Files

1025 lines
33 KiB
Go

package handlers
import (
"authentication/db"
"authentication/helper"
"authentication/models"
"authentication/redisclient"
"authentication/services"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"log"
"os"
"sort"
"strconv"
"time"
"github.com/golang-jwt/jwt/v5"
)
// rsaPrivateKey is the RS256 signing key used by generateAccessToken. It is set
// by this package's init function, either from the PEM file at
// JWT_PRIVATE_KEY_PATH (default "rsa/psa.gov.ph.key") or, when the file cannot
// be read and isTestEnvironment() reports true, from a freshly generated
// 2048-bit key.
var rsaPrivateKey *rsa.PrivateKey
// parseRSAPrivateKey decodes the first PEM block in keyData and returns the RSA
// private key it carries. PKCS#8 is attempted first and PKCS#1 is the fallback;
// the PEM block's type header is not inspected.
//
// It fails when no PEM block is present, when neither encoding parses (the
// PKCS#1 parse error is wrapped), or when a PKCS#8 payload holds a non-RSA key.
func parseRSAPrivateKey(keyData []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(keyData)
	if block == nil {
		return nil, fmt.Errorf("failed to decode PEM block containing private key")
	}
	der := block.Bytes

	parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(der)
	if pkcs8Err != nil {
		// Not PKCS#8 — fall back to the legacy "RSA PRIVATE KEY" encoding.
		legacyKey, pkcs1Err := x509.ParsePKCS1PrivateKey(der)
		if pkcs1Err != nil {
			return nil, fmt.Errorf("failed to parse RSA private key: %w", pkcs1Err)
		}
		return legacyKey, nil
	}

	if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
		return rsaKey, nil
	}
	return nil, fmt.Errorf("not an RSA private key")
}
// init loads the RS256 signing key into rsaPrivateKey at package
// initialisation.
//
// The key path comes from the JWT_PRIVATE_KEY_PATH environment variable and
// defaults to "rsa/psa.gov.ph.key".
//   - Unreadable file, isTestEnvironment() true: a throwaway 2048-bit key is
//     generated so tests can still sign tokens.
//   - Unreadable file otherwise: the process exits via log.Fatalf.
//   - File present but unparseable: always fatal.
//
// NOTE(review): Go runs every package's init() before main() begins. An .env
// file loaded inside main() is therefore NOT visible here. For the override to
// apply, JWT_PRIVATE_KEY_PATH must already be in the process environment, or
// be loaded by a package initialised before this one (for example a
// godotenv/autoload import). Confirm how deployments set it.
func init() {
	keyPath := os.Getenv("JWT_PRIVATE_KEY_PATH")
	if keyPath == "" {
		keyPath = "rsa/psa.gov.ph.key"
	}

	keyData, err := os.ReadFile(keyPath) // #nosec G304 -- path comes from the environment or a fixed default, not request input
	if err != nil {
		if !isTestEnvironment() {
			log.Fatalf("Failed to read RSA private key file %s: %v", keyPath, err)
		}
		log.Printf("Failed to read RSA private key file at %s, generating test key: %v", keyPath, err)
		generatedKey, genErr := rsa.GenerateKey(rand.Reader, 2048)
		if genErr != nil {
			log.Fatalf("Failed to generate test RSA private key: %v", genErr)
		}
		rsaPrivateKey = generatedKey
		return
	}

	parsedKey, err := parseRSAPrivateKey(keyData)
	if err != nil {
		log.Fatalf("Failed to load RSA private key from %s: %v", keyPath, err)
	}
	rsaPrivateKey = parsedKey
	log.Println("RSA private key loaded successfully for JWT signing")
}
// GenerateTokens generates both access and refresh tokens with session management.
// It creates a new session in the database and caches it in Redis for performance.
//
// Parameters:
//   - email: The email address to include in the JWT claims. It is replaced by
//     the stored profile email when services.FetchUserByEmail succeeds.
//   - userAgent: The user agent string from the request.
//   - ipAddress: The IP address of the client.
//
// Session lifetime depends on CheckEmailInDB (a non-deleted user with this
// email): 7 days if such a user exists, 2 hours otherwise.
//
// Returns:
//   - string: the signed access token (see generateAccessToken).
//   - string: the opaque refresh token. Only its hash (helper.CalculateSHA256)
//     is persisted.
//   - error: non-nil if the email check, refresh-token generation, session
//     insert, or access-token signing fails. Redis caching failures are only
//     logged.
func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) {
	totalStart := time.Now()
	helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens start email=%s", email))
	ctx := context.Background()

	emailExistsStart := time.Now()
	emailExists, err := CheckEmailInDB(email)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens CheckEmailInDB failed duration_ms=%d", time.Since(emailExistsStart).Milliseconds()))
		return "", "", fmt.Errorf("error checking email in database: %w", err)
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens CheckEmailInDB ok duration_ms=%d", time.Since(emailExistsStart).Milliseconds()))

	// If the ID lookup fails for any reason, a fresh random ID still gives the
	// session row an owner.
	userIDLookupStart := time.Now()
	userID, err := services.GetUserIDFromEmail(email)
	if err != nil {
		userID = helper.UUIDGenerator()
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens GetUserIDFromEmail done duration_ms=%d", time.Since(userIDLookupStart).Milliseconds()))
	log.Print("userID:", userID)

	// The full profile, when available, overrides the user ID and email and
	// supplies role claims. Failure here is non-fatal: the token is issued
	// without roles.
	tokenEmail := email
	roleID := make([]int, 0)
	fetchUserStart := time.Now()
	user, err := services.FetchUserByEmail(email)
	if err != nil {
		helper.LogWarn(fmt.Sprintf("Failed to fetch user profile for JWT role mapping (%s): %v", email, err))
	} else {
		if user.UserID != "" {
			userID = user.UserID
		}
		if user.EmailAddress != "" {
			tokenEmail = user.EmailAddress
		}
		roleID = buildJWTClaimRoleIDs(user)
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens FetchUserByEmail done duration_ms=%d", time.Since(fetchUserStart).Milliseconds()))

	sessionID := helper.UUIDGenerator()
	refreshToken, err := generateSecureToken()
	if err != nil {
		return "", "", fmt.Errorf("failed to generate refresh token: %w", err)
	}
	refreshTokenHash := helper.CalculateSHA256(refreshToken)

	location, err := helper.LoadAsiaManilaLocation()
	if err != nil {
		helper.LogError(err, "Failed to load Asia/Manila timezone, using UTC+8 offset")
	}
	if location == nil {
		// time.Time.In panics on a nil *time.Location. Apply the UTC+8 fallback
		// the log message above promises instead of crashing the request.
		location = time.FixedZone("UTC+8", 8*60*60)
	}
	currentTime := time.Now().In(location)

	var expiresAt time.Time
	if emailExists {
		expiresAt = currentTime.Add(7 * 24 * time.Hour)
	} else {
		expiresAt = currentTime.Add(2 * time.Hour)
	}

	session := models.JWTSession{
		ID:               sessionID,
		UsersID:          userID,
		RefreshTokenHash: refreshTokenHash,
		UserAgent:        userAgent,
		IPAddress:        ipAddress,
		CreatedAt:        currentTime,
		UpdatedAt:        currentTime,
		ExpiresAt:        expiresAt,
		IsRevoked:        false,
	}

	dbInsertStart := time.Now()
	helper.LogInfo("[oauth-debug] GenerateTokens jwt_sessions insert start")
	dbInsertCtx, dbInsertCancel := context.WithTimeout(ctx, 3*time.Second)
	defer dbInsertCancel()
	_, err = db.DB.ExecContext(dbInsertCtx, `
INSERT INTO jwt_sessions (jwt_sessions_id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`, sessionID, userID, refreshTokenHash, userAgent, ipAddress, currentTime, currentTime, expiresAt, false)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens jwt_sessions insert failed duration_ms=%d", time.Since(dbInsertStart).Milliseconds()))
		return "", "", fmt.Errorf("failed to store session: %w", err)
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens jwt_sessions insert ok duration_ms=%d", time.Since(dbInsertStart).Milliseconds()))

	// Cache under both the refresh-token hash (read by the refresh paths) and the
	// session ID (read by ValidateSession and the websocket helpers). Both
	// entries expire with the session.
	sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
	sessionIDKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
	sessionTTL := int(time.Until(expiresAt).Seconds())
	if sessionTTL > 0 {
		redisStart := time.Now()
		helper.LogInfo("[oauth-debug] GenerateTokens redis cache start")
		redisCtx, redisCancel := context.WithTimeout(ctx, 2*time.Second)
		defer redisCancel()
		if err := helper.SetJSON(redisCtx, sessionKey, session, &sessionTTL); err != nil {
			helper.LogError(err, "Failed to cache session in Redis (sessionKey)")
		}
		if err := helper.SetJSON(redisCtx, sessionIDKey, session, &sessionTTL); err != nil {
			helper.LogError(err, "Failed to cache session in Redis (sessionIDKey)")
		}
		helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens redis cache done duration_ms=%d", time.Since(redisStart).Milliseconds()))
	}

	signStart := time.Now()
	helper.LogInfo("[oauth-debug] GenerateTokens access token sign start")
	accessToken, err := generateAccessToken(tokenEmail, sessionID, userID, roleID)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens access token sign failed duration_ms=%d", time.Since(signStart).Milliseconds()))
		return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err)
	}
	helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens access token sign ok duration_ms=%d", time.Since(signStart).Milliseconds()))
	helper.LogInfo(fmt.Sprintf("[oauth-debug] GenerateTokens complete total_ms=%d", time.Since(totalStart).Milliseconds()))
	log.Printf("Generated tokens for user %s with session %s", email, sessionID)
	return accessToken, refreshToken, nil
}
// buildJWTClaimRoleIDs flattens a user's base role and per-project roles into
// the role list embedded in access-token claims.
//
// Roles are de-duplicated. The first entry is the primary role: the base
// role_id when present, otherwise the first project role encountered. Any
// further roles follow in ascending order. Each step is traced through
// helper.LogInfo, including which projects grant role_id 1.
func buildJWTClaimRoleIDs(user models.User) []int {
	seen := make(map[int]struct{})
	roles := make([]int, 0)
	projectsWithRoleOne := make([]int, 0)

	// add records r the first time it appears, preserving first-seen order.
	add := func(r int) {
		if _, dup := seen[r]; !dup {
			seen[r] = struct{}{}
			roles = append(roles, r)
		}
	}

	if base := user.RoleID; base != nil {
		helper.LogInfo(fmt.Sprintf("JWT role source - base role_id: %d", *base))
		add(*base)
	}

	if user.Projects != nil {
		for _, p := range *user.Projects {
			helper.LogInfo(fmt.Sprintf("JWT role source - project %d: role_id=%v", p.ProjectID, p.RoleID))
			for _, r := range p.RoleID {
				if r == 1 {
					projectsWithRoleOne = append(projectsWithRoleOne, p.ProjectID)
				}
				add(r)
			}
		}
	}

	if len(projectsWithRoleOne) == 0 {
		helper.LogInfo("JWT role trace - additional role_id=1 not found in project roles")
	} else {
		helper.LogInfo(fmt.Sprintf("JWT role trace - additional role_id=1 found in project_id(s): %v", projectsWithRoleOne))
	}

	switch len(roles) {
	case 0:
		helper.LogInfo("JWT role claims - primary role_id=nil, additional_role_id=[]")
		return roles
	case 1:
		helper.LogInfo(fmt.Sprintf("JWT role claims - primary role_id=%d, additional_role_id=[]", roles[0]))
		return roles
	}

	primary := roles[0]
	rest := make([]int, len(roles)-1)
	copy(rest, roles[1:])
	sort.Ints(rest)
	helper.LogInfo(fmt.Sprintf("JWT role claims - primary role_id=%d, additional_role_id=%v", primary, rest))
	return append([]int{primary}, rest...)
}
// generateAccessToken builds and RS256-signs an access token for a session
// using rsaPrivateKey.
//
// roleID follows the shape produced by buildJWTClaimRoleIDs. Its first element
// becomes the RoleID claim and the rest become AdditionalRoleID. A nil or empty
// slice yields a nil RoleID and an empty (non-nil) AdditionalRoleID.
//
// The lifetime is read from ACCESS_TOKEN_EXPIRATION_MINUTES on every call. It
// falls back to 45 minutes when the variable is unset, not an integer, or not
// positive.
//
// NOTE(review): the expiry is carried twice, in models.AccessToken.Exp and in
// RegisteredClaims.ExpiresAt. If both marshal to the same "exp" JSON key,
// encoding/json keeps only the less-nested Exp field. Confirm the tags on
// models.AccessToken.
func generateAccessToken(email, sessionID, userID string, roleID []int) (string, error) {
	expirationMinutes := accessTokenLifetimeMinutes()
	primaryRoleID, additionalRoleIDs := splitRoleClaims(roleID)

	expirationTime := time.Now().Add(time.Duration(expirationMinutes) * time.Minute).Unix()
	claims := &models.AccessToken{
		Email:            email,
		UsersID:          userID,
		RoleID:           primaryRoleID,
		AdditionalRoleID: additionalRoleIDs,
		SessionID:        sessionID,
		Exp:              expirationTime,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(time.Unix(expirationTime, 0)),
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	return token.SignedString(rsaPrivateKey)
}

// accessTokenLifetimeMinutes reads ACCESS_TOKEN_EXPIRATION_MINUTES. It logs a
// message and returns 45 when the value is unset, non-numeric, or not positive.
func accessTokenLifetimeMinutes() int {
	const fallbackMinutes = 45
	raw := os.Getenv("ACCESS_TOKEN_EXPIRATION_MINUTES")
	if raw == "" {
		log.Println("AccessTokenExpiration not set (in minutes), defaulting to 45 minutes")
		raw = strconv.Itoa(fallbackMinutes)
	}
	log.Print("AccessTokenExpiration (minutes):", raw)

	minutes, err := strconv.Atoi(raw)
	if err != nil || minutes <= 0 {
		// A zero or negative lifetime would mint tokens that are already expired.
		log.Println("Invalid AccessTokenExpiration value, defaulting to 45 minutes")
		return fallbackMinutes
	}
	return minutes
}

// splitRoleClaims returns a pointer to a copy of roleIDs[0] (nil when roleIDs
// is empty) and a fresh, never-nil slice holding the remaining roles. Copying
// keeps the claims from aliasing the caller's slice.
func splitRoleClaims(roleIDs []int) (*int, []int) {
	additional := make([]int, 0)
	if len(roleIDs) == 0 {
		return nil, additional
	}
	primary := roleIDs[0]
	additional = append(additional, roleIDs[1:]...)
	return &primary, additional
}
func generateSecureToken() (string, error) {
bytes := make([]byte, 32) // 256 bits
_, err := rand.Read(bytes)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes), nil
}
// RefreshAccessToken refreshes the access token using a valid refresh token.
// It validates the refresh token, checks the session status, and generates a new access token.
// Uses Redis for session caching to improve performance for websocket connections.
//
// Session lookup:
//   - Redis is checked first, under the refresh-token hash.
//   - On a miss, the non-revoked jwt_sessions row is read and re-cached.
//
// Checks applied:
//   - At most 5 attempts per refresh token per minute. Rate limiting is skipped
//     if Redis cannot increment the counter.
//   - Revoked or expired sessions are rejected and evicted from the cache.
//   - User-agent and IP mismatches are logged but do not block the refresh.
//
// When the session's user ID has no users row, the placeholder email
// "registrant_<userID>@pending.local" is used so pending registrants can still
// refresh.
//
// Parameters:
//   - refreshTokenString: The refresh token to use for refreshing the access token.
//   - userAgent: The user agent string from the request.
//   - ipAddress: The IP address of the client.
//
// Returns:
//   - string: The new signed access token as a string.
//   - error: An error if the token is invalid or the process fails.
func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string, error) {
	ctx := context.Background()
	refreshTokenHash := helper.CalculateSHA256(refreshTokenString)
	helper.LogInfo(fmt.Sprintf("RefreshAccessToken called - Token length: %d, Hash: %s", len(refreshTokenString), refreshTokenHash[:16]+"..."))
	helper.LogInfo(fmt.Sprintf("Client details - UserAgent: %s, IP: %s", userAgent, ipAddress))

	// Use the shared key format so this entry point and
	// RefreshAccessTokenWithEmailFallback draw on the same per-token budget.
	rateLimitKey := fmt.Sprintf(redisKeyRefreshRateLimit, refreshTokenHash)
	attempts, err := redisclient.RDB.Incr(ctx, rateLimitKey).Result()
	if err == nil {
		if attempts == 1 {
			redisclient.RDB.Expire(ctx, rateLimitKey, time.Minute)
		}
		if attempts > 5 {
			return "", fmt.Errorf("too many refresh attempts, please wait")
		}
	}

	sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
	var session models.JWTSession
	err = helper.GetJSON(ctx, sessionKey, &session)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("Session not found in Redis cache, querying database for hash: %s", refreshTokenHash[:16]+"..."))
		// The session key column is jwt_sessions_id. That is what GenerateTokens
		// inserts and what every other query here uses.
		dbCtx, dbCancel := context.WithTimeout(ctx, 3*time.Second)
		err = db.DB.QueryRowContext(dbCtx, `
SELECT jwt_sessions_id, user_id, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
FROM jwt_sessions
WHERE refresh_token_hash = ? AND is_revoked = false
`, refreshTokenHash).Scan(
			&session.ID,
			&session.UsersID,
			&session.UserAgent,
			&session.IPAddress,
			&session.CreatedAt,
			&session.UpdatedAt,
			&session.ExpiresAt,
			&session.IsRevoked,
		)
		dbCancel()
		if err != nil {
			helper.LogError(err, fmt.Sprintf("Session not found in database for hash: %s", refreshTokenHash[:16]+"..."))
			return "", fmt.Errorf("invalid refresh token: %w", err)
		}
		// Not selected above. Set it so the re-cached copy matches what GenerateTokens caches.
		session.RefreshTokenHash = refreshTokenHash
		helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s",
			session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
		sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
		if sessionTTL > 0 {
			if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
				helper.LogError(err, errMsgFailedToUpdateSessionActivity)
			}
		}
	} else {
		helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s",
			session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime)))
	}

	// Evict both cache entries for a dead session, not only the hash-keyed one.
	sessionIDKey := fmt.Sprintf(redisKeyJWTSessionID, session.ID)
	if session.IsRevoked {
		helper.LogWarn(fmt.Sprintf("Attempted to use revoked session: %s", session.ID))
		redisclient.RDB.Del(ctx, sessionKey, sessionIDKey)
		return "", fmt.Errorf("%s", errMsgSessionHasBeenRevoked)
	}
	if time.Now().After(session.ExpiresAt) {
		helper.LogWarn(fmt.Sprintf("Attempted to use expired session: %s (expired at %s)",
			session.ID, session.ExpiresAt.Format(timeFormatDateTime)))
		if _, err := db.DB.Exec(sqlUpdateRevokeSession, session.ID); err != nil {
			helper.LogError(err, "Failed to revoke expired session")
		}
		redisclient.RDB.Del(ctx, sessionKey, sessionIDKey)
		return "", fmt.Errorf("refresh token has expired")
	}

	if session.UserAgent != userAgent {
		helper.LogWarn(fmt.Sprintf("Session User Agent security mismatch for session %s: stored='%s', received='%s'",
			session.ID, session.UserAgent, userAgent))
	}
	if session.IPAddress != ipAddress {
		helper.LogWarn(fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s",
			session.ID, session.IPAddress, ipAddress))
	}

	// Resolve the email for the session's user (cached). Registrants who are not
	// yet in the users table get a placeholder address so the refresh can proceed.
	// The real email is resolved once registration completes.
	email, err := getUserEmailFromIDCached(session.UsersID)
	if err != nil {
		helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID))
		helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UsersID))
		email = fmt.Sprintf("registrant_%s@pending.local", session.UsersID)
	}
	helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID))

	userID, err := helper.FetchUserIDFromDB(email)
	if err != nil {
		helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email))
		userID = session.UsersID // Fallback to session's user ID
	}
	roleIDs := make([]int, 0)
	user, err := services.FetchUserByEmail(email)
	if err != nil {
		helper.LogError(err, fmt.Sprintf("Failed to fetch user profile for role mapping during refresh for email %s", email))
	} else {
		if user.UserID != "" {
			userID = user.UserID
		}
		if user.EmailAddress != "" {
			email = user.EmailAddress
		}
		roleIDs = buildJWTClaimRoleIDs(user)
	}

	accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
	if err != nil {
		helper.LogError(err, "Failed to generate access token during refresh")
		return "", fmt.Errorf("failed to generate access token: %w", err)
	}
	helper.LogInfo(fmt.Sprintf("Successfully refreshed access token for user %s (session: %s)", email, session.ID))

	session.UpdatedAt = time.Now()
	sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
	if sessionTTL > 0 {
		if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
			helper.LogError(err, errMsgFailedToUpdateSessionActivity)
		}
	}

	// Persist the activity timestamp in the background so the refresh response
	// is not delayed. The values are captured so the goroutine shares no state.
	updatedAt, sessionID := session.UpdatedAt, session.ID
	go func() {
		if _, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE jwt_sessions_id = ?", updatedAt, sessionID); err != nil {
			helper.LogError(err, "Failed to update session activity in DB")
		}
	}()
	return accessToken, nil
}
// RefreshAccessTokenWithEmailFallback refreshes the access token using a valid refresh token with email fallback.
// This version handles cases where the user ID in the session doesn't exist in the database (e.g., registrants).
//
// It follows the same lookup, rate-limit, revocation and expiry rules as
// RefreshAccessToken. The difference is what happens when the session's user ID
// has no users row: emailFallback is used, and if it is empty the call fails.
//
// Parameters:
//   - refreshTokenString: The refresh token to use for refreshing the access token.
//   - userAgent: The user agent string from the request.
//   - ipAddress: The IP address of the client.
//   - emailFallback: Email to use if user ID lookup fails (extracted from current access token).
//
// Returns:
//   - string: The new signed access token as a string.
//   - error: An error if the token is invalid or the process fails.
func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddress, emailFallback string) (string, error) {
	ctx := context.Background()
	refreshTokenHash := helper.CalculateSHA256(refreshTokenString)
	helper.LogInfo(fmt.Sprintf("RefreshAccessTokenWithEmailFallback called - Token length: %d, Hash: %s", len(refreshTokenString), refreshTokenHash[:16]+"..."))
	helper.LogInfo(fmt.Sprintf("Client details - UserAgent: %s, IP: %s, EmailFallback: %s", userAgent, ipAddress, emailFallback))

	// A Redis error skips rate limiting rather than blocking the refresh.
	rateLimitKey := fmt.Sprintf(redisKeyRefreshRateLimit, refreshTokenHash)
	attempts, err := redisclient.RDB.Incr(ctx, rateLimitKey).Result()
	if err == nil {
		if attempts == 1 {
			redisclient.RDB.Expire(ctx, rateLimitKey, time.Minute)
		}
		if attempts > 5 {
			return "", fmt.Errorf("too many refresh attempts, please wait")
		}
	}

	sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
	var session models.JWTSession
	err = helper.GetJSON(ctx, sessionKey, &session)
	if err != nil {
		helper.LogInfo(fmt.Sprintf("Session not found in Redis cache, querying database for hash: %s", refreshTokenHash[:16]+"..."))
		// The session key column is jwt_sessions_id. That is what GenerateTokens
		// inserts and what every other query here uses.
		dbCtx, dbCancel := context.WithTimeout(ctx, 3*time.Second)
		err = db.DB.QueryRowContext(dbCtx, `
SELECT jwt_sessions_id, user_id, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
FROM jwt_sessions
WHERE refresh_token_hash = ? AND is_revoked = false
`, refreshTokenHash).Scan(
			&session.ID,
			&session.UsersID,
			&session.UserAgent,
			&session.IPAddress,
			&session.CreatedAt,
			&session.UpdatedAt,
			&session.ExpiresAt,
			&session.IsRevoked,
		)
		dbCancel()
		if err != nil {
			helper.LogError(err, fmt.Sprintf("Session not found in database for hash: %s", refreshTokenHash[:16]+"..."))
			return "", fmt.Errorf("invalid refresh token: %w", err)
		}
		// Not selected above. Set it so the re-cached copy matches what GenerateTokens caches.
		session.RefreshTokenHash = refreshTokenHash
		helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s",
			session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
		sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
		if sessionTTL > 0 {
			if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
				helper.LogError(err, errMsgFailedToUpdateSessionActivity)
			}
		}
	} else {
		helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s",
			session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime)))
	}

	// Evict both cache entries for a dead session, not only the hash-keyed one.
	sessionIDKey := fmt.Sprintf(redisKeyJWTSessionID, session.ID)
	if session.IsRevoked {
		helper.LogWarn(fmt.Sprintf("Attempted to use revoked session: %s", session.ID))
		redisclient.RDB.Del(ctx, sessionKey, sessionIDKey)
		return "", fmt.Errorf("%s", errMsgSessionHasBeenRevoked)
	}
	if time.Now().After(session.ExpiresAt) {
		helper.LogWarn(fmt.Sprintf("Attempted to use expired session: %s (expired at %s)",
			session.ID, session.ExpiresAt.Format(timeFormatDateTime)))
		if _, err := db.DB.Exec(sqlUpdateRevokeSession, session.ID); err != nil {
			helper.LogError(err, "Failed to revoke expired session")
		}
		redisclient.RDB.Del(ctx, sessionKey, sessionIDKey)
		return "", fmt.Errorf("refresh token has expired")
	}

	if session.UserAgent != userAgent {
		helper.LogWarn(fmt.Sprintf("Session User Agent security mismatch for session %s: stored='%s', received='%s'",
			session.ID, session.UserAgent, userAgent))
	}
	if session.IPAddress != ipAddress {
		helper.LogWarn(fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s",
			session.ID, session.IPAddress, ipAddress))
	}

	// Get user email from user ID (with caching), with fallback to provided email.
	email, err := getUserEmailFromIDCached(session.UsersID)
	if err != nil {
		helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID))
		if emailFallback == "" {
			helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UsersID))
			return "", fmt.Errorf("failed to get user email: %w", err)
		}
		helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UsersID, emailFallback))
		email = emailFallback
	}
	helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID))

	userID, err := helper.FetchUserIDFromDB(email)
	if err != nil {
		helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email))
		userID = session.UsersID // Fallback to session's user ID
	}
	roleIDs := make([]int, 0)
	user, err := services.FetchUserByEmail(email)
	if err != nil {
		helper.LogError(err, fmt.Sprintf("Failed to fetch user profile for role mapping during refresh for email %s", email))
	} else {
		if user.UserID != "" {
			userID = user.UserID
		}
		if user.EmailAddress != "" {
			email = user.EmailAddress
		}
		roleIDs = buildJWTClaimRoleIDs(user)
	}

	accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
	if err != nil {
		helper.LogError(err, "Failed to generate access token during refresh")
		return "", fmt.Errorf("failed to generate access token: %w", err)
	}
	helper.LogInfo(fmt.Sprintf("Successfully refreshed access token for user %s (session: %s)", email, session.ID))

	session.UpdatedAt = time.Now()
	sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
	if sessionTTL > 0 {
		if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
			helper.LogError(err, errMsgFailedToUpdateSessionActivity)
		}
	}

	// Persist the activity timestamp in the background so the refresh response
	// is not delayed. The values are captured so the goroutine shares no state.
	updatedAt, sessionID := session.UpdatedAt, session.ID
	go func() {
		if _, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE jwt_sessions_id = ?", updatedAt, sessionID); err != nil {
			helper.LogError(err, "Failed to update session activity in DB")
		}
	}()
	return accessToken, nil
}
func RevokeSession(sessionID string) error {
ctx := context.Background()
_, err := db.DB.Exec(sqlUpdateRevokeSession, sessionID)
if err != nil {
return fmt.Errorf("failed to revoke session %s: %w", sessionID, err)
}
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
redisclient.RDB.Del(ctx, sessionKey)
return nil
}
// RevokeAllUserSessions revokes every session belonging to userID and evicts the
// related Redis entries:
//   - both cache keys (session ID and refresh-token hash) of each session that
//     was active at lookup time;
//   - the user's cached email.
//
// Evicting the refresh-token-hash key matters because the refresh functions read
// it first. A stale copy would keep the revoked sessions refreshable.
//
// Returns an error if the session lookup or the revoking UPDATE fails. Row-level
// scan or iteration problems are logged; the DB revoke still runs, so the
// database stays authoritative even if some cache entries are missed.
//
// NOTE(review): sessions created between the SELECT and the UPDATE are revoked
// in the database but not evicted from Redis.
func RevokeAllUserSessions(userID string) error {
	ctx := context.Background()
	rows, err := db.DB.Query("SELECT jwt_sessions_id, refresh_token_hash FROM jwt_sessions WHERE user_id = ? AND is_revoked = false", userID)
	if err != nil {
		return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
	}
	defer rows.Close()

	var cacheKeys []string
	for rows.Next() {
		var sessionID, refreshTokenHash string
		if err := rows.Scan(&sessionID, &refreshTokenHash); err != nil {
			helper.LogError(err, fmt.Sprintf("Failed to scan session row while revoking sessions for user %s", userID))
			continue
		}
		cacheKeys = append(cacheKeys,
			fmt.Sprintf(redisKeyJWTSessionID, sessionID),
			fmt.Sprintf(redisKeyJWTSession, refreshTokenHash),
		)
	}
	if err := rows.Err(); err != nil {
		helper.LogError(err, fmt.Sprintf("Session listing for user %s ended early; some cache entries may survive revocation", userID))
	}

	if _, err := db.DB.Exec("UPDATE jwt_sessions SET is_revoked = true WHERE user_id = ?", userID); err != nil {
		return fmt.Errorf("failed to revoke all sessions for user %s: %w", userID, err)
	}

	cacheKeys = append(cacheKeys, fmt.Sprintf(redisKeyUserEmail, userID))
	redisclient.RDB.Del(ctx, cacheKeys...)
	return nil
}
// RevokeAllUserSessionsExceptCurrent revokes every session of userID except
// currentSessionID (for example "log out other devices"). For each session that
// was active at lookup time, both Redis keys are evicted: session ID and
// refresh-token hash.
//
// The session key column is jwt_sessions_id, as in the UPDATE below and the
// INSERT in GenerateTokens. The lookup previously filtered on a column named
// "id", inconsistent with the rest of the file.
//
// Returns an error if the lookup or the revoking UPDATE fails. Row-level scan or
// iteration problems are logged and the DB revoke still runs.
func RevokeAllUserSessionsExceptCurrent(userID, currentSessionID string) error {
	ctx := context.Background()
	rows, err := db.DB.Query(
		"SELECT jwt_sessions_id, refresh_token_hash FROM jwt_sessions WHERE user_id = ? AND jwt_sessions_id != ? AND is_revoked = false",
		userID, currentSessionID,
	)
	if err != nil {
		return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
	}
	defer rows.Close()

	var cacheKeys []string
	for rows.Next() {
		var sessionID, refreshTokenHash string
		if err := rows.Scan(&sessionID, &refreshTokenHash); err != nil {
			helper.LogError(err, fmt.Sprintf("Failed to scan session row while revoking other sessions for user %s", userID))
			continue
		}
		cacheKeys = append(cacheKeys,
			fmt.Sprintf(redisKeyJWTSessionID, sessionID),
			fmt.Sprintf(redisKeyJWTSession, refreshTokenHash),
		)
	}
	if err := rows.Err(); err != nil {
		helper.LogError(err, fmt.Sprintf("Session listing for user %s ended early; some cache entries may survive revocation", userID))
	}

	_, err = db.DB.Exec(
		"UPDATE jwt_sessions SET is_revoked = true WHERE user_id = ? AND jwt_sessions_id != ?",
		userID, currentSessionID,
	)
	if err != nil {
		return fmt.Errorf("failed to revoke other sessions for user %s: %w", userID, err)
	}

	if len(cacheKeys) > 0 {
		redisclient.RDB.Del(ctx, cacheKeys...)
	}
	return nil
}
// ValidateSession returns the session with the given ID if it exists, is not
// revoked, and has not expired.
//
// Redis (session-ID key) is consulted first. On a miss, the row is read from
// jwt_sessions and re-cached for its remaining lifetime.
//   - Revoked session: its cache entry is dropped.
//   - Expired session: also revoked in the database via RevokeSession; a failure
//     there is only logged.
func ValidateSession(sessionID string) (*models.JWTSession, error) {
	ctx := context.Background()
	cacheKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)

	var s models.JWTSession
	if cacheErr := helper.GetJSON(ctx, cacheKey, &s); cacheErr != nil {
		row := db.DB.QueryRow(`
SELECT jwt_sessions_id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
FROM jwt_sessions
WHERE jwt_sessions_id = ?
`, sessionID)
		scanErr := row.Scan(
			&s.ID,
			&s.UsersID,
			&s.RefreshTokenHash,
			&s.UserAgent,
			&s.IPAddress,
			&s.CreatedAt,
			&s.UpdatedAt,
			&s.ExpiresAt,
			&s.IsRevoked,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("session not found: %w", scanErr)
		}
		if ttl := int(time.Until(s.ExpiresAt).Seconds()); ttl > 0 {
			if setErr := helper.SetJSON(ctx, cacheKey, s, &ttl); setErr != nil {
				helper.LogError(setErr, "Failed to cache session in Redis (ValidateSession)")
			}
		}
	}

	switch {
	case s.IsRevoked:
		redisclient.RDB.Del(ctx, cacheKey)
		return nil, fmt.Errorf("%s", errMsgSessionHasBeenRevoked)
	case time.Now().After(s.ExpiresAt):
		if revokeErr := RevokeSession(sessionID); revokeErr != nil {
			helper.LogError(revokeErr, "Failed to auto-revoke expired session")
		}
		redisclient.RDB.Del(ctx, cacheKey)
		return nil, fmt.Errorf("session has expired")
	}
	return &s, nil
}
// ValidateSessionForWebSocket is the cache-only counterpart of ValidateSession.
// It never queries the database, so a Redis miss is reported as
// errMsgSessionNotFoundInCache. A revoked or expired cached session is evicted
// from Redis and rejected. Unlike ValidateSession, expiry does not revoke the
// database row.
func ValidateSessionForWebSocket(sessionID string) (*models.JWTSession, error) {
	ctx := context.Background()
	cacheKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)

	var s models.JWTSession
	if err := helper.GetJSON(ctx, cacheKey, &s); err != nil {
		return nil, fmt.Errorf("%s", errMsgSessionNotFoundInCache)
	}

	var rejection error
	switch {
	case s.IsRevoked:
		rejection = fmt.Errorf("%s", errMsgSessionHasBeenRevoked)
	case time.Now().After(s.ExpiresAt):
		rejection = fmt.Errorf("session has expired")
	default:
		return &s, nil
	}
	redisclient.RDB.Del(ctx, cacheKey)
	return nil, rejection
}
// ExtendSessionActivity sets UpdatedAt to now on the Redis-cached copy of the
// session (session-ID key) and rewrites it with its remaining TTL. The database
// row is not touched.
//
// A cache miss returns errMsgSessionNotFoundInCache. If less than one second of
// lifetime remains, nothing is written. A failed cache write is logged, and the
// function still returns nil.
func ExtendSessionActivity(sessionID string) error {
	ctx := context.Background()
	cacheKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)

	var s models.JWTSession
	if err := helper.GetJSON(ctx, cacheKey, &s); err != nil {
		return fmt.Errorf("%s", errMsgSessionNotFoundInCache)
	}
	s.UpdatedAt = time.Now()

	remaining := int(time.Until(s.ExpiresAt).Seconds())
	if remaining <= 0 {
		return nil
	}
	if err := helper.SetJSON(ctx, cacheKey, s, &remaining); err != nil {
		helper.LogError(err, "Failed to extend session activity in Redis cache")
	}
	return nil
}
// GetSessionUserInfo resolves a cached session ID to (userID, email).
//
// Only Redis is consulted for the session; a miss returns
// errMsgSessionNotFoundInCache. The email comes from getUserEmailFromIDCached,
// and its error is wrapped.
func GetSessionUserInfo(sessionID string) (string, string, error) {
	ctx := context.Background()

	var s models.JWTSession
	if err := helper.GetJSON(ctx, fmt.Sprintf(redisKeyJWTSessionID, sessionID), &s); err != nil {
		return "", "", fmt.Errorf("%s", errMsgSessionNotFoundInCache)
	}

	email, lookupErr := getUserEmailFromIDCached(s.UsersID)
	if lookupErr != nil {
		return "", "", fmt.Errorf("failed to get user email: %w", lookupErr)
	}
	return s.UsersID, email, nil
}
// InvalidateUserSessionsInCache removes from Redis every cached entry for the
// user's sessions (revoked or not), plus the user's cached email. The database
// is not modified.
//
// Whatever keys were collected are always evicted. An error is returned if the
// session query fails, or if row iteration stops early, because the eviction may
// then be incomplete. Individual rows that fail to scan are logged and skipped.
func InvalidateUserSessionsInCache(userID string) error {
	ctx := context.Background()
	rows, err := db.DB.Query("SELECT jwt_sessions_id, refresh_token_hash FROM jwt_sessions WHERE user_id = ?", userID)
	if err != nil {
		return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
	}
	defer rows.Close()

	keys := []string{fmt.Sprintf(redisKeyUserEmail, userID)}
	for rows.Next() {
		var sessionID, refreshTokenHash string
		if err := rows.Scan(&sessionID, &refreshTokenHash); err != nil {
			helper.LogError(err, fmt.Sprintf("Failed to scan session row for user %s during cache invalidation", userID))
			continue
		}
		keys = append(keys,
			fmt.Sprintf(redisKeyJWTSessionID, sessionID),
			fmt.Sprintf(redisKeyJWTSession, refreshTokenHash),
		)
	}
	iterErr := rows.Err()

	// Evict what was collected even if iteration stopped early.
	redisclient.RDB.Del(ctx, keys...)

	if iterErr != nil {
		return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, iterErr)
	}
	return nil
}
// CleanupExpiredSessions permanently deletes jwt_sessions rows whose expires_at
// is in the past. It evicts both Redis entries (refresh-token hash and session
// ID) for each collected row.
//
// The SELECT and the DELETE share one cutoff timestamp. Every row the DELETE
// removes was therefore also collected for eviction. With two separate
// time.Now() calls, rows expiring in between were deleted without eviction.
// If row iteration fails, the DELETE is skipped and an error is returned.
//
// NOTE(review): a row that fails to scan is skipped for eviction but is still
// deleted.
func CleanupExpiredSessions() error {
	ctx := context.Background()
	cutoff := time.Now()

	rows, err := db.DB.Query("SELECT jwt_sessions_id, refresh_token_hash FROM jwt_sessions WHERE expires_at < ?", cutoff)
	if err != nil {
		return fmt.Errorf("failed to query expired sessions: %w", err)
	}
	defer rows.Close()

	var cacheKeys []string
	expiredCount := 0
	for rows.Next() {
		var sessionID, refreshTokenHash string
		if err := rows.Scan(&sessionID, &refreshTokenHash); err != nil {
			helper.LogError(err, "Failed to scan expired session row")
			continue
		}
		expiredCount++
		cacheKeys = append(cacheKeys,
			fmt.Sprintf(redisKeyJWTSession, refreshTokenHash),
			fmt.Sprintf(redisKeyJWTSessionID, sessionID),
		)
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("failed to query expired sessions: %w", err)
	}
	// Release the result set's connection before issuing the DELETE.
	rows.Close()

	if _, err := db.DB.Exec("DELETE FROM jwt_sessions WHERE expires_at < ?", cutoff); err != nil {
		return fmt.Errorf("failed to cleanup expired sessions: %w", err)
	}

	// Evict in bounded batches: one round trip per batch instead of per session,
	// without sending a single unbounded DEL for a large backlog.
	const delBatchSize = 500
	for start := 0; start < len(cacheKeys); start += delBatchSize {
		end := start + delBatchSize
		if end > len(cacheKeys) {
			end = len(cacheKeys)
		}
		redisclient.RDB.Del(ctx, cacheKeys[start:end]...)
	}

	// Role cache invalidation removed - handled by separate authz microservice
	log.Printf("Cleaned up %d expired sessions", expiredCount)
	return nil
}
// GetUserSessions lists a user's active sessions, newest first. Active means
// not revoked and expires_at still in the future. The list is read directly from
// jwt_sessions; Redis is not used.
//
// Returns an error if the query fails, a row cannot be scanned, or iteration
// stops early. The rows.Err check ensures a truncated list is never returned as
// if it were complete. No matching sessions yields a nil slice and nil error.
func GetUserSessions(userID string) ([]models.JWTSession, error) {
	rows, err := db.DB.Query(`
SELECT jwt_sessions_id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
FROM jwt_sessions
WHERE user_id = ? AND is_revoked = false AND expires_at > ?
ORDER BY created_at DESC
`, userID, time.Now())
	if err != nil {
		return nil, fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
	}
	defer rows.Close()

	var sessions []models.JWTSession
	for rows.Next() {
		var session models.JWTSession
		if err := rows.Scan(
			&session.ID,
			&session.UsersID,
			&session.RefreshTokenHash,
			&session.UserAgent,
			&session.IPAddress,
			&session.CreatedAt,
			&session.UpdatedAt,
			&session.ExpiresAt,
			&session.IsRevoked,
		); err != nil {
			return nil, fmt.Errorf("failed to scan session row: %w", err)
		}
		sessions = append(sessions, session)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
	}
	return sessions, nil
}
// UpdateSessionLastActivity sets updated_at to the current time on the session's
// jwt_sessions row. Only the database is written; any Redis-cached copy keeps
// its previous UpdatedAt.
func UpdateSessionLastActivity(sessionID string) error {
	now := time.Now()
	if _, err := db.DB.Exec(`
UPDATE jwt_sessions
SET updated_at = ?
WHERE jwt_sessions_id = ?
`, now, sessionID); err != nil {
		return fmt.Errorf("failed to update session activity: %w", err)
	}
	return nil
}
// getUserEmailFromID returns users.email_address for the given users_id.
//
// Any query failure is returned with the original cause wrapped via %w. This
// covers both "no such row" and connection or driver errors, which were
// previously discarded. Callers can still tell them apart with errors.Is
// (e.g. against sql.ErrNoRows).
func getUserEmailFromID(userID string) (string, error) {
	var email string
	err := db.DB.QueryRow("SELECT email_address FROM users WHERE users_id = ?", userID).Scan(&email)
	if err != nil {
		return "", fmt.Errorf("user not found with ID %s in any table: %w", userID, err)
	}
	return email, nil
}
// getUserEmailFromIDCached is the read-through cached form of getUserEmailFromID.
//
// The email is first read from Redis (redisKeyUserEmail). A miss, a decode
// error, or an empty cached value falls through to the database. A successful
// database lookup is cached for one hour (3600 s). A failed cache write is
// logged and does not fail the lookup.
func getUserEmailFromIDCached(userID string) (string, error) {
	ctx := context.Background()
	cacheKey := fmt.Sprintf(redisKeyUserEmail, userID)

	var cached string
	if err := helper.GetJSON(ctx, cacheKey, &cached); err == nil && cached != "" {
		return cached, nil
	}

	// Cache miss: fetch from the database and repopulate the cache.
	email, err := getUserEmailFromID(userID)
	if err != nil {
		return "", err
	}
	oneHour := 3600
	if err := helper.SetJSON(ctx, cacheKey, email, &oneHour); err != nil {
		helper.LogError(err, "Failed to cache user email in Redis")
	}
	return email, nil
}
// AddToSessionBlacklist stores the marker value "revoked" under the session's
// redisKeySessionBlacklist key, expiring after ttlSeconds seconds. It returns
// the Redis error, if any.
//
// NOTE(review): in go-redis, a non-positive expiration generally stores the key
// with no TTL, so it never expires. Confirm callers never pass ttlSeconds <= 0.
func AddToSessionBlacklist(sessionID string, ttlSeconds int) error {
	key := fmt.Sprintf(redisKeySessionBlacklist, sessionID)
	expiry := time.Second * time.Duration(ttlSeconds)
	return redisclient.RDB.Set(context.Background(), key, "revoked", expiry).Err()
}
// IsSessionBlacklisted reports whether a blacklist marker exists for sessionID.
//
// It fails open: if the Redis EXISTS call errors, the session is reported as not
// blacklisted.
func IsSessionBlacklisted(sessionID string) bool {
	key := fmt.Sprintf(redisKeySessionBlacklist, sessionID)
	count, err := redisclient.RDB.Exists(context.Background(), key).Result()
	if err != nil {
		// NOTE(review): fail-open — during a Redis outage, blacklisted sessions pass.
		return false
	}
	return count > 0
}
// ClearSessionFromAllCaches deletes both Redis entries for a session (the
// session-ID key and the refresh-token-hash key) in a single DEL. It returns the
// Redis error, if any. The database row is left unchanged.
func ClearSessionFromAllCaches(sessionID, refreshTokenHash string) error {
	idKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
	hashKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
	return redisclient.RDB.Del(context.Background(), idKey, hashKey).Err()
}
// CheckEmailInDB reports whether a non-deleted user (is_deleted = 0) has the
// given email_address.
//
// It returns an error if the database handle is nil or the query fails. The
// query error is wrapped with %w (previously %v) so callers can inspect it with
// errors.Is / errors.As.
func CheckEmailInDB(email string) (bool, error) {
	if db.DB == nil {
		return false, fmt.Errorf("database connection is nil")
	}
	var exists bool
	err := db.DB.QueryRow(
		`SELECT EXISTS(
SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0)`, email,
	).Scan(&exists)
	if err != nil {
		return false, fmt.Errorf("error checking email in database: %w", err)
	}
	return exists, nil
}