Files
Authentication/handlers/google_auth.go
T
2025-11-26 11:31:09 +08:00

597 lines
20 KiB
Go

package handlers
import (
"authentication/db"
"authentication/helper"
"authentication/models"
"authentication/services"
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/joho/godotenv"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// googleOauthConfig holds the Google OAuth2 client settings; it is filled in by init.
var googleOauthConfig oauth2.Config
// oauthStateString is a random hex value produced by generateRandomState once,
// when the package-level variables are initialized.
var oauthStateString = generateRandomState()
// DashboardBaseURL is the dashboard base URL used to build redirect targets;
// init reads it from the DASHBOARD_URL environment variable.
var DashboardBaseURL string
// init loads environment variables (from a .env file when one can be read) and
// builds the Google OAuth2 configuration used by the handlers in this package.
//
// The .env file is loaded here because package initialization order does not
// guarantee that another package has loaded it before this init runs. A
// missing or unreadable .env file is only logged: the variables may already be
// present in the process environment.
//
// The process exits via log.Fatal when GOOGLE_CLIENT_ID or
// GOOGLE_CLIENT_SECRET is empty. The client secret itself is never written to
// the log; only whether it is set and its length are reported.
func init() {
	if cwd, err := os.Getwd(); err != nil {
		log.Printf("[google_auth.init] Failed to determine working directory: %v", err)
	} else {
		log.Printf("[google_auth.init] Current working directory: %s", cwd)
	}

	// godotenv.Load with no arguments already targets ".env", so a second
	// explicit Load(".env") attempt would only repeat the same call.
	if err := godotenv.Load(); err != nil {
		log.Printf("[google_auth.init] Failed to load .env: %v", err)
	}

	clientID := os.Getenv("GOOGLE_CLIENT_ID")
	clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
	backendURL := os.Getenv("BACKEND_URL")

	log.Printf("[google_auth.init] GOOGLE_CLIENT_ID: '%s' (length: %d)", clientID, len(clientID))
	// Secrets must not end up in log files: report presence and length only.
	log.Printf("[google_auth.init] GOOGLE_CLIENT_SECRET set: %t (length: %d)", clientSecret != "", len(clientSecret))
	log.Printf("[google_auth.init] BACKEND_URL: '%s'", backendURL)
	if backendURL == "" {
		log.Printf("[google_auth.init] BACKEND_URL is empty; the OAuth redirect URL will not be absolute")
	}

	googleOauthConfig = oauth2.Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		RedirectURL:  fmt.Sprintf("%s/v1/auth/callback", backendURL),
		Scopes: []string{
			"https://www.googleapis.com/auth/userinfo.email",
			"https://www.googleapis.com/auth/userinfo.profile",
		},
		Endpoint: google.Endpoint,
	}

	if googleOauthConfig.ClientID == "" {
		log.Fatal("GOOGLE_CLIENT_ID is not set in environment variables")
	}
	if googleOauthConfig.ClientSecret == "" {
		log.Fatal("GOOGLE_CLIENT_SECRET is not set in environment variables")
	}

	DashboardBaseURL = os.Getenv("DASHBOARD_URL")
}
// generateRandomState reads 16 bytes from crypto/rand and returns them as a
// 32-character lowercase hexadecimal string. When the random source fails,
// the error is logged and an empty string is returned.
func generateRandomState() string {
	var raw [16]byte
	_, err := rand.Read(raw[:])
	if err == nil {
		return fmt.Sprintf("%x", raw[:])
	}
	helper.LogError(err, "Error generating random state")
	return ""
}
func GoogleLogin(w http.ResponseWriter, r *http.Request) {
helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", oauthStateString))
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Value: oauthStateString,
Path: "/",
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
Expires: time.Now().Add(5 * time.Minute),
})
url := googleOauthConfig.AuthCodeURL(oauthStateString, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
http.Redirect(w, r, url, http.StatusFound)
}
// getIPAddress returns the client IP address for the request as a string.
//
// Lookup order:
//  1. the first entry of X-Forwarded-For, if it parses as an IP;
//  2. X-Real-IP, if it parses as an IP;
//  3. the host part of r.RemoteAddr, with any loopback address reported as
//     "127.0.0.1".
//
// An empty string is returned when r.RemoteAddr cannot be split into host and
// port. All request headers are logged for diagnostics, but the values of
// credential-bearing headers are redacted so tokens and cookies never reach
// the log.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are supplied by the client unless
// a trusted proxy overwrites them — confirm the deployment before relying on
// this value for security decisions.
func getIPAddress(r *http.Request) string {
	for header, values := range r.Header {
		for _, value := range values {
			switch http.CanonicalHeaderKey(header) {
			case "Authorization", "Proxy-Authorization", "Cookie":
				value = "[REDACTED]"
			}
			helper.LogInfo(fmt.Sprintf("Header: %s = %s", header, value))
		}
	}

	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// The header is a comma-separated chain; the first entry is the original client.
		first := strings.TrimSpace(strings.Split(forwarded, ",")[0])
		if net.ParseIP(first) != nil {
			return first
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" && net.ParseIP(realIP) != nil {
		return realIP
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		helper.LogError(err, "Error parsing remote address")
		return ""
	}
	// Normalize every loopback form (for example ::1) to the IPv4 loopback address.
	if parsed := net.ParseIP(host); parsed != nil && parsed.IsLoopback() {
		return "127.0.0.1"
	}
	return host
}
func GoogleCallback(w http.ResponseWriter, r *http.Request) {
ipAddress := getIPAddress(r)
fmt.Printf("INFO: Extracted IP address: %s\n", ipAddress)
userAgent := r.Header.Get("User-Agent")
if !validateState(w, r) {
return
}
userInfo, err := FetchGoogleUserInfo(w, r)
if err != nil {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("")), http.StatusSeeOther)
return
}
email := userInfo.Email
profilePicture := userInfo.Picture
emailExists, err := checkEmailInDB(email)
if err != nil {
helper.LogError(err, "Error checking email")
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Error checking email in database")), http.StatusSeeOther)
return
}
helper.LogError(fmt.Errorf("%v", emailExists), "Email exists in DB")
accessToken, refreshToken, err := GenerateTokens(email, userAgent, ipAddress)
if err != nil {
helper.LogError(err, "Error generating access token")
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token generation failed")), http.StatusSeeOther)
return
}
var refreshTokenExpiry time.Duration
if emailExists {
refreshTokenExpiry = 7 * 24 * time.Hour
} else {
refreshTokenExpiry = 2 * time.Hour
}
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
cookieConfig := &http.Cookie{
Name: "refresh_token",
Value: refreshToken,
Path: "/",
HttpOnly: true,
Expires: time.Now().Add(refreshTokenExpiry),
}
if isSecure {
cookieConfig.Secure = true
cookieConfig.SameSite = http.SameSiteLaxMode
helper.LogInfo("Setting refresh_token cookie for PRODUCTION (secure=true)")
} else {
cookieConfig.Secure = false
cookieConfig.SameSite = http.SameSiteLaxMode
cookieConfig.Domain = "localhost"
helper.LogInfo("Setting refresh_token cookie for DEVELOPMENT (secure=false, domain=localhost)")
}
http.SetCookie(w, cookieConfig)
helper.LogInfo(fmt.Sprintf("Refresh token cookie set: Domain=%s, Secure=%v, HttpOnly=%v, SameSite=%v",
cookieConfig.Domain, cookieConfig.Secure, cookieConfig.HttpOnly, cookieConfig.SameSite))
if !emailExists {
helper.LogWarn(fmt.Sprintf("Email %s does not exist in the database", email))
registrationURL := fmt.Sprintf("%s/callback?error=%s&token=%s", DashboardBaseURL, url.QueryEscape("Please register first"), accessToken)
http.Redirect(w, r, registrationURL, http.StatusSeeOther)
return
}
var firstName string
helper.LogInfo("Fetching first name for email: " + email)
helper.LogInfo("Userinfo Email: " + userInfo.Email)
userID, firstNamePtr, lastNamePtr, emailAddressPtr, err := services.GetUser(email)
if err != nil {
helper.LogError(err, "Error fetching user")
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("User not found")), http.StatusSeeOther)
return
}
// Dereference pointers to get actual string values
if firstNamePtr != nil {
firstName = *firstNamePtr
}
lastName := ""
if lastNamePtr != nil {
lastName = *lastNamePtr
}
emailAddress := emailAddressPtr
helper.LogInfo("Access Token Generated Copy this: " + accessToken)
err = helper.LogLoginEventV2(userID, ipAddress)
if err != nil {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Failed to log login event")), http.StatusSeeOther)
return
}
helper.LogInfo("Copy this access token: " + accessToken)
DashboardURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s&first_name=%s&last_name=%s&email_address=%s&profile_picture=%s", DashboardBaseURL, accessToken, userID, firstName, lastName, emailAddress, profilePicture)
http.Redirect(w, r, DashboardURL, http.StatusSeeOther)
}
// validateState checks the OAuth "state" query parameter against the value
// stored in the "oauth_state" cookie.
//
// It returns true only when the cookie exists, its value is non-empty and it
// equals the query parameter. Rejecting an empty value matters: otherwise an
// empty cookie combined with a missing "state" parameter would compare equal
// and bypass the check. On failure a warning is logged, the browser is
// redirected to the dashboard error URL and false is returned, so the caller
// must not write a response of its own. The state values are not logged.
func validateState(w http.ResponseWriter, r *http.Request) bool {
	callbackState := r.URL.Query().Get("state")
	cookie, err := r.Cookie("oauth_state")
	if err != nil || cookie.Value == "" || callbackState != cookie.Value {
		helper.LogWarn(errorInvalidState)
		http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(errorInvalidState)), http.StatusSeeOther)
		return false
	}
	helper.LogInfo("OAuth state cookie matches callback state")
	return true
}
// FetchGoogleUserInfo exchanges the authorization code found in the "code"
// query parameter for a Google token and uses it to fetch the user's profile
// from the Google userinfo endpoint.
//
// Parameters:
//   - w is not used; it is kept so existing callers keep compiling.
//   - r supplies the "code" query parameter and the request context.
//
// It returns the decoded profile, or a zero models.UserGoogleInfo and an error
// when the code is missing, the exchange fails, the request fails, Google
// answers with a non-200 status, or the body cannot be decoded. Each failure
// is logged here; no response is written. Outbound calls are bound to the
// incoming request's context with a 15 second timeout so that a stalled Google
// endpoint cannot hold the handler open. The Google access token is never logged.
func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoogleInfo, error) {
	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
	defer cancel()

	code := r.URL.Query().Get("code")
	if code == "" {
		err := errors.New("authorization code is missing from callback")
		helper.LogError(err, "Error exchanging token")
		return models.UserGoogleInfo{}, err
	}

	token, err := googleOauthConfig.Exchange(ctx, code)
	if err != nil {
		helper.LogError(err, "Error exchanging token")
		return models.UserGoogleInfo{}, err
	}

	client := googleOauthConfig.Client(ctx, token)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
	if err != nil {
		helper.LogError(err, "Error creating request")
		return models.UserGoogleInfo{}, err
	}
	req.Header.Set("Authorization", bearerPrefix+token.AccessToken)

	resp, err := client.Do(req)
	if err != nil {
		helper.LogError(err, "Error sending request")
		return models.UserGoogleInfo{}, err
	}
	defer func(body io.ReadCloser) {
		if err := body.Close(); err != nil {
			helper.LogError(err, "Error closing response body")
		}
	}(resp.Body)

	// An error payload must not be decoded as if it were a user profile.
	if resp.StatusCode != http.StatusOK {
		err := fmt.Errorf("google userinfo request returned status %d", resp.StatusCode)
		helper.LogError(err, "Error fetching user info")
		return models.UserGoogleInfo{}, err
	}

	var userInfo models.UserGoogleInfo
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		helper.LogError(err, "Error decoding user info")
		return models.UserGoogleInfo{}, err
	}
	return userInfo, nil
}
// HandleTokenRefresh issues a new access token from the "refresh_token" cookie.
//
// Behaviour:
//   - OPTIONS requests get the CORS preflight headers and a 200.
//   - If an Authorization bearer token is present it is parsed once (HMAC
//     signing methods only, key from JWT_SECRET_KEY). A token that fails for
//     any reason other than being expired or used before its issue time is
//     rejected with 400. A fully valid token contributes its e-mail claim as a
//     fallback for the refresh call.
//   - A missing or empty refresh cookie yields 401.
//   - Refresh errors are mapped to 429 ("too many refresh attempts"), 401
//     "Session expired" (error text contains "expired" or "revoked") or 401
//     "Invalid refresh token". An empty token without an error yields 401.
//   - On success a JSON body with access_token, token_type and expires_in
//     (45 minutes when GO_ENV is "production" or "canary", otherwise 15
//     minutes) is written.
//
// Exactly one response is written on every path, and neither the Authorization
// header nor any token value is written to the log.
func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodOptions {
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
		w.Header().Set("Access-Control-Max-Age", "3600")
		w.WriteHeader(http.StatusOK)
		return
	}

	helper.LogInfo("Refresh token handler called")

	// emailFromToken is only filled from a token that passed full validation.
	var emailFromToken string
	authHeader := r.Header.Get("Authorization")
	if authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix) {
		accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
		token, err := jwt.ParseWithClaims(accessToken, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
			}
			return []byte(os.Getenv("JWT_SECRET_KEY")), nil
		})
		switch {
		case err == nil:
			if claims, ok := token.Claims.(*models.AccessToken); ok && claims != nil && claims.Email != "" {
				emailFromToken = claims.Email
			}
			helper.LogInfo("Access token present, but proceeding with refresh as requested")
		case strings.Contains(err.Error(), "expired"), strings.Contains(err.Error(), "used before issued"):
			// An expired token is the normal reason for calling this endpoint.
			helper.LogInfo("Access token is expired or invalid, proceeding with refresh")
		default:
			helper.LogError(err, "Invalid access token format")
			helper.RespondWithError(w, http.StatusBadRequest, "Invalid access token format")
			return
		}
	}

	cookie, err := r.Cookie("refresh_token")
	if err != nil {
		helper.LogError(err, "Refresh token cookie not found")
		helper.RespondWithError(w, http.StatusUnauthorized, "Refresh token not found")
		return
	}
	refreshToken := cookie.Value
	if refreshToken == "" {
		helper.LogError(errors.New("refresh token cookie is empty"), "refresh token cookie is empty")
		helper.RespondWithError(w, http.StatusUnauthorized, "refresh token is empty")
		return
	}

	// Client details are passed on for the security checks done during refresh.
	userAgent := r.Header.Get("User-Agent")
	ipAddress := getIPAddress(r)

	newAccessToken, err := GenerateTokensFromRefreshWithEmail(refreshToken, userAgent, ipAddress, emailFromToken)
	if err != nil {
		helper.LogError(err, "Failed to refresh access token")
		// The error is checked before the empty-token case so that the specific
		// status codes below are the ones the client actually receives.
		switch msg := err.Error(); {
		case strings.Contains(msg, "too many refresh attempts"):
			helper.RespondWithError(w, http.StatusTooManyRequests, "Too many refresh attempts, please wait")
		case strings.Contains(msg, "expired"), strings.Contains(msg, "revoked"):
			helper.RespondWithError(w, http.StatusUnauthorized, "Session expired, please login again")
		default:
			helper.RespondWithError(w, http.StatusUnauthorized, "Invalid refresh token")
		}
		return
	}
	if newAccessToken == "" {
		helper.LogError(errors.New("generated access token is empty"), "Generated access token is empty")
		helper.RespondWithError(w, http.StatusUnauthorized, "Failed to generate new access token")
		return
	}

	expiresInSeconds := 15 * 60
	if env := os.Getenv("GO_ENV"); env == "production" || env == "canary" {
		expiresInSeconds = 45 * 60
	}

	response := map[string]interface{}{
		"access_token": newAccessToken,
		"token_type":   "Bearer",
		"expires_in":   expiresInSeconds,
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(response); err != nil {
		helper.LogError(err, "Failed to encode response")
		return
	}
	helper.LogInfo("Refresh response successfully encoded and sent")
}
// GenerateTokensFromRefresh delegates to RefreshAccessToken and returns its
// result unchanged. The call is bracketed by TRACE log lines that record the
// refresh token length (not its value), the user agent, the IP address and the
// outcome.
func GenerateTokensFromRefresh(refreshToken, userAgent, ipAddress string) (string, error) {
	helper.LogInfo("TRACE: GenerateTokensFromRefresh called")
	helper.LogInfo(fmt.Sprintf("TRACE: refreshToken length: %d", len(refreshToken)))
	helper.LogInfo(fmt.Sprintf("TRACE: userAgent: %s", userAgent))
	helper.LogInfo(fmt.Sprintf("TRACE: ipAddress: %s", ipAddress))

	newToken, refreshErr := RefreshAccessToken(refreshToken, userAgent, ipAddress)

	helper.LogInfo(fmt.Sprintf("TRACE: RefreshAccessToken returned - token length: %d, error: %v", len(newToken), refreshErr))
	return newToken, refreshErr
}
// GenerateTokensFromRefreshWithEmail delegates to
// RefreshAccessTokenWithEmailFallback, passing emailFallback through, and
// returns its result unchanged. The call is bracketed by TRACE log lines that
// record the refresh token length (not its value), the user agent, the IP
// address, the fallback e-mail and the outcome.
func GenerateTokensFromRefreshWithEmail(refreshToken, userAgent, ipAddress, emailFallback string) (string, error) {
	helper.LogInfo("TRACE: GenerateTokensFromRefreshWithEmail called")
	helper.LogInfo(fmt.Sprintf("TRACE: refreshToken length: %d", len(refreshToken)))
	helper.LogInfo(fmt.Sprintf("TRACE: userAgent: %s", userAgent))
	helper.LogInfo(fmt.Sprintf("TRACE: ipAddress: %s", ipAddress))
	helper.LogInfo(fmt.Sprintf("TRACE: emailFallback: %s", emailFallback))

	newToken, refreshErr := RefreshAccessTokenWithEmailFallback(refreshToken, userAgent, ipAddress, emailFallback)

	helper.LogInfo(fmt.Sprintf("TRACE: RefreshAccessTokenWithEmailFallback returned - token length: %d, error: %v", len(newToken), refreshErr))
	return newToken, refreshErr
}
// checkEmailInDB reports whether email is known to services.CheckEmailInDB.
//
// It returns an error without querying when db.DB is nil, and passes through
// any error returned by the service. The same non-nil error value is both
// logged and returned for the nil-connection case, so the error logger never
// receives a nil error.
func checkEmailInDB(email string) (bool, error) {
	if db.DB == nil {
		err := errors.New(dbConnNilError)
		helper.LogError(err, dbConnNilError)
		return false, err
	}

	exists, err := services.CheckEmailInDB(email)
	if err != nil {
		return false, err
	}
	helper.LogInfo(fmt.Sprintf("Email exists in DB: %v", exists))
	return exists, nil
}
// LogoutHandler logs the caller out.
//
// It requires an Authorization header carrying a non-empty bearer token and
// answers 401 otherwise. The token is parsed with the JWT_SECRET_KEY secret,
// accepting HMAC signing methods only (the same check HandleTokenRefresh
// applies). When parsing succeeds, the user's sessions are revoked via
// RevokeAllUserSessions. Parsing, lookup and revocation failures are logged
// but do not fail the request: the refresh cookie is cleared and a 200 JSON
// response is written regardless.
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
	authHeader := r.Header.Get("Authorization")
	if !isValidAuthHeader(authHeader) {
		helper.RespondWithError(w, http.StatusUnauthorized, "Authorization header missing or invalid")
		return
	}

	tokenString := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
	if tokenString == "" {
		helper.RespondWithError(w, http.StatusUnauthorized, "Token is missing or empty")
		return
	}

	token, err := jwt.ParseWithClaims(tokenString, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
		// Refuse tokens signed with anything other than the expected HMAC family.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(os.Getenv("JWT_SECRET_KEY")), nil
	})
	if err != nil {
		helper.LogError(err, "Failed to parse JWT token during logout")
	} else if claims, ok := token.Claims.(*models.AccessToken); ok && claims != nil {
		// Best effort: logout still completes when revocation cannot be performed.
		if userID, err := services.GetUserIDFromEmail(claims.Email); err != nil {
			helper.LogError(err, "Failed to get user ID during logout")
		} else if err := RevokeAllUserSessions(userID); err != nil {
			helper.LogError(err, "Failed to revoke user sessions during logout")
		}
	}

	accessLog(w, r, nil, 18, nil)
	clearRefreshTokenCookie(w)

	response := map[string]interface{}{
		"message": "Successfully logged out",
		"action":  "clear_session_storage",
		"keys":    []string{"refresh_token", "access_token"},
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(response); err != nil {
		helper.LogError(err, "Failed to encode logout response")
	}
}
// isValidAuthHeader reports whether authHeader is non-empty and begins with
// bearerPrefix. It does not inspect the token that follows the prefix.
func isValidAuthHeader(authHeader string) bool {
	if authHeader == "" {
		return false
	}
	return strings.HasPrefix(authHeader, bearerPrefix)
}
// clearRefreshTokenCookie asks the browser to delete the "refresh_token"
// cookie by sending two already-expired Set-Cookie headers (empty value,
// Expires at the Unix epoch, MaxAge -1).
//
// The first cookie mirrors how the cookie is set: Secure when BACKEND_URL
// starts with HTTPS, otherwise non-secure with Domain "localhost". The second
// is the same cookie without a Domain attribute, as a fallback. Each step is
// logged.
func clearRefreshTokenCookie(w http.ResponseWriter) {
	helper.LogInfo("Clearing refresh_token cookie...")

	backendURL := os.Getenv("BACKEND_URL")
	secure := strings.HasPrefix(backendURL, HTTPS)
	helper.LogInfo(fmt.Sprintf("Cookie clearing - isSecure: %v, BACKEND_URL: %s", secure, backendURL))

	// expiredCookie builds the deletion cookie for the given Domain attribute.
	expiredCookie := func(domain string) *http.Cookie {
		return &http.Cookie{
			Name:     "refresh_token",
			Value:    "",
			Path:     "/",
			Domain:   domain,
			HttpOnly: true,
			Secure:   secure,
			SameSite: http.SameSiteLaxMode,
			Expires:  time.Unix(0, 0),
			MaxAge:   -1,
		}
	}

	primaryDomain := ""
	if secure {
		helper.LogInfo("Setting cookie clear for PRODUCTION (secure=true)")
	} else {
		primaryDomain = "localhost"
		helper.LogInfo("Setting cookie clear for DEVELOPMENT (secure=false, domain=localhost)")
	}

	primary := expiredCookie(primaryDomain)
	http.SetCookie(w, primary)
	helper.LogInfo(fmt.Sprintf("Cookie clear #1 sent: Name=%s, Value=%s, Domain=%s, Secure=%v, HttpOnly=%v",
		primary.Name, primary.Value, primary.Domain, primary.Secure, primary.HttpOnly))

	fallback := expiredCookie("")
	http.SetCookie(w, fallback)
	helper.LogInfo(fmt.Sprintf("Cookie clear #2 sent: Name=%s, Value=%s, Domain=%s, Secure=%v, HttpOnly=%v",
		fallback.Name, fallback.Value, fallback.Domain, fallback.Secure, fallback.HttpOnly))

	helper.LogInfo("Refresh token cookie clearing commands sent to browser")
}