package handlers

import (
	"authentication/db"
	"authentication/helper"
	"authentication/models"
	"authentication/services"
	"context"
	"crypto/rand"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/joho/godotenv"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
)

var (
	// googleOauthConfig is the OAuth2 client configuration used for the Google
	// authorization-code flow. It is populated once by init.
	googleOauthConfig oauth2.Config

	// AuthorizationURL holds the value of the AUTHORIZATION_URL environment
	// variable as read during init.
	AuthorizationURL string
)

const (
	// oauthStateCookieName is the cookie that carries the anti-CSRF state
	// between GoogleLogin and GoogleCallback.
	oauthStateCookieName = "oauth_state"
	// oauthRedirectURICookieName is the cookie that remembers which client
	// redirect URI started the login flow.
	oauthRedirectURICookieName = "oauth_redirect_uri"
)

// isTestEnvironment reports whether the binary is running under `go test`,
// detected through the registered test.v flag or a ".test" binary name.
func isTestEnvironment() bool {
	return flag.Lookup("test.v") != nil || strings.Contains(os.Args[0], ".test")
}

// init builds the Google OAuth2 configuration from environment variables.
//
// Package init functions of imported packages run before main's own init, so
// this package cannot rely on main having loaded .env already; it loads the
// file itself. A missing .env is not fatal (the variables may come from the
// real environment), but a missing client ID or secret is, except under
// `go test` where placeholder values are substituted.
func init() {
	if cwd, err := os.Getwd(); err != nil {
		log.Printf("[google_auth.init] Failed to determine working directory: %v", err)
	} else {
		log.Printf("[google_auth.init] Current working directory: %s", cwd)
	}

	// godotenv.Load with no arguments already reads ".env" from the working
	// directory, so a single attempt is sufficient.
	if err := godotenv.Load(); err != nil {
		log.Printf("[google_auth.init] Failed to load .env: %v", err)
	}

	clientID := os.Getenv("GOOGLE_CLIENT_ID")
	clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
	backendURL := os.Getenv("BACKEND_URL")

	if (clientID == "" || clientSecret == "" || backendURL == "") && isTestEnvironment() {
		if clientID == "" {
			clientID = "test-google-client-id"
		}
		if clientSecret == "" {
			clientSecret = "test-google-client-secret"
		}
		if backendURL == "" {
			backendURL = "http://localhost:8080"
		}
		log.Print("[google_auth.init] Using test fallback values for Google OAuth configuration")
	}

	log.Printf("[google_auth.init] GOOGLE_CLIENT_ID: '%s' (length: %d)", clientID, len(clientID))
	// Never write the client secret itself to the log; its length is enough
	// to tell whether it was picked up.
	log.Printf("[google_auth.init] GOOGLE_CLIENT_SECRET: <redacted> (length: %d)", len(clientSecret))
	log.Printf("[google_auth.init] BACKEND_URL: '%s'", backendURL)

	googleOauthConfig = oauth2.Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		RedirectURL:  fmt.Sprintf("%s/v1/auth/callback", backendURL),
		Scopes: []string{
			"https://www.googleapis.com/auth/userinfo.email",
			"https://www.googleapis.com/auth/userinfo.profile",
		},
		Endpoint: google.Endpoint,
	}
	log.Print("Redirect URL set to: ", googleOauthConfig.RedirectURL)

	if googleOauthConfig.ClientID == "" {
		log.Fatal("GOOGLE_CLIENT_ID is not set in environment variables")
	}
	if googleOauthConfig.ClientSecret == "" {
		log.Fatal("GOOGLE_CLIENT_SECRET is not set in environment variables")
	}

	AuthorizationURL = os.Getenv("AUTHORIZATION_URL")
}

// backendUsesHTTPS reports whether BACKEND_URL starts with the HTTPS scheme
// prefix. Cookies are marked Secure only in that case.
func backendUsesHTTPS() bool {
	return strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
}

// generateRandomState returns 16 random bytes from crypto/rand encoded as
// lowercase hex, or the empty string when the random source fails.
func generateRandomState() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		helper.LogError(err, "Error generating random state")
		return ""
	}
	return fmt.Sprintf("%x", b)
}

