diff --git a/handlers/authorize.go b/handlers/authorize.go index 452decf..0dad462 100644 --- a/handlers/authorize.go +++ b/handlers/authorize.go @@ -87,15 +87,24 @@ func AuthorizeHandler(w http.ResponseWriter, r *http.Request) { } claimRoles := collectClaimRoles(claims) + projectRoleCount := 0 + for _, project := range claims.Projects { + projectRoleCount += len(project.RoleID) + log.Printf("[Handler] Project role claim - project_id=%d, alias=%s, roles=%v", project.ProjectID, project.Alias, project.RoleID) + } + log.Printf("[Handler] Claim roles parsed - base=%v, additional=%v, projects=%d, projectRoleEntries=%d, combined=%v", + claims.RoleID, + claims.AdditionalRoleID, + len(claims.Projects), + projectRoleCount, + claimRoles, + ) if len(claimRoles) == 0 { log.Printf("ERROR: No roles found in JWT claims for user=%s", claims.UsersID) } requestedRoles := collectRequestedRoles(&ctx) - if len(requestedRoles) == 0 { - requestedRoles = claimRoles - } - - validRoles := intersectRoles(requestedRoles, claimRoles) + validRoles := buildRoleCandidates(requestedRoles, claimRoles) + log.Printf("[Handler] Role candidate resolution - requested=%v, finalCandidates=%v", requestedRoles, validRoles) if len(validRoles) == 0 { log.Printf("ERROR: Role mismatch for user=%s - requestedRoles=%v, claimRoles=%v", ctx.UsersID, requestedRoles, claimRoles) helper.RespondWithError(w, http.StatusForbidden, "Role ID mismatch") @@ -197,3 +206,34 @@ func intersectRoles(requested, available []int) []int { return result } + +func buildRoleCandidates(requested, claimRoles []int) []int { + if len(claimRoles) == 0 { + return nil + } + + if len(requested) == 0 { + return append([]int(nil), claimRoles...) + } + + primary := intersectRoles(requested, claimRoles) + if len(primary) == 0 { + return nil + } + + seen := make(map[int]struct{}, len(primary)) + for _, role := range primary { + seen[role] = struct{}{} + } + + result := append([]int(nil), primary...) + for _, role := range claimRoles { + if _, exists := seen[role]; exists { + continue + } + seen[role] = struct{}{} + result = append(result, role) + } + + return result +} diff --git a/handlers/authorize_test.go b/handlers/authorize_test.go index 80f6897..cbf885a 100644 --- a/handlers/authorize_test.go +++ b/handlers/authorize_test.go @@ -440,3 +440,27 @@ func TestCollectClaimRolesIncludesAdditionalRoles(t *testing.T) { t.Fatalf("unexpected role order/content: %v", roles) } } + +func TestBuildRoleCandidates_PrioritizesRequestedThenFallsBackToClaims(t *testing.T) { + requested := []int{30} + claimRoles := []int{30, 44, 52} + + result := buildRoleCandidates(requested, claimRoles) + if len(result) != 3 { + t.Fatalf("expected 3 candidate roles, got %d (%v)", len(result), result) + } + + if result[0] != 30 || result[1] != 44 || result[2] != 52 { + t.Fatalf("unexpected candidate role order/content: %v", result) + } +} + +func TestBuildRoleCandidates_ReturnsNilWhenRequestedNotInClaims(t *testing.T) { + requested := []int{999} + claimRoles := []int{30, 44} + + result := buildRoleCandidates(requested, claimRoles) + if len(result) != 0 { + t.Fatalf("expected no candidates for mismatched requested role, got %v", result) + } +} diff --git a/middleware/jwt.go b/middleware/jwt.go index b799f94..fd24efa 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -35,7 +35,7 @@ var ( errExpiredToken = "Invalid or expired token" // #nosec G101 // Redis key prefix for token cache - redisTokenPrefix = "jwt:token:" + redisTokenPrefix = "jwt:v2:token:" ) func getRSAPublicKey() (*rsa.PublicKey, error) { @@ -257,6 +257,15 @@ func buildContext(parent context.Context, claims *models.Claims) context.Context } } + for _, project := range claims.Projects { + for _, role := range project.RoleID { + if _, exists := unique[role]; !exists { + unique[role] = struct{}{} + roles = append(roles, role) + } + } + } + ctx = context.WithValue(ctx, roleIDKey, roles) return ctx } diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 1e3b9f2..88ace22 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -212,6 +212,32 @@ func TestBuildContextIncludesAdditionalRoles(t *testing.T) { } } +func TestBuildContextIncludesProjectRoles(t *testing.T) { + claims := &models.Claims{ + UsersID: "user123", + RoleID: models.RoleIDs{30}, + AdditionalRoleID: models.RoleIDs{4}, + Projects: []models.ProjectClaim{ + {ProjectID: 10, RoleID: models.RoleIDs{44, 52}}, + {ProjectID: 11, RoleID: models.RoleIDs{30, 52, 61}}, + }, + } + + ctx := buildContext(context.Background(), claims) + val, ok := ctx.Value(roleIDKey).([]int) + if !ok { + t.Fatal("Role not properly set in context") + } + + if len(val) != 5 { + t.Fatalf("expected 5 unique roles, got %d (%v)", len(val), val) + } + + if val[0] != 30 || val[1] != 4 || val[2] != 44 || val[3] != 52 || val[4] != 61 { + t.Fatalf("unexpected roles in context: %v", val) + } +} + func TestGetClaims(t *testing.T) { claims := &models.Claims{ UsersID: "user123", diff --git a/models/authorize.go b/models/authorize.go index 64d6381..f7bfdbb 100644 --- a/models/authorize.go +++ b/models/authorize.go @@ -28,6 +28,26 @@ type ProjectClaim struct { OfficeID int `json:"office_id,omitempty"` } +func (p *ProjectClaim) UnmarshalJSON(data []byte) error { + type Alias ProjectClaim + aux := &struct { + RoleIDs RoleIDs `json:"role_ids"` + *Alias + }{ + Alias: (*Alias)(p), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + if len(p.RoleID) == 0 && len(aux.RoleIDs) > 0 { + p.RoleID = aux.RoleIDs + } + + return nil +} + // RoleIDs represents one or more role IDs. // It is defined as a custom type so we can implement flexible JSON unmarshalling // that accepts a single string ("1"), a single number (1), or an array ([1,2,...]). @@ -92,28 +112,78 @@ type Claims struct { // UnmarshalJSON handles both "user_id" and "users_id" field names in JWT claims func (c *Claims) UnmarshalJSON(data []byte) error { - type Alias Claims - aux := &struct { - UserID string `json:"user_id"` - Email string `json:"email"` - *Alias - }{ - Alias: (*Alias)(c), - } + aux := struct { + UsersID string `json:"users_id"` + UserID string `json:"user_id"` + EmailAddress string `json:"email_address"` + Email string `json:"email"` + RoleID RoleIDs `json:"role_id"` + AdditionalRoleID RoleIDs `json:"additional_role_id"` + Projects json.RawMessage `json:"projects"` + ProjectMetadata json.RawMessage `json:"project_metadata"` + ProjectsMetadata json.RawMessage `json:"projects_metadata"` + jwt.RegisteredClaims + }{} + if err := json.Unmarshal(data, &aux); err != nil { return err } - // If UsersID is empty but UserID is set, copy UserID to UsersID + + c.UsersID = aux.UsersID + c.EmailAddress = aux.EmailAddress + c.RoleID = aux.RoleID + c.AdditionalRoleID = aux.AdditionalRoleID + c.RegisteredClaims = aux.RegisteredClaims + if c.UsersID == "" && aux.UserID != "" { c.UsersID = aux.UserID } - // If EmailAddress is empty but Email is set, copy Email to EmailAddress + if c.EmailAddress == "" && aux.Email != "" { c.EmailAddress = aux.Email } + + projects := make([]ProjectClaim, 0) + rawProjectFields := []json.RawMessage{aux.Projects, aux.ProjectMetadata, aux.ProjectsMetadata} + for _, raw := range rawProjectFields { + parsedProjects, err := parseProjectClaims(raw) + if err != nil { + return err + } + if len(parsedProjects) > 0 { + projects = append(projects, parsedProjects...) + } + } + + c.Projects = projects + return nil } +func parseProjectClaims(raw json.RawMessage) ([]ProjectClaim, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + return nil, nil + } + + switch trimmed[0] { + case '[': + var projects []ProjectClaim + if err := json.Unmarshal(trimmed, &projects); err != nil { + return nil, err + } + return projects, nil + case '{': + var project ProjectClaim + if err := json.Unmarshal(trimmed, &project); err != nil { + return nil, err + } + return []ProjectClaim{project}, nil + default: + return nil, fmt.Errorf("unsupported JSON for projects claim: %s", string(trimmed)) + } +} + // ContextKey is a custom type for context keys to avoid collisions type ContextKey string diff --git a/models/authorize_test.go b/models/authorize_test.go new file mode 100644 index 0000000..8f38386 --- /dev/null +++ b/models/authorize_test.go @@ -0,0 +1,68 @@ +package models + +import ( + "encoding/json" + "testing" +) + +func TestClaimsUnmarshal_ProjectMetadataRoleIDs(t *testing.T) { + payload := `{ + "users_id":"U0000000003", + "email_address":"user@example.com", + "role_id":[30], + "project_metadata":[ + {"project_id":101,"alias":"proj-a","role_ids":[44,52]} + ] + }` + + var claims Claims + if err := json.Unmarshal([]byte(payload), &claims); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(claims.RoleID) != 1 || claims.RoleID[0] != 30 { + t.Fatalf("unexpected base role: %v", claims.RoleID) + } + + if len(claims.Projects) != 1 { + t.Fatalf("expected 1 project, got %d", len(claims.Projects)) + } + + if claims.Projects[0].ProjectID != 101 { + t.Fatalf("expected project_id=101, got %d", claims.Projects[0].ProjectID) + } + + if len(claims.Projects[0].RoleID) != 2 || claims.Projects[0].RoleID[0] != 44 || claims.Projects[0].RoleID[1] != 52 { + t.Fatalf("unexpected project role ids: %v", claims.Projects[0].RoleID) + } +} + +func TestClaimsUnmarshal_ProjectsMetadataSingleObject(t *testing.T) { + payload := `{ + "user_id":"U0000000003", + "email":"user@example.com", + "role_id":"30", + "projects_metadata":{"project_id":202,"alias":"proj-b","role_id":"61"} + }` + + var claims Claims + if err := json.Unmarshal([]byte(payload), &claims); err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if claims.UsersID != "U0000000003" { + t.Fatalf("expected users_id fallback from user_id, got %s", claims.UsersID) + } + + if claims.EmailAddress != "user@example.com" { + t.Fatalf("expected email fallback from email, got %s", claims.EmailAddress) + } + + if len(claims.Projects) != 1 { + t.Fatalf("expected 1 project from projects_metadata object, got %d", len(claims.Projects)) + } + + if len(claims.Projects[0].RoleID) != 1 || claims.Projects[0].RoleID[0] != 61 { + t.Fatalf("unexpected project role ids: %v", claims.Projects[0].RoleID) + } +}