added fetching of origin in auth login
This commit is contained in:
+63
-57
@@ -27,9 +27,12 @@ import (
|
||||
)
|
||||
|
||||
var googleOauthConfig oauth2.Config
|
||||
var oauthStateString = generateRandomState()
|
||||
var AuthorizationURL string
|
||||
var FetchedRedirectURI *string
|
||||
|
||||
const (
|
||||
oauthStateCookieName = "oauth_state"
|
||||
oauthRedirectURICookieName = "oauth_redirect_uri"
|
||||
)
|
||||
|
||||
func isTestEnvironment() bool {
|
||||
return flag.Lookup("test.v") != nil || strings.Contains(os.Args[0], ".test")
|
||||
@@ -106,29 +109,41 @@ func generateRandomState() string {
|
||||
}
|
||||
|
||||
func GoogleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", oauthStateString))
|
||||
|
||||
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
||||
state := generateRandomState()
|
||||
helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", state))
|
||||
|
||||
redirectURI := strings.TrimSpace(r.URL.Query().Get("redirect_uri"))
|
||||
if redirectURI == "" {
|
||||
helper.RespondWithError(w, http.StatusBadRequest, "redirect_uri is required")
|
||||
return
|
||||
}
|
||||
|
||||
if !IsAllowedRedirectURI(redirectURI) {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: oauthStateString,
|
||||
Name: oauthStateCookieName,
|
||||
Value: state,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Expires: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: oauthRedirectURICookieName,
|
||||
Value: redirectURI,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Expires: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
redirectURI := r.URL.Query().Get("redirect_uri")
|
||||
if redirectURI != "" {
|
||||
FetchedRedirectURI = &redirectURI
|
||||
log.Print("FetchedRedirectURI set to: ", *FetchedRedirectURI)
|
||||
} else {
|
||||
FetchedRedirectURI = nil
|
||||
}
|
||||
|
||||
url := googleOauthConfig.AuthCodeURL(oauthStateString, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
url := googleOauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -184,6 +199,11 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
helper.LogInfo(fmt.Sprintf("[oauth-debug] state validation ok duration_ms=%d", time.Since(stateStart).Milliseconds()))
|
||||
|
||||
redirectURI, ok := callbackRedirectURI(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
googleUserInfoStart := time.Now()
|
||||
userInfo, err := FetchGoogleUserInfo(w, r)
|
||||
if err != nil {
|
||||
@@ -232,24 +252,8 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if !emailExists {
|
||||
helper.LogError(errors.New("unregistered email"), "Google login attempt with unregistered email: "+email)
|
||||
if FetchedRedirectURI != nil && *FetchedRedirectURI != "" {
|
||||
RedirectURI := *FetchedRedirectURI
|
||||
log.Print("RedirectURI from query param: ", RedirectURI)
|
||||
if !IsAllowedRedirectURI(RedirectURI) {
|
||||
helper.LogError(errors.New("unauthorized redirect uri"), "Blocked redirect URI for unregistered email: "+RedirectURI)
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
||||
log.Print("Unauthorized RedirectURI: ", RedirectURI)
|
||||
return
|
||||
}
|
||||
log.Print("Valid redirect_uri: ", RedirectURI)
|
||||
RedirectURL := fmt.Sprintf("%s/callback?error=%s=", RedirectURI, "unregistered_email")
|
||||
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
log.Print("No redirect_uri provided, returning JSON response")
|
||||
// No redirect_uri provided, return JSON response
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Your email is not registered in the system. Please contact your administrator to request access.")
|
||||
RedirectURL := fmt.Sprintf("%s/callback?error=%s=", redirectURI, "unregistered_email")
|
||||
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -322,33 +326,12 @@ func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
helper.LogInfo("Copy this access token: " + accessToken)
|
||||
|
||||
if FetchedRedirectURI != nil && *FetchedRedirectURI != "" {
|
||||
RedirectURI := *FetchedRedirectURI
|
||||
log.Print("RedirectURI from query param: ", RedirectURI)
|
||||
if !IsAllowedRedirectURI(RedirectURI) {
|
||||
helper.LogError(errors.New("unauthorized redirect uri"), "Blocked redirect URI after successful auth: "+RedirectURI)
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
||||
log.Print("Unauthorized RedirectURI: ", RedirectURI)
|
||||
return
|
||||
}
|
||||
log.Print("Valid redirect_uri: ", RedirectURI)
|
||||
RedirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", RedirectURI, accessToken, userID)
|
||||
helper.LogInfo(fmt.Sprintf("[oauth-debug] callback complete redirect=true total_ms=%d", time.Since(callbackStart).Milliseconds()))
|
||||
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
log.Print("No redirect_uri provided, returning JSON response")
|
||||
// No redirect_uri provided, return JSON response
|
||||
helper.LogInfo(fmt.Sprintf("[oauth-debug] callback complete redirect=false total_ms=%d", time.Since(callbackStart).Milliseconds()))
|
||||
helper.RespondWithJSON(w, http.StatusOK, map[string]string{
|
||||
"message": "Authentication successful",
|
||||
"access_token": accessToken,
|
||||
})
|
||||
RedirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", redirectURI, accessToken, userID)
|
||||
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func validateState(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookie, err := r.Cookie("oauth_state")
|
||||
cookie, err := r.Cookie(oauthStateCookieName)
|
||||
callbackState := r.URL.Query().Get("state")
|
||||
if err != nil {
|
||||
helper.LogError(err, "oauth_state cookie missing or unreadable during callback")
|
||||
@@ -357,6 +340,12 @@ func validateState(w http.ResponseWriter, r *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.TrimSpace(callbackState) == "" {
|
||||
helper.LogWarn(errorInvalidState)
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, errorInvalidState)
|
||||
return false
|
||||
}
|
||||
|
||||
if callbackState != cookie.Value {
|
||||
helper.LogError(errors.New("oauth state mismatch"), fmt.Sprintf("OAuth state mismatch. cookie_state=%s callback_state=%s", cookie.Value, callbackState))
|
||||
helper.LogWarn(errorInvalidState)
|
||||
@@ -367,6 +356,23 @@ func validateState(w http.ResponseWriter, r *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func callbackRedirectURI(w http.ResponseWriter, r *http.Request) (string, bool) {
|
||||
cookie, err := r.Cookie(oauthRedirectURICookieName)
|
||||
if err != nil {
|
||||
helper.LogError(err, "oauth redirect_uri cookie missing or unreadable during callback")
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
||||
return "", false
|
||||
}
|
||||
|
||||
redirectURI := strings.TrimSpace(cookie.Value)
|
||||
if redirectURI == "" || !IsAllowedRedirectURI(redirectURI) {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
||||
return "", false
|
||||
}
|
||||
|
||||
return redirectURI, true
|
||||
}
|
||||
|
||||
func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoogleInfo, error) {
|
||||
fetchStart := time.Now()
|
||||
code := r.URL.Query().Get("code")
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGoogleLogin_RequiresRedirectURI(t *testing.T) {
|
||||
original := os.Getenv("ALLOWED_REDIRECT_URIS")
|
||||
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
|
||||
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/login", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
GoogleLogin(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleLogin_RejectsUnauthorizedRedirectURI(t *testing.T) {
|
||||
original := os.Getenv("ALLOWED_REDIRECT_URIS")
|
||||
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
|
||||
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/login?redirect_uri=http://malicious.example", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
GoogleLogin(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateState_MissingCookie(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
ok := validateState(recorder, req)
|
||||
if ok {
|
||||
t.Fatal("expected validateState to return false when oauth_state cookie is missing")
|
||||
}
|
||||
if recorder.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateState_Success(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
|
||||
req.AddCookie(&http.Cookie{Name: oauthStateCookieName, Value: "test-state"})
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
ok := validateState(recorder, req)
|
||||
if !ok {
|
||||
t.Fatal("expected validateState to return true for matching state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbackRedirectURI_MissingCookie(t *testing.T) {
|
||||
original := os.Getenv("ALLOWED_REDIRECT_URIS")
|
||||
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
|
||||
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
_, ok := callbackRedirectURI(recorder, req)
|
||||
if ok {
|
||||
t.Fatal("expected callbackRedirectURI to return false when redirect cookie is missing")
|
||||
}
|
||||
if recorder.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbackRedirectURI_Success(t *testing.T) {
|
||||
original := os.Getenv("ALLOWED_REDIRECT_URIS")
|
||||
os.Setenv("ALLOWED_REDIRECT_URIS", "http://localhost:5173")
|
||||
defer os.Setenv("ALLOWED_REDIRECT_URIS", original)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/callback?state=test-state", nil)
|
||||
req.AddCookie(&http.Cookie{Name: oauthRedirectURICookieName, Value: "http://localhost:5173"})
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
uri, ok := callbackRedirectURI(recorder, req)
|
||||
if !ok {
|
||||
t.Fatal("expected callbackRedirectURI to return true for allowed redirect URI")
|
||||
}
|
||||
if uri != "http://localhost:5173" {
|
||||
t.Fatalf("expected redirect URI %q, got %q", "http://localhost:5173", uri)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"authentication/helper"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultFrontendOrigin = "http://localhost:5173"
|
||||
|
||||
func allowedFrontendOrigins() map[string]struct{} {
|
||||
raw := os.Getenv("FRONTEND_ORIGINS")
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
raw = defaultFrontendOrigin
|
||||
}
|
||||
|
||||
allowed := make(map[string]struct{})
|
||||
for _, origin := range strings.Split(raw, ",") {
|
||||
trimmed := strings.TrimSpace(origin)
|
||||
if trimmed != "" {
|
||||
allowed[trimmed] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return allowed
|
||||
}
|
||||
|
||||
func FrontendOriginWhitelist(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||
if origin == "" {
|
||||
helper.RespondWithError(w, http.StatusBadRequest, "missing origin header")
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := allowedFrontendOrigins()[origin]; !ok {
|
||||
helper.RespondWithError(w, http.StatusForbidden, "forbidden origin")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFrontendOriginWhitelist_DefaultAllowedOrigin(t *testing.T) {
|
||||
os.Unsetenv("FRONTEND_ORIGINS")
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
handler := FrontendOriginWhitelist(next)
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
|
||||
req.Header.Set("Origin", defaultFrontendOrigin)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("expected next handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrontendOriginWhitelist_RejectsMissingOrigin(t *testing.T) {
|
||||
os.Unsetenv("FRONTEND_ORIGINS")
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
handler := FrontendOriginWhitelist(next)
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrontendOriginWhitelist_RejectsNonWhitelistedOrigin(t *testing.T) {
|
||||
os.Unsetenv("FRONTEND_ORIGINS")
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
handler := FrontendOriginWhitelist(next)
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
|
||||
req.Header.Set("Origin", "http://malicious-site.example")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusForbidden, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrontendOriginWhitelist_UsesConfiguredOrigins(t *testing.T) {
|
||||
os.Setenv("FRONTEND_ORIGINS", "http://localhost:4173, http://localhost:5173")
|
||||
defer os.Unsetenv("FRONTEND_ORIGINS")
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
handler := FrontendOriginWhitelist(next)
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/auth/forgot-password", nil)
|
||||
req.Header.Set("Origin", "http://localhost:4173")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("expected next handler to be called")
|
||||
}
|
||||
}
|
||||
+5
-3
@@ -15,9 +15,11 @@ func SetupRoutes(router *mux.Router, db *sql.DB) {
|
||||
router.HandleFunc("/ready", handlers.ReadyHandler).Methods("GET")
|
||||
|
||||
authRoutes := router.PathPrefix("/v1/auth").Subrouter()
|
||||
authRoutes.HandleFunc("/login", handlers.GoogleLogin).Methods("GET")
|
||||
authRoutes.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET")
|
||||
authRoutes.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET")
|
||||
frontendOnly := authRoutes.NewRoute().Subrouter()
|
||||
frontendOnly.Use(middleware.FrontendOriginWhitelist)
|
||||
frontendOnly.HandleFunc("/login", handlers.GoogleLogin).Methods("GET")
|
||||
frontendOnly.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET")
|
||||
frontendOnly.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET")
|
||||
|
||||
csrfProtected := authRoutes.NewRoute().Subrouter()
|
||||
csrfProtected.Use(middleware.CSRFMiddleware)
|
||||
|
||||
+16
-16
@@ -34,10 +34,10 @@ func TestGetUser(t *testing.T) {
|
||||
email := "test@example.com"
|
||||
expectedID := "user123"
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
||||
rows := sqlmock.NewRows([]string{"users_id"}).
|
||||
AddRow(expectedID)
|
||||
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -49,7 +49,7 @@ func TestGetUserNotFound(t *testing.T) {
|
||||
|
||||
email := "nonexistent@example.com"
|
||||
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
@@ -71,10 +71,10 @@ func TestGetUserNullNames(t *testing.T) {
|
||||
email := "test@example.com"
|
||||
expectedID := "user456"
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
||||
rows := sqlmock.NewRows([]string{"users_id"}).
|
||||
AddRow(expectedID)
|
||||
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -97,10 +97,10 @@ func TestGetUserID(t *testing.T) {
|
||||
email := "test@example.com"
|
||||
expectedID := "user789"
|
||||
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
||||
rows := sqlmock.NewRows([]string{"users_id"}).
|
||||
AddRow(expectedID)
|
||||
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestCheckEmailInDB(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"exists"}).
|
||||
AddRow(true)
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -152,7 +152,7 @@ func TestCheckEmailInDBNotExists(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"exists"}).
|
||||
AddRow(false)
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -173,7 +173,7 @@ func TestCheckEmailInDBError(t *testing.T) {
|
||||
|
||||
email := "error@example.com"
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||
WithArgs(email).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
@@ -198,7 +198,7 @@ func TestGetUserIDFromEmail(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(expectedID)
|
||||
|
||||
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
|
||||
mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -223,7 +223,7 @@ func TestGetUserIDFromEmailNotFound(t *testing.T) {
|
||||
|
||||
email := "notfound@example.com"
|
||||
|
||||
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
|
||||
mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
@@ -244,7 +244,7 @@ func TestGetUserIDFromEmailDBError(t *testing.T) {
|
||||
|
||||
email := "error@example.com"
|
||||
|
||||
mock.ExpectQuery(`SELECT user_id FROM \( SELECT user_id, 1 AS priority FROM users WHERE email_address = \? AND is_deleted = 0 \) t ORDER BY priority ASC LIMIT 1;`).
|
||||
mock.ExpectQuery(`SELECT users_id\s+FROM \(\s*SELECT users_id, 1 AS priority\s+FROM users\s+WHERE email_address = \?\s+AND is_deleted = 0\s*\) t\s+ORDER BY priority ASC\s+LIMIT 1;`).
|
||||
WithArgs(email).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
@@ -279,10 +279,10 @@ func TestGetUserMultipleEmails(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.email, func(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"user_id"}).
|
||||
rows := sqlmock.NewRows([]string{"users_id"}).
|
||||
AddRow(tc.userID)
|
||||
|
||||
mock.ExpectQuery(`SELECT user_id FROM users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
mock.ExpectQuery(`SELECT users_id FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0 LIMIT 1;`).
|
||||
WithArgs(tc.email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -316,7 +316,7 @@ func TestCheckEmailInDBVariousEmails(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"exists"}).
|
||||
AddRow(exists)
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS \( SELECT 1 FROM users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||
mock.ExpectQuery(`SELECT EXISTS \(\s*SELECT 1 FROM uess_user_management\.users WHERE email_address = \? AND is_deleted = 0\)`).
|
||||
WithArgs(email).
|
||||
WillReturnRows(rows)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user