feat(auth): support multiple user roles in JWT and services

- Change JWT access token RoleID claim from int to []int to support multiple roles per user
- Update all token generation and refresh logic to handle multiple role IDs as []int
- Refactor services to return and process multiple role IDs from user_roles table
- Fix OAuth state handling explanation and improve code comments
- Clean up related function signatures and usages for consistency
This commit is contained in:
2026-02-03 16:35:08 +08:00
parent f4b8651a5c
commit fee314870d
5 changed files with 48 additions and 65 deletions
+35 -52
View File
@@ -96,6 +96,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
userID = helper.UUIDGenerator()
}
log.Print("userID:", userID)
roleID, err := services.GetRoleIDsFromEmail(email)
if err != nil {
return "", "", fmt.Errorf("error checking role in database: %w", err)
@@ -125,7 +126,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
session := models.JWTSession{
ID: sessionID,
UserID: userID,
UsersID: userID,
RefreshTokenHash: refreshTokenHash,
UserAgent: userAgent,
IPAddress: ipAddress,
@@ -166,7 +167,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
}
}
accessToken, err := generateAccessToken(email, sessionID, userID, roleIDsStr)
accessToken, err := generateAccessToken(email, sessionID, userID, roleID)
if err != nil {
return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err)
}
@@ -175,7 +176,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
return accessToken, refreshToken, nil
}
func generateAccessToken(email, sessionID, userID, roleID string) (string, error) {
func generateAccessToken(email, sessionID, userID string, roleID []int) (string, error) {
AccessTokenExpiration := os.Getenv("ACCESS_TOKEN_EXPIRATION_MINUTES")
if AccessTokenExpiration == "" {
log.Println("AccessTokenExpiration not set (in minutes), defaulting to 45 minutes")
@@ -186,7 +187,7 @@ func generateAccessToken(email, sessionID, userID, roleID string) (string, error
claims := &models.AccessToken{
Email: email,
UserID: userID,
UsersID: userID,
RoleID: roleID,
SessionID: sessionID,
Exp: expirationTime,
@@ -251,7 +252,7 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
WHERE refresh_token_hash = ? AND is_revoked = false
`, refreshTokenHash).Scan(
&session.ID,
&session.UserID,
&session.UsersID,
&session.UserAgent,
&session.IPAddress,
&session.CreatedAt,
@@ -265,8 +266,8 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
return "", fmt.Errorf("invalid refresh token: %w", err)
}
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UserID: %s, Created: %s, Expires: %s",
session.ID, session.UserID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s",
session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
if sessionTTL > 0 {
@@ -275,8 +276,8 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
}
}
} else {
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UserID: %s, Expires: %s",
session.ID, session.UserID, session.ExpiresAt.Format(timeFormatDateTime)))
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s",
session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime)))
}
if session.IsRevoked {
@@ -307,18 +308,18 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
}
// Get user email from user ID (with caching)
email, err := getUserEmailFromIDCached(session.UserID)
email, err := getUserEmailFromIDCached(session.UsersID)
if err != nil {
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UserID))
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID))
// For registrants or users not yet in the main tables, we still want to allow refresh
// but we need to get the email from somewhere else. Since we don't store email in session,
// we'll need to handle this gracefully by allowing the refresh to continue with a placeholder
// The email will be properly resolved when they complete registration
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UserID))
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UsersID))
// For now, we'll use a placeholder email pattern and let the access token generation handle it
// The system should work as long as the session is valid
email = fmt.Sprintf("registrant_%s@pending.local", session.UserID)
email = fmt.Sprintf("registrant_%s@pending.local", session.UsersID)
}
helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID))
@@ -326,24 +327,15 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
userID, err := helper.FetchUserIDFromDB(email)
if err != nil {
helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email))
userID = session.UserID // Fallback to session's user ID
userID = session.UsersID // Fallback to session's user ID
}
roleIDs, err := services.GetRoleIDsFromEmail(email)
var roleIDsStr string
if err != nil {
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
roleIDsStr = ""
} else {
for i, r := range roleIDs {
if i > 0 {
roleIDsStr += ","
}
roleIDsStr += fmt.Sprintf("%d", r)
}
roleIDs = []int{}
}
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDsStr)
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
if err != nil {
helper.LogError(err, "Failed to generate access token during refresh")
return "", fmt.Errorf("failed to generate access token: %w", err)
@@ -411,7 +403,7 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
WHERE refresh_token_hash = ? AND is_revoked = false
`, refreshTokenHash).Scan(
&session.ID,
&session.UserID,
&session.UsersID,
&session.UserAgent,
&session.IPAddress,
&session.CreatedAt,
@@ -425,8 +417,8 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
return "", fmt.Errorf("invalid refresh token: %w", err)
}
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UserID: %s, Created: %s, Expires: %s",
session.ID, session.UserID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s",
session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime)))
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
if sessionTTL > 0 {
@@ -435,8 +427,8 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
}
}
} else {
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UserID: %s, Expires: %s",
session.ID, session.UserID, session.ExpiresAt.Format(timeFormatDateTime)))
helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s",
session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime)))
}
if session.IsRevoked {
@@ -467,15 +459,15 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
}
// Get user email from user ID (with caching), with fallback to provided email
email, err := getUserEmailFromIDCached(session.UserID)
email, err := getUserEmailFromIDCached(session.UsersID)
if err != nil {
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UserID))
helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID))
if emailFallback != "" {
helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UserID, emailFallback))
helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UsersID, emailFallback))
email = emailFallback
} else {
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UserID))
helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UsersID))
return "", fmt.Errorf("failed to get user email: %w", err)
}
}
@@ -485,24 +477,15 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
userID, err := helper.FetchUserIDFromDB(email)
if err != nil {
helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email))
userID = session.UserID // Fallback to session's user ID
userID = session.UsersID // Fallback to session's user ID
}
roleIDs, err := services.GetRoleIDsFromEmail(email)
var roleIDsStr string
if err != nil {
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
roleIDsStr = ""
} else {
for i, r := range roleIDs {
if i > 0 {
roleIDsStr += ","
}
roleIDsStr += fmt.Sprintf("%d", r)
}
roleIDs = []int{}
}
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDsStr)
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
if err != nil {
helper.LogError(err, "Failed to generate access token during refresh")
return "", fmt.Errorf("failed to generate access token: %w", err)
@@ -623,7 +606,7 @@ func ValidateSession(sessionID string) (*models.JWTSession, error) {
WHERE id = ?
`, sessionID).Scan(
&session.ID,
&session.UserID,
&session.UsersID,
&session.RefreshTokenHash,
&session.UserAgent,
&session.IPAddress,
@@ -716,12 +699,12 @@ func GetSessionUserInfo(sessionID string) (string, string, error) {
return "", "", fmt.Errorf("%s", errMsgSessionNotFoundInCache)
}
email, err := getUserEmailFromIDCached(session.UserID)
email, err := getUserEmailFromIDCached(session.UsersID)
if err != nil {
return "", "", fmt.Errorf("failed to get user email: %w", err)
}
return session.UserID, email, nil
return session.UsersID, email, nil
}
func InvalidateUserSessionsInCache(userID string) error {
@@ -767,11 +750,11 @@ func CleanupExpiredSessions() error {
userIDsToCleanup := make(map[string]bool)
for rows.Next() {
var session models.ExpiredSession
if err := rows.Scan(&session.ID, &session.UserID, &session.RefreshTokenHash); err != nil {
if err := rows.Scan(&session.ID, &session.UsersID, &session.RefreshTokenHash); err != nil {
continue
}
expiredSessions = append(expiredSessions, session)
userIDsToCleanup[session.UserID] = true
userIDsToCleanup[session.UsersID] = true
}
_, err = db.DB.Exec("DELETE FROM jwt_sessions WHERE expires_at < ?", time.Now())
@@ -809,7 +792,7 @@ func GetUserSessions(userID string) ([]models.JWTSession, error) {
var session models.JWTSession
err := rows.Scan(
&session.ID,
&session.UserID,
&session.UsersID,
&session.RefreshTokenHash,
&session.UserAgent,
&session.IPAddress,
+1 -1
View File
@@ -186,7 +186,7 @@ func validateSessionFromDB(sessionID string) (*models.JWTSession, error) {
WHERE id = ? AND is_revoked = false
`, sessionID).Scan(
&session.ID,
&session.UserID,
&session.UsersID,
&session.RefreshTokenHash,
&session.UserAgent,
&session.IPAddress,
+4 -4
View File
@@ -8,8 +8,8 @@ import (
type AccessToken struct {
Email string `json:"email"`
UserID string `json:"user_id"`
RoleID string `json:"role_id"`
UsersID string `json:"users_id"`
RoleID []int `json:"role_id"`
SessionID string `json:"session_id"`
Exp int64 `json:"exp"`
jwt.RegisteredClaims
@@ -17,7 +17,7 @@ type AccessToken struct {
type JWTSession struct {
ID string `json:"id" db:"id"`
UserID string `json:"user_id" db:"user_id"`
UsersID string `json:"users_id" db:"users_id"`
RefreshTokenHash string `json:"refresh_token_hash" db:"refresh_token_hash"`
UserAgent string `json:"user_agent" db:"user_agent"`
IPAddress string `json:"ip_address" db:"ip_address"`
@@ -29,6 +29,6 @@ type JWTSession struct {
type ExpiredSession struct {
ID string
UserID string
UsersID string
RefreshTokenHash string
}
+6 -6
View File
@@ -68,7 +68,7 @@ func TestJWTSessionCreation(t *testing.T) {
now := time.Now()
session := &JWTSession{
ID: "session-id-123",
UserID: "user-456",
UsersID: "user-456",
RefreshTokenHash: "hash123",
UserAgent: "Mozilla/5.0",
IPAddress: "192.168.1.1",
@@ -82,8 +82,8 @@ func TestJWTSessionCreation(t *testing.T) {
t.Errorf("Expected session ID 'session-id-123', got '%s'", session.ID)
}
if session.UserID != "user-456" {
t.Errorf("Expected user ID 'user-456', got '%s'", session.UserID)
if session.UsersID != "user-456" {
t.Errorf("Expected user ID 'user-456', got '%s'", session.UsersID)
}
if session.RefreshTokenHash != "hash123" {
@@ -182,7 +182,7 @@ func TestJWTSessionRevokedStatus(t *testing.T) {
func TestExpiredSessionCreation(t *testing.T) {
expiredSession := ExpiredSession{
ID: "expired-id-123",
UserID: "user-789",
UsersID: "user-789",
RefreshTokenHash: "expired-hash",
}
@@ -190,8 +190,8 @@ func TestExpiredSessionCreation(t *testing.T) {
t.Errorf("Expected ID 'expired-id-123', got '%s'", expiredSession.ID)
}
if expiredSession.UserID != "user-789" {
t.Errorf("Expected UserID 'user-789', got '%s'", expiredSession.UserID)
if expiredSession.UsersID != "user-789" {
t.Errorf("Expected UsersID 'user-789', got '%s'", expiredSession.UsersID)
}
if expiredSession.RefreshTokenHash != "expired-hash" {
+2 -2
View File
@@ -29,9 +29,9 @@ func CheckEmailInDB(email string) (bool, error) {
func GetUserIDFromEmail(email string) (string, error) {
log.Print(email)
query := `SELECT user_id
query := `SELECT users_id
FROM (
SELECT user_id, 1 AS priority
SELECT users_id, 1 AS priority
FROM users
WHERE email_address = ?
AND is_deleted = 0