// GoogleLogin starts the Google OAuth flow.
//
// It requires an allowed redirect_uri query parameter, stores the generated
// state and the redirect URI in short-lived HttpOnly cookies, and redirects
// the browser to Google's consent page.
func GoogleLogin(w http.ResponseWriter, r *http.Request) {
	redirectURI := strings.TrimSpace(r.URL.Query().Get("redirect_uri"))
	if redirectURI == "" {
		helper.RespondWithError(w, http.StatusBadRequest, "redirect_uri is required")
		return
	}
	if !IsAllowedRedirectURI(redirectURI) {
		helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
		return
	}

	state := generateRandomState()
	if state == "" {
		// Without a state value the callback cannot be protected against
		// CSRF, so refuse to start the flow at all.
		helper.RespondWithError(w, http.StatusInternalServerError, "Failed to start Google login")
		return
	}
	helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", state))

	isSecure := backendUsesHTTPS()
	expires := time.Now().Add(5 * time.Minute)
	http.SetCookie(w, &http.Cookie{
		Name:     oauthStateCookieName,
		Value:    state,
		Path:     "/",
		HttpOnly: true,
		Secure:   isSecure,
		SameSite: http.SameSiteLaxMode,
		Expires:  expires,
	})
	http.SetCookie(w, &http.Cookie{
		Name:     oauthRedirectURICookieName,
		Value:    redirectURI,
		Path:     "/",
		HttpOnly: true,
		Secure:   isSecure,
		SameSite: http.SameSiteLaxMode,
		Expires:  expires,
	})

	url := googleOauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
	http.Redirect(w, r, url, http.StatusFound)
}

// isSensitiveLogHeader reports whether a request header carries credentials
// and must therefore not be written to the log verbatim.
func isSensitiveLogHeader(name string) bool {
	switch http.CanonicalHeaderKey(name) {
	case "Authorization", "Proxy-Authorization", "Cookie", "Set-Cookie", "X-Csrf-Token":
		return true
	}
	return false
}

// getIPAddress determines the client IP for a request.
//
// Order of preference: the first entry of X-Forwarded-For, then X-Real-IP,
// then the host part of r.RemoteAddr. Loopback remote addresses are reported
// as "127.0.0.1". An empty string is returned when RemoteAddr cannot be
// split. The forwarding headers are client-controlled unless a proxy in
// front of this service overwrites them.
func getIPAddress(r *http.Request) string {
	for header, values := range r.Header {
		for _, value := range values {
			if isSensitiveLogHeader(header) {
				value = "<redacted>"
			}
			helper.LogInfo(fmt.Sprintf("Header: %s = %s", header, value))
		}
	}

	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := strings.TrimSpace(strings.Split(forwarded, ",")[0])
		if net.ParseIP(first) != nil {
			return first
		}
	}

	if realIP := r.Header.Get("X-Real-IP"); realIP != "" && net.ParseIP(realIP) != nil {
		return realIP
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		helper.LogError(err, "Error parsing remote address")
		return ""
	}
	if parsed := net.ParseIP(host); parsed != nil && parsed.IsLoopback() {
		return "127.0.0.1"
	}
	return host
}

// respondGoogleUserInfoError maps an error returned by FetchGoogleUserInfo to
// an HTTP status and a user-facing message. Matching is done on the error
// text because the causes arrive as formatted strings.
func respondGoogleUserInfoError(w http.ResponseWriter, err error) {
	errMsg := err.Error()
	switch {
	case strings.Contains(errMsg, "TLS handshake timeout"):
		helper.RespondWithError(w, http.StatusGatewayTimeout, "Connection to Google failed due to network issues. Please try again in a moment.")
	case strings.Contains(errMsg, "timeout"):
		helper.RespondWithError(w, http.StatusGatewayTimeout, "Request to Google took too long. Please try again.")
	case strings.Contains(errMsg, "connection refused"), strings.Contains(errMsg, "no such host"):
		helper.RespondWithError(w, http.StatusServiceUnavailable, "Unable to reach Google authentication servers. Please check your internet connection and try again.")
	case strings.Contains(errMsg, "status 401"):
		helper.RespondWithError(w, http.StatusUnauthorized, "Invalid authorization code. Please start the login process again.")
	case strings.Contains(errMsg, "status 403"):
		helper.RespondWithError(w, http.StatusForbidden, "Access to Google authentication was denied. Please try again later.")
	default:
		helper.RespondWithError(w, http.StatusBadGateway, "Failed to fetch user information from Google. Please try again.")
	}
}

