package handlers import ( "authentication/db" "authentication/helper" "authentication/models" "authentication/redisclient" "authentication/services" "context" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/pem" "fmt" "log" "os" "time" "github.com/golang-jwt/jwt/v5" ) var rsaPrivateKey *rsa.PrivateKey func parseRSAPrivateKey(keyData []byte) (*rsa.PrivateKey, error) { block, _ := pem.Decode(keyData) if block == nil { return nil, fmt.Errorf("failed to decode PEM block containing private key") } key, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { key, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return nil, fmt.Errorf("failed to parse RSA private key: %w", err) } } rsaKey, ok := key.(*rsa.PrivateKey) if !ok { return nil, fmt.Errorf("not an RSA private key") } return rsaKey, nil } // Note: .env file is loaded in main() before this init runs func init() { keyPath := os.Getenv("JWT_PRIVATE_KEY_PATH") if keyPath == "" { keyPath = "rsa/psa.gov.ph.key" } keyData, err := os.ReadFile(keyPath) // #nosec G304 if err != nil { if isTestEnvironment() { log.Printf("Failed to read RSA private key file at %s, generating test key: %v", keyPath, err) generatedKey, genErr := rsa.GenerateKey(rand.Reader, 2048) if genErr != nil { log.Fatalf("Failed to generate test RSA private key: %v", genErr) } rsaPrivateKey = generatedKey return } log.Fatalf("Failed to read RSA private key file: %v", err) } parsedKey, err := parseRSAPrivateKey(keyData) if err != nil { log.Fatalf("%v", err) } rsaPrivateKey = parsedKey log.Println("RSA private key loaded successfully for JWT signing") } // GenerateTokens generates both access and refresh tokens with session management. // It creates a new session in the database and caches it in Redis for performance. // // Parameters: // - email: The email address to include in the JWT claims. // - userAgent: The user agent string from the request. // - ipAddress: The IP address of the client. // // Returns: func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) { ctx := context.Background() emailExists, err := CheckEmailInDB(email) if err != nil { return "", "", fmt.Errorf("error checking email in database: %w", err) } userID, err := services.GetUserIDFromEmail(email) if err != nil { userID = helper.UUIDGenerator() } log.Print("userID:", userID) roleID, err := services.GetRoleIDsFromEmail(email) if err != nil { return "", "", fmt.Errorf("error checking role in database: %w", err) } sessionID := helper.UUIDGenerator() refreshToken, err := generateSecureToken() if err != nil { return "", "", fmt.Errorf("failed to generate refresh token: %w", err) } refreshTokenHash := helper.CalculateSHA256(refreshToken) location, err := helper.LoadAsiaManilaLocation() if err != nil { helper.LogError(err, "Failed to load Asia/Manila timezone, using UTC+8 offset") } currentTime := time.Now().In(location) var expiresAt time.Time if emailExists { expiresAt = currentTime.Add(7 * 24 * time.Hour) } else { expiresAt = currentTime.Add(2 * time.Hour) } session := models.JWTSession{ ID: sessionID, UsersID: userID, RefreshTokenHash: refreshTokenHash, UserAgent: userAgent, IPAddress: ipAddress, CreatedAt: currentTime, UpdatedAt: currentTime, ExpiresAt: expiresAt, IsRevoked: false, } _, err = db.DB.Exec(` INSERT INTO jwt_sessions (jwt_sessions_id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) `, sessionID, userID, refreshTokenHash, userAgent, ipAddress, currentTime, currentTime, expiresAt, false) if err != nil { return "", "", fmt.Errorf("failed to store session: %w", err) } sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash) sessionIDKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID) sessionTTL := int(time.Until(expiresAt).Seconds()) if sessionTTL > 0 { if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil { helper.LogError(err, "Failed to cache session in Redis (sessionKey)") } if err := helper.SetJSON(ctx, sessionIDKey, session, &sessionTTL); err != nil { helper.LogError(err, "Failed to cache session in Redis (sessionIDKey)") } } // Convert roleIDs slice to a comma-separated string for the token claim var roleIDsStr string if len(roleID) > 0 { for i, r := range roleID { if i > 0 { roleIDsStr += "," } roleIDsStr += fmt.Sprintf("%d", r) } } accessToken, err := generateAccessToken(email, sessionID, userID, roleID) if err != nil { return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err) } log.Printf("Generated tokens for user %s with session %s", email, sessionID) return accessToken, refreshToken, nil } func generateAccessToken(email, sessionID, userID string, roleID []int) (string, error) { AccessTokenExpiration := os.Getenv("ACCESS_TOKEN_EXPIRATION_MINUTES") if AccessTokenExpiration == "" { log.Println("AccessTokenExpiration not set (in minutes), defaulting to 45 minutes") AccessTokenExpiration = "45" } if roleID == nil { roleID = []int{} } var primaryRoleID *int if len(roleID) > 0 { value := roleID[0] primaryRoleID = &value } expirationTime := time.Now().Add(24 * time.Hour).Unix() claims := &models.AccessToken{ Email: email, UsersID: userID, RoleID: primaryRoleID, AdditionalRoleID: roleID, SessionID: sessionID, Exp: expirationTime, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Unix(expirationTime, 0)), }, } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) return token.SignedString(rsaPrivateKey) } func generateSecureToken() (string, error) { bytes := make([]byte, 32) // 256 bits _, err := rand.Read(bytes) if err != nil { return "", err } return base64.URLEncoding.EncodeToString(bytes), nil } // RefreshAccessToken refreshes the access token using a valid refresh token. // It validates the refresh token, checks the session status, and generates a new access token. // Uses Redis for session caching to improve performance for websocket connections. // // Parameters: // - refreshTokenString: The refresh token to use for refreshing the access token. // - userAgent: The user agent string from the request. // - ipAddress: The IP address of the client. // // Returns: // - string: The new signed access token as a string. // - error: An error if the token is invalid or the process fails. func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string, error) { ctx := context.Background() refreshTokenHash := helper.CalculateSHA256(refreshTokenString) helper.LogInfo(fmt.Sprintf("RefreshAccessToken called - Token length: %d, Hash: %s", len(refreshTokenString), refreshTokenHash[:16]+"...")) helper.LogInfo(fmt.Sprintf("Client details - UserAgent: %s, IP: %s", userAgent, ipAddress)) rateLimitKey := fmt.Sprintf("refresh_rate_limit:%s", refreshTokenHash) attempts, err := redisclient.RDB.Incr(ctx, rateLimitKey).Result() if err == nil { if attempts == 1 { redisclient.RDB.Expire(ctx, rateLimitKey, time.Minute) } if attempts > 5 { return "", fmt.Errorf("too many refresh attempts, please wait") } } sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash) var session models.JWTSession err = helper.GetJSON(ctx, sessionKey, &session) if err != nil { helper.LogInfo(fmt.Sprintf("Session not found in Redis cache, querying database for hash: %s", refreshTokenHash[:16]+"...")) err = db.DB.QueryRow(` SELECT id, user_id, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked FROM jwt_sessions WHERE refresh_token_hash = ? AND is_revoked = false `, refreshTokenHash).Scan( &session.ID, &session.UsersID, &session.UserAgent, &session.IPAddress, &session.CreatedAt, &session.UpdatedAt, &session.ExpiresAt, &session.IsRevoked, ) if err != nil { helper.LogError(err, fmt.Sprintf("Session not found in database for hash: %s", refreshTokenHash[:16]+"...")) return "", fmt.Errorf("invalid refresh token: %w", err) } helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s", session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime))) sessionTTL := int(time.Until(session.ExpiresAt).Seconds()) if sessionTTL > 0 { if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil { helper.LogError(err, errMsgFailedToUpdateSessionActivity) } } } else { helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s", session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime))) } if session.IsRevoked { helper.LogWarn(fmt.Sprintf("Attempted to use revoked session: %s", session.ID)) redisclient.RDB.Del(ctx, sessionKey) return "", fmt.Errorf(errMsgSessionHasBeenRevoked) } if time.Now().After(session.ExpiresAt) { helper.LogWarn(fmt.Sprintf("Attempted to use expired session: %s (expired at %s)", session.ID, session.ExpiresAt.Format(timeFormatDateTime))) _, err = db.DB.Exec(sqlUpdateRevokeSession, session.ID) if err != nil { helper.LogError(err, "Failed to revoke expired session") } redisclient.RDB.Del(ctx, sessionKey) return "", fmt.Errorf("refresh token has expired") } if session.UserAgent != userAgent { helper.LogWarn(fmt.Sprintf("Session User Agent security mismatch for session %s: stored='%s', received='%s'", session.ID, session.UserAgent, userAgent)) } if session.IPAddress != ipAddress { helper.LogWarn(fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s", session.ID, session.IPAddress, ipAddress)) } // Get user email from user ID (with caching) email, err := getUserEmailFromIDCached(session.UsersID) if err != nil { helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID)) // For registrants or users not yet in the main tables, we still want to allow refresh // but we need to get the email from somewhere else. Since we don't store email in session, // we'll need to handle this gracefully by allowing the refresh to continue with a placeholder // The email will be properly resolved when they complete registration helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UsersID)) // For now, we'll use a placeholder email pattern and let the access token generation handle it // The system should work as long as the session is valid email = fmt.Sprintf("registrant_%s@pending.local", session.UsersID) } helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID)) userID, err := helper.FetchUserIDFromDB(email) if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email)) userID = session.UsersID // Fallback to session's user ID } roleIDs, err := services.GetRoleIDsFromEmail(email) if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email)) roleIDs = []int{} } accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs) if err != nil { helper.LogError(err, "Failed to generate access token during refresh") return "", fmt.Errorf("failed to generate access token: %w", err) } helper.LogInfo(fmt.Sprintf("Successfully refreshed access token for user %s (session: %s)", email, session.ID)) session.UpdatedAt = time.Now() sessionTTL := int(time.Until(session.ExpiresAt).Seconds()) if sessionTTL > 0 { if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil { helper.LogError(err, errMsgFailedToUpdateSessionActivity) } } go func() { _, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE id = ?", session.UpdatedAt, session.ID) if err != nil { helper.LogError(err, "Failed to update session activity in DB") } }() return accessToken, nil } // RefreshAccessTokenWithEmailFallback refreshes the access token using a valid refresh token with email fallback. // This version handles cases where the user ID in the session doesn't exist in the database (e.g., registrants). // // Parameters: // - refreshTokenString: The refresh token to use for refreshing the access token. // - userAgent: The user agent string from the request. // - ipAddress: The IP address of the client. // - emailFallback: Email to use if user ID lookup fails (extracted from current access token). // // Returns: // - string: The new signed access token as a string. // - error: An error if the token is invalid or the process fails. func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddress, emailFallback string) (string, error) { ctx := context.Background() refreshTokenHash := helper.CalculateSHA256(refreshTokenString) helper.LogInfo(fmt.Sprintf("RefreshAccessTokenWithEmailFallback called - Token length: %d, Hash: %s", len(refreshTokenString), refreshTokenHash[:16]+"...")) helper.LogInfo(fmt.Sprintf("Client details - UserAgent: %s, IP: %s, EmailFallback: %s", userAgent, ipAddress, emailFallback)) rateLimitKey := fmt.Sprintf(redisKeyRefreshRateLimit, refreshTokenHash) attempts, err := redisclient.RDB.Incr(ctx, rateLimitKey).Result() if err == nil { if attempts == 1 { redisclient.RDB.Expire(ctx, rateLimitKey, time.Minute) } if attempts > 5 { return "", fmt.Errorf("too many refresh attempts, please wait") } } sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash) var session models.JWTSession err = helper.GetJSON(ctx, sessionKey, &session) if err != nil { helper.LogInfo(fmt.Sprintf("Session not found in Redis cache, querying database for hash: %s", refreshTokenHash[:16]+"...")) err = db.DB.QueryRow(` SELECT id, user_id, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked FROM jwt_sessions WHERE refresh_token_hash = ? AND is_revoked = false `, refreshTokenHash).Scan( &session.ID, &session.UsersID, &session.UserAgent, &session.IPAddress, &session.CreatedAt, &session.UpdatedAt, &session.ExpiresAt, &session.IsRevoked, ) if err != nil { helper.LogError(err, fmt.Sprintf("Session not found in database for hash: %s", refreshTokenHash[:16]+"...")) return "", fmt.Errorf("invalid refresh token: %w", err) } helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s", session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime))) sessionTTL := int(time.Until(session.ExpiresAt).Seconds()) if sessionTTL > 0 { if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil { helper.LogError(err, errMsgFailedToUpdateSessionActivity) } } } else { helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s", session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime))) } if session.IsRevoked { helper.LogWarn(fmt.Sprintf("Attempted to use revoked session: %s", session.ID)) redisclient.RDB.Del(ctx, sessionKey) return "", fmt.Errorf(errMsgSessionHasBeenRevoked) } if time.Now().After(session.ExpiresAt) { helper.LogWarn(fmt.Sprintf("Attempted to use expired session: %s (expired at %s)", session.ID, session.ExpiresAt.Format(timeFormatDateTime))) _, err = db.DB.Exec(sqlUpdateRevokeSession, session.ID) if err != nil { helper.LogError(err, "Failed to revoke expired session") } redisclient.RDB.Del(ctx, sessionKey) return "", fmt.Errorf("refresh token has expired") } if session.UserAgent != userAgent { helper.LogWarn(fmt.Sprintf("Session User Agent security mismatch for session %s: stored='%s', received='%s'", session.ID, session.UserAgent, userAgent)) } if session.IPAddress != ipAddress { helper.LogWarn(fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s", session.ID, session.IPAddress, ipAddress)) } // Get user email from user ID (with caching), with fallback to provided email email, err := getUserEmailFromIDCached(session.UsersID) if err != nil { helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID)) if emailFallback != "" { helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UsersID, emailFallback)) email = emailFallback } else { helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UsersID)) return "", fmt.Errorf("failed to get user email: %w", err) } } helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID)) userID, err := helper.FetchUserIDFromDB(email) if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email)) userID = session.UsersID // Fallback to session's user ID } roleIDs, err := services.GetRoleIDsFromEmail(email) if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email)) roleIDs = []int{} } accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs) if err != nil { helper.LogError(err, "Failed to generate access token during refresh") return "", fmt.Errorf("failed to generate access token: %w", err) } helper.LogInfo(fmt.Sprintf("Successfully refreshed access token for user %s (session: %s)", email, session.ID)) session.UpdatedAt = time.Now() sessionTTL := int(time.Until(session.ExpiresAt).Seconds()) if sessionTTL > 0 { if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil { helper.LogError(err, errMsgFailedToUpdateSessionActivity) } } go func() { _, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE id = ?", session.UpdatedAt, session.ID) if err != nil { helper.LogError(err, "Failed to update session activity in DB") } }() return accessToken, nil } func RevokeSession(sessionID string) error { ctx := context.Background() _, err := db.DB.Exec(sqlUpdateRevokeSession, sessionID) if err != nil { return fmt.Errorf("failed to revoke session %s: %w", sessionID, err) } sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID) redisclient.RDB.Del(ctx, sessionKey) return nil } func RevokeAllUserSessions(userID string) error { ctx := context.Background() rows, err := db.DB.Query("SELECT jwt_sessions_id FROM jwt_sessions WHERE user_id = ? AND is_revoked = false", userID) if err != nil { return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err) } defer rows.Close() var sessionIDs []string for rows.Next() { var sessionID string if err := rows.Scan(&sessionID); err != nil { continue } sessionIDs = append(sessionIDs, sessionID) } _, err = db.DB.Exec("UPDATE jwt_sessions SET is_revoked = true WHERE user_id = ?", userID) if err != nil { return fmt.Errorf("failed to revoke all sessions for user %s: %w", userID, err) } for _, sessionID := range sessionIDs { sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID) redisclient.RDB.Del(ctx, sessionKey) } userEmailKey := fmt.Sprintf(redisKeyUserEmail, userID) redisclient.RDB.Del(ctx, userEmailKey) return nil } func RevokeAllUserSessionsExceptCurrent(userID, currentSessionID string) error { ctx := context.Background() rows, err := db.DB.Query("SELECT id FROM jwt_sessions WHERE user_id = ? AND id != ? AND is_revoked = false", userID, currentSessionID) if err != nil { return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err) } defer rows.Close() var sessionIDs []string for rows.Next() { var sessionID string if err := rows.Scan(&sessionID); err != nil { continue } sessionIDs = append(sessionIDs, sessionID) } _, err = db.DB.Exec( "UPDATE jwt_sessions SET is_revoked = true WHERE user_id = ? AND id != ?", userID, currentSessionID, ) if err != nil { return fmt.Errorf("failed to revoke other sessions for user %s: %w", userID, err) } for _, sessionID := range sessionIDs { sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID) redisclient.RDB.Del(ctx, sessionKey) } return nil } func ValidateSession(sessionID string) (*models.JWTSession, error) { ctx := context.Background() sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID) var session models.JWTSession err := helper.GetJSON(ctx, sessionKey, &session) if err != nil { err = db.DB.QueryRow(` SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked FROM jwt_sessions WHERE id = ? `, sessionID).Scan( &session.ID, &session.UsersID, &session.RefreshTokenHash, &session.UserAgent, &session.IPAddress, &session.CreatedAt, &session.UpdatedAt, &session.ExpiresAt, &session.IsRevoked, ) if err != nil { return nil, fmt.Errorf("session not found: %w", err) } sessionTTL := int(time.Until(session.ExpiresAt).Seconds()) if sessionTTL > 0 { if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil { helper.LogError(err, "Failed to cache session in Redis (ValidateSession)") } } } if session.IsRevoked { redisclient.RDB.Del(ctx, sessionKey) return nil, fmt.Errorf("%s", errMsgSessionHasBeenRevoked) } if time.Now().After(session.ExpiresAt) { if err := RevokeSession(sessionID); err != nil { helper.LogError(err, "Failed to auto-revoke expired session") } redisclient.RDB.Del(ctx, sessionKey) return nil, fmt.Errorf("session has expired") } return &session, nil } func ValidateSessionForWebSocket(sessionID string) (*models.JWTSession, error) { ctx := context.Background() sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID) var session models.JWTSession err := helper.GetJSON(ctx, sessionKey, &session) if err != nil { return nil, fmt.Errorf("%s", errMsgSessionNotFoundInCache) } if session.IsRevoked { redisclient.RDB.Del(ctx, sessionKey) return nil, fmt.Errorf("%s", errMsgSessionHasBeenRevoked) } if time.Now().After(session.ExpiresAt) { redisclient.RDB.Del(ctx, sessionKey) return nil, fmt.Errorf("session has expired") } return &session, nil } func ExtendSessionActivity(sessionID string) error { ctx := context.Background() sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID) var session models.JWTSession err := helper.GetJSON(ctx, sessionKey, &session) if err != nil { return fmt.Errorf("%s", errMsgSessionNotFoundInCache) } session.UpdatedAt = time.Now() sessionTTL := int(time.Until(session.ExpiresAt).Seconds()) if sessionTTL > 0 { if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil { helper.LogError(err, "Failed to extend session activity in Redis cache") } } return nil } func GetSessionUserInfo(sessionID string) (string, string, error) { ctx := context.Background() sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID) var session models.JWTSession err := helper.GetJSON(ctx, sessionKey, &session) if err != nil { return "", "", fmt.Errorf("%s", errMsgSessionNotFoundInCache) } email, err := getUserEmailFromIDCached(session.UsersID) if err != nil { return "", "", fmt.Errorf("failed to get user email: %w", err) } return session.UsersID, email, nil } func InvalidateUserSessionsInCache(userID string) error { ctx := context.Background() rows, err := db.DB.Query("SELECT id, refresh_token_hash FROM jwt_sessions WHERE user_id = ?", userID) if err != nil { return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err) } defer rows.Close() var keys []string for rows.Next() { var sessionID, refreshTokenHash string if err := rows.Scan(&sessionID, &refreshTokenHash); err != nil { continue } keys = append(keys, fmt.Sprintf(redisKeyJWTSessionID, sessionID)) keys = append(keys, fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)) } if len(keys) > 0 { redisclient.RDB.Del(ctx, keys...) } userEmailKey := fmt.Sprintf(redisKeyUserEmail, userID) redisclient.RDB.Del(ctx, userEmailKey) return nil } func CleanupExpiredSessions() error { ctx := context.Background() rows, err := db.DB.Query("SELECT id, user_id, refresh_token_hash FROM jwt_sessions WHERE expires_at < ?", time.Now()) if err != nil { return fmt.Errorf("failed to query expired sessions: %w", err) } defer rows.Close() var expiredSessions []models.ExpiredSession userIDsToCleanup := make(map[string]bool) for rows.Next() { var session models.ExpiredSession if err := rows.Scan(&session.ID, &session.UsersID, &session.RefreshTokenHash); err != nil { continue } expiredSessions = append(expiredSessions, session) userIDsToCleanup[session.UsersID] = true } _, err = db.DB.Exec("DELETE FROM jwt_sessions WHERE expires_at < ?", time.Now()) if err != nil { return fmt.Errorf("failed to cleanup expired sessions: %w", err) } for _, session := range expiredSessions { sessionKey := fmt.Sprintf(redisKeyJWTSession, session.RefreshTokenHash) sessionIDKey := fmt.Sprintf(redisKeyJWTSessionID, session.ID) redisclient.RDB.Del(ctx, sessionKey, sessionIDKey) } // Role cache invalidation removed - handled by separate authz microservice log.Printf("Cleaned up %d expired sessions", len(expiredSessions)) return nil } func GetUserSessions(userID string) ([]models.JWTSession, error) { rows, err := db.DB.Query(` SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked FROM jwt_sessions WHERE user_id = ? AND is_revoked = false AND expires_at > ? ORDER BY created_at DESC `, userID, time.Now()) if err != nil { return nil, fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err) } defer rows.Close() var sessions []models.JWTSession for rows.Next() { var session models.JWTSession err := rows.Scan( &session.ID, &session.UsersID, &session.RefreshTokenHash, &session.UserAgent, &session.IPAddress, &session.CreatedAt, &session.UpdatedAt, &session.ExpiresAt, &session.IsRevoked, ) if err != nil { return nil, fmt.Errorf("failed to scan session row: %w", err) } sessions = append(sessions, session) } return sessions, nil } func UpdateSessionLastActivity(sessionID string) error { _, err := db.DB.Exec(` UPDATE jwt_sessions SET updated_at = ? WHERE id = ? `, time.Now(), sessionID) if err != nil { return fmt.Errorf("failed to update session activity: %w", err) } return nil } func getUserEmailFromID(userID string) (string, error) { var email string err := db.DB.QueryRow("SELECT email_address FROM users WHERE id = ?", userID).Scan(&email) if err == nil { return email, nil } return "", fmt.Errorf("user not found with ID %s in any table", userID) } func getUserEmailFromIDCached(userID string) (string, error) { ctx := context.Background() cacheKey := fmt.Sprintf(redisKeyUserEmail, userID) var email string err := helper.GetJSON(ctx, cacheKey, &email) if err == nil && email != "" { return email, nil } // Cache miss, feth from database email, err = getUserEmailFromID(userID) if err != nil { return "", err } cacheTTL := 3600 if err := helper.SetJSON(ctx, cacheKey, email, &cacheTTL); err != nil { helper.LogError(err, "Failed to cache user email in Redis") } return email, nil } func AddToSessionBlacklist(sessionID string, ttlSeconds int) error { ctx := context.Background() blacklistKey := fmt.Sprintf(redisKeySessionBlacklist, sessionID) ttl := time.Duration(ttlSeconds) * time.Second return redisclient.RDB.Set(ctx, blacklistKey, "revoked", ttl).Err() } func IsSessionBlacklisted(sessionID string) bool { ctx := context.Background() blacklistKey := fmt.Sprintf(redisKeySessionBlacklist, sessionID) exists, err := redisclient.RDB.Exists(ctx, blacklistKey).Result() return err == nil && exists > 0 } func ClearSessionFromAllCaches(sessionID, refreshTokenHash string) error { ctx := context.Background() keys := []string{ fmt.Sprintf(redisKeyJWTSessionID, sessionID), fmt.Sprintf(redisKeyJWTSession, refreshTokenHash), } return redisclient.RDB.Del(ctx, keys...).Err() } func CheckEmailInDB(email string) (bool, error) { if db.DB == nil { return false, fmt.Errorf("database connection is nil") } var exists bool err := db.DB.QueryRow( `SELECT EXISTS( SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0)`, email, ).Scan(&exists) if err != nil { return false, fmt.Errorf("error checking email in database: %v", err) } return exists, nil }