feat(auth): support multiple user roles in JWT and services
- Change JWT access token RoleID claim from int to []int to support multiple roles per user - Update all token generation and refresh logic to handle multiple role IDs as []int - Refactor services to return and process multiple role IDs from user_roles table - Fix OAuth state handling explanation and improve code comments - Clean up related function signatures and usages for consistency
This commit is contained in:
+35
-52
@@ -96,6 +96,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
|
||||
userID = helper.UUIDGenerator()
|
||||
}
|
||||
|
||||
log.Print("userID:", userID)
|
||||
roleID, err := services.GetRoleIDsFromEmail(email)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error checking role in database: %w", err)
|
||||
@@ -125,7 +126,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
|
||||
|
||||
session := models.JWTSession{
|
||||
ID: sessionID,
|
||||
UserID: userID,
|
||||
UsersID: userID,
|
||||
RefreshTokenHash: refreshTokenHash,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: ipAddress,
|
||||
@@ -166,7 +167,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := generateAccessToken(email, sessionID, userID, roleIDsStr)
|
||||
accessToken, err := generateAccessToken(email, sessionID, userID, roleID)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err)
|
||||
}
|
||||
@@ -175,7 +176,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
func generateAccessToken(email, sessionID, userID, roleID string) (string, error) {
|
||||
func generateAccessToken(email, sessionID, userID string, roleID []int) (string, error) {
|
||||
AccessTokenExpiration := os.Getenv("ACCESS_TOKEN_EXPIRATION_MINUTES")
|
||||
if AccessTokenExpiration == "" {
|
||||
log.Println("AccessTokenExpiration not set (in minutes), defaulting to 45 minutes")
|
||||
@@ -186,7 +187,7 @@ func generateAccessToken(email, sessionID, userID, roleID string) (string, error
|
||||
|
||||
claims := &models.AccessToken{
|
||||
Email: email,
|
||||
UserID: userID,
|
||||
UsersID: userID,
|
||||
RoleID: roleID,
|
||||
SessionID: sessionID,
|
||||
Exp: expirationTime,
|
||||
@@ -251,7 +252,7 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
|
||||
WHERE refresh_token_hash = ? AND is_revoked = false
|
||||
`, refreshTokenHash).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.UsersID,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
&session.CreatedAt,
|
||||
@@ -265,8 +266,8 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UserID: %s, Created: %s, Expires: %s",
|
||||
session.ID, session.UserID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s",
|
||||
session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
|
||||
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
@@ -275,8 +276,8 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
|
||||
}
|
||||
}
|
||||
} else {
|
||||
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UserID: %s, Expires: %s",
|
||||
session.ID, session.UserID, session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s",
|
||||
session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
}
|
||||
|
||||
if session.IsRevoked {
|
||||
@@ -307,18 +308,18 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
|
||||
}
|
||||
|
||||
// Get user email from user ID (with caching)
|
||||
email, err := getUserEmailFromIDCached(session.UserID)
|
||||
email, err := getUserEmailFromIDCached(session.UsersID)
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UserID))
|
||||
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID))
|
||||
// For registrants or users not yet in the main tables, we still want to allow refresh
|
||||
// but we need to get the email from somewhere else. Since we don't store email in session,
|
||||
// we'll need to handle this gracefully by allowing the refresh to continue with a placeholder
|
||||
// The email will be properly resolved when they complete registration
|
||||
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UserID))
|
||||
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UsersID))
|
||||
|
||||
// For now, we'll use a placeholder email pattern and let the access token generation handle it
|
||||
// The system should work as long as the session is valid
|
||||
email = fmt.Sprintf("registrant_%s@pending.local", session.UserID)
|
||||
email = fmt.Sprintf("registrant_%s@pending.local", session.UsersID)
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID))
|
||||
@@ -326,24 +327,15 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
|
||||
userID, err := helper.FetchUserIDFromDB(email)
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email))
|
||||
userID = session.UserID // Fallback to session's user ID
|
||||
userID = session.UsersID // Fallback to session's user ID
|
||||
}
|
||||
|
||||
roleIDs, err := services.GetRoleIDsFromEmail(email)
|
||||
var roleIDsStr string
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
|
||||
roleIDsStr = ""
|
||||
} else {
|
||||
for i, r := range roleIDs {
|
||||
if i > 0 {
|
||||
roleIDsStr += ","
|
||||
}
|
||||
roleIDsStr += fmt.Sprintf("%d", r)
|
||||
}
|
||||
roleIDs = []int{}
|
||||
}
|
||||
|
||||
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDsStr)
|
||||
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to generate access token during refresh")
|
||||
return "", fmt.Errorf("failed to generate access token: %w", err)
|
||||
@@ -411,7 +403,7 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
|
||||
WHERE refresh_token_hash = ? AND is_revoked = false
|
||||
`, refreshTokenHash).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.UsersID,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
&session.CreatedAt,
|
||||
@@ -425,8 +417,8 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
|
||||
return "", fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UserID: %s, Created: %s, Expires: %s",
|
||||
session.ID, session.UserID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s",
|
||||
session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
|
||||
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
@@ -435,8 +427,8 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
|
||||
}
|
||||
}
|
||||
} else {
|
||||
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UserID: %s, Expires: %s",
|
||||
session.ID, session.UserID, session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s",
|
||||
session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime)))
|
||||
}
|
||||
|
||||
if session.IsRevoked {
|
||||
@@ -467,15 +459,15 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
|
||||
}
|
||||
|
||||
// Get user email from user ID (with caching), with fallback to provided email
|
||||
email, err := getUserEmailFromIDCached(session.UserID)
|
||||
email, err := getUserEmailFromIDCached(session.UsersID)
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UserID))
|
||||
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID))
|
||||
|
||||
if emailFallback != "" {
|
||||
helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UserID, emailFallback))
|
||||
helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UsersID, emailFallback))
|
||||
email = emailFallback
|
||||
} else {
|
||||
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UserID))
|
||||
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UsersID))
|
||||
return "", fmt.Errorf("failed to get user email: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -485,24 +477,15 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
|
||||
userID, err := helper.FetchUserIDFromDB(email)
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email))
|
||||
userID = session.UserID // Fallback to session's user ID
|
||||
userID = session.UsersID // Fallback to session's user ID
|
||||
}
|
||||
|
||||
roleIDs, err := services.GetRoleIDsFromEmail(email)
|
||||
var roleIDsStr string
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
|
||||
roleIDsStr = ""
|
||||
} else {
|
||||
for i, r := range roleIDs {
|
||||
if i > 0 {
|
||||
roleIDsStr += ","
|
||||
}
|
||||
roleIDsStr += fmt.Sprintf("%d", r)
|
||||
}
|
||||
roleIDs = []int{}
|
||||
}
|
||||
|
||||
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDsStr)
|
||||
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to generate access token during refresh")
|
||||
return "", fmt.Errorf("failed to generate access token: %w", err)
|
||||
@@ -623,7 +606,7 @@ func ValidateSession(sessionID string) (*models.JWTSession, error) {
|
||||
WHERE id = ?
|
||||
`, sessionID).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.UsersID,
|
||||
&session.RefreshTokenHash,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
@@ -716,12 +699,12 @@ func GetSessionUserInfo(sessionID string) (string, string, error) {
|
||||
return "", "", fmt.Errorf("%s", errMsgSessionNotFoundInCache)
|
||||
}
|
||||
|
||||
email, err := getUserEmailFromIDCached(session.UserID)
|
||||
email, err := getUserEmailFromIDCached(session.UsersID)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get user email: %w", err)
|
||||
}
|
||||
|
||||
return session.UserID, email, nil
|
||||
return session.UsersID, email, nil
|
||||
}
|
||||
|
||||
func InvalidateUserSessionsInCache(userID string) error {
|
||||
@@ -767,11 +750,11 @@ func CleanupExpiredSessions() error {
|
||||
userIDsToCleanup := make(map[string]bool)
|
||||
for rows.Next() {
|
||||
var session models.ExpiredSession
|
||||
if err := rows.Scan(&session.ID, &session.UserID, &session.RefreshTokenHash); err != nil {
|
||||
if err := rows.Scan(&session.ID, &session.UsersID, &session.RefreshTokenHash); err != nil {
|
||||
continue
|
||||
}
|
||||
expiredSessions = append(expiredSessions, session)
|
||||
userIDsToCleanup[session.UserID] = true
|
||||
userIDsToCleanup[session.UsersID] = true
|
||||
}
|
||||
|
||||
_, err = db.DB.Exec("DELETE FROM jwt_sessions WHERE expires_at < ?", time.Now())
|
||||
@@ -809,7 +792,7 @@ func GetUserSessions(userID string) ([]models.JWTSession, error) {
|
||||
var session models.JWTSession
|
||||
err := rows.Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.UsersID,
|
||||
&session.RefreshTokenHash,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
|
||||
+1
-1
@@ -186,7 +186,7 @@ func validateSessionFromDB(sessionID string) (*models.JWTSession, error) {
|
||||
WHERE id = ? AND is_revoked = false
|
||||
`, sessionID).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.UsersID,
|
||||
&session.RefreshTokenHash,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
|
||||
+4
-4
@@ -8,8 +8,8 @@ import (
|
||||
|
||||
type AccessToken struct {
|
||||
Email string `json:"email"`
|
||||
UserID string `json:"user_id"`
|
||||
RoleID string `json:"role_id"`
|
||||
UsersID string `json:"users_id"`
|
||||
RoleID []int `json:"role_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Exp int64 `json:"exp"`
|
||||
jwt.RegisteredClaims
|
||||
@@ -17,7 +17,7 @@ type AccessToken struct {
|
||||
|
||||
type JWTSession struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
UserID string `json:"user_id" db:"user_id"`
|
||||
UsersID string `json:"users_id" db:"users_id"`
|
||||
RefreshTokenHash string `json:"refresh_token_hash" db:"refresh_token_hash"`
|
||||
UserAgent string `json:"user_agent" db:"user_agent"`
|
||||
IPAddress string `json:"ip_address" db:"ip_address"`
|
||||
@@ -29,6 +29,6 @@ type JWTSession struct {
|
||||
|
||||
type ExpiredSession struct {
|
||||
ID string
|
||||
UserID string
|
||||
UsersID string
|
||||
RefreshTokenHash string
|
||||
}
|
||||
|
||||
+6
-6
@@ -68,7 +68,7 @@ func TestJWTSessionCreation(t *testing.T) {
|
||||
now := time.Now()
|
||||
session := &JWTSession{
|
||||
ID: "session-id-123",
|
||||
UserID: "user-456",
|
||||
UsersID: "user-456",
|
||||
RefreshTokenHash: "hash123",
|
||||
UserAgent: "Mozilla/5.0",
|
||||
IPAddress: "192.168.1.1",
|
||||
@@ -82,8 +82,8 @@ func TestJWTSessionCreation(t *testing.T) {
|
||||
t.Errorf("Expected session ID 'session-id-123', got '%s'", session.ID)
|
||||
}
|
||||
|
||||
if session.UserID != "user-456" {
|
||||
t.Errorf("Expected user ID 'user-456', got '%s'", session.UserID)
|
||||
if session.UsersID != "user-456" {
|
||||
t.Errorf("Expected user ID 'user-456', got '%s'", session.UsersID)
|
||||
}
|
||||
|
||||
if session.RefreshTokenHash != "hash123" {
|
||||
@@ -182,7 +182,7 @@ func TestJWTSessionRevokedStatus(t *testing.T) {
|
||||
func TestExpiredSessionCreation(t *testing.T) {
|
||||
expiredSession := ExpiredSession{
|
||||
ID: "expired-id-123",
|
||||
UserID: "user-789",
|
||||
UsersID: "user-789",
|
||||
RefreshTokenHash: "expired-hash",
|
||||
}
|
||||
|
||||
@@ -190,8 +190,8 @@ func TestExpiredSessionCreation(t *testing.T) {
|
||||
t.Errorf("Expected ID 'expired-id-123', got '%s'", expiredSession.ID)
|
||||
}
|
||||
|
||||
if expiredSession.UserID != "user-789" {
|
||||
t.Errorf("Expected UserID 'user-789', got '%s'", expiredSession.UserID)
|
||||
if expiredSession.UsersID != "user-789" {
|
||||
t.Errorf("Expected UsersID 'user-789', got '%s'", expiredSession.UsersID)
|
||||
}
|
||||
|
||||
if expiredSession.RefreshTokenHash != "expired-hash" {
|
||||
|
||||
+2
-2
@@ -29,9 +29,9 @@ func CheckEmailInDB(email string) (bool, error) {
|
||||
|
||||
func GetUserIDFromEmail(email string) (string, error) {
|
||||
log.Print(email)
|
||||
query := `SELECT user_id
|
||||
query := `SELECT users_id
|
||||
FROM (
|
||||
SELECT user_id, 1 AS priority
|
||||
SELECT users_id, 1 AS priority
|
||||
FROM users
|
||||
WHERE email_address = ?
|
||||
AND is_deleted = 0
|
||||
|
||||
Reference in New Issue
Block a user