diff --git a/handlers/jwt.go b/handlers/jwt.go index 32f306d..0bc0592 100644 --- a/handlers/jwt.go +++ b/handlers/jwt.go @@ -15,6 +15,7 @@ import ( "fmt" "log" "os" + "sort" "time" "github.com/golang-jwt/jwt/v5" @@ -97,9 +98,20 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) } log.Print("userID:", userID) - roleID, err := services.GetRoleIDsFromEmail(email) + + tokenEmail := email + roleID := make([]int, 0) + user, err := services.FetchUserByEmail(email) if err != nil { - return "", "", fmt.Errorf("error checking role in database: %w", err) + helper.LogWarn(fmt.Sprintf("Failed to fetch user profile for JWT role mapping (%s): %v", email, err)) + } else { + if user.UserID != "" { + userID = user.UserID + } + if user.EmailAddress != "" { + tokenEmail = user.EmailAddress + } + roleID = buildJWTClaimRoleIDs(user) } sessionID := helper.UUIDGenerator() @@ -156,18 +168,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) } } - // Convert roleIDs slice to a comma-separated string for the token claim - var roleIDsStr string - if len(roleID) > 0 { - for i, r := range roleID { - if i > 0 { - roleIDsStr += "," - } - roleIDsStr += fmt.Sprintf("%d", r) - } - } - - accessToken, err := generateAccessToken(email, sessionID, userID, roleID) + accessToken, err := generateAccessToken(tokenEmail, sessionID, userID, roleID) if err != nil { return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err) } @@ -176,6 +177,64 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) return accessToken, refreshToken, nil } +func buildJWTClaimRoleIDs(user models.User) []int { + roleMap := make(map[int]struct{}) + orderedRoles := make([]int, 0) + roleOneProjects := make([]int, 0) + + if user.RoleID != nil { + helper.LogInfo(fmt.Sprintf("JWT role source - base role_id: %d", *user.RoleID)) + } + + addUnique := func(role int) { + if _, exists := roleMap[role]; exists { + return + } + roleMap[role] = struct{}{} + orderedRoles = append(orderedRoles, role) + } + + if user.RoleID != nil { + addUnique(*user.RoleID) + } + + if user.Projects != nil { + for _, project := range *user.Projects { + helper.LogInfo(fmt.Sprintf("JWT role source - project %d: role_id=%v", project.ProjectID, project.RoleID)) + for _, role := range project.RoleID { + if role == 1 { + roleOneProjects = append(roleOneProjects, project.ProjectID) + } + addUnique(role) + } + } + } + + if len(roleOneProjects) > 0 { + helper.LogInfo(fmt.Sprintf("JWT role trace - additional role_id=1 found in project_id(s): %v", roleOneProjects)) + } else { + helper.LogInfo("JWT role trace - additional role_id=1 not found in project roles") + } + + if len(orderedRoles) <= 1 { + if len(orderedRoles) == 1 { + helper.LogInfo(fmt.Sprintf("JWT role claims - primary role_id=%d, additional_role_id=[]", orderedRoles[0])) + } else { + helper.LogInfo("JWT role claims - primary role_id=nil, additional_role_id=[]") + } + return orderedRoles + } + + primaryRole := orderedRoles[0] + remainingRoles := append([]int(nil), orderedRoles[1:]...) + sort.Ints(remainingRoles) + + finalRoles := append([]int{primaryRole}, remainingRoles...) + helper.LogInfo(fmt.Sprintf("JWT role claims - primary role_id=%d, additional_role_id=%v", primaryRole, remainingRoles)) + + return finalRoles +} + func generateAccessToken(email, sessionID, userID string, roleID []int) (string, error) { AccessTokenExpiration := os.Getenv("ACCESS_TOKEN_EXPIRATION_MINUTES") if AccessTokenExpiration == "" { @@ -188,9 +247,12 @@ func generateAccessToken(email, sessionID, userID string, roleID []int) (string, } var primaryRoleID *int + additionalRoleIDs := make([]int, 0) if len(roleID) > 0 { - value := roleID[0] - primaryRoleID = &value + primaryRoleID = &roleID[0] + if len(roleID) > 1 { + additionalRoleIDs = append(additionalRoleIDs, roleID[1:]...) + } } expirationTime := time.Now().Add(24 * time.Hour).Unix() @@ -199,7 +261,7 @@ func generateAccessToken(email, sessionID, userID string, roleID []int) (string, Email: email, UsersID: userID, RoleID: primaryRoleID, - AdditionalRoleID: roleID, + AdditionalRoleID: additionalRoleIDs, SessionID: sessionID, Exp: expirationTime, RegisteredClaims: jwt.RegisteredClaims{ @@ -341,10 +403,18 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string userID = session.UsersID // Fallback to session's user ID } - roleIDs, err := services.GetRoleIDsFromEmail(email) + roleIDs := make([]int, 0) + user, err := services.FetchUserByEmail(email) if err != nil { - helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email)) - roleIDs = []int{} + helper.LogError(err, fmt.Sprintf("Failed to fetch user profile for role mapping during refresh for email %s", email)) + } else { + if user.UserID != "" { + userID = user.UserID + } + if user.EmailAddress != "" { + email = user.EmailAddress + } + roleIDs = buildJWTClaimRoleIDs(user) } accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs) if err != nil { @@ -491,10 +561,18 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres userID = session.UsersID // Fallback to session's user ID } - roleIDs, err := services.GetRoleIDsFromEmail(email) + roleIDs := make([]int, 0) + user, err := services.FetchUserByEmail(email) if err != nil { - helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email)) - roleIDs = []int{} + helper.LogError(err, fmt.Sprintf("Failed to fetch user profile for role mapping during refresh for email %s", email)) + } else { + if user.UserID != "" { + userID = user.UserID + } + if user.EmailAddress != "" { + email = user.EmailAddress + } + roleIDs = buildJWTClaimRoleIDs(user) } accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs) if err != nil {