// GoogleCallback completes the Google OAuth flow.
//
// Steps: validate the state cookie against the callback state, recover the
// redirect URI from its cookie, exchange the code and fetch the Google
// profile, verify that the email is registered, issue access and refresh
// tokens, set the refresh_token cookie, record the login event and finally
// redirect the browser to "<redirectURI>/callback" with the tokens in the
// query string. Unregistered emails are redirected with
// error=unregistered_email instead.
func GoogleCallback(w http.ResponseWriter, r *http.Request) {
	callbackStart := time.Now()
	query := r.URL.Query()
	// The raw query contains the one-time authorization code, so only log
	// which parameters are present.
	helper.LogInfo(fmt.Sprintf("[oauth-debug] callback start path=%s has_code=%t has_state=%t", r.URL.Path, query.Get("code") != "", query.Get("state") != ""))

	ipAddress := getIPAddress(r)
	helper.LogInfo(fmt.Sprintf("Extracted IP address: %s", ipAddress))
	helper.LogInfo(fmt.Sprintf("[oauth-debug] ip extraction done duration_ms=%d", time.Since(callbackStart).Milliseconds()))
	userAgent := r.Header.Get("User-Agent")

	stateStart := time.Now()
	if !validateState(w, r) {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] state validation failed duration_ms=%d total_ms=%d", time.Since(stateStart).Milliseconds(), time.Since(callbackStart).Milliseconds()))
		return
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] state validation ok duration_ms=%d", time.Since(stateStart).Milliseconds()))

	redirectURI, ok := callbackRedirectURI(w, r)
	if !ok {
		return
	}

	googleUserInfoStart := time.Now()
	userInfo, err := FetchGoogleUserInfo(w, r)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] fetch google userinfo failed duration_ms=%d total_ms=%d", time.Since(googleUserInfoStart).Milliseconds(), time.Since(callbackStart).Milliseconds()))
		helper.LogError(err, "Failed to fetch Google user info")
		respondGoogleUserInfoError(w, err)
		return
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] fetch google userinfo ok duration_ms=%d", time.Since(googleUserInfoStart).Milliseconds()))

	email := userInfo.Email

	emailCheckStart := time.Now()
	emailExists, err := checkEmailInDB(email)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] email check failed duration_ms=%d total_ms=%d", time.Since(emailCheckStart).Milliseconds(), time.Since(callbackStart).Milliseconds()))
		helper.LogError(err, "Error checking email")
		helper.RespondWithError(w, http.StatusBadGateway, "Error checking email in database")
		return
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] email check ok duration_ms=%d", time.Since(emailCheckStart).Milliseconds()))

	if !emailExists {
		helper.LogError(errors.New("unregistered email"), "Google login attempt with unregistered email: "+email)
		http.Redirect(w, r, fmt.Sprintf("%s/callback?error=%s", redirectURI, "unregistered_email"), http.StatusSeeOther)
		return
	}

	accessToken, refreshToken, err := GenerateTokens(email, userAgent, ipAddress)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] token generation failed total_ms=%d", time.Since(callbackStart).Milliseconds()))
		helper.LogError(err, "Error generating access token")
		helper.RespondWithError(w, http.StatusInternalServerError, "Token generation failed")
		return
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] token generation ok elapsed_ms=%d", time.Since(callbackStart).Milliseconds()))

	// Only registered users reach this point, so the refresh cookie always
	// gets the seven-day lifetime.
	refreshCookie := &http.Cookie{
		Name:     "refresh_token",
		Value:    refreshToken,
		Path:     "/",
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Expires:  time.Now().Add(7 * 24 * time.Hour),
	}
	if backendUsesHTTPS() {
		refreshCookie.Secure = true
		helper.LogInfo("Setting refresh_token cookie for PRODUCTION (secure=true)")
	} else {
		refreshCookie.Secure = false
		refreshCookie.Domain = "localhost"
		helper.LogInfo("Setting refresh_token cookie for DEVELOPMENT (secure=false, domain=localhost)")
	}
	http.SetCookie(w, refreshCookie)
	helper.LogInfo(fmt.Sprintf("Refresh token cookie set: Domain=%s, Secure=%v, HttpOnly=%v, SameSite=%v", refreshCookie.Domain, refreshCookie.Secure, refreshCookie.HttpOnly, refreshCookie.SameSite))

	helper.LogInfo("Fetching user ID for email: " + email)
	userID, err := services.GetUserID(email)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] get user id failed total_ms=%d", time.Since(callbackStart).Milliseconds()))
		helper.LogError(err, "Error fetching user")
		helper.RespondWithError(w, http.StatusBadGateway, "Failed to fetch user information")
		return
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] get user id ok total_ms=%d", time.Since(callbackStart).Milliseconds()))

	loginLogStart := time.Now()
	if err := helper.LogLoginEventV2(userID, ipAddress); err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] login event log failed duration_ms=%d total_ms=%d", time.Since(loginLogStart).Milliseconds(), time.Since(callbackStart).Milliseconds()))
		helper.LogError(err, fmt.Sprintf("Failed to log login event. user_id=%s ip=%s", userID, ipAddress))
		helper.RespondWithError(w, http.StatusBadGateway, "Failed to Log Login Event")
		return
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] login event log ok duration_ms=%d", time.Since(loginLogStart).Milliseconds()))

	// Web clients receive the refresh token through the cookie only. Redirect
	// URIs containing "com." are treated as mobile app schemes, which cannot
	// read that cookie, so the refresh token is appended to the query.
	redirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", redirectURI, accessToken, userID)
	if strings.Contains(redirectURI, "com.") {
		redirectURL = fmt.Sprintf("%s/callback?token=%s&refresh_token=%s&user_id=%s", redirectURI, accessToken, refreshToken, userID)
	}
	http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}

