init commit
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"authentication/helper"
|
||||
"authentication/services"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func accessLog(w http.ResponseWriter, r *http.Request, user *string, actType int, fieldUpdated interface{}) {
|
||||
email, err := helper.ExtractEmailFromToken(r.Header.Get(Authorization))
|
||||
if err != nil {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, UnauthorizedAccess)
|
||||
return
|
||||
}
|
||||
userID, err := services.GetUserIDFromEmail(email)
|
||||
if err != nil {
|
||||
helper.LogError(err, ErrorExtractingMailFromToken)
|
||||
helper.RespondWithError(w, http.StatusBadRequest, ErrorExtractingMailFromToken)
|
||||
return
|
||||
}
|
||||
ipAddress := getIPAddress(r)
|
||||
err = helper.LogEvent(userID, user, ipAddress, actType, fieldUpdated)
|
||||
if err != nil {
|
||||
errMsg, err := services.GetActivityMessages(actType)
|
||||
if err == nil {
|
||||
errMsg = "Perform Action"
|
||||
}
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(fmt.Sprintf("Failed to %s", errMsg))), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package handlers
|
||||
|
||||
const (
|
||||
Authorization = "Authorization"
|
||||
UnauthorizedAccess = "Unauthorized access"
|
||||
ErrorExtractingMailFromToken = "Error extracting email from token"
|
||||
HTTPS = "https://"
|
||||
|
||||
// Time format constants
|
||||
timeFormatDateTime = "2006-01-02 15:04:05"
|
||||
|
||||
// Redis key format constants
|
||||
redisKeyJWTSession = "jwt_session:%s"
|
||||
redisKeyJWTSessionID = "jwt_session_id:%s"
|
||||
redisKeyUserEmail = "user_email:%s"
|
||||
redisKeySessionBlacklist = "session_blacklist:%s"
|
||||
redisKeyRefreshRateLimit = "refresh_rate_limit:%s"
|
||||
|
||||
// Error message constants
|
||||
errMsgFailedToGenerateAccessToken = "failed to generate access token"
|
||||
errMsgFailedToGetUserSessions = "failed to get user sessions"
|
||||
errMsgSessionNotFoundInCache = "session not found in cache"
|
||||
errMsgSessionHasBeenRevoked = "session has been revoked"
|
||||
errMsgFailedToUpdateSessionActivity = "Failed to update session activity in Redis cache"
|
||||
|
||||
// Format string constants
|
||||
errFormatWithContext = "%s: %w"
|
||||
errorFormat = "%s?error=%s"
|
||||
|
||||
// SQL query constants
|
||||
sqlUpdateRevokeSession = "UPDATE jwt_sessions SET is_revoked = true WHERE id = ?"
|
||||
|
||||
// Google OAuth constants
|
||||
dbConnNilError = "database connection is nil"
|
||||
errorInvalidState = "invalid state" // #nosec G101
|
||||
bearerPrefix = "Bearer "
|
||||
)
|
||||
@@ -0,0 +1,582 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"authentication/db"
|
||||
"authentication/helper"
|
||||
"authentication/models"
|
||||
"authentication/services"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
var googleOauthConfig oauth2.Config
|
||||
var oauthStateString = generateRandomState()
|
||||
var DashboardBaseURL string
|
||||
|
||||
// init initializes the Google OAuth2 configuration by loading environment variables
|
||||
// from a .env file. If the .env file cannot be loaded, it logs a fatal error.
|
||||
func init() {
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error loading .env file")
|
||||
log.Fatalf("Error loading .env file: %v", err)
|
||||
}
|
||||
|
||||
googleOauthConfig = oauth2.Config{
|
||||
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||
RedirectURL: fmt.Sprintf("%s/v1/auth/callback", os.Getenv("BACKEND_URL")),
|
||||
Scopes: []string{
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
if googleOauthConfig.ClientID == "" {
|
||||
helper.LogError(errors.New("GOOGLE_CLIENT_ID is not set"), "GOOGLE_CLIENT_ID is not set in environment variables")
|
||||
log.Fatalf("GOOGLE_CLIENT_ID is not set in environment variables")
|
||||
}
|
||||
|
||||
if googleOauthConfig.ClientSecret == "" {
|
||||
helper.LogError(errors.New("GOOGLE_CLIENT_SECRET is not set"), "GOOGLE_CLIENT_SECRET is not set in environment variables")
|
||||
log.Fatalf("GOOGLE_CLIENT_SECRET is not set in environment variables")
|
||||
}
|
||||
|
||||
DashboardBaseURL = os.Getenv("DASHBOARD_URL")
|
||||
|
||||
}
|
||||
|
||||
func generateRandomState() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
helper.LogError(err, "Error generating random state")
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%x", b)
|
||||
}
|
||||
|
||||
func GoogleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
helper.LogInfo(fmt.Sprintf("Generated oauth_state: %s", oauthStateString))
|
||||
|
||||
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: oauthStateString,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Expires: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
url := googleOauthConfig.AuthCodeURL(oauthStateString, oauth2.AccessTypeOffline, oauth2.ApprovalForce)
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func getIPAddress(r *http.Request) string {
|
||||
for header, values := range r.Header {
|
||||
for _, value := range values {
|
||||
helper.LogInfo(fmt.Sprintf("Header: %s = %s", header, value))
|
||||
}
|
||||
}
|
||||
|
||||
xForwardedFor := r.Header.Get("X-Forwarded-For")
|
||||
if xForwardedFor != "" {
|
||||
ips := strings.Split(xForwardedFor, ",")
|
||||
ip := strings.TrimSpace(ips[0])
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
xRealIP := r.Header.Get("X-Real-IP")
|
||||
if xRealIP != "" && net.ParseIP(xRealIP) != nil {
|
||||
return xRealIP
|
||||
}
|
||||
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error parsing remote address")
|
||||
return ""
|
||||
}
|
||||
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP != nil && parsedIP.IsLoopback() {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
return ip
|
||||
}
|
||||
|
||||
func GoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ipAddress := getIPAddress(r)
|
||||
fmt.Printf("INFO: Extracted IP address: %s\n", ipAddress)
|
||||
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
|
||||
if !validateState(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
userInfo, err := FetchGoogleUserInfo(w, r)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
email := userInfo.Email
|
||||
profilePicture := userInfo.Picture
|
||||
|
||||
emailExists, err := checkEmailInDB(email)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error checking email")
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Error checking email in database")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
helper.LogError(fmt.Errorf("%v", emailExists), "Email exists in DB")
|
||||
accessToken, refreshToken, err := GenerateTokens(email, userAgent, ipAddress)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error generating access token")
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token generation failed")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
var refreshTokenExpiry time.Duration
|
||||
if emailExists {
|
||||
refreshTokenExpiry = 7 * 24 * time.Hour
|
||||
} else {
|
||||
refreshTokenExpiry = 2 * time.Hour
|
||||
}
|
||||
|
||||
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
||||
|
||||
cookieConfig := &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: refreshToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Expires: time.Now().Add(refreshTokenExpiry),
|
||||
}
|
||||
|
||||
if isSecure {
|
||||
cookieConfig.Secure = true
|
||||
cookieConfig.SameSite = http.SameSiteLaxMode
|
||||
helper.LogInfo("Setting refresh_token cookie for PRODUCTION (secure=true)")
|
||||
} else {
|
||||
cookieConfig.Secure = false
|
||||
cookieConfig.SameSite = http.SameSiteLaxMode
|
||||
cookieConfig.Domain = "localhost"
|
||||
helper.LogInfo("Setting refresh_token cookie for DEVELOPMENT (secure=false, domain=localhost)")
|
||||
}
|
||||
|
||||
http.SetCookie(w, cookieConfig)
|
||||
helper.LogInfo(fmt.Sprintf("Refresh token cookie set: Domain=%s, Secure=%v, HttpOnly=%v, SameSite=%v",
|
||||
cookieConfig.Domain, cookieConfig.Secure, cookieConfig.HttpOnly, cookieConfig.SameSite))
|
||||
|
||||
if !emailExists {
|
||||
helper.LogWarn(fmt.Sprintf("Email %s does not exist in the database", email))
|
||||
registrationURL := fmt.Sprintf("%s/callback?error=%s&token=%s", DashboardBaseURL, url.QueryEscape("Please register first"), accessToken)
|
||||
http.Redirect(w, r, registrationURL, http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
var firstName string
|
||||
|
||||
helper.LogInfo("Fetching first name for email: " + email)
|
||||
helper.LogInfo("Userinfo Email: " + userInfo.Email)
|
||||
|
||||
userID, firstNamePtr, lastNamePtr, emailAddressPtr, err := services.GetUser(email)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error fetching user")
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("User not found")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// Dereference pointers to get actual string values
|
||||
if firstNamePtr != nil {
|
||||
firstName = *firstNamePtr
|
||||
}
|
||||
lastName := ""
|
||||
if lastNamePtr != nil {
|
||||
lastName = *lastNamePtr
|
||||
}
|
||||
emailAddress := emailAddressPtr
|
||||
|
||||
helper.LogInfo("Access Token Generated Copy this: " + accessToken)
|
||||
|
||||
err = helper.LogLoginEventV2(userID, ipAddress)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Failed to log login event")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
helper.LogInfo("Copy this access token: " + accessToken)
|
||||
DashboardURL := fmt.Sprintf("%s/callback?token=%s&user_id=%s&first_name=%s&last_name=%s&email_address=%s&profile_picture=%s", DashboardBaseURL, accessToken, userID, firstName, lastName, emailAddress, profilePicture)
|
||||
http.Redirect(w, r, DashboardURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func validateState(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookie, err := r.Cookie("oauth_state")
|
||||
if err != nil || r.URL.Query().Get("state") != cookie.Value {
|
||||
helper.LogWarn(errorInvalidState)
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(errorInvalidState)), http.StatusSeeOther)
|
||||
return false
|
||||
}
|
||||
helper.LogInfo(fmt.Sprintf("Cookie state: %s, Callback state: %s", cookie.Value, r.URL.Query().Get("state")))
|
||||
return true
|
||||
}
|
||||
|
||||
func FetchGoogleUserInfo(w http.ResponseWriter, r *http.Request) (models.UserGoogleInfo, error) {
|
||||
code := r.URL.Query().Get("code")
|
||||
token, err := googleOauthConfig.Exchange(context.Background(), code)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error exchanging token")
|
||||
// http.Redirect(w, r, DashboardBaseURL, http.StatusSeeOther)
|
||||
return models.UserGoogleInfo{}, err
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Access Token: %s", token.AccessToken))
|
||||
|
||||
client := googleOauthConfig.Client(context.Background(), token)
|
||||
req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error creating request")
|
||||
return models.UserGoogleInfo{}, err
|
||||
}
|
||||
req.Header.Set("Authorization", bearerPrefix+token.AccessToken)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error sending request")
|
||||
return models.UserGoogleInfo{}, err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error closing response body")
|
||||
}
|
||||
}(resp.Body)
|
||||
|
||||
var userInfo models.UserGoogleInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
helper.LogError(err, "Error decoding user info")
|
||||
return models.UserGoogleInfo{}, err
|
||||
}
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodOptions {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
w.Header().Set("Access-Control-Max-Age", "3600")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// First, check if access token is provided and if it's expired
|
||||
helper.LogInfo("Refresh token handler called")
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
helper.LogInfo("Authorization header: " + authHeader)
|
||||
if authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
||||
helper.LogInfo("Access token from header: " + accessToken)
|
||||
token, err := jwt.ParseWithClaims(accessToken, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(os.Getenv("JWT_SECRET_KEY")), nil
|
||||
})
|
||||
helper.LogInfo("Parsed token: " + fmt.Sprintf("%v", token))
|
||||
|
||||
if err == nil && token != nil && token.Claims != nil {
|
||||
if claims, ok := token.Claims.(*models.AccessToken); ok && claims != nil {
|
||||
if claims.Exp != 0 && claims.ExpiresAt != nil {
|
||||
helper.LogInfo("Token expiration timestamp: " + fmt.Sprintf("%v", claims.ExpiresAt.Unix()))
|
||||
helper.LogInfo("Current timestamp: " + fmt.Sprintf("%v", time.Now().Unix()))
|
||||
} else {
|
||||
helper.LogInfo("Token Exp is zero or ExpiresAt is nil")
|
||||
if claims.Exp != 0 {
|
||||
helper.LogInfo("Exp: " + fmt.Sprintf("%d (%s)", claims.Exp, time.Unix(claims.Exp, 0).Format(time.RFC3339)))
|
||||
} else {
|
||||
helper.LogInfo("Exp field is 0")
|
||||
}
|
||||
}
|
||||
helper.LogInfo("Token expiration (Exp field): " + fmt.Sprintf("%d", claims.Exp))
|
||||
helper.LogInfo("Current time: " + fmt.Sprintf("%d", time.Now().Unix()))
|
||||
if claims.Exp < time.Now().Unix() {
|
||||
helper.LogInfo("Token is actually expired based on Exp field")
|
||||
} else {
|
||||
helper.LogInfo("Token is NOT expired based on Exp field")
|
||||
}
|
||||
helper.LogInfo("Token valid: " + fmt.Sprintf("%v", token.Valid))
|
||||
|
||||
// Always proceed to refresh when requested, regardless of current token validity
|
||||
helper.LogInfo("Access token present, but proceeding with refresh as requested")
|
||||
} else {
|
||||
helper.LogInfo("Failed to cast token claims to AccessToken or claims is nil")
|
||||
}
|
||||
} else {
|
||||
helper.LogInfo("Token parsing failed or token is nil. Error: " + fmt.Sprintf("%v", err))
|
||||
}
|
||||
|
||||
if err != nil && !strings.Contains(err.Error(), "expired") && !strings.Contains(err.Error(), "used before issued") {
|
||||
helper.LogError(err, "Invalid access token format")
|
||||
helper.RespondWithError(w, http.StatusBadRequest, "Invalid access token format")
|
||||
return
|
||||
}
|
||||
helper.LogInfo("Access token is expired or invalid, proceeding with refresh")
|
||||
}
|
||||
|
||||
// Log all cookies for debugging
|
||||
helper.LogInfo("TRACE: All cookies in request: " + fmt.Sprintf("%d cookies", len(r.Cookies())))
|
||||
for i, cookie := range r.Cookies() {
|
||||
helper.LogInfo(fmt.Sprintf("TRACE: Cookie %d: Name=%s, Value-length=%d, Domain=%s, Path=%s",
|
||||
i, cookie.Name, len(cookie.Value), cookie.Domain, cookie.Path))
|
||||
}
|
||||
|
||||
cookie, err := r.Cookie("refresh_token")
|
||||
helper.LogInfo("TRACE: Cookie retrieval - error: " + fmt.Sprintf("%v", err))
|
||||
if err != nil {
|
||||
helper.LogError(err, "Refresh token cookie not found")
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Refresh token not found")
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := cookie.Value
|
||||
helper.LogInfo("TRACE: Refresh token from cookie - length: " + fmt.Sprintf("%d", len(refreshToken)))
|
||||
if refreshToken == "" {
|
||||
helper.LogError(errors.New("refresh token cookie is empty"), "refresh token cookie is empty")
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "refresh token is empty")
|
||||
return
|
||||
}
|
||||
|
||||
// Get client info for security validation
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
ipAddress := getIPAddress(r)
|
||||
|
||||
// Try to extract email from access token for fallback during refresh
|
||||
var emailFromToken string
|
||||
if authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
||||
if token, err := jwt.ParseWithClaims(accessToken, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(os.Getenv("JWT_SECRET_KEY")), nil
|
||||
}); err == nil {
|
||||
if claims, ok := token.Claims.(*models.AccessToken); ok && claims.Email != "" {
|
||||
emailFromToken = claims.Email
|
||||
helper.LogInfo("TRACE: Extracted email from access token for fallback: " + emailFromToken)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use the improved RefreshAccessToken function
|
||||
newAccessToken, err := GenerateTokensFromRefreshWithEmail(refreshToken, userAgent, ipAddress, emailFromToken)
|
||||
helper.LogInfo("New access token: " + newAccessToken)
|
||||
helper.LogInfo("New access token length: " + fmt.Sprintf("%d", len(newAccessToken)))
|
||||
if newAccessToken == "" {
|
||||
helper.LogError(errors.New("generated access token is empty"), "Generated access token is empty")
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Failed to generate new access token")
|
||||
}
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to refresh access token")
|
||||
|
||||
// Return specific error messages
|
||||
if strings.Contains(err.Error(), "too many refresh attempts") {
|
||||
helper.RespondWithError(w, http.StatusTooManyRequests, "Too many refresh attempts, please wait")
|
||||
return
|
||||
}
|
||||
if strings.Contains(err.Error(), "expired") || strings.Contains(err.Error(), "revoked") {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Session expired, please login again")
|
||||
return
|
||||
}
|
||||
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Invalid refresh token")
|
||||
return
|
||||
}
|
||||
|
||||
var expiresInSeconds int
|
||||
env := os.Getenv("GO_ENV")
|
||||
if env == "production" || env == "canary" {
|
||||
expiresInSeconds = 45 * 60
|
||||
} else {
|
||||
expiresInSeconds = 15 * 60
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"access_token": newAccessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": expiresInSeconds,
|
||||
}
|
||||
|
||||
helper.LogInfo("TRACE: About to send response: " + fmt.Sprintf("%+v", response))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
helper.LogError(err, "Failed to encode response")
|
||||
} else {
|
||||
helper.LogInfo("TRACE: Response successfully encoded and sent")
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTokensFromRefresh creates a new access token from a refresh token
|
||||
func GenerateTokensFromRefresh(refreshToken, userAgent, ipAddress string) (string, error) {
|
||||
helper.LogInfo("TRACE: GenerateTokensFromRefresh called")
|
||||
helper.LogInfo("TRACE: refreshToken length: " + fmt.Sprintf("%d", len(refreshToken)))
|
||||
helper.LogInfo("TRACE: userAgent: " + userAgent)
|
||||
helper.LogInfo("TRACE: ipAddress: " + ipAddress)
|
||||
|
||||
result, err := RefreshAccessToken(refreshToken, userAgent, ipAddress)
|
||||
helper.LogInfo("TRACE: RefreshAccessToken returned - token length: " + fmt.Sprintf("%d", len(result)) + ", error: " + fmt.Sprintf("%v", err))
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// GenerateTokensFromRefreshWithEmail creates a new access token from a refresh token with email fallback
|
||||
func GenerateTokensFromRefreshWithEmail(refreshToken, userAgent, ipAddress, emailFallback string) (string, error) {
|
||||
helper.LogInfo("TRACE: GenerateTokensFromRefreshWithEmail called")
|
||||
helper.LogInfo("TRACE: refreshToken length: " + fmt.Sprintf("%d", len(refreshToken)))
|
||||
helper.LogInfo("TRACE: userAgent: " + userAgent)
|
||||
helper.LogInfo("TRACE: ipAddress: " + ipAddress)
|
||||
helper.LogInfo("TRACE: emailFallback: " + emailFallback)
|
||||
|
||||
result, err := RefreshAccessTokenWithEmailFallback(refreshToken, userAgent, ipAddress, emailFallback)
|
||||
helper.LogInfo("TRACE: RefreshAccessTokenWithEmailFallback returned - token length: " + fmt.Sprintf("%d", len(result)) + ", error: " + fmt.Sprintf("%v", err))
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func checkEmailInDB(email string) (bool, error) {
|
||||
if db.DB == nil {
|
||||
helper.LogError(nil, dbConnNilError)
|
||||
return false, errors.New(dbConnNilError)
|
||||
}
|
||||
exists, err := services.CheckEmailInDB(email)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
helper.LogInfo("Email exists in DB: " + fmt.Sprintf("%v", exists))
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if !isValidAuthHeader(authHeader) {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Authorization header missing or invalid")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
||||
if tokenString == "" {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Token is missing or empty")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(os.Getenv("JWT_SECRET_KEY")), nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
if claims, ok := token.Claims.(*models.AccessToken); ok {
|
||||
userID, err := services.GetUserIDFromEmail(claims.Email)
|
||||
if err == nil {
|
||||
if err := RevokeAllUserSessions(userID); err != nil {
|
||||
helper.LogError(err, "Failed to revoke user sessions during logout")
|
||||
}
|
||||
} else {
|
||||
helper.LogError(err, "Failed to get user ID during logout")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
helper.LogError(err, "Failed to parse JWT token during logout")
|
||||
}
|
||||
|
||||
accessLog(w, r, nil, 18, nil)
|
||||
|
||||
clearRefreshTokenCookie(w)
|
||||
|
||||
response := map[string]interface{}{
|
||||
"message": "Successfully logged out",
|
||||
"action": "clear_session_storage",
|
||||
"keys": []string{"refresh_token", "access_token"},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
helper.LogError(err, "Failed to encode logout response")
|
||||
}
|
||||
}
|
||||
|
||||
func isValidAuthHeader(authHeader string) bool {
|
||||
return authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix)
|
||||
}
|
||||
|
||||
func clearRefreshTokenCookie(w http.ResponseWriter) {
|
||||
helper.LogInfo("Clearing refresh_token cookie...")
|
||||
|
||||
isSecure := strings.HasPrefix(os.Getenv("BACKEND_URL"), HTTPS)
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Cookie clearing - isSecure: %v, BACKEND_URL: %s", isSecure, os.Getenv("BACKEND_URL")))
|
||||
|
||||
cookieConfig := &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
}
|
||||
|
||||
if isSecure {
|
||||
cookieConfig.Secure = true
|
||||
cookieConfig.SameSite = http.SameSiteLaxMode
|
||||
helper.LogInfo("Setting cookie clear for PRODUCTION (secure=true)")
|
||||
} else {
|
||||
cookieConfig.Secure = false
|
||||
cookieConfig.SameSite = http.SameSiteLaxMode
|
||||
cookieConfig.Domain = "localhost"
|
||||
helper.LogInfo("Setting cookie clear for DEVELOPMENT (secure=false, domain=localhost)")
|
||||
}
|
||||
|
||||
http.SetCookie(w, cookieConfig)
|
||||
helper.LogInfo(fmt.Sprintf("Cookie clear #1 sent: Name=%s, Value=%s, Domain=%s, Secure=%v, HttpOnly=%v",
|
||||
cookieConfig.Name, cookieConfig.Value, cookieConfig.Domain, cookieConfig.Secure, cookieConfig.HttpOnly))
|
||||
|
||||
fallbackCookie := &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
}
|
||||
http.SetCookie(w, fallbackCookie)
|
||||
helper.LogInfo(fmt.Sprintf("Cookie clear #2 sent: Name=%s, Value=%s, Domain=%s, Secure=%v, HttpOnly=%v",
|
||||
fallbackCookie.Name, fallbackCookie.Value, fallbackCookie.Domain, fallbackCookie.Secure, fallbackCookie.HttpOnly))
|
||||
|
||||
helper.LogInfo("Refresh token cookie clearing commands sent to browser")
|
||||
}
|
||||
@@ -0,0 +1,306 @@
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Note: handlers package requires .env file and proper initialization of OAuth configs.
|
||||
// These tests document the expected handler behavior and endpoints.
|
||||
|
||||
func TestGoogleAuthEndpoints(t *testing.T) {
|
||||
// Test documents Google OAuth endpoints
|
||||
endpoints := []struct {
|
||||
name string
|
||||
path string
|
||||
method string
|
||||
function string
|
||||
}{
|
||||
{"Google Login", "/v1/auth/login", "GET", "GoogleLogin"},
|
||||
{"Google Callback", "/v1/auth/callback", "GET", "GoogleCallback"},
|
||||
{"Token Refresh", "/v1/auth/refresh_token", "GET/POST/OPTIONS", "HandleTokenRefresh"},
|
||||
{"Logout", "/v1/auth/logout", "GET", "LogoutHandler"},
|
||||
}
|
||||
|
||||
if len(endpoints) != 4 {
|
||||
t.Errorf("Expected 4 Google auth endpoints, documented %d", len(endpoints))
|
||||
}
|
||||
|
||||
for _, ep := range endpoints {
|
||||
if ep.name == "" || ep.path == "" || ep.method == "" || ep.function == "" {
|
||||
t.Error("Endpoint should have complete documentation")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthScopes(t *testing.T) {
|
||||
// Test documents required OAuth scopes
|
||||
requiredScopes := []string{
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
|
||||
if len(requiredScopes) != 2 {
|
||||
t.Errorf("Expected 2 OAuth scopes, documented %d", len(requiredScopes))
|
||||
}
|
||||
|
||||
for _, scope := range requiredScopes {
|
||||
if scope == "" {
|
||||
t.Error("OAuth scope should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthEnvironmentVariables(t *testing.T) {
|
||||
// Test documents required OAuth environment variables
|
||||
requiredVars := []string{
|
||||
"GOOGLE_CLIENT_ID",
|
||||
"GOOGLE_CLIENT_SECRET",
|
||||
"BACKEND_URL",
|
||||
}
|
||||
|
||||
if len(requiredVars) != 3 {
|
||||
t.Errorf("Expected 3 OAuth environment variables, documented %d", len(requiredVars))
|
||||
}
|
||||
|
||||
for _, varName := range requiredVars {
|
||||
if varName == "" {
|
||||
t.Error("Environment variable name should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTEnvironmentVariables(t *testing.T) {
|
||||
// Test documents JWT-related environment variables
|
||||
requiredVars := []string{
|
||||
"JWT_SECRET_KEY",
|
||||
"DASHBOARD_URL",
|
||||
}
|
||||
|
||||
if len(requiredVars) != 2 {
|
||||
t.Errorf("Expected 2 JWT environment variables, documented %d", len(requiredVars))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenGenerationRequirements(t *testing.T) {
|
||||
// Test documents what's needed for token generation
|
||||
requirements := []string{
|
||||
"User ID",
|
||||
"Email address",
|
||||
"Session ID",
|
||||
"IP address",
|
||||
"User agent",
|
||||
}
|
||||
|
||||
if len(requirements) == 0 {
|
||||
t.Error("Token generation should have requirements")
|
||||
}
|
||||
|
||||
for _, req := range requirements {
|
||||
if req == "" {
|
||||
t.Error("Requirement should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManagementOperations(t *testing.T) {
|
||||
// Test documents session management operations
|
||||
operations := []string{
|
||||
"GenerateTokens",
|
||||
"RefreshAccessToken",
|
||||
"RefreshAccessTokenWithEmailFallback",
|
||||
"RevokeSession",
|
||||
"RevokeAllUserSessions",
|
||||
"RevokeAllUserSessionsExceptCurrent",
|
||||
"ValidateSession",
|
||||
"CleanupExpiredSessions",
|
||||
}
|
||||
|
||||
if len(operations) != 8 {
|
||||
t.Errorf("Expected 8 session operations, documented %d", len(operations))
|
||||
}
|
||||
|
||||
for _, op := range operations {
|
||||
if op == "" {
|
||||
t.Error("Operation name should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogOperations(t *testing.T) {
|
||||
// Test documents access log handler operations
|
||||
operations := []struct {
|
||||
name string
|
||||
description string
|
||||
}{
|
||||
{"Log access events", "Records user access events to database"},
|
||||
{"Track IP addresses", "Stores IP address for security auditing"},
|
||||
{"Record timestamps", "Uses Asia/Manila timezone for consistency"},
|
||||
{"Store metadata", "JSON field for additional event data"},
|
||||
}
|
||||
|
||||
if len(operations) != 4 {
|
||||
t.Errorf("Expected 4 access log operations, documented %d", len(operations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenExpirationTimes(t *testing.T) {
|
||||
// Test documents token expiration settings
|
||||
type tokenExpiration struct {
|
||||
tokenType string
|
||||
duration string
|
||||
refreshable bool
|
||||
}
|
||||
|
||||
tokens := []tokenExpiration{
|
||||
{"Access Token", "short-lived", true},
|
||||
{"Refresh Token", "long-lived", false},
|
||||
}
|
||||
|
||||
if len(tokens) != 2 {
|
||||
t.Errorf("Expected 2 token types, documented %d", len(tokens))
|
||||
}
|
||||
|
||||
for _, token := range tokens {
|
||||
if token.tokenType == "" || token.duration == "" {
|
||||
t.Error("Token should have type and duration")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityFeatures(t *testing.T) {
|
||||
// Test documents security features implemented in handlers
|
||||
features := []string{
|
||||
"JWT signature validation",
|
||||
"Token blacklisting",
|
||||
"Session invalidation",
|
||||
"IP address validation",
|
||||
"User agent validation",
|
||||
"Refresh token hashing",
|
||||
"CSRF protection",
|
||||
}
|
||||
|
||||
if len(features) == 0 {
|
||||
t.Error("Should implement security features")
|
||||
}
|
||||
|
||||
for _, feature := range features {
|
||||
if feature == "" {
|
||||
t.Error("Security feature should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorResponses(t *testing.T) {
|
||||
// Test documents expected error responses
|
||||
errorTypes := []struct {
|
||||
scenario string
|
||||
redirect bool
|
||||
httpCode int
|
||||
}{
|
||||
{"Invalid token", true, 0},
|
||||
{"Expired token", true, 0},
|
||||
{"Missing credentials", true, 0},
|
||||
{"Database error", false, 500},
|
||||
{"Validation error", false, 400},
|
||||
}
|
||||
|
||||
if len(errorTypes) == 0 {
|
||||
t.Error("Should handle error scenarios")
|
||||
}
|
||||
|
||||
for _, err := range errorTypes {
|
||||
if err.scenario == "" {
|
||||
t.Error("Error scenario should not be empty")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectURLs(t *testing.T) {
|
||||
// Test documents redirect URL patterns
|
||||
redirects := []struct {
|
||||
scenario string
|
||||
destination string
|
||||
hasError bool
|
||||
}{
|
||||
{"Successful login", "DASHBOARD_URL", false},
|
||||
{"Invalid token", "DASHBOARD_URL?error=...", true},
|
||||
{"Missing auth", "DASHBOARD_URL?error=...", true},
|
||||
}
|
||||
|
||||
if len(redirects) == 0 {
|
||||
t.Error("Should define redirect behavior")
|
||||
}
|
||||
|
||||
for _, redirect := range redirects {
|
||||
if redirect.scenario == "" || redirect.destination == "" {
|
||||
t.Error("Redirect should have scenario and destination")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStateParameter(t *testing.T) {
|
||||
// Test documents OAuth state parameter usage
|
||||
// State parameter should be generated and validated to prevent CSRF
|
||||
t.Log("OAuth flow should use state parameter for CSRF protection")
|
||||
}
|
||||
|
||||
func TestSessionStorageLocations(t *testing.T) {
|
||||
// Test documents where sessions are stored
|
||||
storageLocations := []string{
|
||||
"Redis cache (for active sessions)",
|
||||
"MySQL database (for persistence)",
|
||||
}
|
||||
|
||||
if len(storageLocations) != 2 {
|
||||
t.Errorf("Expected 2 storage locations, documented %d", len(storageLocations))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRefreshFlow(t *testing.T) {
|
||||
// Test documents token refresh flow
|
||||
steps := []string{
|
||||
"1. Client sends refresh token",
|
||||
"2. Server validates refresh token hash",
|
||||
"3. Server checks session validity",
|
||||
"4. Server generates new access token",
|
||||
"5. Server returns new access token",
|
||||
}
|
||||
|
||||
if len(steps) != 5 {
|
||||
t.Errorf("Expected 5 refresh flow steps, documented %d", len(steps))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogoutBehavior(t *testing.T) {
|
||||
// Test documents logout behavior
|
||||
logoutActions := []string{
|
||||
"Invalidate current session",
|
||||
"Blacklist current token",
|
||||
"Clear Redis cache",
|
||||
"Update database session status",
|
||||
"Redirect to dashboard",
|
||||
}
|
||||
|
||||
if len(logoutActions) == 0 {
|
||||
t.Error("Logout should perform cleanup actions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerConstants(t *testing.T) {
|
||||
// Test documents handler-related constants
|
||||
constants := map[string]string{
|
||||
"ErrorInvalidToken": "Invalid or expired token",
|
||||
"ErrorMissingAuthorization": "Invalid authorization header",
|
||||
"ErrorDatabaseFailure": "Database error occurred",
|
||||
}
|
||||
|
||||
if len(constants) == 0 {
|
||||
t.Error("Should define error constants")
|
||||
}
|
||||
|
||||
for key, value := range constants {
|
||||
if key == "" || value == "" {
|
||||
t.Error("Constant should have key and value")
|
||||
}
|
||||
}
|
||||
}
|
||||
+818
@@ -0,0 +1,818 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"authentication/db"
|
||||
"authentication/helper"
|
||||
"authentication/models"
|
||||
"authentication/redisclient"
|
||||
"authentication/services"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
var jwtSecretKey []byte
|
||||
|
||||
// init initializes the JWT secret key by loading environment variables from a .env file.
|
||||
// If the .env file cannot be loaded, it logs an error message.
|
||||
// If the JWT_SECRET_KEY is not set in the .env file, it logs a warning message.
|
||||
func init() {
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
helper.LogError(err, "Error loading .env file")
|
||||
}
|
||||
|
||||
jwtSecretKey = []byte(os.Getenv("JWT_SECRET_KEY"))
|
||||
if len(jwtSecretKey) == 0 {
|
||||
helper.LogError(nil, "JWT_SECRET_KEY not set in .env file")
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTokens generates both access and refresh tokens with session management.
|
||||
// It creates a new session in the database and caches it in Redis for performance.
|
||||
//
|
||||
// Parameters:
|
||||
// - email: The email address to include in the JWT claims.
|
||||
// - userAgent: The user agent string from the request.
|
||||
// - ipAddress: The IP address of the client.
|
||||
//
|
||||
// Returns:
|
||||
func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
emailExists, err := CheckEmailInDB(email)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error checking email in database: %w", err)
|
||||
}
|
||||
|
||||
userID, err := services.GetUserIDFromEmail(email)
|
||||
if err != nil {
|
||||
userID = helper.UUIDGenerator()
|
||||
}
|
||||
|
||||
sessionID := helper.UUIDGenerator()
|
||||
|
||||
refreshToken, err := generateSecureToken()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
refreshTokenHash := helper.CalculateSHA256(refreshToken)
|
||||
|
||||
location, err := helper.LoadAsiaManilaLocation()
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to load Asia/Manila timezone, using UTC+8 offset")
|
||||
}
|
||||
|
||||
currentTime := time.Now().In(location)
|
||||
|
||||
var expiresAt time.Time
|
||||
if emailExists {
|
||||
expiresAt = currentTime.Add(7 * 24 * time.Hour)
|
||||
} else {
|
||||
expiresAt = currentTime.Add(2 * time.Hour)
|
||||
}
|
||||
|
||||
session := models.JWTSession{
|
||||
ID: sessionID,
|
||||
UserID: userID,
|
||||
RefreshTokenHash: refreshTokenHash,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: ipAddress,
|
||||
CreatedAt: currentTime,
|
||||
UpdatedAt: currentTime,
|
||||
ExpiresAt: expiresAt,
|
||||
IsRevoked: false,
|
||||
}
|
||||
|
||||
_, err = db.DB.Exec(`
|
||||
INSERT INTO jwt_sessions (id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, sessionID, userID, refreshTokenHash, userAgent, ipAddress, currentTime, currentTime, expiresAt, false)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to store session: %w", err)
|
||||
}
|
||||
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
|
||||
sessionIDKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
||||
sessionTTL := int(time.Until(expiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
||||
helper.LogError(err, "Failed to cache session in Redis (sessionKey)")
|
||||
}
|
||||
if err := helper.SetJSON(ctx, sessionIDKey, session, &sessionTTL); err != nil {
|
||||
helper.LogError(err, "Failed to cache session in Redis (sessionIDKey)")
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := generateAccessToken(email, sessionID)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err)
|
||||
}
|
||||
|
||||
log.Printf("Generated tokens for user %s with session %s", email, sessionID)
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
func generateAccessToken(email, sessionID string) (string, error) {
|
||||
expirationTime := time.Now().Add(45 * time.Minute).Unix()
|
||||
|
||||
claims := &models.AccessToken{
|
||||
Email: email,
|
||||
SessionID: sessionID,
|
||||
Exp: expirationTime,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Unix(expirationTime, 0)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
return token.SignedString(jwtSecretKey)
|
||||
}
|
||||
|
||||
func generateSecureToken() (string, error) {
|
||||
bytes := make([]byte, 32) // 256 bits
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// RefreshAccessToken refreshes the access token using a valid refresh token.
|
||||
// It validates the refresh token, checks the session status, and generates a new access token.
|
||||
// Uses Redis for session caching to improve performance for websocket connections.
|
||||
//
|
||||
// Parameters:
|
||||
// - refreshTokenString: The refresh token to use for refreshing the access token.
|
||||
// - userAgent: The user agent string from the request.
|
||||
// - ipAddress: The IP address of the client.
|
||||
//
|
||||
// Returns:
|
||||
// - string: The new signed access token as a string.
|
||||
// - error: An error if the token is invalid or the process fails.
|
||||
func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string, error) {
|
||||
ctx := context.Background()
|
||||
refreshTokenHash := helper.CalculateSHA256(refreshTokenString)
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("RefreshAccessToken called - Token length: %d, Hash: %s", len(refreshTokenString), refreshTokenHash[:16]+"..."))
|
||||
helper.LogInfo(fmt.Sprintf("Client details - UserAgent: %s, IP: %s", userAgent, ipAddress))
|
||||
|
||||
rateLimitKey := fmt.Sprintf("refresh_rate_limit:%s", refreshTokenHash)
|
||||
attempts, err := redisclient.RDB.Incr(ctx, rateLimitKey).Result()
|
||||
if err == nil {
|
||||
if attempts == 1 {
|
||||
redisclient.RDB.Expire(ctx, rateLimitKey, time.Minute)
|
||||
}
|
||||
if attempts > 5 {
|
||||
return "", fmt.Errorf("too many refresh attempts, please wait")
|
||||
}
|
||||
}
|
||||
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
|
||||
var session models.JWTSession
|
||||
|
||||
err = helper.GetJSON(ctx, sessionKey, &session)
|
||||
if err != nil {
|
||||
helper.LogInfo(fmt.Sprintf("Session not found in Redis cache, querying database for hash: %s", refreshTokenHash[:16]+"..."))
|
||||
err = db.DB.QueryRow(`
|
||||
SELECT id, user_id, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
||||
FROM jwt_sessions
|
||||
WHERE refresh_token_hash = ? AND is_revoked = false
|
||||
`, refreshTokenHash).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&session.ExpiresAt,
|
||||
&session.IsRevoked,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Session not found in database for hash: %s", refreshTokenHash[:16]+"..."))
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UserID: %s, Created: %s, Expires: %s",
|
||||
session.ID, session.UserID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
|
||||
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
||||
helper.LogError(err, errMsgFailedToUpdateSessionActivity)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UserID: %s, Expires: %s",
|
||||
session.ID, session.UserID, session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
}
|
||||
|
||||
if session.IsRevoked {
|
||||
helper.LogWarn(fmt.Sprintf("Attempted to use revoked session: %s", session.ID))
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
return "", fmt.Errorf(errMsgSessionHasBeenRevoked)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
helper.LogWarn(fmt.Sprintf("Attempted to use expired session: %s (expired at %s)",
|
||||
session.ID, session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
_, err = db.DB.Exec(sqlUpdateRevokeSession, session.ID)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to revoke expired session")
|
||||
}
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
return "", fmt.Errorf("refresh token has expired")
|
||||
}
|
||||
|
||||
if session.UserAgent != userAgent {
|
||||
helper.LogWarn(fmt.Sprintf("Session User Agent security mismatch for session %s: stored='%s', received='%s'",
|
||||
session.ID, session.UserAgent, userAgent))
|
||||
}
|
||||
|
||||
if session.IPAddress != ipAddress {
|
||||
helper.LogWarn(fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s",
|
||||
session.ID, session.IPAddress, ipAddress))
|
||||
}
|
||||
|
||||
// Get user email from user ID (with caching)
|
||||
email, err := getUserEmailFromIDCached(session.UserID)
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UserID))
|
||||
// For registrants or users not yet in the main tables, we still want to allow refresh
|
||||
// but we need to get the email from somewhere else. Since we don't store email in session,
|
||||
// we'll need to handle this gracefully by allowing the refresh to continue with a placeholder
|
||||
// The email will be properly resolved when they complete registration
|
||||
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UserID))
|
||||
|
||||
// For now, we'll use a placeholder email pattern and let the access token generation handle it
|
||||
// The system should work as long as the session is valid
|
||||
email = fmt.Sprintf("registrant_%s@pending.local", session.UserID)
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID))
|
||||
|
||||
accessToken, err := generateAccessToken(email, session.ID)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to generate access token during refresh")
|
||||
return "", fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Successfully refreshed access token for user %s (session: %s)", email, session.ID))
|
||||
|
||||
session.UpdatedAt = time.Now()
|
||||
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
||||
helper.LogError(err, errMsgFailedToUpdateSessionActivity)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE id = ?", session.UpdatedAt, session.ID)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to update session activity in DB")
|
||||
}
|
||||
}()
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// RefreshAccessTokenWithEmailFallback refreshes the access token using a valid refresh token with email fallback.
|
||||
// This version handles cases where the user ID in the session doesn't exist in the database (e.g., registrants).
|
||||
//
|
||||
// Parameters:
|
||||
// - refreshTokenString: The refresh token to use for refreshing the access token.
|
||||
// - userAgent: The user agent string from the request.
|
||||
// - ipAddress: The IP address of the client.
|
||||
// - emailFallback: Email to use if user ID lookup fails (extracted from current access token).
|
||||
//
|
||||
// Returns:
|
||||
// - string: The new signed access token as a string.
|
||||
// - error: An error if the token is invalid or the process fails.
|
||||
func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddress, emailFallback string) (string, error) {
|
||||
ctx := context.Background()
|
||||
refreshTokenHash := helper.CalculateSHA256(refreshTokenString)
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("RefreshAccessTokenWithEmailFallback called - Token length: %d, Hash: %s", len(refreshTokenString), refreshTokenHash[:16]+"..."))
|
||||
helper.LogInfo(fmt.Sprintf("Client details - UserAgent: %s, IP: %s, EmailFallback: %s", userAgent, ipAddress, emailFallback))
|
||||
|
||||
rateLimitKey := fmt.Sprintf(redisKeyRefreshRateLimit, refreshTokenHash)
|
||||
attempts, err := redisclient.RDB.Incr(ctx, rateLimitKey).Result()
|
||||
if err == nil {
|
||||
if attempts == 1 {
|
||||
redisclient.RDB.Expire(ctx, rateLimitKey, time.Minute)
|
||||
}
|
||||
if attempts > 5 {
|
||||
return "", fmt.Errorf("too many refresh attempts, please wait")
|
||||
}
|
||||
}
|
||||
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
|
||||
var session models.JWTSession
|
||||
|
||||
err = helper.GetJSON(ctx, sessionKey, &session)
|
||||
if err != nil {
|
||||
helper.LogInfo(fmt.Sprintf("Session not found in Redis cache, querying database for hash: %s", refreshTokenHash[:16]+"..."))
|
||||
err = db.DB.QueryRow(`
|
||||
SELECT id, user_id, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
||||
FROM jwt_sessions
|
||||
WHERE refresh_token_hash = ? AND is_revoked = false
|
||||
`, refreshTokenHash).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&session.ExpiresAt,
|
||||
&session.IsRevoked,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Session not found in database for hash: %s", refreshTokenHash[:16]+"..."))
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UserID: %s, Created: %s, Expires: %s",
|
||||
session.ID, session.UserID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
|
||||
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
||||
helper.LogError(err, errMsgFailedToUpdateSessionActivity)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UserID: %s, Expires: %s",
|
||||
session.ID, session.UserID, session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
}
|
||||
|
||||
if session.IsRevoked {
|
||||
helper.LogWarn(fmt.Sprintf("Attempted to use revoked session: %s", session.ID))
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
return "", fmt.Errorf(errMsgSessionHasBeenRevoked)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
helper.LogWarn(fmt.Sprintf("Attempted to use expired session: %s (expired at %s)",
|
||||
session.ID, session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
_, err = db.DB.Exec(sqlUpdateRevokeSession, session.ID)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to revoke expired session")
|
||||
}
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
return "", fmt.Errorf("refresh token has expired")
|
||||
}
|
||||
|
||||
if session.UserAgent != userAgent {
|
||||
helper.LogWarn(fmt.Sprintf("Session User Agent security mismatch for session %s: stored='%s', received='%s'",
|
||||
session.ID, session.UserAgent, userAgent))
|
||||
}
|
||||
|
||||
if session.IPAddress != ipAddress {
|
||||
helper.LogWarn(fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s",
|
||||
session.ID, session.IPAddress, ipAddress))
|
||||
}
|
||||
|
||||
// Get user email from user ID (with caching), with fallback to provided email
|
||||
email, err := getUserEmailFromIDCached(session.UserID)
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UserID))
|
||||
|
||||
if emailFallback != "" {
|
||||
helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UserID, emailFallback))
|
||||
email = emailFallback
|
||||
} else {
|
||||
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UserID))
|
||||
return "", fmt.Errorf("failed to get user email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID))
|
||||
|
||||
accessToken, err := generateAccessToken(email, session.ID)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to generate access token during refresh")
|
||||
return "", fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Successfully refreshed access token for user %s (session: %s)", email, session.ID))
|
||||
|
||||
session.UpdatedAt = time.Now()
|
||||
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
||||
helper.LogError(err, errMsgFailedToUpdateSessionActivity)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE id = ?", session.UpdatedAt, session.ID)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to update session activity in DB")
|
||||
}
|
||||
}()
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func RevokeSession(sessionID string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := db.DB.Exec(sqlUpdateRevokeSession, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke session %s: %w", sessionID, err)
|
||||
}
|
||||
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RevokeAllUserSessions(userID string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
rows, err := db.DB.Query("SELECT id FROM jwt_sessions WHERE user_id = ? AND is_revoked = false", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessionIDs []string
|
||||
for rows.Next() {
|
||||
var sessionID string
|
||||
if err := rows.Scan(&sessionID); err != nil {
|
||||
continue
|
||||
}
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
}
|
||||
|
||||
_, err = db.DB.Exec("UPDATE jwt_sessions SET is_revoked = true WHERE user_id = ?", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke all sessions for user %s: %w", userID, err)
|
||||
}
|
||||
|
||||
for _, sessionID := range sessionIDs {
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
}
|
||||
|
||||
userEmailKey := fmt.Sprintf(redisKeyUserEmail, userID)
|
||||
redisclient.RDB.Del(ctx, userEmailKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RevokeAllUserSessionsExceptCurrent(userID, currentSessionID string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
rows, err := db.DB.Query("SELECT id FROM jwt_sessions WHERE user_id = ? AND id != ? AND is_revoked = false", userID, currentSessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessionIDs []string
|
||||
for rows.Next() {
|
||||
var sessionID string
|
||||
if err := rows.Scan(&sessionID); err != nil {
|
||||
continue
|
||||
}
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
}
|
||||
|
||||
_, err = db.DB.Exec(
|
||||
"UPDATE jwt_sessions SET is_revoked = true WHERE user_id = ? AND id != ?",
|
||||
userID, currentSessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke other sessions for user %s: %w", userID, err)
|
||||
}
|
||||
|
||||
for _, sessionID := range sessionIDs {
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateSession(sessionID string) (*models.JWTSession, error) {
|
||||
ctx := context.Background()
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
||||
|
||||
var session models.JWTSession
|
||||
err := helper.GetJSON(ctx, sessionKey, &session)
|
||||
if err != nil {
|
||||
err = db.DB.QueryRow(`
|
||||
SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
||||
FROM jwt_sessions
|
||||
WHERE id = ?
|
||||
`, sessionID).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.RefreshTokenHash,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&session.ExpiresAt,
|
||||
&session.IsRevoked,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session not found: %w", err)
|
||||
}
|
||||
|
||||
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
||||
helper.LogError(err, "Failed to cache session in Redis (ValidateSession)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if session.IsRevoked {
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
return nil, fmt.Errorf("%s", errMsgSessionHasBeenRevoked)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
if err := RevokeSession(sessionID); err != nil {
|
||||
helper.LogError(err, "Failed to auto-revoke expired session")
|
||||
}
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
return nil, fmt.Errorf("session has expired")
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func ValidateSessionForWebSocket(sessionID string) (*models.JWTSession, error) {
|
||||
ctx := context.Background()
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
||||
|
||||
var session models.JWTSession
|
||||
err := helper.GetJSON(ctx, sessionKey, &session)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s", errMsgSessionNotFoundInCache)
|
||||
}
|
||||
|
||||
if session.IsRevoked {
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
return nil, fmt.Errorf("%s", errMsgSessionHasBeenRevoked)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
return nil, fmt.Errorf("session has expired")
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func ExtendSessionActivity(sessionID string) error {
|
||||
ctx := context.Background()
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
||||
|
||||
var session models.JWTSession
|
||||
err := helper.GetJSON(ctx, sessionKey, &session)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s", errMsgSessionNotFoundInCache)
|
||||
}
|
||||
|
||||
session.UpdatedAt = time.Now()
|
||||
|
||||
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
||||
helper.LogError(err, "Failed to extend session activity in Redis cache")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSessionUserInfo(sessionID string) (string, string, error) {
|
||||
ctx := context.Background()
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
||||
|
||||
var session models.JWTSession
|
||||
err := helper.GetJSON(ctx, sessionKey, &session)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("%s", errMsgSessionNotFoundInCache)
|
||||
}
|
||||
|
||||
email, err := getUserEmailFromIDCached(session.UserID)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get user email: %w", err)
|
||||
}
|
||||
|
||||
return session.UserID, email, nil
|
||||
}
|
||||
|
||||
func InvalidateUserSessionsInCache(userID string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
rows, err := db.DB.Query("SELECT id, refresh_token_hash FROM jwt_sessions WHERE user_id = ?", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []string
|
||||
for rows.Next() {
|
||||
var sessionID, refreshTokenHash string
|
||||
if err := rows.Scan(&sessionID, &refreshTokenHash); err != nil {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf(redisKeyJWTSessionID, sessionID))
|
||||
keys = append(keys, fmt.Sprintf(redisKeyJWTSession, refreshTokenHash))
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
redisclient.RDB.Del(ctx, keys...)
|
||||
}
|
||||
|
||||
userEmailKey := fmt.Sprintf(redisKeyUserEmail, userID)
|
||||
redisclient.RDB.Del(ctx, userEmailKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CleanupExpiredSessions() error {
|
||||
ctx := context.Background()
|
||||
|
||||
rows, err := db.DB.Query("SELECT id, user_id, refresh_token_hash FROM jwt_sessions WHERE expires_at < ?", time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query expired sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var expiredSessions []models.ExpiredSession
|
||||
|
||||
userIDsToCleanup := make(map[string]bool)
|
||||
for rows.Next() {
|
||||
var session models.ExpiredSession
|
||||
if err := rows.Scan(&session.ID, &session.UserID, &session.RefreshTokenHash); err != nil {
|
||||
continue
|
||||
}
|
||||
expiredSessions = append(expiredSessions, session)
|
||||
userIDsToCleanup[session.UserID] = true
|
||||
}
|
||||
|
||||
_, err = db.DB.Exec("DELETE FROM jwt_sessions WHERE expires_at < ?", time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cleanup expired sessions: %w", err)
|
||||
}
|
||||
|
||||
for _, session := range expiredSessions {
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSession, session.RefreshTokenHash)
|
||||
sessionIDKey := fmt.Sprintf(redisKeyJWTSessionID, session.ID)
|
||||
redisclient.RDB.Del(ctx, sessionKey, sessionIDKey)
|
||||
}
|
||||
|
||||
// Role cache invalidation removed - handled by separate authz microservice
|
||||
|
||||
log.Printf("Cleaned up %d expired sessions", len(expiredSessions))
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUserSessions(userID string) ([]models.JWTSession, error) {
|
||||
rows, err := db.DB.Query(`
|
||||
SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
||||
FROM jwt_sessions
|
||||
WHERE user_id = ? AND is_revoked = false AND expires_at > ?
|
||||
ORDER BY created_at DESC
|
||||
`, userID, time.Now())
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []models.JWTSession
|
||||
for rows.Next() {
|
||||
var session models.JWTSession
|
||||
err := rows.Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.RefreshTokenHash,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&session.ExpiresAt,
|
||||
&session.IsRevoked,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan session row: %w", err)
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func UpdateSessionLastActivity(sessionID string) error {
|
||||
_, err := db.DB.Exec(`
|
||||
UPDATE jwt_sessions
|
||||
SET updated_at = ?
|
||||
WHERE id = ?
|
||||
`, time.Now(), sessionID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update session activity: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUserEmailFromID(userID string) (string, error) {
|
||||
var email string
|
||||
|
||||
err := db.DB.QueryRow("SELECT email_address FROM users WHERE id = ?", userID).Scan(&email)
|
||||
if err == nil {
|
||||
return email, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("user not found with ID %s in any table", userID)
|
||||
}
|
||||
|
||||
func getUserEmailFromIDCached(userID string) (string, error) {
|
||||
ctx := context.Background()
|
||||
cacheKey := fmt.Sprintf(redisKeyUserEmail, userID)
|
||||
|
||||
var email string
|
||||
err := helper.GetJSON(ctx, cacheKey, &email)
|
||||
if err == nil && email != "" {
|
||||
return email, nil
|
||||
}
|
||||
|
||||
// Cache miss, feth from database
|
||||
email, err = getUserEmailFromID(userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cacheTTL := 3600
|
||||
if err := helper.SetJSON(ctx, cacheKey, email, &cacheTTL); err != nil {
|
||||
helper.LogError(err, "Failed to cache user email in Redis")
|
||||
}
|
||||
|
||||
return email, nil
|
||||
}
|
||||
|
||||
func AddToSessionBlacklist(sessionID string, ttlSeconds int) error {
|
||||
ctx := context.Background()
|
||||
blacklistKey := fmt.Sprintf(redisKeySessionBlacklist, sessionID)
|
||||
|
||||
ttl := time.Duration(ttlSeconds) * time.Second
|
||||
return redisclient.RDB.Set(ctx, blacklistKey, "revoked", ttl).Err()
|
||||
}
|
||||
|
||||
func IsSessionBlacklisted(sessionID string) bool {
|
||||
ctx := context.Background()
|
||||
blacklistKey := fmt.Sprintf(redisKeySessionBlacklist, sessionID)
|
||||
|
||||
exists, err := redisclient.RDB.Exists(ctx, blacklistKey).Result()
|
||||
return err == nil && exists > 0
|
||||
}
|
||||
|
||||
func ClearSessionFromAllCaches(sessionID, refreshTokenHash string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
keys := []string{
|
||||
fmt.Sprintf(redisKeyJWTSessionID, sessionID),
|
||||
fmt.Sprintf(redisKeyJWTSession, refreshTokenHash),
|
||||
}
|
||||
|
||||
return redisclient.RDB.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
func CheckEmailInDB(email string) (bool, error) {
|
||||
if db.DB == nil {
|
||||
return false, fmt.Errorf("database connection is nil")
|
||||
}
|
||||
var exists bool
|
||||
err := db.DB.QueryRow(
|
||||
`SELECT EXISTS(
|
||||
SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0)`, email,
|
||||
).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error checking email in database: %v", err)
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Set GO_ENV to test mode before any tests run
|
||||
// This prevents error_logging from failing when handlers package is imported
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
// Set other required environment variables for handlers init()
|
||||
os.Setenv("JWT_SECRET_KEY", "test-secret-key-for-jwt-testing")
|
||||
os.Setenv("GOOGLE_CLIENT_ID", "test-google-client-id.apps.googleusercontent.com")
|
||||
os.Setenv("GOOGLE_CLIENT_SECRET", "test-google-client-secret")
|
||||
os.Setenv("BACKEND_URL", "http://localhost:8080")
|
||||
os.Setenv("DASHBOARD_URL", "http://localhost:3000")
|
||||
|
||||
// Create a temporary .env file if it doesn't exist
|
||||
// handlers/google_auth.go and handlers/jwt.go have init() that calls godotenv.Load()
|
||||
// We need to ensure .env exists to prevent log.Fatalf
|
||||
if _, err := os.Stat(".env"); os.IsNotExist(err) {
|
||||
// .env should already exist from earlier test setup
|
||||
// If not, tests may still fail due to handlers init()
|
||||
}
|
||||
|
||||
// Run all tests
|
||||
exitCode := m.Run()
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
Reference in New Issue
Block a user