816 lines
30 KiB
Go
816 lines
30 KiB
Go
package handlers
|
|
|
|
import (
|
|
"authentication/db"
|
|
"authentication/helper"
|
|
"authentication/models"
|
|
"authentication/services"
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/joho/godotenv"
|
|
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google"
|
|
)
|
|
|
|
var googleOauthConfig oauth2.Config
|
|
var AuthorizationURL string
|
|
|
|
// oauthStateCookieName names the short-lived cookie in which GoogleLogin stores
// the random OAuth state that GoogleCallback compares with the "state" query
// parameter.
const oauthStateCookieName = "oauth_state"

// oauthRedirectURICookieName names the short-lived cookie in which GoogleLogin
// stores the client's validated redirect_uri for GoogleCallback to read back.
const oauthRedirectURICookieName = "oauth_redirect_uri"
|
|
|
|
// isTestEnvironment reports whether the current process looks like a `go test`
// binary: either the testing package has registered its -test.v flag, or the
// executable path contains ".test".
func isTestEnvironment() bool {
	if flag.Lookup("test.v") != nil {
		return true
	}
	return strings.Contains(os.Args[0], ".test")
}
|
|
|
|
// init initializes the Google OAuth2 configuration by loading environment variables
|
|
// from a .env file. If the .env file cannot be loaded, it logs a fatal error.
|
|
// Note: This init runs AFTER .env is loaded in main() init
|
|
// But we need to load .env here too since init order is package-based
|
|
func init() {
|
|
cwd, _ := os.Getwd()
|
|
log.Printf("[google_auth.init] Current working directory: %s", cwd)
|
|
|
|
err := godotenv.Load()
|
|
if err != nil {
|
|
log.Printf("[google_auth.init] Failed to load .env: %v, trying .env explicitly", err)
|
|
err = godotenv.Load(".env")
|
|
if err != nil {
|
|
log.Printf("[google_auth.init] Failed to load .env explicitly: %v", err)
|
|
}
|
|
}
|
|
|
|
clientID := os.Getenv("GOOGLE_CLIENT_ID")
|
|
clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
|
backendURL := os.Getenv("BACKEND_URL")
|
|
|
|
if (clientID == "" || clientSecret == "" || backendURL == "") && isTestEnvironment() {
|
|
if clientID == "" {
|
|
clientID = "test-google-client-id"
|
|
}
|
|
if clientSecret == "" {
|
|
clientSecret = "test-google-client-secret"
|
|
}
|
|
if backendURL == "" {
|
|
backendURL = "http://localhost:8080"
|
|
}
|
|
log.Print("[google_auth.init] Using test fallback values for Google OAuth configuration")
|
|
}
|
|
|
|
log.Printf("[google_auth.init] GOOGLE_CLIENT_ID: '%s' (length: %d)", clientID, len(clientID))
|
|
log.Printf("[google_auth.init] GOOGLE_CLIENT_SECRET: '%s' (length: %d)", clientSecret, len(clientSecret))
|
|
log.Printf("[google_auth.init] BACKEND_URL: '%s'", backendURL)
|
|
|
|
googleOauthConfig = oauth2.Config{
|
|
ClientID: clientID,
|
|
ClientSecret: clientSecret,
|
|
RedirectURL: fmt.Sprintf("%s/v1/auth/callback", backendURL),
|
|
Scopes: []string{
|
|
"https://www.googleapis.com/auth/userinfo.email",
|
|
"https://www.googleapis.com/auth/userinfo.profile",
|
|
},
|
|
Endpoint: google.Endpoint,
|
|
}
|
|
log.Print("Redirect URL set to: ", googleOauthConfig.RedirectURL)
|
|
if googleOauthConfig.ClientID == "" {
|
|
log.Fatal("GOOGLE_CLIENT_ID is not set in environment variables")
|
|
}
|
|
|
|
if googleOauthConfig.ClientSecret == "" {
|
|
log.Fatal("GOOGLE_CLIENT_SECRET is not set in environment variables")
|
|
}
|
|
|
|
AuthorizationURL = os.Getenv("AUTHORIZATION_URL")
|
|
|
|
}
|
|
|
|
func generateRandomState() string {
|
|
b := make([]byte, 16)
|
|
if _, err := rand.Read(b); err != nil {
|
|
helper.LogError(err, "Error generating random state")
|
|
return ""
|
|
}
|
|
return fmt.Sprintf("%x", b)
|
|
}
|
|
|
|
func GoogleLogin(w http.ResponseWriter, r *http.Request) {
|
|
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
|
state := generateRandomState()
|
|
helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", state))
|
|
|
|
redirectURI := strings.TrimSpace(r.URL.Query().Get("redirect_uri"))
|
|
if redirectURI == "" {
|
|
helper.RespondWithError(w, http.StatusBadRequest, "redirect_uri is required")
|
|
return
|
|
}
|
|
|
|
if !IsAllowedRedirectURI(redirectURI) {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
|
return
|
|
}
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: oauthStateCookieName,
|
|
Value: state,
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: isSecure,
|
|
SameSite: http.SameSiteLaxMode,
|
|
Expires: time.Now().Add(5 * time.Minute),
|
|
})
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: oauthRedirectURICookieName,
|
|
Value: redirectURI,
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: isSecure,
|
|
SameSite: http.SameSiteLaxMode,
|
|
Expires: time.Now().Add(5 * time.Minute),
|
|
})
|
|
|
|
url := googleOauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
|
http.Redirect(w, r, url, http.StatusFound)
|
|
}
|
|
|
|
// getIPAddress returns a best-effort client IP address for r.
//
// Sources are tried in order:
//  1. the first comma-separated entry of X-Forwarded-For, if it parses as an IP;
//  2. X-Real-IP (trimmed), if it parses as an IP;
//  3. the host part of r.RemoteAddr, or RemoteAddr itself when it is a bare IP
//     literal without a port.
//
// Only the RemoteAddr-derived address is normalized: any loopback address
// (IPv4 or IPv6) becomes "127.0.0.1". The function returns "" when RemoteAddr
// cannot be interpreted.
//
// Security notes:
//   - X-Forwarded-For and X-Real-IP are client-controlled unless a trusted
//     reverse proxy overwrites them, so the result must not drive
//     authorization decisions. NOTE(review): confirm the deployment sits
//     behind such a proxy.
//   - Request headers are deliberately not logged here. They include
//     Authorization and Cookie (which carries refresh_token), and dumping them
//     would leak credentials.
func getIPAddress(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); net.ParseIP(ip) != nil {
			return ip
		}
	}

	if xRealIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); xRealIP != "" && net.ParseIP(xRealIP) != nil {
		return xRealIP
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr without a port (e.g. rewritten by middleware) is still
		// usable when it is a plain IP literal.
		bare := strings.TrimSpace(r.RemoteAddr)
		if net.ParseIP(bare) == nil {
			log.Printf("getIPAddress: cannot parse remote address %q: %v", r.RemoteAddr, err)
			return ""
		}
		host = bare
	}

	if parsed := net.ParseIP(host); parsed != nil && parsed.IsLoopback() {
		return "127.0.0.1"
	}
	return host
}
|
|
|
|
func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
|
callbackStart := time.Now()
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] callback start path=%s query=%s", r.URL.Path, r.URL.RawQuery))
|
|
|
|
ipAddress := getIPAddress(r)
|
|
fmt.Printf("INFO: Extracted IP address: %s\n", ipAddress)
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] ip extraction done duration_ms=%d", time.Since(callbackStart).Milliseconds()))
|
|
|
|
userAgent := r.Header.Get("User-Agent")
|
|
|
|
stateStart := time.Now()
|
|
if !validateState(w, r) {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] state validation failed duration_ms=%d total_ms=%d", time.Since(stateStart).Milliseconds(), time.Since(callbackStart).Milliseconds()))
|
|
return
|
|
}
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] state validation ok duration_ms=%d", time.Since(stateStart).Milliseconds()))
|
|
|
|
redirectURI, ok := callbackRedirectURI(w, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
googleUserInfoStart := time.Now()
|
|
userInfo, err := FetchGoogleUserInfo(w, r)
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] fetch google userinfo failed duration_ms=%d total_ms=%d", time.Since(googleUserInfoStart).Milliseconds(), time.Since(callbackStart).Milliseconds()))
|
|
errMsg := err.Error()
|
|
helper.LogError(err, "Failed to fetch Google user info")
|
|
|
|
// Provide user-friendly error messages for different scenarios
|
|
if strings.Contains(errMsg, "TLS handshake timeout") {
|
|
helper.RespondWithError(w, http.StatusGatewayTimeout, "Connection to Google failed due to network issues. Please try again in a moment.")
|
|
return
|
|
}
|
|
if strings.Contains(errMsg, "timeout") {
|
|
helper.RespondWithError(w, http.StatusGatewayTimeout, "Request to Google took too long. Please try again.")
|
|
return
|
|
}
|
|
if strings.Contains(errMsg, "connection refused") || strings.Contains(errMsg, "no such host") {
|
|
helper.RespondWithError(w, http.StatusServiceUnavailable, "Unable to reach Google authentication servers. Please check your internet connection and try again.")
|
|
return
|
|
}
|
|
if strings.Contains(errMsg, "status 401") {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Invalid authorization code. Please start the login process again.")
|
|
return
|
|
}
|
|
if strings.Contains(errMsg, "status 403") {
|
|
helper.RespondWithError(w, http.StatusForbidden, "Access to Google authentication was denied. Please try again later.")
|
|
return
|
|
}
|
|
|
|
helper.RespondWithError(w, http.StatusBadGateway, "Failed to fetch user information from Google. Please try again.")
|
|
return
|
|
}
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] fetch google userinfo ok duration_ms=%d", time.Since(googleUserInfoStart).Milliseconds()))
|
|
|
|
email := userInfo.Email
|
|
|
|
emailCheckStart := time.Now()
|
|
emailExists, err := checkEmailInDB(email)
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] email check failed duration_ms=%d total_ms=%d", time.Since(emailCheckStart).Milliseconds(), time.Since(callbackStart).Milliseconds()))
|
|
helper.LogError(err, "Error checking email")
|
|
helper.RespondWithError(w, http.StatusBadGateway, "Error checking email in database")
|
|
return
|
|
}
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] email check ok duration_ms=%d", time.Since(emailCheckStart).Milliseconds()))
|
|
|
|
if !emailExists {
|
|
helper.LogError(errors.New("unregistered email"), "Google login attempt with unregistered email: "+email)
|
|
RedirectURL := fmt.Sprintf("%s/callback?error=%s=", redirectURI, "unregistered_email")
|
|
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
accessToken, refreshToken, err := GenerateTokens(email, userAgent, ipAddress)
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] token generation failed total_ms=%d", time.Since(callbackStart).Milliseconds()))
|
|
helper.LogError(err, "Error generating access token")
|
|
helper.RespondWithError(w, http.StatusInternalServerError, "Token generation failed")
|
|
return
|
|
}
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] token generation ok elapsed_ms=%d", time.Since(callbackStart).Milliseconds()))
|
|
|
|
var refreshTokenExpiry time.Duration
|
|
if emailExists {
|
|
refreshTokenExpiry = 7 * 24 * time.Hour
|
|
} else {
|
|
refreshTokenExpiry = 2 * time.Hour
|
|
}
|
|
|
|
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
|
|
|
cookieConfig := &http.Cookie{
|
|
Name: "refresh_token",
|
|
Value: refreshToken,
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Expires: time.Now().Add(refreshTokenExpiry),
|
|
}
|
|
|
|
if isSecure {
|
|
cookieConfig.Secure = true
|
|
cookieConfig.SameSite = http.SameSiteLaxMode
|
|
helper.LogInfo("Setting refresh_token cookie for PRODUCTION (secure=true)")
|
|
} else {
|
|
cookieConfig.Secure = false
|
|
cookieConfig.SameSite = http.SameSiteLaxMode
|
|
cookieConfig.Domain = "localhost"
|
|
helper.LogInfo("Setting refresh_token cookie for DEVELOPMENT (secure=false, domain=localhost)")
|
|
}
|
|
|
|
http.SetCookie(w, cookieConfig)
|
|
helper.LogInfo(fmt.Sprintf("Refresh token cookie set: Domain=%s, Secure=%v, HttpOnly=%v, SameSite=%v",
|
|
cookieConfig.Domain, cookieConfig.Secure, cookieConfig.HttpOnly, cookieConfig.SameSite))
|
|
|
|
helper.LogInfo("Fetching first name for email: " + email)
|
|
helper.LogInfo("Userinfo Email: " + userInfo.Email)
|
|
|
|
userID, err := services.GetUserID(email)
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] get user id failed total_ms=%d", time.Since(callbackStart).Milliseconds()))
|
|
helper.LogError(err, "Error fetching user")
|
|
helper.RespondWithError(w, http.StatusBadGateway, "Failed to fetch user information")
|
|
return
|
|
}
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] get user id ok total_ms=%d", time.Since(callbackStart).Milliseconds()))
|
|
|
|
// Dereference pointers to get actual string values
|
|
|
|
helper.LogInfo("Access Token Generated Copy this: " + accessToken)
|
|
|
|
loginLogStart := time.Now()
|
|
err = helper.LogLoginEventV2(userID, ipAddress)
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] login event log failed duration_ms=%d total_ms=%d", time.Since(loginLogStart).Milliseconds(), time.Since(callbackStart).Milliseconds()))
|
|
helper.LogError(err, fmt.Sprintf("Failed to log login event. user_id=%s ip=%s", userID, ipAddress))
|
|
helper.RespondWithError(w, http.StatusBadGateway, "Failed to Log Login Event")
|
|
return
|
|
}
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] login event log ok duration_ms=%d", time.Since(loginLogStart).Milliseconds()))
|
|
|
|
helper.LogInfo("Copy this access token: " + accessToken)
|
|
|
|
RedirectURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s", redirectURI, accessToken, userID)
|
|
http.Redirect(w, r, RedirectURL, http.StatusSeeOther)
|
|
}
|
|
|
|
func validateState(w http.ResponseWriter, r *http.Request) bool {
|
|
cookie, err := r.Cookie(oauthStateCookieName)
|
|
callbackState := r.URL.Query().Get("state")
|
|
if err != nil {
|
|
helper.LogError(err, "oauth_state cookie missing or unreadable during callback")
|
|
helper.LogWarn(errorInvalidState)
|
|
helper.RespondWithError(w, http.StatusUnauthorized, errorInvalidState)
|
|
return false
|
|
}
|
|
|
|
if strings.TrimSpace(callbackState) == "" {
|
|
helper.LogWarn(errorInvalidState)
|
|
helper.RespondWithError(w, http.StatusUnauthorized, errorInvalidState)
|
|
return false
|
|
}
|
|
|
|
if callbackState != cookie.Value {
|
|
helper.LogError(errors.New("oauth state mismatch"), fmt.Sprintf("OAuth state mismatch. cookie_state=%s callback_state=%s", cookie.Value, callbackState))
|
|
helper.LogWarn(errorInvalidState)
|
|
helper.RespondWithError(w, http.StatusUnauthorized, errorInvalidState)
|
|
return false
|
|
}
|
|
helper.LogInfo(fmt.Sprintf("Cookie state: %s, Callback state: %s", cookie.Value, callbackState))
|
|
return true
|
|
}
|
|
|
|
func callbackRedirectURI(w http.ResponseWriter, r *http.Request) (string, bool) {
|
|
cookie, err := r.Cookie(oauthRedirectURICookieName)
|
|
if err != nil {
|
|
helper.LogError(err, "oauth redirect_uri cookie missing or unreadable during callback")
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
|
return "", false
|
|
}
|
|
|
|
redirectURI := strings.TrimSpace(cookie.Value)
|
|
if redirectURI == "" || !IsAllowedRedirectURI(redirectURI) {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized RedirectURI")
|
|
return "", false
|
|
}
|
|
|
|
return redirectURI, true
|
|
}
|
|
|
|
func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoogleInfo, error) {
|
|
fetchStart := time.Now()
|
|
code := r.URL.Query().Get("code")
|
|
log.Print("Authorization code received: ", code)
|
|
exchangeStart := time.Now()
|
|
exchangeCtx, exchangeCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer exchangeCancel()
|
|
token, err := googleOauthConfig.Exchange(exchangeCtx, code)
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] google exchange failed duration_ms=%d total_ms=%d", time.Since(exchangeStart).Milliseconds(), time.Since(fetchStart).Milliseconds()))
|
|
helper.LogError(err, "Error exchanging authorization code for token")
|
|
return models.UserGoogleInfo{}, fmt.Errorf("failed to exchange authorization code: %w", err)
|
|
}
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] google exchange ok duration_ms=%d", time.Since(exchangeStart).Milliseconds()))
|
|
|
|
helper.LogInfo(fmt.Sprintf("Access Token: %s", token.AccessToken))
|
|
|
|
// Create a context with a 30-second timeout for the userinfo request
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
client := googleOauthConfig.Client(ctx, token)
|
|
req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
|
if err != nil {
|
|
helper.LogError(err, "Error creating request to fetch user info")
|
|
return models.UserGoogleInfo{}, fmt.Errorf("failed to create userinfo request: %w", err)
|
|
}
|
|
req.Header.Set("Authorization", bearerPrefix+token.AccessToken)
|
|
req = req.WithContext(ctx)
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] google userinfo request failed duration_ms=%d total_ms=%d", time.Since(exchangeStart).Milliseconds(), time.Since(fetchStart).Milliseconds()))
|
|
errMsg := fmt.Sprintf("Failed to fetch user info from Google: %v", err)
|
|
helper.LogError(err, errMsg)
|
|
|
|
// Provide more specific error messages for common issues
|
|
if os.IsTimeout(err) {
|
|
return models.UserGoogleInfo{}, fmt.Errorf("request timed out: Google userinfo endpoint took too long to respond (timeout: 30s)")
|
|
}
|
|
if strings.Contains(err.Error(), "net/http: TLS handshake timeout") {
|
|
return models.UserGoogleInfo{}, fmt.Errorf("TLS handshake timeout: Unable to establish secure connection to Google")
|
|
}
|
|
if strings.Contains(err.Error(), "context deadline exceeded") {
|
|
return models.UserGoogleInfo{}, fmt.Errorf("request deadline exceeded: Connection attempt exceeded 30 second timeout")
|
|
}
|
|
if strings.Contains(err.Error(), "connection refused") {
|
|
return models.UserGoogleInfo{}, fmt.Errorf("connection refused: Cannot reach Google servers")
|
|
}
|
|
if strings.Contains(err.Error(), "no such host") {
|
|
return models.UserGoogleInfo{}, fmt.Errorf("DNS resolution failed: Cannot resolve googleapis.com")
|
|
}
|
|
|
|
return models.UserGoogleInfo{}, fmt.Errorf("network error while fetching user info: %w", err)
|
|
}
|
|
defer func(Body io.ReadCloser) {
|
|
err := Body.Close()
|
|
if err != nil {
|
|
helper.LogError(err, "Error closing response body")
|
|
}
|
|
}(resp.Body)
|
|
|
|
// Check HTTP status code
|
|
if resp.StatusCode != http.StatusOK {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] google userinfo non-200 status=%d total_ms=%d", resp.StatusCode, time.Since(fetchStart).Milliseconds()))
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
errMsg := fmt.Sprintf("Google API returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
|
helper.LogError(nil, errMsg)
|
|
return models.UserGoogleInfo{}, fmt.Errorf("google api error (status %d): %s", resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
var userInfo models.UserGoogleInfo
|
|
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] google userinfo decode failed total_ms=%d", time.Since(fetchStart).Milliseconds()))
|
|
helper.LogError(err, "Error decoding user info from Google response")
|
|
return models.UserGoogleInfo{}, fmt.Errorf("failed to parse user info response: %w", err)
|
|
}
|
|
helper.LogInfo(fmt.Sprintf("[oauth-debug] fetch google userinfo complete total_ms=%d", time.Since(fetchStart).Milliseconds()))
|
|
return userInfo, nil
|
|
}
|
|
|
|
func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodOptions {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
|
w.Header().Set("Access-Control-Max-Age", "3600")
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
// First, check if access token is provided and if it's expired
|
|
helper.LogInfo("Refresh token handler called")
|
|
authHeader := r.Header.Get("Authorization")
|
|
helper.LogInfo("Authorization header: " + authHeader)
|
|
if authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix) {
|
|
accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
|
helper.LogInfo("Access token from header: " + accessToken)
|
|
token, err := jwt.ParseWithClaims(accessToken, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(os.Getenv("JWT_SECRET_KEY")), nil
|
|
})
|
|
helper.LogInfo("Parsed token: " + fmt.Sprintf("%v", token))
|
|
|
|
if err == nil && token != nil && token.Claims != nil {
|
|
if claims, ok := token.Claims.(*models.AccessToken); ok && claims != nil {
|
|
if claims.Exp != 0 && claims.ExpiresAt != nil {
|
|
helper.LogInfo("Token expiration timestamp: " + fmt.Sprintf("%v", claims.ExpiresAt.Unix()))
|
|
helper.LogInfo("Current timestamp: " + fmt.Sprintf("%v", time.Now().Unix()))
|
|
} else {
|
|
helper.LogInfo("Token Exp is zero or ExpiresAt is nil")
|
|
if claims.Exp != 0 {
|
|
helper.LogInfo("Exp: " + fmt.Sprintf("%d (%s)", claims.Exp, time.Unix(claims.Exp, 0).Format(time.RFC3339)))
|
|
} else {
|
|
helper.LogInfo("Exp field is 0")
|
|
}
|
|
}
|
|
helper.LogInfo("Token expiration (Exp field): " + fmt.Sprintf("%d", claims.Exp))
|
|
helper.LogInfo("Current time: " + fmt.Sprintf("%d", time.Now().Unix()))
|
|
if claims.Exp < time.Now().Unix() {
|
|
helper.LogInfo("Token is actually expired based on Exp field")
|
|
} else {
|
|
helper.LogInfo("Token is NOT expired based on Exp field")
|
|
}
|
|
helper.LogInfo("Token valid: " + fmt.Sprintf("%v", token.Valid))
|
|
|
|
// Always proceed to refresh when requested, regardless of current token validity
|
|
helper.LogInfo("Access token present, but proceeding with refresh as requested")
|
|
} else {
|
|
helper.LogInfo("Failed to cast token claims to AccessToken or claims is nil")
|
|
}
|
|
} else {
|
|
helper.LogInfo("Token parsing failed or token is nil. Error: " + fmt.Sprintf("%v", err))
|
|
}
|
|
|
|
if err != nil && !strings.Contains(err.Error(), "expired") && !strings.Contains(err.Error(), "used before issued") {
|
|
helper.LogError(err, "Invalid access token format")
|
|
helper.RespondWithError(w, http.StatusBadRequest, "Invalid access token format")
|
|
return
|
|
}
|
|
helper.LogInfo("Access token is expired or invalid, proceeding with refresh")
|
|
}
|
|
|
|
// Log all cookies for debugging
|
|
helper.LogInfo("TRACE: All cookies in request: " + fmt.Sprintf("%d cookies", len(r.Cookies())))
|
|
for i, cookie := range r.Cookies() {
|
|
helper.LogInfo(fmt.Sprintf("TRACE: Cookie %d: Name=%s, Value-length=%d, Domain=%s, Path=%s",
|
|
i, cookie.Name, len(cookie.Value), cookie.Domain, cookie.Path))
|
|
}
|
|
|
|
cookie, err := r.Cookie("refresh_token")
|
|
helper.LogInfo("TRACE: Cookie retrieval - error: " + fmt.Sprintf("%v", err))
|
|
if err != nil {
|
|
helper.LogError(err, "Refresh token cookie not found")
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Refresh token not found")
|
|
return
|
|
}
|
|
|
|
refreshToken := cookie.Value
|
|
helper.LogInfo("TRACE: Refresh token from cookie - length: " + fmt.Sprintf("%d", len(refreshToken)))
|
|
if refreshToken == "" {
|
|
helper.LogError(errors.New("refresh token cookie is empty"), "refresh token cookie is empty")
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "refresh token is empty")
|
|
return
|
|
}
|
|
|
|
// Get client info for security validation
|
|
userAgent := r.Header.Get("User-Agent")
|
|
ipAddress := getIPAddress(r)
|
|
|
|
// Try to extract email from access token for fallback during refresh
|
|
var emailFromToken string
|
|
if authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix) {
|
|
accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
|
if token, err := jwt.ParseWithClaims(accessToken, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(os.Getenv("JWT_SECRET_KEY")), nil
|
|
}); err == nil {
|
|
if claims, ok := token.Claims.(*models.AccessToken); ok && claims.Email != "" {
|
|
emailFromToken = claims.Email
|
|
helper.LogInfo("TRACE: Extracted email from access token for fallback: " + emailFromToken)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Use the improved RefreshAccessToken function
|
|
newAccessToken, err := GenerateTokensFromRefreshWithEmail(refreshToken, userAgent, ipAddress, emailFromToken)
|
|
helper.LogInfo("New access token: " + newAccessToken)
|
|
helper.LogInfo("New access token length: " + fmt.Sprintf("%d", len(newAccessToken)))
|
|
if newAccessToken == "" {
|
|
helper.LogError(errors.New("generated access token is empty"), "Generated access token is empty")
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Failed to generate new access token")
|
|
}
|
|
if err != nil {
|
|
helper.LogError(err, "Failed to refresh access token")
|
|
|
|
// Return specific error messages
|
|
if strings.Contains(err.Error(), "too many refresh attempts") {
|
|
helper.RespondWithError(w, http.StatusTooManyRequests, "Too many refresh attempts, please wait")
|
|
return
|
|
}
|
|
if strings.Contains(err.Error(), "expired") || strings.Contains(err.Error(), "revoked") {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Session expired, please login again")
|
|
return
|
|
}
|
|
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Invalid refresh token")
|
|
return
|
|
}
|
|
|
|
var expiresInSeconds int
|
|
env := os.Getenv("GO_ENV")
|
|
if env == "production" || env == "canary" {
|
|
expiresInSeconds = 45 * 60
|
|
} else {
|
|
expiresInSeconds = 15 * 60
|
|
}
|
|
|
|
response := map[string]interface{}{
|
|
"access_token": newAccessToken,
|
|
"token_type": "Bearer",
|
|
"expires_in": expiresInSeconds,
|
|
}
|
|
|
|
helper.LogInfo("TRACE: About to send response: " + fmt.Sprintf("%+v", response))
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
helper.LogError(err, "Failed to encode response")
|
|
} else {
|
|
helper.LogInfo("TRACE: Response successfully encoded and sent")
|
|
}
|
|
}
|
|
|
|
// GenerateTokensFromRefresh creates a new access token from a refresh token
|
|
func GenerateTokensFromRefresh(refreshToken, userAgent, ipAddress string) (string, error) {
|
|
helper.LogInfo("TRACE: GenerateTokensFromRefresh called")
|
|
helper.LogInfo("TRACE: refreshToken length: " + fmt.Sprintf("%d", len(refreshToken)))
|
|
helper.LogInfo("TRACE: userAgent: " + userAgent)
|
|
helper.LogInfo("TRACE: ipAddress: " + ipAddress)
|
|
|
|
result, err := RefreshAccessToken(refreshToken, userAgent, ipAddress)
|
|
helper.LogInfo("TRACE: RefreshAccessToken returned - token length: " + fmt.Sprintf("%d", len(result)) + ", error: " + fmt.Sprintf("%v", err))
|
|
|
|
return result, err
|
|
}
|
|
|
|
// GenerateTokensFromRefreshWithEmail creates a new access token from a refresh token with email fallback
|
|
func GenerateTokensFromRefreshWithEmail(refreshToken, userAgent, ipAddress, emailFallback string) (string, error) {
|
|
helper.LogInfo("TRACE: GenerateTokensFromRefreshWithEmail called")
|
|
helper.LogInfo("TRACE: refreshToken length: " + fmt.Sprintf("%d", len(refreshToken)))
|
|
helper.LogInfo("TRACE: userAgent: " + userAgent)
|
|
helper.LogInfo("TRACE: ipAddress: " + ipAddress)
|
|
helper.LogInfo("TRACE: emailFallback: " + emailFallback)
|
|
|
|
result, err := RefreshAccessTokenWithEmailFallback(refreshToken, userAgent, ipAddress, emailFallback)
|
|
helper.LogInfo("TRACE: RefreshAccessTokenWithEmailFallback returned - token length: " + fmt.Sprintf("%d", len(result)) + ", error: " + fmt.Sprintf("%v", err))
|
|
|
|
return result, err
|
|
}
|
|
|
|
func checkEmailInDB(email string) (bool, error) {
|
|
if db.DB == nil {
|
|
helper.LogError(nil, dbConnNilError)
|
|
return false, errors.New(dbConnNilError)
|
|
}
|
|
exists, err := services.CheckEmailInDB(email)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
helper.LogInfo("Email exists in DB: " + fmt.Sprintf("%v", exists))
|
|
return exists, nil
|
|
}
|
|
|
|
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
clearRefreshTokenCookie(w)
|
|
clearCSRFCookie(w)
|
|
|
|
if isValidAuthHeader(authHeader) {
|
|
tokenString := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
|
if tokenString != "" {
|
|
token, err := jwt.ParseWithClaims(tokenString, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
if rsaPrivateKey == nil {
|
|
return nil, errors.New("RSA private key is not initialized")
|
|
}
|
|
return &rsaPrivateKey.PublicKey, nil
|
|
})
|
|
|
|
if err == nil {
|
|
if claims, ok := token.Claims.(*models.AccessToken); ok {
|
|
userID, err := services.GetUserIDFromEmail(claims.Email)
|
|
if err == nil {
|
|
if err := RevokeAllUserSessions(userID); err != nil {
|
|
helper.LogError(err, "Failed to revoke user sessions during logout")
|
|
}
|
|
} else {
|
|
helper.LogError(err, "Failed to get user ID during logout")
|
|
}
|
|
}
|
|
} else {
|
|
helper.LogError(err, "Failed to parse JWT token during logout")
|
|
}
|
|
} else {
|
|
helper.LogWarn("Authorization header contains empty bearer token during logout")
|
|
}
|
|
} else {
|
|
helper.LogWarn("Authorization header missing or invalid during logout; proceeding with cookie clear only")
|
|
}
|
|
|
|
if err := accessLog(r, nil, 18, nil); err != nil {
|
|
helper.LogError(err, "Failed to write access log during logout")
|
|
}
|
|
|
|
response := map[string]interface{}{
|
|
"message": "Successfully logged out",
|
|
"action": "clear_session_storage",
|
|
"keys": []string{"refresh_token", "access_token"},
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
helper.LogError(err, "Failed to encode logout response")
|
|
}
|
|
}
|
|
|
|
func isValidAuthHeader(authHeader string) bool {
|
|
return authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix)
|
|
}
|
|
|
|
func clearRefreshTokenCookie(w http.ResponseWriter) {
|
|
helper.LogInfo("Clearing refresh_token cookie...")
|
|
|
|
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
|
|
|
helper.LogInfo(fmt.Sprintf("Cookie clearing - isSecure: %v, BACKEND_URL: %s", isSecure, os.Getenv("BACKEND_URL")))
|
|
|
|
cookieConfig := &http.Cookie{
|
|
Name: "refresh_token",
|
|
Value: "",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Expires: time.Unix(0, 0),
|
|
MaxAge: -1,
|
|
}
|
|
|
|
if isSecure {
|
|
cookieConfig.Secure = true
|
|
cookieConfig.SameSite = http.SameSiteLaxMode
|
|
helper.LogInfo("Setting cookie clear for PRODUCTION (secure=true)")
|
|
} else {
|
|
cookieConfig.Secure = false
|
|
cookieConfig.SameSite = http.SameSiteLaxMode
|
|
cookieConfig.Domain = "localhost"
|
|
helper.LogInfo("Setting cookie clear for DEVELOPMENT (secure=false, domain=localhost)")
|
|
}
|
|
|
|
http.SetCookie(w, cookieConfig)
|
|
helper.LogInfo(fmt.Sprintf("Cookie clear #1 sent: Name=%s, Value=%s, Domain=%s, Secure=%v, HttpOnly=%v",
|
|
cookieConfig.Name, cookieConfig.Value, cookieConfig.Domain, cookieConfig.Secure, cookieConfig.HttpOnly))
|
|
|
|
fallbackCookie := &http.Cookie{
|
|
Name: "refresh_token",
|
|
Value: "",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: isSecure,
|
|
SameSite: http.SameSiteLaxMode,
|
|
Expires: time.Unix(0, 0),
|
|
MaxAge: -1,
|
|
}
|
|
http.SetCookie(w, fallbackCookie)
|
|
helper.LogInfo(fmt.Sprintf("Cookie clear #2 sent: Name=%s, Value=%s, Domain=%s, Secure=%v, HttpOnly=%v",
|
|
fallbackCookie.Name, fallbackCookie.Value, fallbackCookie.Domain, fallbackCookie.Secure, fallbackCookie.HttpOnly))
|
|
|
|
helper.LogInfo("Refresh token cookie clearing commands sent to browser")
|
|
}
|
|
|
|
func clearCSRFCookie(w http.ResponseWriter) {
|
|
helper.LogInfo("Clearing csrf_token cookie...")
|
|
|
|
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
|
|
|
// Match middleware cookie characteristics first (host-only, SameSiteStrict)
|
|
primaryCookie := &http.Cookie{
|
|
Name: "csrf_token",
|
|
Value: "",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteStrictMode,
|
|
Expires: time.Unix(0, 0),
|
|
MaxAge: -1,
|
|
}
|
|
http.SetCookie(w, primaryCookie)
|
|
helper.LogInfo(fmt.Sprintf("CSRF cookie clear #1 sent: Name=%s, Domain=%s, Secure=%v, SameSite=%v",
|
|
primaryCookie.Name, primaryCookie.Domain, primaryCookie.Secure, primaryCookie.SameSite))
|
|
|
|
// Fallback for local/dev browser behavior where secure or samesite attributes differ
|
|
fallbackCookie := &http.Cookie{
|
|
Name: "csrf_token",
|
|
Value: "",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: isSecure,
|
|
SameSite: http.SameSiteLaxMode,
|
|
Expires: time.Unix(0, 0),
|
|
MaxAge: -1,
|
|
}
|
|
http.SetCookie(w, fallbackCookie)
|
|
helper.LogInfo(fmt.Sprintf("CSRF cookie clear #2 sent: Name=%s, Domain=%s, Secure=%v, SameSite=%v",
|
|
fallbackCookie.Name, fallbackCookie.Domain, fallbackCookie.Secure, fallbackCookie.SameSite))
|
|
|
|
if !isSecure {
|
|
localhostCookie := &http.Cookie{
|
|
Name: "csrf_token",
|
|
Value: "",
|
|
Path: "/",
|
|
Domain: "localhost",
|
|
HttpOnly: true,
|
|
Secure: false,
|
|
SameSite: http.SameSiteLaxMode,
|
|
Expires: time.Unix(0, 0),
|
|
MaxAge: -1,
|
|
}
|
|
http.SetCookie(w, localhostCookie)
|
|
helper.LogInfo(fmt.Sprintf("CSRF cookie clear #3 sent: Name=%s, Domain=%s, Secure=%v, SameSite=%v",
|
|
localhostCookie.Name, localhostCookie.Domain, localhostCookie.Secure, localhostCookie.SameSite))
|
|
}
|
|
|
|
helper.LogInfo("CSRF token cookie clearing commands sent to browser")
|
|
}
|