// validateState checks that the "state" query parameter of the callback is
// non-blank and equal to the value stored in the oauth_state cookie. On any
// failure it writes a 401 response and returns false.
func validateState(w http.ResponseWriter, r *http.Request) bool {
	reject := func() bool {
		helper.LogWarn(errorInvalidState)
		helper.RespondWithError(w, http.StatusUnauthorized, errorInvalidState)
		return false
	}

	cookie, err := r.Cookie(oauthStateCookieName)
	if err != nil {
		helper.LogError(err, "oauth_state cookie missing or unreadable during callback")
		return reject()
	}

	callbackState := r.URL.Query().Get("state")
	if strings.TrimSpace(callbackState) == "" {
		return reject()
	}

	if callbackState != cookie.Value {
		helper.LogError(errors.New("oauth state mismatch"), fmt.Sprintf("OAuth state mismatch. cookie_state=%s callback_state=%s", cookie.Value, callbackState))
		return reject()
	}

	helper.LogInfo(fmt.Sprintf("Cookie state: %s, Callback state: %s", cookie.Value, callbackState))
	return true
}

// callbackRedirectURI reads the redirect URI remembered by GoogleLogin from
// its cookie and re-validates it against the allow-list. On failure it writes
// a 401 response and returns ("", false).
func callbackRedirectURI(w http.ResponseWriter, r *http.Request) (string, bool) {
	cookie, err := r.Cookie(oauthRedirectURICookieName)
	if err != nil {
		helper.LogError(err, "oauth redirect_uri cookie missing or unreadable during callback")
		helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
		return "", false
	}

	redirectURI := strings.TrimSpace(cookie.Value)
	if redirectURI == "" || !IsAllowedRedirectURI(redirectURI) {
		helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
		return "", false
	}
	return redirectURI, true
}

