Merge branch 'rsa' into 'main'
fix(jwt): make role claims consistent with /me and correct additional_role_id See merge request psa/uess/authn!2
This commit is contained in:
+101
-23
@@ -15,6 +15,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -97,9 +98,20 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
|
||||
}
|
||||
|
||||
log.Print("userID:", userID)
|
||||
roleID, err := services.GetRoleIDsFromEmail(email)
|
||||
|
||||
tokenEmail := email
|
||||
roleID := make([]int, 0)
|
||||
user, err := services.FetchUserByEmail(email)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error checking role in database: %w", err)
|
||||
helper.LogWarn(fmt.Sprintf("Failed to fetch user profile for JWT role mapping (%s): %v", email, err))
|
||||
} else {
|
||||
if user.UserID != "" {
|
||||
userID = user.UserID
|
||||
}
|
||||
if user.EmailAddress != "" {
|
||||
tokenEmail = user.EmailAddress
|
||||
}
|
||||
roleID = buildJWTClaimRoleIDs(user)
|
||||
}
|
||||
sessionID := helper.UUIDGenerator()
|
||||
|
||||
@@ -156,18 +168,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert roleIDs slice to a comma-separated string for the token claim
|
||||
var roleIDsStr string
|
||||
if len(roleID) > 0 {
|
||||
for i, r := range roleID {
|
||||
if i > 0 {
|
||||
roleIDsStr += ","
|
||||
}
|
||||
roleIDsStr += fmt.Sprintf("%d", r)
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := generateAccessToken(email, sessionID, userID, roleID)
|
||||
accessToken, err := generateAccessToken(tokenEmail, sessionID, userID, roleID)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err)
|
||||
}
|
||||
@@ -176,6 +177,64 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error)
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
func buildJWTClaimRoleIDs(user models.User) []int {
|
||||
roleMap := make(map[int]struct{})
|
||||
orderedRoles := make([]int, 0)
|
||||
roleOneProjects := make([]int, 0)
|
||||
|
||||
if user.RoleID != nil {
|
||||
helper.LogInfo(fmt.Sprintf("JWT role source - base role_id: %d", *user.RoleID))
|
||||
}
|
||||
|
||||
addUnique := func(role int) {
|
||||
if _, exists := roleMap[role]; exists {
|
||||
return
|
||||
}
|
||||
roleMap[role] = struct{}{}
|
||||
orderedRoles = append(orderedRoles, role)
|
||||
}
|
||||
|
||||
if user.RoleID != nil {
|
||||
addUnique(*user.RoleID)
|
||||
}
|
||||
|
||||
if user.Projects != nil {
|
||||
for _, project := range *user.Projects {
|
||||
helper.LogInfo(fmt.Sprintf("JWT role source - project %d: role_id=%v", project.ProjectID, project.RoleID))
|
||||
for _, role := range project.RoleID {
|
||||
if role == 1 {
|
||||
roleOneProjects = append(roleOneProjects, project.ProjectID)
|
||||
}
|
||||
addUnique(role)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(roleOneProjects) > 0 {
|
||||
helper.LogInfo(fmt.Sprintf("JWT role trace - additional role_id=1 found in project_id(s): %v", roleOneProjects))
|
||||
} else {
|
||||
helper.LogInfo("JWT role trace - additional role_id=1 not found in project roles")
|
||||
}
|
||||
|
||||
if len(orderedRoles) <= 1 {
|
||||
if len(orderedRoles) == 1 {
|
||||
helper.LogInfo(fmt.Sprintf("JWT role claims - primary role_id=%d, additional_role_id=[]", orderedRoles[0]))
|
||||
} else {
|
||||
helper.LogInfo("JWT role claims - primary role_id=nil, additional_role_id=[]")
|
||||
}
|
||||
return orderedRoles
|
||||
}
|
||||
|
||||
primaryRole := orderedRoles[0]
|
||||
remainingRoles := append([]int(nil), orderedRoles[1:]...)
|
||||
sort.Ints(remainingRoles)
|
||||
|
||||
finalRoles := append([]int{primaryRole}, remainingRoles...)
|
||||
helper.LogInfo(fmt.Sprintf("JWT role claims - primary role_id=%d, additional_role_id=%v", primaryRole, remainingRoles))
|
||||
|
||||
return finalRoles
|
||||
}
|
||||
|
||||
func generateAccessToken(email, sessionID, userID string, roleID []int) (string, error) {
|
||||
AccessTokenExpiration := os.Getenv("ACCESS_TOKEN_EXPIRATION_MINUTES")
|
||||
if AccessTokenExpiration == "" {
|
||||
@@ -188,9 +247,12 @@ func generateAccessToken(email, sessionID, userID string, roleID []int) (string,
|
||||
}
|
||||
|
||||
var primaryRoleID *int
|
||||
additionalRoleIDs := make([]int, 0)
|
||||
if len(roleID) > 0 {
|
||||
value := roleID[0]
|
||||
primaryRoleID = &value
|
||||
primaryRoleID = &roleID[0]
|
||||
if len(roleID) > 1 {
|
||||
additionalRoleIDs = append(additionalRoleIDs, roleID[1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
expirationTime := time.Now().Add(24 * time.Hour).Unix()
|
||||
@@ -199,7 +261,7 @@ func generateAccessToken(email, sessionID, userID string, roleID []int) (string,
|
||||
Email: email,
|
||||
UsersID: userID,
|
||||
RoleID: primaryRoleID,
|
||||
AdditionalRoleID: roleID,
|
||||
AdditionalRoleID: additionalRoleIDs,
|
||||
SessionID: sessionID,
|
||||
Exp: expirationTime,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
@@ -341,10 +403,18 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string
|
||||
userID = session.UsersID // Fallback to session's user ID
|
||||
}
|
||||
|
||||
roleIDs, err := services.GetRoleIDsFromEmail(email)
|
||||
roleIDs := make([]int, 0)
|
||||
user, err := services.FetchUserByEmail(email)
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
|
||||
roleIDs = []int{}
|
||||
helper.LogError(err, fmt.Sprintf("Failed to fetch user profile for role mapping during refresh for email %s", email))
|
||||
} else {
|
||||
if user.UserID != "" {
|
||||
userID = user.UserID
|
||||
}
|
||||
if user.EmailAddress != "" {
|
||||
email = user.EmailAddress
|
||||
}
|
||||
roleIDs = buildJWTClaimRoleIDs(user)
|
||||
}
|
||||
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
|
||||
if err != nil {
|
||||
@@ -491,10 +561,18 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
|
||||
userID = session.UsersID // Fallback to session's user ID
|
||||
}
|
||||
|
||||
roleIDs, err := services.GetRoleIDsFromEmail(email)
|
||||
roleIDs := make([]int, 0)
|
||||
user, err := services.FetchUserByEmail(email)
|
||||
if err != nil {
|
||||
helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email))
|
||||
roleIDs = []int{}
|
||||
helper.LogError(err, fmt.Sprintf("Failed to fetch user profile for role mapping during refresh for email %s", email))
|
||||
} else {
|
||||
if user.UserID != "" {
|
||||
userID = user.UserID
|
||||
}
|
||||
if user.EmailAddress != "" {
|
||||
email = user.EmailAddress
|
||||
}
|
||||
roleIDs = buildJWTClaimRoleIDs(user)
|
||||
}
|
||||
accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user