added fetching of origin in auth login

This commit is contained in:
2026-03-05 10:09:12 +08:00
parent 8f51faeb12
commit 30c91cf5c8
6 changed files with 319 additions and 76 deletions
+63 -57
View File
@@ -27,9 +27,12 @@ import (
)
var googleOauthConfig oauth2.Config
var oauthStateString = generateRandomState()
var AuthorizationURL string
var FetchedRedirectURI *string
const (
oauthStateCookieName = "oauth_state"
oauthRedirectURICookieName = "oauth_redirect_uri"
)
func isTestEnvironment() bool {
return flag.Lookup("test.v") != nil || strings.Contains(os.Args[0], ".test")
@@ -106,29 +109,41 @@ func generateRandomState() string {
}
func GoogleLogin(w http.ResponseWriter, r *http.Request) {
helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", oauthStateString))
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
state := generateRandomState()
helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", state))
redirectURI := strings.TrimSpace(r.URL.Query().Get("redirect_uri"))
if redirectURI == "" {
helper.RespondWithError(w, http.StatusBadRequest, "redirect_uri is required")
return
}
if !IsAllowedRedirectURI(redirectURI) {
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
return
}
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Value: oauthStateString,
Name: oauthStateCookieName,
Value: state,
Path: "/",
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
Expires: time.Now().Add(5 * time.Minute),
})
http.SetCookie(w, &http.Cookie{
Name: oauthRedirectURICookieName,
Value: redirectURI,
Path: "/",
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
Expires: time.Now().Add(5 * time.Minute),
})
redirectURI := r.URL.Query().Get("redirect_uri")
if redirectURI != "" {
FetchedRedirectURI = &redirectURI
log.Print("FetchedRedirectURI set to: ", *FetchedRedirectURI)
} else {
FetchedRedirectURI = nil
}
url := googleOauthConfig.AuthCodeURL(oauthStateString, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
url := googleOauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
http.Redirect(w, r, url, http.StatusFound)
}
@@ -184,6 +199,11 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
}
helper.LogInfo(fmt.Sprintf("[oauth-debug] state validation ok duration_ms=%d", time.Since(stateStart).Milliseconds()))
redirectURI, ok := callbackRedirectURI(w, r)
if !ok {
return
}
googleUserInfoStart := time.Now()
userInfo, err := FetchGoogleUserInfo(w, r)
if err != nil {
@@ -232,24 +252,8 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
if !emailExists {
helper.LogError(errors.New("unregistered email"), "Google login attempt with unregistered email: "+email)
if FetchedRedirectURI != nil && *FetchedRedirectURI != "" {
RedirectURI := *FetchedRedirectURI
log.Print("RedirectURI from query param: ", RedirectURI)
if !IsAllowedRedirectURI(RedirectURI) {
helper.LogError(errors.New("unauthorized redirect uri"), "Blocked redirect URI for unregistered email: "+RedirectURI)
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
log.Print("Unauthorized RedirectURI: ", RedirectURI)
return
}
log.Print("Valid redirect_uri: ", RedirectURI)
RedirectURL := fmt.Sprintf("%s/callback?error=%s=", RedirectURI, "unregistered_email")
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
return
}
log.Print("No redirect_uri provided, returning JSON response")
// No redirect_uri provided, return JSON response
helper.RespondWithError(w, http.StatusUnauthorized, "Your email is not registered in the system. Please contact your administrator to request access.")
RedirectURL := fmt.Sprintf("%s/callback?error=%s=", redirectURI, "unregistered_email")
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
return
}
@@ -322,33 +326,12 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
helper.LogInfo("Copy this access token: " + accessToken)
if FetchedRedirectURI != nil && *FetchedRedirectURI != "" {
RedirectURI := *FetchedRedirectURI
log.Print("RedirectURI from query param: ", RedirectURI)
if !IsAllowedRedirectURI(RedirectURI) {
helper.LogError(errors.New("unauthorized redirect uri"), "Blocked redirect URI after successful auth: "+RedirectURI)
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
log.Print("Unauthorized RedirectURI: ", RedirectURI)
return
}
log.Print("Valid redirect_uri: ", RedirectURI)
RedirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", RedirectURI, accessToken, userID)
helper.LogInfo(fmt.Sprintf("[oauth-debug] callback complete redirect=true total_ms=%d", time.Since(callbackStart).Milliseconds()))
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
return
}
log.Print("No redirect_uri provided, returning JSON response")
// No redirect_uri provided, return JSON response
helper.LogInfo(fmt.Sprintf("[oauth-debug] callback complete redirect=false total_ms=%d", time.Since(callbackStart).Milliseconds()))
helper.RespondWithJSON(w, http.StatusOK, map[string]string{
"message": "Authentication successful",
"access_token": accessToken,
})
RedirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", redirectURI, accessToken, userID)
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
}
func validateState(w http.ResponseWriter, r *http.Request) bool {
cookie, err := r.Cookie("oauth_state")
cookie, err := r.Cookie(oauthStateCookieName)
callbackState := r.URL.Query().Get("state")
if err != nil {
helper.LogError(err, "oauth_state cookie missing or unreadable during callback")
@@ -357,6 +340,12 @@ func validateState(w http.ResponseWriter, r *http.Request) bool {
return false
}
if strings.TrimSpace(callbackState) == "" {
helper.LogWarn(errorInvalidState)
helper.RespondWithError(w, http.StatusUnauthorized, errorInvalidState)
return false
}
if callbackState != cookie.Value {
helper.LogError(errors.New("oauth state mismatch"), fmt.Sprintf("OAuth state mismatch. cookie_state=%s callback_state=%s", cookie.Value, callbackState))
helper.LogWarn(errorInvalidState)
@@ -367,6 +356,23 @@ func validateState(w http.ResponseWriter, r *http.Request) bool {
return true
}
func callbackRedirectURI(w http.ResponseWriter, r *http.Request) (string, bool) {
cookie, err := r.Cookie(oauthRedirectURICookieName)
if err != nil {
helper.LogError(err, "oauth redirect_uri cookie missing or unreadable during callback")
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
return "", false
}
redirectURI := strings.TrimSpace(cookie.Value)
if redirectURI == "" || !IsAllowedRedirectURI(redirectURI) {
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
return "", false
}
return redirectURI, true
}
func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoogleInfo, error) {
fetchStart := time.Now()
code := r.URL.Query().Get("code")
+97
View File
@@ -0,0 +1,97 @@
package handlers
import (
"net/http"
"net/http/httptest"
"os"
"testing"
)
func TestGoogleLogin_RequiresRedirectURI(t *testing.T) {
original := os.Getenv("ALLOWED_REDIRECT_URIS")
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
req := httptest.NewRequest(http.MethodGet, "/v1/auth/login", nil)
recorder := httptest.NewRecorder()
GoogleLogin(recorder, req)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, recorder.Code)
}
}
func TestGoogleLogin_RejectsUnauthorizedRedirectURI(t *testing.T) {
original := os.Getenv("ALLOWED_REDIRECT_URIS")
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
req := httptest.NewRequest(http.MethodGet, "/v1/auth/login?redirect_uri=http://malicious.example", nil)
recorder := httptest.NewRecorder()
GoogleLogin(recorder, req)
if recorder.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code)
}
}
func TestValidateState_MissingCookie(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
recorder := httptest.NewRecorder()
ok := validateState(recorder, req)
if ok {
t.Fatal("expected validateState to return false when oauth_state cookie is missing")
}
if recorder.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code)
}
}
func TestValidateState_Success(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
req.AddCookie(&http.Cookie{Name: oauthStateCookieName, Value: "test-state"})
recorder := httptest.NewRecorder()
ok := validateState(recorder, req)
if !ok {
t.Fatal("expected validateState to return true for matching state")
}
}
func TestCallbackRedirectURI_MissingCookie(t *testing.T) {
original := os.Getenv("ALLOWED_REDIRECT_URIS")
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
recorder := httptest.NewRecorder()
_, ok := callbackRedirectURI(recorder, req)
if ok {
t.Fatal("expected callbackRedirectURI to return false when redirect cookie is missing")
}
if recorder.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code)
}
}
func TestCallbackRedirectURI_Success(t *testing.T) {
original := os.Getenv("ALLOWED_REDIRECT_URIS")
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
req.AddCookie(&http.Cookie{Name: oauthRedirectURICookieName, Value: "http://localhost:5173"})
recorder := httptest.NewRecorder()
uri, ok := callbackRedirectURI(recorder, req)
if !ok {
t.Fatal("expected callbackRedirectURI to return true for allowed redirect URI")
}
if uri != "http://localhost:5173" {
t.Fatalf("expected redirect URI %q, got %q", "http://localhost:5173", uri)
}
}
+44
View File
@@ -0,0 +1,44 @@
package middleware
import (
"authentication/helper"
"net/http"
"os"
"strings"
)
const defaultFrontendOrigin = "http://localhost:5173"
func allowedFrontendOrigins() map[string]struct{} {
raw := os.Getenv("FRONTEND_ORIGINS")
if strings.TrimSpace(raw) == "" {
raw = defaultFrontendOrigin
}
allowed := make(map[string]struct{})
for _, origin := range strings.Split(raw, ",") {
trimmed := strings.TrimSpace(origin)
if trimmed != "" {
allowed[trimmed] = struct{}{}
}
}
return allowed
}
func FrontendOriginWhitelist(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" {
helper.RespondWithError(w, http.StatusBadRequest, "missing origin header")
return
}
if _, ok := allowedFrontendOrigins()[origin]; !ok {
helper.RespondWithError(w, http.StatusForbidden, "forbidden origin")
return
}
next.ServeHTTP(w, r)
})
}
+94
View File
@@ -0,0 +1,94 @@
package middleware
import (
"net/http"
"net/http/httptest"
"os"
"testing"
)
func TestFrontendOriginWhitelist_DefaultAllowedOrigin(t *testing.T) {
os.Unsetenv("FRONTEND_ORIGINS")
called := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
handler := FrontendOriginWhitelist(next)
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
req.Header.Set("Origin", defaultFrontendOrigin)
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code)
}
if !called {
t.Fatal("expected next handler to be called")
}
}
func TestFrontendOriginWhitelist_RejectsMissingOrigin(t *testing.T) {
os.Unsetenv("FRONTEND_ORIGINS")
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := FrontendOriginWhitelist(next)
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code)
}
}
func TestFrontendOriginWhitelist_RejectsNonWhitelistedOrigin(t *testing.T) {
os.Unsetenv("FRONTEND_ORIGINS")
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := FrontendOriginWhitelist(next)
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
req.Header.Set("Origin", "http://malicious-site.example")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code)
}
}
func TestFrontendOriginWhitelist_UsesConfiguredOrigins(t *testing.T) {
os.Setenv("FRONTEND_ORIGINS", "http://localhost:4173, http://localhost:5173")
defer os.Unsetenv("FRONTEND_ORIGINS")
called := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
handler := FrontendOriginWhitelist(next)
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
req.Header.Set("Origin", "http://localhost:4173")
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code)
}
if !called {
t.Fatal("expected next handler to be called")
}
}
+5 -3
View File
@@ -15,9 +15,11 @@ func SetupRoutes(router *mux.Router, db *sql.DB) {
router.HandleFunc("/ready", handlers.ReadyHandler).Methods("GET")
authRoutes := router.PathPrefix("/v1/auth").Subrouter()
authRoutes.HandleFunc("/login", handlers.GoogleLogin).Methods("GET")
authRoutes.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET")
authRoutes.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET")
frontendOnly := authRoutes.NewRoute().Subrouter()
frontendOnly.Use(middleware.FrontendOriginWhitelist)
frontendOnly.HandleFunc("/login", handlers.GoogleLogin).Methods("GET")
frontendOnly.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET")
frontendOnly.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET")
csrfProtected := authRoutes.NewRoute().Subrouter()
csrfProtected.Use(middleware.CSRFMiddleware)
+16 -16
View File
@@ -34,10 +34,10 @@ func TestGetUser(t *testing.T) {
email := "test@example.com"
expectedID := "user123"
rows := sqlmock.NewRows([]string{"user_id"}).
rows := sqlmock.NewRows([]string{"users_id"}).
AddRow(expectedID)
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(email).
WillReturnRows(rows)
@@ -49,7 +49,7 @@ func TestGetUserNotFound(t *testing.T) {
email := "nonexistent@example.com"
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(email).
WillReturnError(sql.ErrNoRows)
@@ -71,10 +71,10 @@ func TestGetUserNullNames(t *testing.T) {
email := "test@example.com"
expectedID := "user456"
rows := sqlmock.NewRows([]string{"user_id"}).
rows := sqlmock.NewRows([]string{"users_id"}).
AddRow(expectedID)
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(email).
WillReturnRows(rows)
@@ -97,10 +97,10 @@ func TestGetUserID(t *testing.T) {
email := "test@example.com"
expectedID := "user789"
rows := sqlmock.NewRows([]string{"user_id"}).
rows := sqlmock.NewRows([]string{"users_id"}).
AddRow(expectedID)
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(email).
WillReturnRows(rows)
@@ -124,7 +124,7 @@ func TestCheckEmailInDB(t *testing.T) {
rows := sqlmock.NewRows([]string{"exists"}).
AddRow(true)
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
WithArgs(email).
WillReturnRows(rows)
@@ -152,7 +152,7 @@ func TestCheckEmailInDBNotExists(t *testing.T) {
rows := sqlmock.NewRows([]string{"exists"}).
AddRow(false)
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
WithArgs(email).
WillReturnRows(rows)
@@ -173,7 +173,7 @@ func TestCheckEmailInDBError(t *testing.T) {
email := "error@example.com"
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
WithArgs(email).
WillReturnError(sql.ErrConnDone)
@@ -198,7 +198,7 @@ func TestGetUserIDFromEmail(t *testing.T) {
rows := sqlmock.NewRows([]string{"id"}).
AddRow(expectedID)
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`).
WithArgs(email).
WillReturnRows(rows)
@@ -223,7 +223,7 @@ func TestGetUserIDFromEmailNotFound(t *testing.T) {
email := "notfound@example.com"
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`).
WithArgs(email).
WillReturnError(sql.ErrNoRows)
@@ -244,7 +244,7 @@ func TestGetUserIDFromEmailDBError(t *testing.T) {
email := "error@example.com"
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`).
WithArgs(email).
WillReturnError(sql.ErrConnDone)
@@ -279,10 +279,10 @@ func TestGetUserMultipleEmails(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.email, func(t *testing.T) {
rows := sqlmock.NewRows([]string{"user_id"}).
rows := sqlmock.NewRows([]string{"users_id"}).
AddRow(tc.userID)
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
WithArgs(tc.email).
WillReturnRows(rows)
@@ -316,7 +316,7 @@ func TestCheckEmailInDBVariousEmails(t *testing.T) {
rows := sqlmock.NewRows([]string{"exists"}).
AddRow(exists)
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
WithArgs(email).
WillReturnRows(rows)