// FetchGoogleUserInfo exchanges the "code" query parameter for a Google token
// and retrieves the user's profile from the userinfo endpoint.
//
// The exchange is limited to 10 seconds and the userinfo request to 30
// seconds; both are also cancelled when the incoming request is. The w
// parameter is unused and kept for signature compatibility. Returned error
// texts are matched by GoogleCallback to pick a response status, so their
// wording is part of the contract.
func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoogleInfo, error) {
	fetchStart := time.Now()

	code := r.URL.Query().Get("code")
	// The authorization code is a credential; log only its length.
	log.Printf("Authorization code received (length: %d)", len(code))
	if code == "" {
		return models.UserGoogleInfo{}, errors.New("failed to exchange authorization code: code is missing from callback request")
	}

	exchangeStart := time.Now()
	exchangeCtx, exchangeCancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer exchangeCancel()

	token, err := googleOauthConfig.Exchange(exchangeCtx, code)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] google exchange failed duration_ms=%d total_ms=%d", time.Since(exchangeStart).Milliseconds(), time.Since(fetchStart).Milliseconds()))
		helper.LogError(err, "Error exchanging authorization code for token")
		return models.UserGoogleInfo{}, fmt.Errorf("failed to exchange authorization code: %w", err)
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] google exchange ok duration_ms=%d", time.Since(exchangeStart).Milliseconds()))

	// Bound the userinfo request to 30 seconds.
	ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
	defer cancel()

	client := googleOauthConfig.Client(ctx, token)
	req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
	if err != nil {
		helper.LogError(err, "Error creating request to fetch user info")
		return models.UserGoogleInfo{}, fmt.Errorf("failed to create userinfo request: %w", err)
	}
	req.Header.Set("Authorization", bearerPrefix+token.AccessToken)
	req = req.WithContext(ctx)

	userInfoStart := time.Now()
	resp, err := client.Do(req)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] google userinfo request failed duration_ms=%d total_ms=%d", time.Since(userInfoStart).Milliseconds(), time.Since(fetchStart).Milliseconds()))
		helper.LogError(err, fmt.Sprintf("Failed to fetch user info from Google: %v", err))

		// The TLS case is tested before the generic timeout check because a
		// handshake timeout also reports itself as a timeout.
		cause := err.Error()
		switch {
		case strings.Contains(cause, "net/http: TLS handshake timeout"):
			return models.UserGoogleInfo{}, errors.New("TLS handshake timeout: Unable to establish secure connection to Google")
		case os.IsTimeout(err):
			return models.UserGoogleInfo{}, errors.New("request timed out: Google userinfo endpoint took too long to respond (timeout: 30s)")
		case strings.Contains(cause, "context deadline exceeded"):
			return models.UserGoogleInfo{}, errors.New("request deadline exceeded: Connection attempt exceeded 30 second timeout")
		case strings.Contains(cause, "connection refused"):
			return models.UserGoogleInfo{}, errors.New("connection refused: Cannot reach Google servers")
		case strings.Contains(cause, "no such host"):
			return models.UserGoogleInfo{}, errors.New("DNS resolution failed: Cannot resolve googleapis.com")
		}
		return models.UserGoogleInfo{}, fmt.Errorf("network error while fetching user info: %w", err)
	}
	defer func(body io.ReadCloser) {
		if err := body.Close(); err != nil {
			helper.LogError(err, "Error closing response body")
		}
	}(resp.Body)

	if resp.StatusCode != http.StatusOK {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] google userinfo non-200 status=%d total_ms=%d", resp.StatusCode, time.Since(fetchStart).Milliseconds()))
		// Best effort: the body is only used to enrich the error text, so a
		// read failure is ignored and the amount read is capped.
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		helper.LogError(nil, fmt.Sprintf("Google API returned status %d: %s", resp.StatusCode, string(bodyBytes)))
		return models.UserGoogleInfo{}, fmt.Errorf("google api error (status %d): %s", resp.StatusCode, string(bodyBytes))
	}

	var userInfo models.UserGoogleInfo
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] google userinfo decode failed total_ms=%d", time.Since(fetchStart).Milliseconds()))
		helper.LogError(err, "Error decoding user info from Google response")
		return models.UserGoogleInfo{}, fmt.Errorf("failed to parse user info response: %w", err)
	}

	helper.LogInfo(fmt.Sprintf("[oauth-debug] fetch google userinfo complete total_ms=%d", time.Since(fetchStart).Milliseconds()))
	return userInfo, nil
}

// parseRSASignedAccessToken parses tokenString into models.AccessToken claims,
// accepting only RSA signing methods and verifying the signature with the
// public half of rsaPrivateKey. It returns an error instead of panicking when
// the key has not been initialized.
func parseRSASignedAccessToken(tokenString string) (*jwt.Token, error) {
	return jwt.ParseWithClaims(tokenString, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		if rsaPrivateKey == nil {
			return nil, errors.New("RSA private key is not initialized")
		}
		return &rsaPrivateKey.PublicKey, nil
	})
}

