diff --git a/middleware/frontend_origin.go b/middleware/frontend_origin.go deleted file mode 100644 index e3a8691..0000000 --- a/middleware/frontend_origin.go +++ /dev/null @@ -1,44 +0,0 @@ -package middleware - -import ( - "authentication/helper" - "net/http" - "os" - "strings" -) - -const defaultFrontendOrigin = "http://localhost:5173" - -func allowedFrontendOrigins() map[string]struct{} { - raw := os.Getenv("FRONTEND_ORIGINS") - if strings.TrimSpace(raw) == "" { - raw = defaultFrontendOrigin - } - - allowed := make(map[string]struct{}) - for _, origin := range strings.Split(raw, ",") { - trimmed := strings.TrimSpace(origin) - if trimmed != "" { - allowed[trimmed] = struct{}{} - } - } - - return allowed -} - -func FrontendOriginWhitelist(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := strings.TrimSpace(r.Header.Get("Origin")) - if origin == "" { - helper.RespondWithError(w, http.StatusBadRequest, "missing origin header") - return - } - - if _, ok := allowedFrontendOrigins()[origin]; !ok { - helper.RespondWithError(w, http.StatusForbidden, "forbidden origin") - return - } - - next.ServeHTTP(w, r) - }) -} diff --git a/middleware/frontend_origin_test.go b/middleware/frontend_origin_test.go deleted file mode 100644 index 2aa63c2..0000000 --- a/middleware/frontend_origin_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestFrontendOriginWhitelist_DefaultAllowedOrigin(t *testing.T) { - os.Unsetenv("FRONTEND_ORIGINS") - - called := false - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - w.WriteHeader(http.StatusOK) - }) - - handler := FrontendOriginWhitelist(next) - req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil) - req.Header.Set("Origin", defaultFrontendOrigin) - recorder := httptest.NewRecorder() - - handler.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code) - } - if !called { - t.Fatal("expected next handler to be called") - } -} - -func TestFrontendOriginWhitelist_RejectsMissingOrigin(t *testing.T) { - os.Unsetenv("FRONTEND_ORIGINS") - - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - - handler := FrontendOriginWhitelist(next) - req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil) - recorder := httptest.NewRecorder() - - handler.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusForbidden { - t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code) - } -} - -func TestFrontendOriginWhitelist_RejectsNonWhitelistedOrigin(t *testing.T) { - os.Unsetenv("FRONTEND_ORIGINS") - - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - - handler := FrontendOriginWhitelist(next) - req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil) - req.Header.Set("Origin", "http://malicious-site.example") - recorder := httptest.NewRecorder() - - handler.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusForbidden { - t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code) - } -} - -func TestFrontendOriginWhitelist_UsesConfiguredOrigins(t *testing.T) { - os.Setenv("FRONTEND_ORIGINS", "http://localhost:4173, http://localhost:5173") - defer os.Unsetenv("FRONTEND_ORIGINS") - - called := false - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - called = true - w.WriteHeader(http.StatusOK) - }) - - handler := FrontendOriginWhitelist(next) - req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil) - req.Header.Set("Origin", "http://localhost:4173") - recorder := httptest.NewRecorder() - - handler.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code) - } - if !called { - t.Fatal("expected next handler to be called") - } -} diff --git a/routes/routes.go b/routes/routes.go index 62f8082..30172e0 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -15,11 +15,9 @@ func SetupRoutes(router *mux.Router, db *sql.DB) { router.HandleFunc("/ready", handlers.ReadyHandler).Methods("GET") authRoutes := router.PathPrefix("/v1/auth").Subrouter() - frontendOnly := authRoutes.NewRoute().Subrouter() - frontendOnly.Use(middleware.FrontendOriginWhitelist) - frontendOnly.HandleFunc("/login", handlers.GoogleLogin).Methods("GET") - frontendOnly.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET") - frontendOnly.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET") + authRoutes.HandleFunc("/login", handlers.GoogleLogin).Methods("GET") + authRoutes.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET") + authRoutes.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET") csrfProtected := authRoutes.NewRoute().Subrouter() csrfProtected.Use(middleware.CSRFMiddleware)