From 30c91cf5c8bd9e79de13778590ec1d4ee2b00224 Mon Sep 17 00:00:00 2001 From: F04C Date: Thu, 5 Mar 2026 10:09:12 +0800 Subject: [PATCH] added fetching of origin in auth login --- handlers/google_auth.go | 120 ++++++++++++++------------ handlers/google_auth_security_test.go | 97 +++++++++++++++++++++ middleware/frontend_origin.go | 44 ++++++++++ middleware/frontend_origin_test.go | 94 ++++++++++++++++++++ routes/routes.go | 8 +- services/users_test.go | 32 +++---- 6 files changed, 319 insertions(+), 76 deletions(-) create mode 100644 handlers/google_auth_security_test.go create mode 100644 middleware/frontend_origin.go create mode 100644 middleware/frontend_origin_test.go diff --git a/handlers/google_auth.go b/handlers/google_auth.go index 4d02254..6fd7d0d 100644 --- a/handlers/google_auth.go +++ b/handlers/google_auth.go @@ -27,9 +27,12 @@ import ( ) var googleOauthConfig oauth2.Config -var oauthStateString = generateRandomState() var AuthorizationURL string -var FetchedRedirectURI *string + +const ( + oauthStateCookieName = "oauth_state" + oauthRedirectURICookieName = "oauth_redirect_uri" +) func isTestEnvironment() bool { return flag.Lookup("test.v") != nil || strings.Contains(os.Args[0], ".test") @@ -106,29 +109,41 @@ func generateRandomState() string { } func GoogleLogin(w http.ResponseWriter, r *http.Request) { - - helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", oauthStateString)) - isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS) + state := generateRandomState() + helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", state)) + + redirectURI := strings.TrimSpace(r.URL.Query().Get("redirect_uri")) + if redirectURI == "" { + helper.RespondWithError(w, http.StatusBadRequest, "redirect_uri is required") + return + } + + if !IsAllowedRedirectURI(redirectURI) { + helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI") + return + } http.SetCookie(w, &http.Cookie{ - Name: "oauth_state", - Value: oauthStateString, + Name: oauthStateCookieName, + Value: state, + Path: "/", + HttpOnly: true, + Secure: isSecure, + SameSite: http.SameSiteLaxMode, + Expires: time.Now().Add(5 * time.Minute), + }) + http.SetCookie(w, &http.Cookie{ + Name: oauthRedirectURICookieName, + Value: redirectURI, Path: "/", HttpOnly: true, Secure: isSecure, SameSite: http.SameSiteLaxMode, Expires: time.Now().Add(5 * time.Minute), }) - redirectURI := r.URL.Query().Get("redirect_uri") - if redirectURI != "" { - FetchedRedirectURI = &redirectURI - log.Print("FetchedRedirectURI set to: ", *FetchedRedirectURI) - } else { - FetchedRedirectURI = nil - } - url := googleOauthConfig.AuthCodeURL(oauthStateString, oauth2.AccessTypeOffline, oauth2.ApprovalForce) + url := googleOauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce) http.Redirect(w, r, url, http.StatusFound) } @@ -184,6 +199,11 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) { } helper.LogInfo(fmt.Sprintf("[oauth-debug] state validation ok duration_ms=%d", time.Since(stateStart).Milliseconds())) + redirectURI, ok := callbackRedirectURI(w, r) + if !ok { + return + } + googleUserInfoStart := time.Now() userInfo, err := FetchGoogleUserInfo(w, r) if err != nil { @@ -232,24 +252,8 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) { if !emailExists { helper.LogError(errors.New("unregistered email"), "Google login attempt with unregistered email: "+email) - if FetchedRedirectURI != nil && *FetchedRedirectURI != "" { - RedirectURI := *FetchedRedirectURI - log.Print("RedirectURI from query param: ", RedirectURI) - if !IsAllowedRedirectURI(RedirectURI) { - helper.LogError(errors.New("unauthorized redirect uri"), "Blocked redirect URI for unregistered email: "+RedirectURI) - helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI") - log.Print("Unauthorized RedirectURI: ", RedirectURI) - return - } - log.Print("Valid redirect_uri: ", RedirectURI) - RedirectURL := fmt.Sprintf("%s/callback?error=%s=", RedirectURI, "unregistered_email") - http.Redirect(w, r, RedirectURL, http.StatusSeeOther) - return - } - - log.Print("No redirect_uri provided, returning JSON response") - // No redirect_uri provided, return JSON response - helper.RespondWithError(w, http.StatusUnauthorized, "Your email is not registered in the system. Please contact your administrator to request access.") + RedirectURL := fmt.Sprintf("%s/callback?error=%s=", redirectURI, "unregistered_email") + http.Redirect(w, r, RedirectURL, http.StatusSeeOther) return } @@ -322,33 +326,12 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) { helper.LogInfo("Copy this access token: " + accessToken) - if FetchedRedirectURI != nil && *FetchedRedirectURI != "" { - RedirectURI := *FetchedRedirectURI - log.Print("RedirectURI from query param: ", RedirectURI) - if !IsAllowedRedirectURI(RedirectURI) { - helper.LogError(errors.New("unauthorized redirect uri"), "Blocked redirect URI after successful auth: "+RedirectURI) - helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI") - log.Print("Unauthorized RedirectURI: ", RedirectURI) - return - } - log.Print("Valid redirect_uri: ", RedirectURI) - RedirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", RedirectURI, accessToken, userID) - helper.LogInfo(fmt.Sprintf("[oauth-debug] callback complete redirect=true total_ms=%d", time.Since(callbackStart).Milliseconds())) - http.Redirect(w, r, RedirectURL, http.StatusSeeOther) - return - } - - log.Print("No redirect_uri provided, returning JSON response") - // No redirect_uri provided, return JSON response - helper.LogInfo(fmt.Sprintf("[oauth-debug] callback complete redirect=false total_ms=%d", time.Since(callbackStart).Milliseconds())) - helper.RespondWithJSON(w, http.StatusOK, map[string]string{ - "message": "Authentication successful", - "access_token": accessToken, - }) + RedirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", redirectURI, accessToken, userID) + http.Redirect(w, r, RedirectURL, http.StatusSeeOther) } func validateState(w http.ResponseWriter, r *http.Request) bool { - cookie, err := r.Cookie("oauth_state") + cookie, err := r.Cookie(oauthStateCookieName) callbackState := r.URL.Query().Get("state") if err != nil { helper.LogError(err, "oauth_state cookie missing or unreadable during callback") @@ -357,6 +340,12 @@ func validateState(w http.ResponseWriter, r *http.Request) bool { return false } + if strings.TrimSpace(callbackState) == "" { + helper.LogWarn(errorInvalidState) + helper.RespondWithError(w, http.StatusUnauthorized, errorInvalidState) + return false + } + if callbackState != cookie.Value { helper.LogError(errors.New("oauth state mismatch"), fmt.Sprintf("OAuth state mismatch. cookie_state=%s callback_state=%s", cookie.Value, callbackState)) helper.LogWarn(errorInvalidState) @@ -367,6 +356,23 @@ func validateState(w http.ResponseWriter, r *http.Request) bool { return true } +func callbackRedirectURI(w http.ResponseWriter, r *http.Request) (string, bool) { + cookie, err := r.Cookie(oauthRedirectURICookieName) + if err != nil { + helper.LogError(err, "oauth redirect_uri cookie missing or unreadable during callback") + helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI") + return "", false + } + + redirectURI := strings.TrimSpace(cookie.Value) + if redirectURI == "" || !IsAllowedRedirectURI(redirectURI) { + helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI") + return "", false + } + + return redirectURI, true +} + func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoogleInfo, error) { fetchStart := time.Now() code := r.URL.Query().Get("code") diff --git a/handlers/google_auth_security_test.go b/handlers/google_auth_security_test.go new file mode 100644 index 0000000..6e01dbe --- /dev/null +++ b/handlers/google_auth_security_test.go @@ -0,0 +1,97 @@ +package handlers + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestGoogleLogin_RequiresRedirectURI(t *testing.T) { + original := os.Getenv("ALLOWED_REDIRECT_URIS") + os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173") + defer os.Setenv("ALLOWED_REDIRECT_URIS", original) + + req := httptest.NewRequest(http.MethodGet, "/v1/auth/login", nil) + recorder := httptest.NewRecorder() + + GoogleLogin(recorder, req) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, recorder.Code) + } +} + +func TestGoogleLogin_RejectsUnauthorizedRedirectURI(t *testing.T) { + original := os.Getenv("ALLOWED_REDIRECT_URIS") + os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173") + defer os.Setenv("ALLOWED_REDIRECT_URIS", original) + + req := httptest.NewRequest(http.MethodGet, "/v1/auth/login?redirect_uri=http://malicious.example", nil) + recorder := httptest.NewRecorder() + + GoogleLogin(recorder, req) + + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code) + } +} + +func TestValidateState_MissingCookie(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil) + recorder := httptest.NewRecorder() + + ok := validateState(recorder, req) + if ok { + t.Fatal("expected validateState to return false when oauth_state cookie is missing") + } + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code) + } +} + +func TestValidateState_Success(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil) + req.AddCookie(&http.Cookie{Name: oauthStateCookieName, Value: "test-state"}) + recorder := httptest.NewRecorder() + + ok := validateState(recorder, req) + if !ok { + t.Fatal("expected validateState to return true for matching state") + } +} + +func TestCallbackRedirectURI_MissingCookie(t *testing.T) { + original := os.Getenv("ALLOWED_REDIRECT_URIS") + os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173") + defer os.Setenv("ALLOWED_REDIRECT_URIS", original) + + req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil) + recorder := httptest.NewRecorder() + + _, ok := callbackRedirectURI(recorder, req) + if ok { + t.Fatal("expected callbackRedirectURI to return false when redirect cookie is missing") + } + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code) + } +} + +func TestCallbackRedirectURI_Success(t *testing.T) { + original := os.Getenv("ALLOWED_REDIRECT_URIS") + os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173") + defer os.Setenv("ALLOWED_REDIRECT_URIS", original) + + req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil) + req.AddCookie(&http.Cookie{Name: oauthRedirectURICookieName, Value: "http://localhost:5173"}) + recorder := httptest.NewRecorder() + + uri, ok := callbackRedirectURI(recorder, req) + if !ok { + t.Fatal("expected callbackRedirectURI to return true for allowed redirect URI") + } + if uri != "http://localhost:5173" { + t.Fatalf("expected redirect URI %q, got %q", "http://localhost:5173", uri) + } +} diff --git a/middleware/frontend_origin.go b/middleware/frontend_origin.go new file mode 100644 index 0000000..e3a8691 --- /dev/null +++ b/middleware/frontend_origin.go @@ -0,0 +1,44 @@ +package middleware + +import ( + "authentication/helper" + "net/http" + "os" + "strings" +) + +const defaultFrontendOrigin = "http://localhost:5173" + +func allowedFrontendOrigins() map[string]struct{} { + raw := os.Getenv("FRONTEND_ORIGINS") + if strings.TrimSpace(raw) == "" { + raw = defaultFrontendOrigin + } + + allowed := make(map[string]struct{}) + for _, origin := range strings.Split(raw, ",") { + trimmed := strings.TrimSpace(origin) + if trimmed != "" { + allowed[trimmed] = struct{}{} + } + } + + return allowed +} + +func FrontendOriginWhitelist(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := strings.TrimSpace(r.Header.Get("Origin")) + if origin == "" { + helper.RespondWithError(w, http.StatusBadRequest, "missing origin header") + return + } + + if _, ok := allowedFrontendOrigins()[origin]; !ok { + helper.RespondWithError(w, http.StatusForbidden, "forbidden origin") + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/middleware/frontend_origin_test.go b/middleware/frontend_origin_test.go new file mode 100644 index 0000000..2aa63c2 --- /dev/null +++ b/middleware/frontend_origin_test.go @@ -0,0 +1,94 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestFrontendOriginWhitelist_DefaultAllowedOrigin(t *testing.T) { + os.Unsetenv("FRONTEND_ORIGINS") + + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + + handler := FrontendOriginWhitelist(next) + req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil) + req.Header.Set("Origin", defaultFrontendOrigin) + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code) + } + if !called { + t.Fatal("expected next handler to be called") + } +} + +func TestFrontendOriginWhitelist_RejectsMissingOrigin(t *testing.T) { + os.Unsetenv("FRONTEND_ORIGINS") + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := FrontendOriginWhitelist(next) + req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil) + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusForbidden { + t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code) + } +} + +func TestFrontendOriginWhitelist_RejectsNonWhitelistedOrigin(t *testing.T) { + os.Unsetenv("FRONTEND_ORIGINS") + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := FrontendOriginWhitelist(next) + req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil) + req.Header.Set("Origin", "http://malicious-site.example") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusForbidden { + t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code) + } +} + +func TestFrontendOriginWhitelist_UsesConfiguredOrigins(t *testing.T) { + os.Setenv("FRONTEND_ORIGINS", "http://localhost:4173, http://localhost:5173") + defer os.Unsetenv("FRONTEND_ORIGINS") + + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + + handler := FrontendOriginWhitelist(next) + req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil) + req.Header.Set("Origin", "http://localhost:4173") + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code) + } + if !called { + t.Fatal("expected next handler to be called") + } +} diff --git a/routes/routes.go b/routes/routes.go index 2813e6e..62f8082 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -15,9 +15,11 @@ func SetupRoutes(router *mux.Router, db *sql.DB) { router.HandleFunc("/ready", handlers.ReadyHandler).Methods("GET") authRoutes := router.PathPrefix("/v1/auth").Subrouter() - authRoutes.HandleFunc("/login", handlers.GoogleLogin).Methods("GET") - authRoutes.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET") - authRoutes.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET") + frontendOnly := authRoutes.NewRoute().Subrouter() + frontendOnly.Use(middleware.FrontendOriginWhitelist) + frontendOnly.HandleFunc("/login", handlers.GoogleLogin).Methods("GET") + frontendOnly.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET") + frontendOnly.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET") csrfProtected := authRoutes.NewRoute().Subrouter() csrfProtected.Use(middleware.CSRFMiddleware) diff --git a/services/users_test.go b/services/users_test.go index 5c06c33..18a2897 100644 --- a/services/users_test.go +++ b/services/users_test.go @@ -34,10 +34,10 @@ func TestGetUser(t *testing.T) { email := "test@example.com" expectedID := "user123" - rows := sqlmock.NewRows([]string{"user_id"}). + rows := sqlmock.NewRows([]string{"users_id"}). AddRow(expectedID) - mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). + mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(email). WillReturnRows(rows) @@ -49,7 +49,7 @@ func TestGetUserNotFound(t *testing.T) { email := "nonexistent@example.com" - mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). + mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(email). WillReturnError(sql.ErrNoRows) @@ -71,10 +71,10 @@ func TestGetUserNullNames(t *testing.T) { email := "test@example.com" expectedID := "user456" - rows := sqlmock.NewRows([]string{"user_id"}). + rows := sqlmock.NewRows([]string{"users_id"}). AddRow(expectedID) - mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). + mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(email). WillReturnRows(rows) @@ -97,10 +97,10 @@ func TestGetUserID(t *testing.T) { email := "test@example.com" expectedID := "user789" - rows := sqlmock.NewRows([]string{"user_id"}). + rows := sqlmock.NewRows([]string{"users_id"}). AddRow(expectedID) - mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). + mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(email). WillReturnRows(rows) @@ -124,7 +124,7 @@ func TestCheckEmailInDB(t *testing.T) { rows := sqlmock.NewRows([]string{"exists"}). AddRow(true) - mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`). + mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`). WithArgs(email). WillReturnRows(rows) @@ -152,7 +152,7 @@ func TestCheckEmailInDBNotExists(t *testing.T) { rows := sqlmock.NewRows([]string{"exists"}). AddRow(false) - mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`). + mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`). WithArgs(email). WillReturnRows(rows) @@ -173,7 +173,7 @@ func TestCheckEmailInDBError(t *testing.T) { email := "error@example.com" - mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`). + mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`). WithArgs(email). WillReturnError(sql.ErrConnDone) @@ -198,7 +198,7 @@ func TestGetUserIDFromEmail(t *testing.T) { rows := sqlmock.NewRows([]string{"id"}). AddRow(expectedID) - mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`). + mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`). WithArgs(email). WillReturnRows(rows) @@ -223,7 +223,7 @@ func TestGetUserIDFromEmailNotFound(t *testing.T) { email := "notfound@example.com" - mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`). + mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`). WithArgs(email). WillReturnError(sql.ErrNoRows) @@ -244,7 +244,7 @@ func TestGetUserIDFromEmailDBError(t *testing.T) { email := "error@example.com" - mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`). + mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`). WithArgs(email). WillReturnError(sql.ErrConnDone) @@ -279,10 +279,10 @@ func TestGetUserMultipleEmails(t *testing.T) { for _, tc := range testCases { t.Run(tc.email, func(t *testing.T) { - rows := sqlmock.NewRows([]string{"user_id"}). + rows := sqlmock.NewRows([]string{"users_id"}). AddRow(tc.userID) - mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). + mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`). WithArgs(tc.email). WillReturnRows(rows) @@ -316,7 +316,7 @@ func TestCheckEmailInDBVariousEmails(t *testing.T) { rows := sqlmock.NewRows([]string{"exists"}). AddRow(exists) - mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`). + mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`). WithArgs(email). WillReturnRows(rows)