// HandleTokenRefresh issues a new access token from a refresh token.
//
// OPTIONS requests get a CORS preflight response. If a Bearer access token is
// supplied it must be well-formed: parse errors other than "expired" or
// "used before issued" yield 400. The refresh token is taken from the
// refresh_token cookie, or from a JSON body field of the same name. When the
// supplied access token verifies successfully, its email claim is passed on
// as a fallback for the refresh. On success the response is a JSON object
// with access_token, token_type and expires_in.
func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodOptions {
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
		w.Header().Set("Access-Control-Max-Age", "3600")
		w.WriteHeader(http.StatusOK)
		return
	}

	helper.LogInfo("Refresh token handler called")

	// Inspect the optional access token. Its value is never logged.
	var emailFromToken string
	authHeader := r.Header.Get("Authorization")
	helper.LogInfo(fmt.Sprintf("Authorization header present: %t", authHeader != ""))
	if isValidAuthHeader(authHeader) {
		accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
		helper.LogInfo(fmt.Sprintf("Access token from header length: %d", len(accessToken)))

		token, err := parseRSASignedAccessToken(accessToken)
		if err != nil && !strings.Contains(err.Error(), "expired") && !strings.Contains(err.Error(), "used before issued") {
			helper.LogError(err, "Invalid access token format")
			helper.RespondWithError(w, http.StatusBadRequest, err.Error())
			return
		}

		if err != nil {
			helper.LogInfo("Token parsing failed or token is nil. Error: " + fmt.Sprintf("%v", err))
		} else if token != nil {
			if claims, ok := token.Claims.(*models.AccessToken); ok && claims != nil {
				helper.LogInfo("Token expiration (Exp field): " + fmt.Sprintf("%d", claims.Exp))
				helper.LogInfo("Current time: " + fmt.Sprintf("%d", time.Now().Unix()))
				helper.LogInfo("Access token present, but proceeding with refresh as requested")
				// Reuse the claims that were just verified with the RSA key
				// instead of re-parsing the token with a different key.
				if claims.Email != "" {
					emailFromToken = claims.Email
					helper.LogInfo("TRACE: Extracted email from access token for fallback: " + emailFromToken)
				}
			} else {
				helper.LogInfo("Failed to cast token claims to AccessToken or claims is nil")
			}
		}
		helper.LogInfo("Access token is expired or invalid, proceeding with refresh")
	}

	// Log cookie names and value lengths only, for debugging.
	cookies := r.Cookies()
	helper.LogInfo("TRACE: All cookies in request: " + fmt.Sprintf("%d cookies", len(cookies)))
	for i, c := range cookies {
		helper.LogInfo(fmt.Sprintf("TRACE: Cookie %d: Name=%s, Value-length=%d, Domain=%s, Path=%s", i, c.Name, len(c.Value), c.Domain, c.Path))
	}

	var refreshToken string
	cookie, err := r.Cookie("refresh_token")
	helper.LogInfo("TRACE: Cookie retrieval - error: " + fmt.Sprintf("%v", err))
	if err == nil && cookie.Value != "" {
		refreshToken = cookie.Value
		helper.LogInfo("TRACE: Refresh token retrieved from cookie")
	} else if r.Body != nil {
		var body struct {
			RefreshToken string `json:"refresh_token"`
		}
		// A malformed or empty body simply means no token was supplied; the
		// missing-token check below produces the response.
		decodeErr := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<20)).Decode(&body)
		if decodeErr == nil && body.RefreshToken != "" {
			refreshToken = body.RefreshToken
			helper.LogInfo("TRACE: Refresh token retrieved from request body")
		}
	}

	if refreshToken == "" {
		helper.LogError(errors.New("refresh token not found"), "Refresh token not found in cookie or body")
		helper.RespondWithError(w, http.StatusUnauthorized, "Refresh token not found")
		return
	}
	helper.LogInfo("TRACE: Refresh token length: " + fmt.Sprintf("%d", len(refreshToken)))

	// Client info is passed along for security validation of the session.
	userAgent := r.Header.Get("User-Agent")
	ipAddress := getIPAddress(r)

	newAccessToken, err := GenerateTokensFromRefreshWithEmail(refreshToken, userAgent, ipAddress, emailFromToken)
	if err != nil {
		helper.LogError(err, "Failed to refresh access token")
		switch {
		case strings.Contains(err.Error(), "too many refresh attempts"):
			helper.RespondWithError(w, http.StatusTooManyRequests, "Too many refresh attempts, please wait")
		case strings.Contains(err.Error(), "expired"), strings.Contains(err.Error(), "revoked"):
			helper.RespondWithError(w, http.StatusUnauthorized, "Session expired, please login again")
		default:
			helper.RespondWithError(w, http.StatusUnauthorized, "Invalid refresh token")
		}
		return
	}
	helper.LogInfo("New access token length: " + fmt.Sprintf("%d", len(newAccessToken)))
	if newAccessToken == "" {
		helper.LogError(errors.New("generated access token is empty"), "Generated access token is empty")
		helper.RespondWithError(w, http.StatusUnauthorized, "Failed to generate new access token")
		return
	}

	expiresInSeconds := 15 * 60
	if env := os.Getenv("GO_ENV"); env == "production" || env == "canary" {
		expiresInSeconds = 45 * 60
	}

	response := map[string]interface{}{
		"access_token": newAccessToken,
		"token_type":   "Bearer",
		"expires_in":   expiresInSeconds,
	}
	helper.LogInfo(fmt.Sprintf("TRACE: About to send response: token_type=Bearer expires_in=%d", expiresInSeconds))

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(response); err != nil {
		helper.LogError(err, "Failed to encode response")
		return
	}
	helper.LogInfo("TRACE: Response successfully encoded and sent")
}

