fixed authorization (now checks the role inside of the project)

This commit is contained in:
2026-03-02 13:46:14 +08:00
parent e32a4a2779
commit 8ca995d490
6 changed files with 253 additions and 16 deletions
+45 -5
View File
@@ -87,15 +87,24 @@ func AuthorizeHandler(w http.ResponseWriter, r *http.Request) {
} }
claimRoles := collectClaimRoles(claims) claimRoles := collectClaimRoles(claims)
projectRoleCount := 0
for _, project := range claims.Projects {
projectRoleCount += len(project.RoleID)
log.Printf("[Handler] Project role claim - project_id=%d, alias=%s, roles=%v", project.ProjectID, project.Alias, project.RoleID)
}
log.Printf("[Handler] Claim roles parsed - base=%v, additional=%v, projects=%d, projectRoleEntries=%d, combined=%v",
claims.RoleID,
claims.AdditionalRoleID,
len(claims.Projects),
projectRoleCount,
claimRoles,
)
if len(claimRoles) == 0 { if len(claimRoles) == 0 {
log.Printf("ERROR: No roles found in JWT claims for user=%s", claims.UsersID) log.Printf("ERROR: No roles found in JWT claims for user=%s", claims.UsersID)
} }
requestedRoles := collectRequestedRoles(&ctx) requestedRoles := collectRequestedRoles(&ctx)
if len(requestedRoles) == 0 { validRoles := buildRoleCandidates(requestedRoles, claimRoles)
requestedRoles = claimRoles log.Printf("[Handler] Role candidate resolution - requested=%v, finalCandidates=%v", requestedRoles, validRoles)
}
validRoles := intersectRoles(requestedRoles, claimRoles)
if len(validRoles) == 0 { if len(validRoles) == 0 {
log.Printf("ERROR: Role mismatch for user=%s - requestedRoles=%v, claimRoles=%v", ctx.UsersID, requestedRoles, claimRoles) log.Printf("ERROR: Role mismatch for user=%s - requestedRoles=%v, claimRoles=%v", ctx.UsersID, requestedRoles, claimRoles)
helper.RespondWithError(w, http.StatusForbidden, "Role ID mismatch") helper.RespondWithError(w, http.StatusForbidden, "Role ID mismatch")
@@ -197,3 +206,34 @@ func intersectRoles(requested, available []int) []int {
return result return result
} }
func buildRoleCandidates(requested, claimRoles []int) []int {
if len(claimRoles) == 0 {
return nil
}
if len(requested) == 0 {
return append([]int(nil), claimRoles...)
}
primary := intersectRoles(requested, claimRoles)
if len(primary) == 0 {
return nil
}
seen := make(map[int]struct{}, len(primary))
for _, role := range primary {
seen[role] = struct{}{}
}
result := append([]int(nil), primary...)
for _, role := range claimRoles {
if _, exists := seen[role]; exists {
continue
}
seen[role] = struct{}{}
result = append(result, role)
}
return result
}
+24
View File
@@ -440,3 +440,27 @@ func TestCollectClaimRolesIncludesAdditionalRoles(t *testing.T) {
t.Fatalf("unexpected role order/content: %v", roles) t.Fatalf("unexpected role order/content: %v", roles)
} }
} }
func TestBuildRoleCandidates_PrioritizesRequestedThenFallsBackToClaims(t *testing.T) {
requested := []int{30}
claimRoles := []int{30, 44, 52}
result := buildRoleCandidates(requested, claimRoles)
if len(result) != 3 {
t.Fatalf("expected 3 candidate roles, got %d (%v)", len(result), result)
}
if result[0] != 30 || result[1] != 44 || result[2] != 52 {
t.Fatalf("unexpected candidate role order/content: %v", result)
}
}
func TestBuildRoleCandidates_ReturnsNilWhenRequestedNotInClaims(t *testing.T) {
requested := []int{999}
claimRoles := []int{30, 44}
result := buildRoleCandidates(requested, claimRoles)
if len(result) != 0 {
t.Fatalf("expected no candidates for mismatched requested role, got %v", result)
}
}
+10 -1
View File
@@ -35,7 +35,7 @@ var (
errExpiredToken = "Invalid or expired token" // #nosec G101 errExpiredToken = "Invalid or expired token" // #nosec G101
// Redis key prefix for token cache // Redis key prefix for token cache
redisTokenPrefix = "jwt:token:" redisTokenPrefix = "jwt:v2:token:"
) )
func getRSAPublicKey() (*rsa.PublicKey, error) { func getRSAPublicKey() (*rsa.PublicKey, error) {
@@ -257,6 +257,15 @@ func buildContext(parent context.Context, claims *models.Claims) context.Context
} }
} }
for _, project := range claims.Projects {
for _, role := range project.RoleID {
if _, exists := unique[role]; !exists {
unique[role] = struct{}{}
roles = append(roles, role)
}
}
}
ctx = context.WithValue(ctx, roleIDKey, roles) ctx = context.WithValue(ctx, roleIDKey, roles)
return ctx return ctx
} }
+26
View File
@@ -212,6 +212,32 @@ func TestBuildContextIncludesAdditionalRoles(t *testing.T) {
} }
} }
func TestBuildContextIncludesProjectRoles(t *testing.T) {
claims := &models.Claims{
UsersID: "user123",
RoleID: models.RoleIDs{30},
AdditionalRoleID: models.RoleIDs{4},
Projects: []models.ProjectClaim{
{ProjectID: 10, RoleID: models.RoleIDs{44, 52}},
{ProjectID: 11, RoleID: models.RoleIDs{30, 52, 61}},
},
}
ctx := buildContext(context.Background(), claims)
val, ok := ctx.Value(roleIDKey).([]int)
if !ok {
t.Fatal("Role not properly set in context")
}
if len(val) != 5 {
t.Fatalf("expected 5 unique roles, got %d (%v)", len(val), val)
}
if val[0] != 30 || val[1] != 4 || val[2] != 44 || val[3] != 52 || val[4] != 61 {
t.Fatalf("unexpected roles in context: %v", val)
}
}
func TestGetClaims(t *testing.T) { func TestGetClaims(t *testing.T) {
claims := &models.Claims{ claims := &models.Claims{
UsersID: "user123", UsersID: "user123",
+80 -10
View File
@@ -28,6 +28,26 @@ type ProjectClaim struct {
OfficeID int `json:"office_id,omitempty"` OfficeID int `json:"office_id,omitempty"`
} }
func (p *ProjectClaim) UnmarshalJSON(data []byte) error {
type Alias ProjectClaim
aux := &struct {
RoleIDs RoleIDs `json:"role_ids"`
*Alias
}{
Alias: (*Alias)(p),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if len(p.RoleID) == 0 && len(aux.RoleIDs) > 0 {
p.RoleID = aux.RoleIDs
}
return nil
}
// RoleIDs represents one or more role IDs. // RoleIDs represents one or more role IDs.
// It is defined as a custom type so we can implement flexible JSON unmarshalling // It is defined as a custom type so we can implement flexible JSON unmarshalling
// that accepts a single string ("1"), a single number (1), or an array ([1,2,...]). // that accepts a single string ("1"), a single number (1), or an array ([1,2,...]).
@@ -92,28 +112,78 @@ type Claims struct {
// UnmarshalJSON handles both "user_id" and "users_id" field names in JWT claims // UnmarshalJSON handles both "user_id" and "users_id" field names in JWT claims
func (c *Claims) UnmarshalJSON(data []byte) error { func (c *Claims) UnmarshalJSON(data []byte) error {
type Alias Claims aux := struct {
aux := &struct { UsersID string `json:"users_id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
Email string `json:"email"` EmailAddress string `json:"email_address"`
*Alias Email string `json:"email"`
}{ RoleID RoleIDs `json:"role_id"`
Alias: (*Alias)(c), AdditionalRoleID RoleIDs `json:"additional_role_id"`
} Projects json.RawMessage `json:"projects"`
ProjectMetadata json.RawMessage `json:"project_metadata"`
ProjectsMetadata json.RawMessage `json:"projects_metadata"`
jwt.RegisteredClaims
}{}
if err := json.Unmarshal(data, &aux); err != nil { if err := json.Unmarshal(data, &aux); err != nil {
return err return err
} }
// If UsersID is empty but UserID is set, copy UserID to UsersID
c.UsersID = aux.UsersID
c.EmailAddress = aux.EmailAddress
c.RoleID = aux.RoleID
c.AdditionalRoleID = aux.AdditionalRoleID
c.RegisteredClaims = aux.RegisteredClaims
if c.UsersID == "" && aux.UserID != "" { if c.UsersID == "" && aux.UserID != "" {
c.UsersID = aux.UserID c.UsersID = aux.UserID
} }
// If EmailAddress is empty but Email is set, copy Email to EmailAddress
if c.EmailAddress == "" && aux.Email != "" { if c.EmailAddress == "" && aux.Email != "" {
c.EmailAddress = aux.Email c.EmailAddress = aux.Email
} }
projects := make([]ProjectClaim, 0)
rawProjectFields := []json.RawMessage{aux.Projects, aux.ProjectMetadata, aux.ProjectsMetadata}
for _, raw := range rawProjectFields {
parsedProjects, err := parseProjectClaims(raw)
if err != nil {
return err
}
if len(parsedProjects) > 0 {
projects = append(projects, parsedProjects...)
}
}
c.Projects = projects
return nil return nil
} }
func parseProjectClaims(raw json.RawMessage) ([]ProjectClaim, error) {
trimmed := bytes.TrimSpace(raw)
if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) {
return nil, nil
}
switch trimmed[0] {
case '[':
var projects []ProjectClaim
if err := json.Unmarshal(trimmed, &projects); err != nil {
return nil, err
}
return projects, nil
case '{':
var project ProjectClaim
if err := json.Unmarshal(trimmed, &project); err != nil {
return nil, err
}
return []ProjectClaim{project}, nil
default:
return nil, fmt.Errorf("unsupported JSON for projects claim: %s", string(trimmed))
}
}
// ContextKey is a custom type for context keys to avoid collisions // ContextKey is a custom type for context keys to avoid collisions
type ContextKey string type ContextKey string
+68
View File
@@ -0,0 +1,68 @@
package models
import (
"encoding/json"
"testing"
)
func TestClaimsUnmarshal_ProjectMetadataRoleIDs(t *testing.T) {
payload := `{
"users_id":"U0000000003",
"email_address":"user@example.com",
"role_id":[30],
"project_metadata":[
{"project_id":101,"alias":"proj-a","role_ids":[44,52]}
]
}`
var claims Claims
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(claims.RoleID) != 1 || claims.RoleID[0] != 30 {
t.Fatalf("unexpected base role: %v", claims.RoleID)
}
if len(claims.Projects) != 1 {
t.Fatalf("expected 1 project, got %d", len(claims.Projects))
}
if claims.Projects[0].ProjectID != 101 {
t.Fatalf("expected project_id=101, got %d", claims.Projects[0].ProjectID)
}
if len(claims.Projects[0].RoleID) != 2 || claims.Projects[0].RoleID[0] != 44 || claims.Projects[0].RoleID[1] != 52 {
t.Fatalf("unexpected project role ids: %v", claims.Projects[0].RoleID)
}
}
func TestClaimsUnmarshal_ProjectsMetadataSingleObject(t *testing.T) {
payload := `{
"user_id":"U0000000003",
"email":"user@example.com",
"role_id":"30",
"projects_metadata":{"project_id":202,"alias":"proj-b","role_id":"61"}
}`
var claims Claims
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
t.Fatalf("expected no error, got %v", err)
}
if claims.UsersID != "U0000000003" {
t.Fatalf("expected users_id fallback from user_id, got %s", claims.UsersID)
}
if claims.EmailAddress != "user@example.com" {
t.Fatalf("expected email fallback from email, got %s", claims.EmailAddress)
}
if len(claims.Projects) != 1 {
t.Fatalf("expected 1 project from projects_metadata object, got %d", len(claims.Projects))
}
if len(claims.Projects[0].RoleID) != 1 || claims.Projects[0].RoleID[0] != 61 {
t.Fatalf("unexpected project role ids: %v", claims.Projects[0].RoleID)
}
}