913 lines
28 KiB
Go
913 lines
28 KiB
Go
package handlers
|
|
|
|
import (
|
|
"authentication/db"
|
|
"authentication/helper"
|
|
"authentication/models"
|
|
"authentication/redisclient"
|
|
"authentication/services"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// rsaPrivateKey signs every access token (RS256). It is assigned once by this
// file's init function, either from the PEM file on disk or, in test
// environments, from a freshly generated key.
var rsaPrivateKey *rsa.PrivateKey

// parseRSAPrivateKey decodes the first PEM block in keyData and extracts the
// RSA private key it carries.
//
// The DER payload is tried as PKCS#8 first and, if that fails, as PKCS#1. The
// PEM block's Type header is not inspected. A PKCS#8 payload that holds a
// non-RSA key (e.g. ECDSA) is rejected with "not an RSA private key".
func parseRSAPrivateKey(keyData []byte) (*rsa.PrivateKey, error) {
	pemBlock, _ := pem.Decode(keyData)
	if pemBlock == nil {
		return nil, fmt.Errorf("failed to decode PEM block containing private key")
	}
	der := pemBlock.Bytes

	// PKCS#8 path: the container may wrap any key algorithm, so check the type.
	if anyKey, pkcs8Err := x509.ParsePKCS8PrivateKey(der); pkcs8Err == nil {
		if rsaKey, isRSA := anyKey.(*rsa.PrivateKey); isRSA {
			return rsaKey, nil
		}
		return nil, fmt.Errorf("not an RSA private key")
	}

	// PKCS#1 path: the format is RSA-only, so no type check is needed.
	rsaKey, pkcs1Err := x509.ParsePKCS1PrivateKey(der)
	if pkcs1Err != nil {
		return nil, fmt.Errorf("failed to parse RSA private key: %w", pkcs1Err)
	}
	return rsaKey, nil
}
|
|
|
|
// Note: .env file is loaded in main() before this init runs
|
|
func init() {
|
|
keyPath := os.Getenv("JWT_PRIVATE_KEY_PATH")
|
|
if keyPath == "" {
|
|
keyPath = "rsa/psa.gov.ph.key"
|
|
}
|
|
|
|
keyData, err := os.ReadFile(keyPath) // #nosec G304
|
|
if err != nil {
|
|
if isTestEnvironment() {
|
|
log.Printf("Failed to read RSA private key file at %s, generating test key: %v", keyPath, err)
|
|
generatedKey, genErr := rsa.GenerateKey(rand.Reader, 2048)
|
|
if genErr != nil {
|
|
log.Fatalf("Failed to generate test RSA private key: %v", genErr)
|
|
}
|
|
rsaPrivateKey = generatedKey
|
|
return
|
|
}
|
|
log.Fatalf("Failed to read RSA private key file: %v", err)
|
|
}
|
|
|
|
parsedKey, err := parseRSAPrivateKey(keyData)
|
|
if err != nil {
|
|
log.Fatalf("%v", err)
|
|
}
|
|
|
|
rsaPrivateKey = parsedKey
|
|
log.Println("RSA private key loaded successfully for JWT signing")
|
|
}
|
|
|
|
// GenerateTokens generates both access and refresh tokens with session management.
|
|
// It creates a new session in the database and caches it in Redis for performance.
|
|
//
|
|
// Parameters:
|
|
// - email: The email address to include in the JWT claims.
|
|
// - userAgent: The user agent string from the request.
|
|
// - ipAddress: The IP address of the client.
|
|
//
|
|
// Returns:
|
|
func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) {
|
|
ctx := context.Background()
|
|
|
|
emailExists, err := CheckEmailInDB(email)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("error checking email in database: %w", err)
|
|
}
|
|
|
|
userID, err := services.GetUserIDFromEmail(email)
|
|
if err != nil {
|
|
userID = helper.UUIDGenerator()
|
|
}
|
|
|
|
log.Print("userID:", userID)
|
|
roleID, err := services.GetRoleIDsFromEmail(email)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("error checking role in database: %w", err)
|
|
}
|
|
sessionID := helper.UUIDGenerator()
|
|
|
|
refreshToken, err := generateSecureToken()
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to generate refresh token: %w", err)
|
|
}
|
|
|
|
refreshTokenHash := helper.CalculateSHA256(refreshToken)
|
|
|
|
location, err := helper.LoadAsiaManilaLocation()
|
|
if err != nil {
|
|
helper.LogError(err, "Failed to load Asia/Manila timezone, using UTC+8 offset")
|
|
}
|
|
|
|
currentTime := time.Now().In(location)
|
|
|
|
var expiresAt time.Time
|
|
if emailExists {
|
|
expiresAt = currentTime.Add(7 * 24 * time.Hour)
|
|
} else {
|
|
expiresAt = currentTime.Add(2 * time.Hour)
|
|
}
|
|
|
|
session := models.JWTSession{
|
|
ID: sessionID,
|
|
UsersID: userID,
|
|
RefreshTokenHash: refreshTokenHash,
|
|
UserAgent: userAgent,
|
|
IPAddress: ipAddress,
|
|
CreatedAt: currentTime,
|
|
UpdatedAt: currentTime,
|
|
ExpiresAt: expiresAt,
|
|
IsRevoked: false,
|
|
}
|
|
|
|
_, err = db.DB.Exec(`
|
|
INSERT INTO jwt_sessions (jwt_sessions_id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
`, sessionID, userID, refreshTokenHash, userAgent, ipAddress, currentTime, currentTime, expiresAt, false)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to store session: %w", err)
|
|
}
|
|
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
|
|
sessionIDKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
|
sessionTTL := int(time.Until(expiresAt).Seconds())
|
|
if sessionTTL > 0 {
|
|
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
|
helper.LogError(err, "Failed to cache session in Redis (sessionKey)")
|
|
}
|
|
if err := helper.SetJSON(ctx, sessionIDKey, session, &sessionTTL); err != nil {
|
|
helper.LogError(err, "Failed to cache session in Redis (sessionIDKey)")
|
|
}
|
|
}
|
|
|
|
// Convert roleIDs slice to a comma-separated string for the token claim
|
|
var roleIDsStr string
|
|
if len(roleID) > 0 {
|
|
for i, r := range roleID {
|
|
if i > 0 {
|
|
roleIDsStr += ","
|
|
}
|
|
roleIDsStr += fmt.Sprintf("%d", r)
|
|
}
|
|
}
|
|
|
|
accessToken, err := generateAccessToken(email, sessionID, userID, roleID)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err)
|
|
}
|
|
|
|
log.Printf("Generated tokens for user %s with session %s", email, sessionID)
|
|
return accessToken, refreshToken, nil
|
|
}
|
|
|
|
func generateAccessToken(email, sessionID, userID string, roleID []int) (string, error) {
|
|
AccessTokenExpiration := os.Getenv("ACCESS_TOKEN_EXPIRATION_MINUTES")
|
|
if AccessTokenExpiration == "" {
|
|
log.Println("AccessTokenExpiration not set (in minutes), defaulting to 45 minutes")
|
|
AccessTokenExpiration = "45"
|
|
}
|
|
|
|
if roleID == nil {
|
|
roleID = []int{}
|
|
}
|
|
|
|
var primaryRoleID *int
|
|
if len(roleID) > 0 {
|
|
value := roleID[0]
|
|
primaryRoleID = &value
|
|
}
|
|
|
|
expirationTime := time.Now().Add(24 * time.Hour).Unix()
|
|
|
|
claims := &models.AccessToken{
|
|
Email: email,
|
|
UsersID: userID,
|
|
RoleID: primaryRoleID,
|
|
AdditionalRoleID: roleID,
|
|
SessionID: sessionID,
|
|
Exp: expirationTime,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Unix(expirationTime, 0)),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
|
|
return token.SignedString(rsaPrivateKey)
|
|
}
|
|
|
|
// generateSecureToken returns 32 bytes (256 bits) from crypto/rand, encoded
// with padded URL-safe base64 (44 characters). The value is used as the opaque
// refresh token handed to clients.
func generateSecureToken() (string, error) {
	var entropy [32]byte
	if _, readErr := rand.Read(entropy[:]); readErr != nil {
		return "", readErr
	}
	encoded := base64.URLEncoding.EncodeToString(entropy[:])
	return encoded, nil
}
|
|
|
|
// RefreshAccessToken refreshes the access token using a valid refresh token.
|
|
// It validates the refresh token, checks the session status, and generates a new access token.
|
|
// Uses Redis for session caching to improve performance for websocket connections.
|
|
//
|
|
// Parameters:
|
|
// - refreshTokenString: The refresh token to use for refreshing the access token.
|
|
// - userAgent: The user agent string from the request.
|
|
// - ipAddress: The IP address of the client.
|
|
//
|
|
// Returns:
|
|
// - string: The new signed access token as a string.
|
|
// - error: An error if the token is invalid or the process fails.
|
|
func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string, error) {
|
|
ctx := context.Background()
|
|
refreshTokenHash := helper.CalculateSHA256(refreshTokenString)
|
|
|
|
helper.LogInfo(fmt.Sprintf("RefreshAccessToken called - Token length: %d, Hash: %s", len(refreshTokenString), refreshTokenHash[:16]+"..."))
|
|
helper.LogInfo(fmt.Sprintf("Client details - UserAgent: %s, IP: %s", userAgent, ipAddress))
|
|
|
|
rateLimitKey := fmt.Sprintf("refresh_rate_limit:%s", refreshTokenHash)
|
|
attempts, err := redisclient.RDB.Incr(ctx, rateLimitKey).Result()
|
|
if err == nil {
|
|
if attempts == 1 {
|
|
redisclient.RDB.Expire(ctx, rateLimitKey, time.Minute)
|
|
}
|
|
if attempts > 5 {
|
|
return "", fmt.Errorf("too many refresh attempts, please wait")
|
|
}
|
|
}
|
|
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
|
|
var session models.JWTSession
|
|
|
|
err = helper.GetJSON(ctx, sessionKey, &session)
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("Session not found in Redis cache, querying database for hash: %s", refreshTokenHash[:16]+"..."))
|
|
err = db.DB.QueryRow(`
|
|
SELECT id, user_id, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
|
FROM jwt_sessions
|
|
WHERE refresh_token_hash = ? AND is_revoked = false
|
|
`, refreshTokenHash).Scan(
|
|
&session.ID,
|
|
&session.UsersID,
|
|
&session.UserAgent,
|
|
&session.IPAddress,
|
|
&session.CreatedAt,
|
|
&session.UpdatedAt,
|
|
&session.ExpiresAt,
|
|
&session.IsRevoked,
|
|
)
|
|
|
|
if err != nil {
|
|
helper.LogError(err, fmt.Sprintf("Session not found in database for hash: %s", refreshTokenHash[:16]+"..."))
|
|
return "", fmt.Errorf("invalid refresh token: %w", err)
|
|
}
|
|
|
|
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s",
|
|
session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
|
|
|
|
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
|
if sessionTTL > 0 {
|
|
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
|
helper.LogError(err, errMsgFailedToUpdateSessionActivity)
|
|
}
|
|
}
|
|
} else {
|
|
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s",
|
|
session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime)))
|
|
}
|
|
|
|
if session.IsRevoked {
|
|
helper.LogWarn(fmt.Sprintf("Attempted to use revoked session: %s", session.ID))
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
return "", fmt.Errorf(errMsgSessionHasBeenRevoked)
|
|
}
|
|
|
|
if time.Now().After(session.ExpiresAt) {
|
|
helper.LogWarn(fmt.Sprintf("Attempted to use expired session: %s (expired at %s)",
|
|
session.ID, session.ExpiresAt.Format(timeFormatDateTime)))
|
|
_, err = db.DB.Exec(sqlUpdateRevokeSession, session.ID)
|
|
if err != nil {
|
|
helper.LogError(err, "Failed to revoke expired session")
|
|
}
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
return "", fmt.Errorf("refresh token has expired")
|
|
}
|
|
|
|
if session.UserAgent != userAgent {
|
|
helper.LogWarn(fmt.Sprintf("Session User Agent security mismatch for session %s: stored='%s', received='%s'",
|
|
session.ID, session.UserAgent, userAgent))
|
|
}
|
|
|
|
if session.IPAddress != ipAddress {
|
|
helper.LogWarn(fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s",
|
|
session.ID, session.IPAddress, ipAddress))
|
|
}
|
|
|
|
// Get user email from user ID (with caching)
|
|
email, err := getUserEmailFromIDCached(session.UsersID)
|
|
if err != nil {
|
|
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID))
|
|
// For registrants or users not yet in the main tables, we still want to allow refresh
|
|
// but we need to get the email from somewhere else. Since we don't store email in session,
|
|
// we'll need to handle this gracefully by allowing the refresh to continue with a placeholder
|
|
// The email will be properly resolved when they complete registration
|
|
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UsersID))
|
|
|
|
// For now, we'll use a placeholder email pattern and let the access token generation handle it
|
|
// The system should work as long as the session is valid
|
|
email = fmt.Sprintf("registrant_%s@pending.local", session.UsersID)
|
|
}
|
|
|
|
helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID))
|
|
|
|
userID, err := helper.FetchUserIDFromDB(email)
|
|
if err != nil {
|
|
helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email))
|
|
userID = session.UsersID // Fallback to session's user ID
|
|
}
|
|
|
|
roleIDs, err := services.GetRoleIDsFromEmail(email)
|
|
if err != nil {
|
|
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
|
|
roleIDs = []int{}
|
|
}
|
|
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
|
|
if err != nil {
|
|
helper.LogError(err, "Failed to generate access token during refresh")
|
|
return "", fmt.Errorf("failed to generate access token: %w", err)
|
|
}
|
|
|
|
helper.LogInfo(fmt.Sprintf("Successfully refreshed access token for user %s (session: %s)", email, session.ID))
|
|
|
|
session.UpdatedAt = time.Now()
|
|
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
|
if sessionTTL > 0 {
|
|
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
|
helper.LogError(err, errMsgFailedToUpdateSessionActivity)
|
|
}
|
|
}
|
|
|
|
go func() {
|
|
_, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE id = ?", session.UpdatedAt, session.ID)
|
|
if err != nil {
|
|
helper.LogError(err, "Failed to update session activity in DB")
|
|
}
|
|
}()
|
|
|
|
return accessToken, nil
|
|
}
|
|
|
|
// RefreshAccessTokenWithEmailFallback refreshes the access token using a valid refresh token with email fallback.
|
|
// This version handles cases where the user ID in the session doesn't exist in the database (e.g., registrants).
|
|
//
|
|
// Parameters:
|
|
// - refreshTokenString: The refresh token to use for refreshing the access token.
|
|
// - userAgent: The user agent string from the request.
|
|
// - ipAddress: The IP address of the client.
|
|
// - emailFallback: Email to use if user ID lookup fails (extracted from current access token).
|
|
//
|
|
// Returns:
|
|
// - string: The new signed access token as a string.
|
|
// - error: An error if the token is invalid or the process fails.
|
|
func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddress, emailFallback string) (string, error) {
|
|
ctx := context.Background()
|
|
refreshTokenHash := helper.CalculateSHA256(refreshTokenString)
|
|
|
|
helper.LogInfo(fmt.Sprintf("RefreshAccessTokenWithEmailFallback called - Token length: %d, Hash: %s", len(refreshTokenString), refreshTokenHash[:16]+"..."))
|
|
helper.LogInfo(fmt.Sprintf("Client details - UserAgent: %s, IP: %s, EmailFallback: %s", userAgent, ipAddress, emailFallback))
|
|
|
|
rateLimitKey := fmt.Sprintf(redisKeyRefreshRateLimit, refreshTokenHash)
|
|
attempts, err := redisclient.RDB.Incr(ctx, rateLimitKey).Result()
|
|
if err == nil {
|
|
if attempts == 1 {
|
|
redisclient.RDB.Expire(ctx, rateLimitKey, time.Minute)
|
|
}
|
|
if attempts > 5 {
|
|
return "", fmt.Errorf("too many refresh attempts, please wait")
|
|
}
|
|
}
|
|
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSession, refreshTokenHash)
|
|
var session models.JWTSession
|
|
|
|
err = helper.GetJSON(ctx, sessionKey, &session)
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("Session not found in Redis cache, querying database for hash: %s", refreshTokenHash[:16]+"..."))
|
|
err = db.DB.QueryRow(`
|
|
SELECT id, user_id, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
|
FROM jwt_sessions
|
|
WHERE refresh_token_hash = ? AND is_revoked = false
|
|
`, refreshTokenHash).Scan(
|
|
&session.ID,
|
|
&session.UsersID,
|
|
&session.UserAgent,
|
|
&session.IPAddress,
|
|
&session.CreatedAt,
|
|
&session.UpdatedAt,
|
|
&session.ExpiresAt,
|
|
&session.IsRevoked,
|
|
)
|
|
|
|
if err != nil {
|
|
helper.LogError(err, fmt.Sprintf("Session not found in database for hash: %s", refreshTokenHash[:16]+"..."))
|
|
return "", fmt.Errorf("invalid refresh token: %w", err)
|
|
}
|
|
|
|
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s",
|
|
session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
|
|
|
|
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
|
if sessionTTL > 0 {
|
|
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
|
helper.LogError(err, errMsgFailedToUpdateSessionActivity)
|
|
}
|
|
}
|
|
} else {
|
|
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s",
|
|
session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime)))
|
|
}
|
|
|
|
if session.IsRevoked {
|
|
helper.LogWarn(fmt.Sprintf("Attempted to use revoked session: %s", session.ID))
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
return "", fmt.Errorf(errMsgSessionHasBeenRevoked)
|
|
}
|
|
|
|
if time.Now().After(session.ExpiresAt) {
|
|
helper.LogWarn(fmt.Sprintf("Attempted to use expired session: %s (expired at %s)",
|
|
session.ID, session.ExpiresAt.Format(timeFormatDateTime)))
|
|
_, err = db.DB.Exec(sqlUpdateRevokeSession, session.ID)
|
|
if err != nil {
|
|
helper.LogError(err, "Failed to revoke expired session")
|
|
}
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
return "", fmt.Errorf("refresh token has expired")
|
|
}
|
|
|
|
if session.UserAgent != userAgent {
|
|
helper.LogWarn(fmt.Sprintf("Session User Agent security mismatch for session %s: stored='%s', received='%s'",
|
|
session.ID, session.UserAgent, userAgent))
|
|
}
|
|
|
|
if session.IPAddress != ipAddress {
|
|
helper.LogWarn(fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s",
|
|
session.ID, session.IPAddress, ipAddress))
|
|
}
|
|
|
|
// Get user email from user ID (with caching), with fallback to provided email
|
|
email, err := getUserEmailFromIDCached(session.UsersID)
|
|
if err != nil {
|
|
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID))
|
|
|
|
if emailFallback != "" {
|
|
helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UsersID, emailFallback))
|
|
email = emailFallback
|
|
} else {
|
|
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UsersID))
|
|
return "", fmt.Errorf("failed to get user email: %w", err)
|
|
}
|
|
}
|
|
|
|
helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID))
|
|
|
|
userID, err := helper.FetchUserIDFromDB(email)
|
|
if err != nil {
|
|
helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email))
|
|
userID = session.UsersID // Fallback to session's user ID
|
|
}
|
|
|
|
roleIDs, err := services.GetRoleIDsFromEmail(email)
|
|
if err != nil {
|
|
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
|
|
roleIDs = []int{}
|
|
}
|
|
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
|
|
if err != nil {
|
|
helper.LogError(err, "Failed to generate access token during refresh")
|
|
return "", fmt.Errorf("failed to generate access token: %w", err)
|
|
}
|
|
|
|
helper.LogInfo(fmt.Sprintf("Successfully refreshed access token for user %s (session: %s)", email, session.ID))
|
|
|
|
session.UpdatedAt = time.Now()
|
|
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
|
if sessionTTL > 0 {
|
|
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
|
helper.LogError(err, errMsgFailedToUpdateSessionActivity)
|
|
}
|
|
}
|
|
|
|
go func() {
|
|
_, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE id = ?", session.UpdatedAt, session.ID)
|
|
if err != nil {
|
|
helper.LogError(err, "Failed to update session activity in DB")
|
|
}
|
|
}()
|
|
|
|
return accessToken, nil
|
|
}
|
|
|
|
func RevokeSession(sessionID string) error {
|
|
ctx := context.Background()
|
|
|
|
_, err := db.DB.Exec(sqlUpdateRevokeSession, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to revoke session %s: %w", sessionID, err)
|
|
}
|
|
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
|
|
return nil
|
|
}
|
|
|
|
func RevokeAllUserSessions(userID string) error {
|
|
ctx := context.Background()
|
|
|
|
rows, err := db.DB.Query("SELECT jwt_sessions_id FROM jwt_sessions WHERE user_id = ? AND is_revoked = false", userID)
|
|
if err != nil {
|
|
return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sessionIDs []string
|
|
for rows.Next() {
|
|
var sessionID string
|
|
if err := rows.Scan(&sessionID); err != nil {
|
|
continue
|
|
}
|
|
sessionIDs = append(sessionIDs, sessionID)
|
|
}
|
|
|
|
_, err = db.DB.Exec("UPDATE jwt_sessions SET is_revoked = true WHERE user_id = ?", userID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to revoke all sessions for user %s: %w", userID, err)
|
|
}
|
|
|
|
for _, sessionID := range sessionIDs {
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
}
|
|
|
|
userEmailKey := fmt.Sprintf(redisKeyUserEmail, userID)
|
|
redisclient.RDB.Del(ctx, userEmailKey)
|
|
|
|
return nil
|
|
}
|
|
|
|
func RevokeAllUserSessionsExceptCurrent(userID, currentSessionID string) error {
|
|
ctx := context.Background()
|
|
|
|
rows, err := db.DB.Query("SELECT id FROM jwt_sessions WHERE user_id = ? AND id != ? AND is_revoked = false", userID, currentSessionID)
|
|
if err != nil {
|
|
return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sessionIDs []string
|
|
for rows.Next() {
|
|
var sessionID string
|
|
if err := rows.Scan(&sessionID); err != nil {
|
|
continue
|
|
}
|
|
sessionIDs = append(sessionIDs, sessionID)
|
|
}
|
|
|
|
_, err = db.DB.Exec(
|
|
"UPDATE jwt_sessions SET is_revoked = true WHERE user_id = ? AND id != ?",
|
|
userID, currentSessionID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to revoke other sessions for user %s: %w", userID, err)
|
|
}
|
|
|
|
for _, sessionID := range sessionIDs {
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ValidateSession(sessionID string) (*models.JWTSession, error) {
|
|
ctx := context.Background()
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
|
|
|
var session models.JWTSession
|
|
err := helper.GetJSON(ctx, sessionKey, &session)
|
|
if err != nil {
|
|
err = db.DB.QueryRow(`
|
|
SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
|
FROM jwt_sessions
|
|
WHERE id = ?
|
|
`, sessionID).Scan(
|
|
&session.ID,
|
|
&session.UsersID,
|
|
&session.RefreshTokenHash,
|
|
&session.UserAgent,
|
|
&session.IPAddress,
|
|
&session.CreatedAt,
|
|
&session.UpdatedAt,
|
|
&session.ExpiresAt,
|
|
&session.IsRevoked,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("session not found: %w", err)
|
|
}
|
|
|
|
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
|
if sessionTTL > 0 {
|
|
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
|
helper.LogError(err, "Failed to cache session in Redis (ValidateSession)")
|
|
}
|
|
}
|
|
}
|
|
|
|
if session.IsRevoked {
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
return nil, fmt.Errorf("%s", errMsgSessionHasBeenRevoked)
|
|
}
|
|
|
|
if time.Now().After(session.ExpiresAt) {
|
|
if err := RevokeSession(sessionID); err != nil {
|
|
helper.LogError(err, "Failed to auto-revoke expired session")
|
|
}
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
return nil, fmt.Errorf("session has expired")
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
func ValidateSessionForWebSocket(sessionID string) (*models.JWTSession, error) {
|
|
ctx := context.Background()
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
|
|
|
var session models.JWTSession
|
|
err := helper.GetJSON(ctx, sessionKey, &session)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s", errMsgSessionNotFoundInCache)
|
|
}
|
|
|
|
if session.IsRevoked {
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
return nil, fmt.Errorf("%s", errMsgSessionHasBeenRevoked)
|
|
}
|
|
|
|
if time.Now().After(session.ExpiresAt) {
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
return nil, fmt.Errorf("session has expired")
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
func ExtendSessionActivity(sessionID string) error {
|
|
ctx := context.Background()
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
|
|
|
var session models.JWTSession
|
|
err := helper.GetJSON(ctx, sessionKey, &session)
|
|
if err != nil {
|
|
return fmt.Errorf("%s", errMsgSessionNotFoundInCache)
|
|
}
|
|
|
|
session.UpdatedAt = time.Now()
|
|
|
|
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
|
if sessionTTL > 0 {
|
|
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
|
helper.LogError(err, "Failed to extend session activity in Redis cache")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func GetSessionUserInfo(sessionID string) (string, string, error) {
|
|
ctx := context.Background()
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
|
|
|
var session models.JWTSession
|
|
err := helper.GetJSON(ctx, sessionKey, &session)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("%s", errMsgSessionNotFoundInCache)
|
|
}
|
|
|
|
email, err := getUserEmailFromIDCached(session.UsersID)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to get user email: %w", err)
|
|
}
|
|
|
|
return session.UsersID, email, nil
|
|
}
|
|
|
|
func InvalidateUserSessionsInCache(userID string) error {
|
|
ctx := context.Background()
|
|
|
|
rows, err := db.DB.Query("SELECT id, refresh_token_hash FROM jwt_sessions WHERE user_id = ?", userID)
|
|
if err != nil {
|
|
return fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var keys []string
|
|
for rows.Next() {
|
|
var sessionID, refreshTokenHash string
|
|
if err := rows.Scan(&sessionID, &refreshTokenHash); err != nil {
|
|
continue
|
|
}
|
|
keys = append(keys, fmt.Sprintf(redisKeyJWTSessionID, sessionID))
|
|
keys = append(keys, fmt.Sprintf(redisKeyJWTSession, refreshTokenHash))
|
|
}
|
|
|
|
if len(keys) > 0 {
|
|
redisclient.RDB.Del(ctx, keys...)
|
|
}
|
|
|
|
userEmailKey := fmt.Sprintf(redisKeyUserEmail, userID)
|
|
redisclient.RDB.Del(ctx, userEmailKey)
|
|
|
|
return nil
|
|
}
|
|
|
|
func CleanupExpiredSessions() error {
|
|
ctx := context.Background()
|
|
|
|
rows, err := db.DB.Query("SELECT id, user_id, refresh_token_hash FROM jwt_sessions WHERE expires_at < ?", time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to query expired sessions: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var expiredSessions []models.ExpiredSession
|
|
|
|
userIDsToCleanup := make(map[string]bool)
|
|
for rows.Next() {
|
|
var session models.ExpiredSession
|
|
if err := rows.Scan(&session.ID, &session.UsersID, &session.RefreshTokenHash); err != nil {
|
|
continue
|
|
}
|
|
expiredSessions = append(expiredSessions, session)
|
|
userIDsToCleanup[session.UsersID] = true
|
|
}
|
|
|
|
_, err = db.DB.Exec("DELETE FROM jwt_sessions WHERE expires_at < ?", time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to cleanup expired sessions: %w", err)
|
|
}
|
|
|
|
for _, session := range expiredSessions {
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSession, session.RefreshTokenHash)
|
|
sessionIDKey := fmt.Sprintf(redisKeyJWTSessionID, session.ID)
|
|
redisclient.RDB.Del(ctx, sessionKey, sessionIDKey)
|
|
}
|
|
|
|
// Role cache invalidation removed - handled by separate authz microservice
|
|
|
|
log.Printf("Cleaned up %d expired sessions", len(expiredSessions))
|
|
return nil
|
|
}
|
|
|
|
func GetUserSessions(userID string) ([]models.JWTSession, error) {
|
|
rows, err := db.DB.Query(`
|
|
SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
|
FROM jwt_sessions
|
|
WHERE user_id = ? AND is_revoked = false AND expires_at > ?
|
|
ORDER BY created_at DESC
|
|
`, userID, time.Now())
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf(errFormatWithContext, errMsgFailedToGetUserSessions, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sessions []models.JWTSession
|
|
for rows.Next() {
|
|
var session models.JWTSession
|
|
err := rows.Scan(
|
|
&session.ID,
|
|
&session.UsersID,
|
|
&session.RefreshTokenHash,
|
|
&session.UserAgent,
|
|
&session.IPAddress,
|
|
&session.CreatedAt,
|
|
&session.UpdatedAt,
|
|
&session.ExpiresAt,
|
|
&session.IsRevoked,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan session row: %w", err)
|
|
}
|
|
sessions = append(sessions, session)
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|
|
|
|
func UpdateSessionLastActivity(sessionID string) error {
|
|
_, err := db.DB.Exec(`
|
|
UPDATE jwt_sessions
|
|
SET updated_at = ?
|
|
WHERE id = ?
|
|
`, time.Now(), sessionID)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update session activity: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getUserEmailFromID(userID string) (string, error) {
|
|
var email string
|
|
|
|
err := db.DB.QueryRow("SELECT email_address FROM users WHERE id = ?", userID).Scan(&email)
|
|
if err == nil {
|
|
return email, nil
|
|
}
|
|
|
|
return "", fmt.Errorf("user not found with ID %s in any table", userID)
|
|
}
|
|
|
|
func getUserEmailFromIDCached(userID string) (string, error) {
|
|
ctx := context.Background()
|
|
cacheKey := fmt.Sprintf(redisKeyUserEmail, userID)
|
|
|
|
var email string
|
|
err := helper.GetJSON(ctx, cacheKey, &email)
|
|
if err == nil && email != "" {
|
|
return email, nil
|
|
}
|
|
|
|
// Cache miss, feth from database
|
|
email, err = getUserEmailFromID(userID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
cacheTTL := 3600
|
|
if err := helper.SetJSON(ctx, cacheKey, email, &cacheTTL); err != nil {
|
|
helper.LogError(err, "Failed to cache user email in Redis")
|
|
}
|
|
|
|
return email, nil
|
|
}
|
|
|
|
func AddToSessionBlacklist(sessionID string, ttlSeconds int) error {
|
|
ctx := context.Background()
|
|
blacklistKey := fmt.Sprintf(redisKeySessionBlacklist, sessionID)
|
|
|
|
ttl := time.Duration(ttlSeconds) * time.Second
|
|
return redisclient.RDB.Set(ctx, blacklistKey, "revoked", ttl).Err()
|
|
}
|
|
|
|
func IsSessionBlacklisted(sessionID string) bool {
|
|
ctx := context.Background()
|
|
blacklistKey := fmt.Sprintf(redisKeySessionBlacklist, sessionID)
|
|
|
|
exists, err := redisclient.RDB.Exists(ctx, blacklistKey).Result()
|
|
return err == nil && exists > 0
|
|
}
|
|
|
|
func ClearSessionFromAllCaches(sessionID, refreshTokenHash string) error {
|
|
ctx := context.Background()
|
|
|
|
keys := []string{
|
|
fmt.Sprintf(redisKeyJWTSessionID, sessionID),
|
|
fmt.Sprintf(redisKeyJWTSession, refreshTokenHash),
|
|
}
|
|
|
|
return redisclient.RDB.Del(ctx, keys...).Err()
|
|
}
|
|
|
|
func CheckEmailInDB(email string) (bool, error) {
|
|
if db.DB == nil {
|
|
return false, fmt.Errorf("database connection is nil")
|
|
}
|
|
var exists bool
|
|
err := db.DB.QueryRow(
|
|
`SELECT EXISTS(
|
|
SELECT 1 FROM users WHERE email_address = ? AND is_deleted = 0)`, email,
|
|
).Scan(&exists)
|
|
if err != nil {
|
|
return false, fmt.Errorf("error checking email in database: %v", err)
|
|
}
|
|
return exists, nil
|
|
}
|