// GenerateTokensFromRefresh creates a new access token from a refresh token
// by delegating to RefreshAccessToken, tracing the call and its outcome.
func GenerateTokensFromRefresh(refreshToken, userAgent, ipAddress string) (string, error) {
	helper.LogInfo("TRACE: GenerateTokensFromRefresh called")
	helper.LogInfo("TRACE: refreshToken length: " + fmt.Sprintf("%d", len(refreshToken)))
	helper.LogInfo("TRACE: userAgent: " + userAgent)
	helper.LogInfo("TRACE: ipAddress: " + ipAddress)

	result, err := RefreshAccessToken(refreshToken, userAgent, ipAddress)
	helper.LogInfo("TRACE: RefreshAccessToken returned - token length: " + fmt.Sprintf("%d", len(result)) + ", error: " + fmt.Sprintf("%v", err))
	return result, err
}

// GenerateTokensFromRefreshWithEmail creates a new access token from a refresh
// token by delegating to RefreshAccessTokenWithEmailFallback, passing
// emailFallback through unchanged and tracing the call and its outcome.
func GenerateTokensFromRefreshWithEmail(refreshToken, userAgent, ipAddress, emailFallback string) (string, error) {
	helper.LogInfo("TRACE: GenerateTokensFromRefreshWithEmail called")
	helper.LogInfo("TRACE: refreshToken length: " + fmt.Sprintf("%d", len(refreshToken)))
	helper.LogInfo("TRACE: userAgent: " + userAgent)
	helper.LogInfo("TRACE: ipAddress: " + ipAddress)
	helper.LogInfo("TRACE: emailFallback: " + emailFallback)

	result, err := RefreshAccessTokenWithEmailFallback(refreshToken, userAgent, ipAddress, emailFallback)
	helper.LogInfo("TRACE: RefreshAccessTokenWithEmailFallback returned - token length: " + fmt.Sprintf("%d", len(result)) + ", error: " + fmt.Sprintf("%v", err))
	return result, err
}

// checkEmailInDB reports whether email is known to services.CheckEmailInDB.
// It fails fast with an error when the shared database handle is nil.
func checkEmailInDB(email string) (bool, error) {
	if db.DB == nil {
		helper.LogError(nil, dbConnNilError)
		return false, errors.New(dbConnNilError)
	}

	exists, err := services.CheckEmailInDB(email)
	if err != nil {
		return false, err
	}
	helper.LogInfo("Email exists in DB: " + fmt.Sprintf("%v", exists))
	return exists, nil
}

// LogoutHandler clears the refresh_token and csrf_token cookies, revokes the
// caller's sessions when a valid Bearer access token is supplied, writes an
// access-log entry and responds with a JSON confirmation. Cookie clearing and
// the 200 response happen regardless of whether revocation succeeded.
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
	clearRefreshTokenCookie(w)
	clearCSRFCookie(w)

	revokeSessionsForLogout(r.Header.Get("Authorization"))

	if err := accessLog(r, nil, 18, nil); err != nil {
		helper.LogError(err, "Failed to write access log during logout")
	}

	response := map[string]interface{}{
		"message": "Successfully logged out",
		"action":  "clear_session_storage",
		"keys":    []string{"refresh_token", "access_token"},
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(response); err != nil {
		helper.LogError(err, "Failed to encode logout response")
	}
}

// revokeSessionsForLogout revokes all sessions of the user identified by the
// Bearer token in authHeader. Every failure is logged and otherwise ignored
// so that logout itself never fails.
func revokeSessionsForLogout(authHeader string) {
	if !isValidAuthHeader(authHeader) {
		helper.LogWarn("Authorization header missing or invalid during logout; proceeding with cookie clear only")
		return
	}

	tokenString := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
	if tokenString == "" {
		helper.LogWarn("Authorization header contains empty bearer token during logout")
		return
	}

	token, err := parseRSASignedAccessToken(tokenString)
	if err != nil {
		helper.LogError(err, "Failed to parse JWT token during logout")
		return
	}

	claims, ok := token.Claims.(*models.AccessToken)
	if !ok {
		return
	}

	userID, err := services.GetUserIDFromEmail(claims.Email)
	if err != nil {
		helper.LogError(err, "Failed to get user ID during logout")
		return
	}
	if err := RevokeAllUserSessions(userID); err != nil {
		helper.LogError(err, "Failed to revoke user sessions during logout")
	}
}

