added fetching of origin in auth login
This commit is contained in:
+63
-57
@@ -27,9 +27,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var googleOauthConfig oauth2.Config
|
var googleOauthConfig oauth2.Config
|
||||||
var oauthStateString = generateRandomState()
|
|
||||||
var AuthorizationURL string
|
var AuthorizationURL string
|
||||||
var FetchedRedirectURI *string
|
|
||||||
|
const (
|
||||||
|
oauthStateCookieName = "oauth_state"
|
||||||
|
oauthRedirectURICookieName = "oauth_redirect_uri"
|
||||||
|
)
|
||||||
|
|
||||||
func isTestEnvironment() bool {
|
func isTestEnvironment() bool {
|
||||||
return flag.Lookup("test.v") != nil || strings.Contains(os.Args[0], ".test")
|
return flag.Lookup("test.v") != nil || strings.Contains(os.Args[0], ".test")
|
||||||
@@ -106,29 +109,41 @@ func generateRandomState() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GoogleLogin(w http.ResponseWriter, r *http.Request) {
|
func GoogleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", oauthStateString))
|
|
||||||
|
|
||||||
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
||||||
|
state := generateRandomState()
|
||||||
|
helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", state))
|
||||||
|
|
||||||
|
redirectURI := strings.TrimSpace(r.URL.Query().Get("redirect_uri"))
|
||||||
|
if redirectURI == "" {
|
||||||
|
helper.RespondWithError(w, http.StatusBadRequest, "redirect_uri is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsAllowedRedirectURI(redirectURI) {
|
||||||
|
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
http.SetCookie(w, &http.Cookie{
|
||||||
Name: "oauth_state",
|
Name: oauthStateCookieName,
|
||||||
Value: oauthStateString,
|
Value: state,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: isSecure,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
Expires: time.Now().Add(5 * time.Minute),
|
||||||
|
})
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: oauthRedirectURICookieName,
|
||||||
|
Value: redirectURI,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: isSecure,
|
Secure: isSecure,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
Expires: time.Now().Add(5 * time.Minute),
|
Expires: time.Now().Add(5 * time.Minute),
|
||||||
})
|
})
|
||||||
redirectURI := r.URL.Query().Get("redirect_uri")
|
|
||||||
if redirectURI != "" {
|
|
||||||
FetchedRedirectURI = &redirectURI
|
|
||||||
log.Print("FetchedRedirectURI set to: ", *FetchedRedirectURI)
|
|
||||||
} else {
|
|
||||||
FetchedRedirectURI = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
url := googleOauthConfig.AuthCodeURL(oauthStateString, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
url := googleOauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||||
http.Redirect(w, r, url, http.StatusFound)
|
http.Redirect(w, r, url, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -184,6 +199,11 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
helper.LogInfo(fmt.Sprintf("[oauth-debug] state validation ok duration_ms=%d", time.Since(stateStart).Milliseconds()))
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] state validation ok duration_ms=%d", time.Since(stateStart).Milliseconds()))
|
||||||
|
|
||||||
|
redirectURI, ok := callbackRedirectURI(w, r)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
googleUserInfoStart := time.Now()
|
googleUserInfoStart := time.Now()
|
||||||
userInfo, err := FetchGoogleUserInfo(w, r)
|
userInfo, err := FetchGoogleUserInfo(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -232,24 +252,8 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if !emailExists {
|
if !emailExists {
|
||||||
helper.LogError(errors.New("unregistered email"), "Google login attempt with unregistered email: "+email)
|
helper.LogError(errors.New("unregistered email"), "Google login attempt with unregistered email: "+email)
|
||||||
if FetchedRedirectURI != nil && *FetchedRedirectURI != "" {
|
RedirectURL := fmt.Sprintf("%s/callback?error=%s=", redirectURI, "unregistered_email")
|
||||||
RedirectURI := *FetchedRedirectURI
|
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
||||||
log.Print("RedirectURI from query param: ", RedirectURI)
|
|
||||||
if !IsAllowedRedirectURI(RedirectURI) {
|
|
||||||
helper.LogError(errors.New("unauthorized redirect uri"), "Blocked redirect URI for unregistered email: "+RedirectURI)
|
|
||||||
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
|
||||||
log.Print("Unauthorized RedirectURI: ", RedirectURI)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Print("Valid redirect_uri: ", RedirectURI)
|
|
||||||
RedirectURL := fmt.Sprintf("%s/callback?error=%s=", RedirectURI, "unregistered_email")
|
|
||||||
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Print("No redirect_uri provided, returning JSON response")
|
|
||||||
// No redirect_uri provided, return JSON response
|
|
||||||
helper.RespondWithError(w, http.StatusUnauthorized, "Your email is not registered in the system. Please contact your administrator to request access.")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -322,33 +326,12 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
helper.LogInfo("Copy this access token: " + accessToken)
|
helper.LogInfo("Copy this access token: " + accessToken)
|
||||||
|
|
||||||
if FetchedRedirectURI != nil && *FetchedRedirectURI != "" {
|
RedirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", redirectURI, accessToken, userID)
|
||||||
RedirectURI := *FetchedRedirectURI
|
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
||||||
log.Print("RedirectURI from query param: ", RedirectURI)
|
|
||||||
if !IsAllowedRedirectURI(RedirectURI) {
|
|
||||||
helper.LogError(errors.New("unauthorized redirect uri"), "Blocked redirect URI after successful auth: "+RedirectURI)
|
|
||||||
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
|
||||||
log.Print("Unauthorized RedirectURI: ", RedirectURI)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Print("Valid redirect_uri: ", RedirectURI)
|
|
||||||
RedirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", RedirectURI, accessToken, userID)
|
|
||||||
helper.LogInfo(fmt.Sprintf("[oauth-debug] callback complete redirect=true total_ms=%d", time.Since(callbackStart).Milliseconds()))
|
|
||||||
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Print("No redirect_uri provided, returning JSON response")
|
|
||||||
// No redirect_uri provided, return JSON response
|
|
||||||
helper.LogInfo(fmt.Sprintf("[oauth-debug] callback complete redirect=false total_ms=%d", time.Since(callbackStart).Milliseconds()))
|
|
||||||
helper.RespondWithJSON(w, http.StatusOK, map[string]string{
|
|
||||||
"message": "Authentication successful",
|
|
||||||
"access_token": accessToken,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateState(w http.ResponseWriter, r *http.Request) bool {
|
func validateState(w http.ResponseWriter, r *http.Request) bool {
|
||||||
cookie, err := r.Cookie("oauth_state")
|
cookie, err := r.Cookie(oauthStateCookieName)
|
||||||
callbackState := r.URL.Query().Get("state")
|
callbackState := r.URL.Query().Get("state")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
helper.LogError(err, "oauth_state cookie missing or unreadable during callback")
|
helper.LogError(err, "oauth_state cookie missing or unreadable during callback")
|
||||||
@@ -357,6 +340,12 @@ func validateState(w http.ResponseWriter, r *http.Request) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(callbackState) == "" {
|
||||||
|
helper.LogWarn(errorInvalidState)
|
||||||
|
helper.RespondWithError(w, http.StatusUnauthorized, errorInvalidState)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if callbackState != cookie.Value {
|
if callbackState != cookie.Value {
|
||||||
helper.LogError(errors.New("oauth state mismatch"), fmt.Sprintf("OAuth state mismatch. cookie_state=%s callback_state=%s", cookie.Value, callbackState))
|
helper.LogError(errors.New("oauth state mismatch"), fmt.Sprintf("OAuth state mismatch. cookie_state=%s callback_state=%s", cookie.Value, callbackState))
|
||||||
helper.LogWarn(errorInvalidState)
|
helper.LogWarn(errorInvalidState)
|
||||||
@@ -367,6 +356,23 @@ func validateState(w http.ResponseWriter, r *http.Request) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func callbackRedirectURI(w http.ResponseWriter, r *http.Request) (string, bool) {
|
||||||
|
cookie, err := r.Cookie(oauthRedirectURICookieName)
|
||||||
|
if err != nil {
|
||||||
|
helper.LogError(err, "oauth redirect_uri cookie missing or unreadable during callback")
|
||||||
|
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := strings.TrimSpace(cookie.Value)
|
||||||
|
if redirectURI == "" || !IsAllowedRedirectURI(redirectURI) {
|
||||||
|
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return redirectURI, true
|
||||||
|
}
|
||||||
|
|
||||||
func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoogleInfo, error) {
|
func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoogleInfo, error) {
|
||||||
fetchStart := time.Now()
|
fetchStart := time.Now()
|
||||||
code := r.URL.Query().Get("code")
|
code := r.URL.Query().Get("code")
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGoogleLogin_RequiresRedirectURI(t *testing.T) {
|
||||||
|
original := os.Getenv("ALLOWED_REDIRECT_URIS")
|
||||||
|
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
|
||||||
|
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/login", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
GoogleLogin(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGoogleLogin_RejectsUnauthorizedRedirectURI(t *testing.T) {
|
||||||
|
original := os.Getenv("ALLOWED_REDIRECT_URIS")
|
||||||
|
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
|
||||||
|
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/login?redirect_uri=http://malicious.example", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
GoogleLogin(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateState_MissingCookie(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ok := validateState(recorder, req)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected validateState to return false when oauth_state cookie is missing")
|
||||||
|
}
|
||||||
|
if recorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateState_Success(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthStateCookieName, Value: "test-state"})
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ok := validateState(recorder, req)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected validateState to return true for matching state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackRedirectURI_MissingCookie(t *testing.T) {
|
||||||
|
original := os.Getenv("ALLOWED_REDIRECT_URIS")
|
||||||
|
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
|
||||||
|
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
_, ok := callbackRedirectURI(recorder, req)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected callbackRedirectURI to return false when redirect cookie is missing")
|
||||||
|
}
|
||||||
|
if recorder.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackRedirectURI_Success(t *testing.T) {
|
||||||
|
original := os.Getenv("ALLOWED_REDIRECT_URIS")
|
||||||
|
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
|
||||||
|
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthRedirectURICookieName, Value: "http://localhost:5173"})
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
uri, ok := callbackRedirectURI(recorder, req)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected callbackRedirectURI to return true for allowed redirect URI")
|
||||||
|
}
|
||||||
|
if uri != "http://localhost:5173" {
|
||||||
|
t.Fatalf("expected redirect URI %q, got %q", "http://localhost:5173", uri)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"authentication/helper"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultFrontendOrigin = "http://localhost:5173"
|
||||||
|
|
||||||
|
func allowedFrontendOrigins() map[string]struct{} {
|
||||||
|
raw := os.Getenv("FRONTEND_ORIGINS")
|
||||||
|
if strings.TrimSpace(raw) == "" {
|
||||||
|
raw = defaultFrontendOrigin
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := make(map[string]struct{})
|
||||||
|
for _, origin := range strings.Split(raw, ",") {
|
||||||
|
trimmed := strings.TrimSpace(origin)
|
||||||
|
if trimmed != "" {
|
||||||
|
allowed[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
func FrontendOriginWhitelist(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||||
|
if origin == "" {
|
||||||
|
helper.RespondWithError(w, http.StatusBadRequest, "missing origin header")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := allowedFrontendOrigins()[origin]; !ok {
|
||||||
|
helper.RespondWithError(w, http.StatusForbidden, "forbidden origin")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFrontendOriginWhitelist_DefaultAllowedOrigin(t *testing.T) {
|
||||||
|
os.Unsetenv("FRONTEND_ORIGINS")
|
||||||
|
|
||||||
|
called := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := FrontendOriginWhitelist(next)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
|
||||||
|
req.Header.Set("Origin", defaultFrontendOrigin)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("expected next handler to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrontendOriginWhitelist_RejectsMissingOrigin(t *testing.T) {
|
||||||
|
os.Unsetenv("FRONTEND_ORIGINS")
|
||||||
|
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := FrontendOriginWhitelist(next)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrontendOriginWhitelist_RejectsNonWhitelistedOrigin(t *testing.T) {
|
||||||
|
os.Unsetenv("FRONTEND_ORIGINS")
|
||||||
|
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := FrontendOriginWhitelist(next)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
|
||||||
|
req.Header.Set("Origin", "http://malicious-site.example")
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrontendOriginWhitelist_UsesConfiguredOrigins(t *testing.T) {
|
||||||
|
os.Setenv("FRONTEND_ORIGINS", "http://localhost:4173, http://localhost:5173")
|
||||||
|
defer os.Unsetenv("FRONTEND_ORIGINS")
|
||||||
|
|
||||||
|
called := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := FrontendOriginWhitelist(next)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
|
||||||
|
req.Header.Set("Origin", "http://localhost:4173")
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("expected next handler to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
+5
-3
@@ -15,9 +15,11 @@ func SetupRoutes(router *mux.Router, db *sql.DB) {
|
|||||||
router.HandleFunc("/ready", handlers.ReadyHandler).Methods("GET")
|
router.HandleFunc("/ready", handlers.ReadyHandler).Methods("GET")
|
||||||
|
|
||||||
authRoutes := router.PathPrefix("/v1/auth").Subrouter()
|
authRoutes := router.PathPrefix("/v1/auth").Subrouter()
|
||||||
authRoutes.HandleFunc("/login", handlers.GoogleLogin).Methods("GET")
|
frontendOnly := authRoutes.NewRoute().Subrouter()
|
||||||
authRoutes.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET")
|
frontendOnly.Use(middleware.FrontendOriginWhitelist)
|
||||||
authRoutes.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET")
|
frontendOnly.HandleFunc("/login", handlers.GoogleLogin).Methods("GET")
|
||||||
|
frontendOnly.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET")
|
||||||
|
frontendOnly.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET")
|
||||||
|
|
||||||
csrfProtected := authRoutes.NewRoute().Subrouter()
|
csrfProtected := authRoutes.NewRoute().Subrouter()
|
||||||
csrfProtected.Use(middleware.CSRFMiddleware)
|
csrfProtected.Use(middleware.CSRFMiddleware)
|
||||||
|
|||||||
+16
-16
@@ -34,10 +34,10 @@ func TestGetUser(t *testing.T) {
|
|||||||
email := "test@example.com"
|
email := "test@example.com"
|
||||||
expectedID := "user123"
|
expectedID := "user123"
|
||||||
|
|
||||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
rows := sqlmock.NewRows([]string{"users_id"}).
|
||||||
AddRow(expectedID)
|
AddRow(expectedID)
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnRows(rows)
|
WillReturnRows(rows)
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ func TestGetUserNotFound(t *testing.T) {
|
|||||||
|
|
||||||
email := "nonexistent@example.com"
|
email := "nonexistent@example.com"
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnError(sql.ErrNoRows)
|
WillReturnError(sql.ErrNoRows)
|
||||||
|
|
||||||
@@ -71,10 +71,10 @@ func TestGetUserNullNames(t *testing.T) {
|
|||||||
email := "test@example.com"
|
email := "test@example.com"
|
||||||
expectedID := "user456"
|
expectedID := "user456"
|
||||||
|
|
||||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
rows := sqlmock.NewRows([]string{"users_id"}).
|
||||||
AddRow(expectedID)
|
AddRow(expectedID)
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnRows(rows)
|
WillReturnRows(rows)
|
||||||
|
|
||||||
@@ -97,10 +97,10 @@ func TestGetUserID(t *testing.T) {
|
|||||||
email := "test@example.com"
|
email := "test@example.com"
|
||||||
expectedID := "user789"
|
expectedID := "user789"
|
||||||
|
|
||||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
rows := sqlmock.NewRows([]string{"users_id"}).
|
||||||
AddRow(expectedID)
|
AddRow(expectedID)
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnRows(rows)
|
WillReturnRows(rows)
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ func TestCheckEmailInDB(t *testing.T) {
|
|||||||
rows := sqlmock.NewRows([]string{"exists"}).
|
rows := sqlmock.NewRows([]string{"exists"}).
|
||||||
AddRow(true)
|
AddRow(true)
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
|
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnRows(rows)
|
WillReturnRows(rows)
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ func TestCheckEmailInDBNotExists(t *testing.T) {
|
|||||||
rows := sqlmock.NewRows([]string{"exists"}).
|
rows := sqlmock.NewRows([]string{"exists"}).
|
||||||
AddRow(false)
|
AddRow(false)
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
|
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnRows(rows)
|
WillReturnRows(rows)
|
||||||
|
|
||||||
@@ -173,7 +173,7 @@ func TestCheckEmailInDBError(t *testing.T) {
|
|||||||
|
|
||||||
email := "error@example.com"
|
email := "error@example.com"
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
|
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnError(sql.ErrConnDone)
|
WillReturnError(sql.ErrConnDone)
|
||||||
|
|
||||||
@@ -198,7 +198,7 @@ func TestGetUserIDFromEmail(t *testing.T) {
|
|||||||
rows := sqlmock.NewRows([]string{"id"}).
|
rows := sqlmock.NewRows([]string{"id"}).
|
||||||
AddRow(expectedID)
|
AddRow(expectedID)
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
|
mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnRows(rows)
|
WillReturnRows(rows)
|
||||||
|
|
||||||
@@ -223,7 +223,7 @@ func TestGetUserIDFromEmailNotFound(t *testing.T) {
|
|||||||
|
|
||||||
email := "notfound@example.com"
|
email := "notfound@example.com"
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
|
mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnError(sql.ErrNoRows)
|
WillReturnError(sql.ErrNoRows)
|
||||||
|
|
||||||
@@ -244,7 +244,7 @@ func TestGetUserIDFromEmailDBError(t *testing.T) {
|
|||||||
|
|
||||||
email := "error@example.com"
|
email := "error@example.com"
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
|
mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnError(sql.ErrConnDone)
|
WillReturnError(sql.ErrConnDone)
|
||||||
|
|
||||||
@@ -279,10 +279,10 @@ func TestGetUserMultipleEmails(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.email, func(t *testing.T) {
|
t.Run(tc.email, func(t *testing.T) {
|
||||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
rows := sqlmock.NewRows([]string{"users_id"}).
|
||||||
AddRow(tc.userID)
|
AddRow(tc.userID)
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||||
WithArgs(tc.email).
|
WithArgs(tc.email).
|
||||||
WillReturnRows(rows)
|
WillReturnRows(rows)
|
||||||
|
|
||||||
@@ -316,7 +316,7 @@ func TestCheckEmailInDBVariousEmails(t *testing.T) {
|
|||||||
rows := sqlmock.NewRows([]string{"exists"}).
|
rows := sqlmock.NewRows([]string{"exists"}).
|
||||||
AddRow(exists)
|
AddRow(exists)
|
||||||
|
|
||||||
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
|
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||||
WithArgs(email).
|
WithArgs(email).
|
||||||
WillReturnRows(rows)
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user