Merge branch 'rsa' into 'main'

fix(jwt): make role claims consistent with /me and correct additional_role_id

See merge request psa/uess/authn!2
This commit is contained in:
2026-02-26 10:45:56 +08:00
+101 -23
View File
@@ -15,6 +15,7 @@ import (
"fmt"
"log"
"os"
"sort"
"time"
"github.com/golang-jwt/jwt/v5"
@@ -97,9 +98,20 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
}
log.Print("userID:", userID)
roleID, err := services.GetRoleIDsFromEmail(email)
tokenEmail := email
roleID := make([]int, 0)
user, err := services.FetchUserByEmail(email)
if err != nil {
return "", "", fmt.Errorf("error checking role in database: %w", err)
helper.LogWarn(fmt.Sprintf("Failed to fetch user profile for JWT role mapping (%s): %v", email, err))
} else {
if user.UserID != "" {
userID = user.UserID
}
if user.EmailAddress != "" {
tokenEmail = user.EmailAddress
}
roleID = buildJWTClaimRoleIDs(user)
}
sessionID := helper.UUIDGenerator()
@@ -156,18 +168,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
}
}
// Convert roleIDs slice to a comma-separated string for the token claim
var roleIDsStr string
if len(roleID) > 0 {
for i, r := range roleID {
if i > 0 {
roleIDsStr += ","
}
roleIDsStr += fmt.Sprintf("%d", r)
}
}
accessToken, err := generateAccessToken(email, sessionID, userID, roleID)
accessToken, err := generateAccessToken(tokenEmail, sessionID, userID, roleID)
if err != nil {
return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err)
}
@@ -176,6 +177,64 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
return accessToken, refreshToken, nil
}
func buildJWTClaimRoleIDs(user models.User) []int {
roleMap := make(map[int]struct{})
orderedRoles := make([]int, 0)
roleOneProjects := make([]int, 0)
if user.RoleID != nil {
helper.LogInfo(fmt.Sprintf("JWT role source - base role_id: %d", *user.RoleID))
}
addUnique := func(role int) {
if _, exists := roleMap[role]; exists {
return
}
roleMap[role] = struct{}{}
orderedRoles = append(orderedRoles, role)
}
if user.RoleID != nil {
addUnique(*user.RoleID)
}
if user.Projects != nil {
for _, project := range *user.Projects {
helper.LogInfo(fmt.Sprintf("JWT role source - project %d: role_id=%v", project.ProjectID, project.RoleID))
for _, role := range project.RoleID {
if role == 1 {
roleOneProjects = append(roleOneProjects, project.ProjectID)
}
addUnique(role)
}
}
}
if len(roleOneProjects) > 0 {
helper.LogInfo(fmt.Sprintf("JWT role trace - additional role_id=1 found in project_id(s): %v", roleOneProjects))
} else {
helper.LogInfo("JWT role trace - additional role_id=1 not found in project roles")
}
if len(orderedRoles) <= 1 {
if len(orderedRoles) == 1 {
helper.LogInfo(fmt.Sprintf("JWT role claims - primary role_id=%d, additional_role_id=[]", orderedRoles[0]))
} else {
helper.LogInfo("JWT role claims - primary role_id=nil, additional_role_id=[]")
}
return orderedRoles
}
primaryRole := orderedRoles[0]
remainingRoles := append([]int(nil), orderedRoles[1:]...)
sort.Ints(remainingRoles)
finalRoles := append([]int{primaryRole}, remainingRoles...)
helper.LogInfo(fmt.Sprintf("JWT role claims - primary role_id=%d, additional_role_id=%v", primaryRole, remainingRoles))
return finalRoles
}
func generateAccessToken(email, sessionID, userID string, roleID []int) (string, error) {
AccessTokenExpiration := os.Getenv("ACCESS_TOKEN_EXPIRATION_MINUTES")
if AccessTokenExpiration == "" {
@@ -188,9 +247,12 @@ func generateAccessToken(email, sessionID, userID string, roleID []int) (string,
}
var primaryRoleID *int
additionalRoleIDs := make([]int, 0)
if len(roleID) > 0 {
value := roleID[0]
primaryRoleID = &value
primaryRoleID = &roleID[0]
if len(roleID) > 1 {
additionalRoleIDs = append(additionalRoleIDs, roleID[1:]...)
}
}
expirationTime := time.Now().Add(24 * time.Hour).Unix()
@@ -199,7 +261,7 @@ func generateAccessToken(email, sessionID, userID string, roleID []int) (string,
Email: email,
UsersID: userID,
RoleID: primaryRoleID,
AdditionalRoleID: roleID,
AdditionalRoleID: additionalRoleIDs,
SessionID: sessionID,
Exp: expirationTime,
RegisteredClaims: jwt.RegisteredClaims{
@@ -341,10 +403,18 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
userID = session.UsersID // Fallback to session's user ID
}
roleIDs, err := services.GetRoleIDsFromEmail(email)
roleIDs := make([]int, 0)
user, err := services.FetchUserByEmail(email)
if err != nil {
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
roleIDs = []int{}
helper.LogError(err, fmt.Sprintf("Failed to fetch user profile for role mapping during refresh for email %s", email))
} else {
if user.UserID != "" {
userID = user.UserID
}
if user.EmailAddress != "" {
email = user.EmailAddress
}
roleIDs = buildJWTClaimRoleIDs(user)
}
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
if err != nil {
@@ -491,10 +561,18 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
userID = session.UsersID // Fallback to session's user ID
}
roleIDs, err := services.GetRoleIDsFromEmail(email)
roleIDs := make([]int, 0)
user, err := services.FetchUserByEmail(email)
if err != nil {
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
roleIDs = []int{}
helper.LogError(err, fmt.Sprintf("Failed to fetch user profile for role mapping during refresh for email %s", email))
} else {
if user.UserID != "" {
userID = user.UserID
}
if user.EmailAddress != "" {
email = user.EmailAddress
}
roleIDs = buildJWTClaimRoleIDs(user)
}
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
if err != nil {