// isValidAuthHeader reports whether authHeader is non-empty and starts with
// the bearer prefix.
func isValidAuthHeader(authHeader string) bool {
	return authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix)
}

// clearRefreshTokenCookie instructs the browser to delete the refresh_token
// cookie. Two expired cookies are sent: one mirroring the attributes used
// when the cookie was set (including Domain=localhost in development) and a
// host-only fallback without a Domain attribute.
func clearRefreshTokenCookie(w http.ResponseWriter) {
	helper.LogInfo("Clearing refresh_token cookie...")

	isSecure := backendUsesHTTPS()
	helper.LogInfo(fmt.Sprintf("Cookie clearing - isSecure: %v, BACKEND_URL: %s", isSecure, os.Getenv("BACKEND_URL")))

	primary := &http.Cookie{
		Name:     "refresh_token",
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		Secure:   isSecure,
		SameSite: http.SameSiteLaxMode,
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
	}
	if isSecure {
		helper.LogInfo("Setting cookie clear for PRODUCTION (secure=true)")
	} else {
		primary.Domain = "localhost"
		helper.LogInfo("Setting cookie clear for DEVELOPMENT (secure=false, domain=localhost)")
	}
	http.SetCookie(w, primary)
	helper.LogInfo(fmt.Sprintf("Cookie clear #1 sent: Name=%s, Value=%s, Domain=%s, Secure=%v, HttpOnly=%v", primary.Name, primary.Value, primary.Domain, primary.Secure, primary.HttpOnly))

	fallback := &http.Cookie{
		Name:     "refresh_token",
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		Secure:   isSecure,
		SameSite: http.SameSiteLaxMode,
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
	}
	http.SetCookie(w, fallback)
	helper.LogInfo(fmt.Sprintf("Cookie clear #2 sent: Name=%s, Value=%s, Domain=%s, Secure=%v, HttpOnly=%v", fallback.Name, fallback.Value, fallback.Domain, fallback.Secure, fallback.HttpOnly))

	helper.LogInfo("Refresh token cookie clearing commands sent to browser")
}

// clearCSRFCookie instructs the browser to delete the csrf_token cookie. It
// sends an expired Secure/SameSite=Strict host-only cookie, then a
// SameSite=Lax fallback whose Secure flag follows BACKEND_URL, and, when the
// backend is not served over HTTPS, a third variant scoped to
// Domain=localhost.
func clearCSRFCookie(w http.ResponseWriter) {
	helper.LogInfo("Clearing csrf_token cookie...")

	isSecure := backendUsesHTTPS()

	// First variant: Secure, SameSite=Strict, no Domain attribute.
	primary := &http.Cookie{
		Name:     "csrf_token",
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
	}
	http.SetCookie(w, primary)
	helper.LogInfo(fmt.Sprintf("CSRF cookie clear #1 sent: Name=%s, Domain=%s, Secure=%v, SameSite=%v", primary.Name, primary.Domain, primary.Secure, primary.SameSite))

	// Fallback for local/dev browsers where Secure or SameSite differ.
	fallback := &http.Cookie{
		Name:     "csrf_token",
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		Secure:   isSecure,
		SameSite: http.SameSiteLaxMode,
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
	}
	http.SetCookie(w, fallback)
	helper.LogInfo(fmt.Sprintf("CSRF cookie clear #2 sent: Name=%s, Domain=%s, Secure=%v, SameSite=%v", fallback.Name, fallback.Domain, fallback.Secure, fallback.SameSite))

	if !isSecure {
		localhost := &http.Cookie{
			Name:     "csrf_token",
			Value:    "",
			Path:     "/",
			Domain:   "localhost",
			HttpOnly: true,
			Secure:   false,
			SameSite: http.SameSiteLaxMode,
			Expires:  time.Unix(0, 0),
			MaxAge:   -1,
		}
		http.SetCookie(w, localhost)
		helper.LogInfo(fmt.Sprintf("CSRF cookie clear #3 sent: Name=%s, Domain=%s, Secure=%v, SameSite=%v", localhost.Name, localhost.Domain, localhost.Secure, localhost.SameSite))
	}

	helper.LogInfo("CSRF token cookie clearing commands sent to browser")
}