From 20bd509bba46d7a104b21da3d8351084cd2af780 Mon Sep 17 00:00:00 2001 From: F04C Date: Fri, 27 Feb 2026 14:03:29 +0800 Subject: [PATCH] added additional_role_id when checking the roles --- handlers/authorize.go | 7 +++++++ handlers/authorize_test.go | 16 ++++++++++++++++ middleware/jwt.go | 20 ++++++++++++++++++-- middleware/jwt_test.go | 22 ++++++++++++++++++++++ models/authorize.go | 1 + 5 files changed, 64 insertions(+), 2 deletions(-) diff --git a/handlers/authorize.go b/handlers/authorize.go index b06fa4c..452decf 100644 --- a/handlers/authorize.go +++ b/handlers/authorize.go @@ -147,6 +147,13 @@ func collectClaimRoles(claims *models.Claims) []int { } } + for _, role := range claims.AdditionalRoleID { + if _, exists := unique[role]; !exists { + unique[role] = struct{}{} + roles = append(roles, role) + } + } + for _, project := range claims.Projects { for _, role := range project.RoleID { if _, exists := unique[role]; !exists { diff --git a/handlers/authorize_test.go b/handlers/authorize_test.go index 8675183..80f6897 100644 --- a/handlers/authorize_test.go +++ b/handlers/authorize_test.go @@ -424,3 +424,19 @@ func TestCollectRequestedRolesFromArray(t *testing.T) { t.Fatalf("unexpected requested roles: %v", result) } } + +func TestCollectClaimRolesIncludesAdditionalRoles(t *testing.T) { + claims := &models.Claims{ + RoleID: models.RoleIDs{30}, + AdditionalRoleID: models.RoleIDs{4, 5, 30}, + } + + roles := collectClaimRoles(claims) + if len(roles) != 3 { + t.Fatalf("expected 3 unique roles, got %d (%v)", len(roles), roles) + } + + if roles[0] != 30 || roles[1] != 4 || roles[2] != 5 { + t.Fatalf("unexpected role order/content: %v", roles) + } +} diff --git a/middleware/jwt.go b/middleware/jwt.go index ceb998f..b799f94 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -240,8 +240,24 @@ func JWTAuth(next http.HandlerFunc) http.HandlerFunc { func buildContext(parent context.Context, claims *models.Claims) context.Context { ctx := context.WithValue(parent, claimsKey, claims) ctx = context.WithValue(ctx, userIDKey, claims.UsersID) - // Store plain []int in context for roles to keep middleware interfaces simple - ctx = context.WithValue(ctx, roleIDKey, []int(claims.RoleID)) + roles := make([]int, 0, len(claims.RoleID)+len(claims.AdditionalRoleID)) + unique := make(map[int]struct{}) + + for _, role := range claims.RoleID { + if _, exists := unique[role]; !exists { + unique[role] = struct{}{} + roles = append(roles, role) + } + } + + for _, role := range claims.AdditionalRoleID { + if _, exists := unique[role]; !exists { + unique[role] = struct{}{} + roles = append(roles, role) + } + } + + ctx = context.WithValue(ctx, roleIDKey, roles) return ctx } diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index ff497c0..564c6fa 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -190,6 +190,28 @@ func TestBuildContext(t *testing.T) { } } +func TestBuildContextIncludesAdditionalRoles(t *testing.T) { + claims := &models.Claims{ + UsersID: "user123", + RoleID: models.RoleIDs{30}, + AdditionalRoleID: models.RoleIDs{4, 5, 30}, + } + + ctx := buildContext(context.Background(), claims) + val, ok := ctx.Value(roleIDKey).([]int) + if !ok { + t.Fatal("Role not properly set in context") + } + + if len(val) != 3 { + t.Fatalf("expected 3 unique roles, got %d (%v)", len(val), val) + } + + if val[0] != 30 || val[1] != 4 || val[2] != 5 { + t.Fatalf("unexpected roles in context: %v", val) + } +} + func TestGetClaims(t *testing.T) { claims := &models.Claims{ UsersID: "user123", diff --git a/models/authorize.go b/models/authorize.go index 4ac6b5d..a187ea7 100644 --- a/models/authorize.go +++ b/models/authorize.go @@ -85,6 +85,7 @@ type Claims struct { UsersID string `json:"users_id,omitempty"` EmailAddress string `json:"email_address,omitempty"` RoleID RoleIDs `json:"role_id"` + AdditionalRoleID RoleIDs `json:"additional_role_id,omitempty"` Projects []ProjectClaim `json:"projects,omitempty"` jwt.RegisteredClaims }