fixed authorization (now checks the role inside of the project)
This commit is contained in:
+45
-5
@@ -87,15 +87,24 @@ func AuthorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claimRoles := collectClaimRoles(claims)
|
||||
projectRoleCount := 0
|
||||
for _, project := range claims.Projects {
|
||||
projectRoleCount += len(project.RoleID)
|
||||
log.Printf("[Handler] Project role claim - project_id=%d, alias=%s, roles=%v", project.ProjectID, project.Alias, project.RoleID)
|
||||
}
|
||||
log.Printf("[Handler] Claim roles parsed - base=%v, additional=%v, projects=%d, projectRoleEntries=%d, combined=%v",
|
||||
claims.RoleID,
|
||||
claims.AdditionalRoleID,
|
||||
len(claims.Projects),
|
||||
projectRoleCount,
|
||||
claimRoles,
|
||||
)
|
||||
if len(claimRoles) == 0 {
|
||||
log.Printf("ERROR: No roles found in JWT claims for user=%s", claims.UsersID)
|
||||
}
|
||||
requestedRoles := collectRequestedRoles(&ctx)
|
||||
if len(requestedRoles) == 0 {
|
||||
requestedRoles = claimRoles
|
||||
}
|
||||
|
||||
validRoles := intersectRoles(requestedRoles, claimRoles)
|
||||
validRoles := buildRoleCandidates(requestedRoles, claimRoles)
|
||||
log.Printf("[Handler] Role candidate resolution - requested=%v, finalCandidates=%v", requestedRoles, validRoles)
|
||||
if len(validRoles) == 0 {
|
||||
log.Printf("ERROR: Role mismatch for user=%s - requestedRoles=%v, claimRoles=%v", ctx.UsersID, requestedRoles, claimRoles)
|
||||
helper.RespondWithError(w, http.StatusForbidden, "Role ID mismatch")
|
||||
@@ -197,3 +206,34 @@ func intersectRoles(requested, available []int) []int {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func buildRoleCandidates(requested, claimRoles []int) []int {
|
||||
if len(claimRoles) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(requested) == 0 {
|
||||
return append([]int(nil), claimRoles...)
|
||||
}
|
||||
|
||||
primary := intersectRoles(requested, claimRoles)
|
||||
if len(primary) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[int]struct{}, len(primary))
|
||||
for _, role := range primary {
|
||||
seen[role] = struct{}{}
|
||||
}
|
||||
|
||||
result := append([]int(nil), primary...)
|
||||
for _, role := range claimRoles {
|
||||
if _, exists := seen[role]; exists {
|
||||
continue
|
||||
}
|
||||
seen[role] = struct{}{}
|
||||
result = append(result, role)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -440,3 +440,27 @@ func TestCollectClaimRolesIncludesAdditionalRoles(t *testing.T) {
|
||||
t.Fatalf("unexpected role order/content: %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRoleCandidates_PrioritizesRequestedThenFallsBackToClaims(t *testing.T) {
|
||||
requested := []int{30}
|
||||
claimRoles := []int{30, 44, 52}
|
||||
|
||||
result := buildRoleCandidates(requested, claimRoles)
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3 candidate roles, got %d (%v)", len(result), result)
|
||||
}
|
||||
|
||||
if result[0] != 30 || result[1] != 44 || result[2] != 52 {
|
||||
t.Fatalf("unexpected candidate role order/content: %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRoleCandidates_ReturnsNilWhenRequestedNotInClaims(t *testing.T) {
|
||||
requested := []int{999}
|
||||
claimRoles := []int{30, 44}
|
||||
|
||||
result := buildRoleCandidates(requested, claimRoles)
|
||||
if len(result) != 0 {
|
||||
t.Fatalf("expected no candidates for mismatched requested role, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
+10
-1
@@ -35,7 +35,7 @@ var (
|
||||
errExpiredToken = "Invalid or expired token" // #nosec G101
|
||||
|
||||
// Redis key prefix for token cache
|
||||
redisTokenPrefix = "jwt:token:"
|
||||
redisTokenPrefix = "jwt:v2:token:"
|
||||
)
|
||||
|
||||
func getRSAPublicKey() (*rsa.PublicKey, error) {
|
||||
@@ -257,6 +257,15 @@ func buildContext(parent context.Context, claims *models.Claims) context.Context
|
||||
}
|
||||
}
|
||||
|
||||
for _, project := range claims.Projects {
|
||||
for _, role := range project.RoleID {
|
||||
if _, exists := unique[role]; !exists {
|
||||
unique[role] = struct{}{}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, roleIDKey, roles)
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -212,6 +212,32 @@ func TestBuildContextIncludesAdditionalRoles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildContextIncludesProjectRoles(t *testing.T) {
|
||||
claims := &models.Claims{
|
||||
UsersID: "user123",
|
||||
RoleID: models.RoleIDs{30},
|
||||
AdditionalRoleID: models.RoleIDs{4},
|
||||
Projects: []models.ProjectClaim{
|
||||
{ProjectID: 10, RoleID: models.RoleIDs{44, 52}},
|
||||
{ProjectID: 11, RoleID: models.RoleIDs{30, 52, 61}},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := buildContext(context.Background(), claims)
|
||||
val, ok := ctx.Value(roleIDKey).([]int)
|
||||
if !ok {
|
||||
t.Fatal("Role not properly set in context")
|
||||
}
|
||||
|
||||
if len(val) != 5 {
|
||||
t.Fatalf("expected 5 unique roles, got %d (%v)", len(val), val)
|
||||
}
|
||||
|
||||
if val[0] != 30 || val[1] != 4 || val[2] != 44 || val[3] != 52 || val[4] != 61 {
|
||||
t.Fatalf("unexpected roles in context: %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClaims(t *testing.T) {
|
||||
claims := &models.Claims{
|
||||
UsersID: "user123",
|
||||
|
||||
+78
-8
@@ -28,6 +28,26 @@ type ProjectClaim struct {
|
||||
OfficeID int `json:"office_id,omitempty"`
|
||||
}
|
||||
|
||||
func (p *ProjectClaim) UnmarshalJSON(data []byte) error {
|
||||
type Alias ProjectClaim
|
||||
aux := &struct {
|
||||
RoleIDs RoleIDs `json:"role_ids"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(p),
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(p.RoleID) == 0 && len(aux.RoleIDs) > 0 {
|
||||
p.RoleID = aux.RoleIDs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoleIDs represents one or more role IDs.
|
||||
// It is defined as a custom type so we can implement flexible JSON unmarshalling
|
||||
// that accepts a single string ("1"), a single number (1), or an array ([1,2,...]).
|
||||
@@ -92,28 +112,78 @@ type Claims struct {
|
||||
|
||||
// UnmarshalJSON handles both "user_id" and "users_id" field names in JWT claims
|
||||
func (c *Claims) UnmarshalJSON(data []byte) error {
|
||||
type Alias Claims
|
||||
aux := &struct {
|
||||
aux := struct {
|
||||
UsersID string `json:"users_id"`
|
||||
UserID string `json:"user_id"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
Email string `json:"email"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(c),
|
||||
}
|
||||
RoleID RoleIDs `json:"role_id"`
|
||||
AdditionalRoleID RoleIDs `json:"additional_role_id"`
|
||||
Projects json.RawMessage `json:"projects"`
|
||||
ProjectMetadata json.RawMessage `json:"project_metadata"`
|
||||
ProjectsMetadata json.RawMessage `json:"projects_metadata"`
|
||||
jwt.RegisteredClaims
|
||||
}{}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
// If UsersID is empty but UserID is set, copy UserID to UsersID
|
||||
|
||||
c.UsersID = aux.UsersID
|
||||
c.EmailAddress = aux.EmailAddress
|
||||
c.RoleID = aux.RoleID
|
||||
c.AdditionalRoleID = aux.AdditionalRoleID
|
||||
c.RegisteredClaims = aux.RegisteredClaims
|
||||
|
||||
if c.UsersID == "" && aux.UserID != "" {
|
||||
c.UsersID = aux.UserID
|
||||
}
|
||||
// If EmailAddress is empty but Email is set, copy Email to EmailAddress
|
||||
|
||||
if c.EmailAddress == "" && aux.Email != "" {
|
||||
c.EmailAddress = aux.Email
|
||||
}
|
||||
|
||||
projects := make([]ProjectClaim, 0)
|
||||
rawProjectFields := []json.RawMessage{aux.Projects, aux.ProjectMetadata, aux.ProjectsMetadata}
|
||||
for _, raw := range rawProjectFields {
|
||||
parsedProjects, err := parseProjectClaims(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(parsedProjects) > 0 {
|
||||
projects = append(projects, parsedProjects...)
|
||||
}
|
||||
}
|
||||
|
||||
c.Projects = projects
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseProjectClaims(raw json.RawMessage) ([]ProjectClaim, error) {
|
||||
trimmed := bytes.TrimSpace(raw)
|
||||
if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch trimmed[0] {
|
||||
case '[':
|
||||
var projects []ProjectClaim
|
||||
if err := json.Unmarshal(trimmed, &projects); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return projects, nil
|
||||
case '{':
|
||||
var project ProjectClaim
|
||||
if err := json.Unmarshal(trimmed, &project); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []ProjectClaim{project}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported JSON for projects claim: %s", string(trimmed))
|
||||
}
|
||||
}
|
||||
|
||||
// ContextKey is a custom type for context keys to avoid collisions
|
||||
type ContextKey string
|
||||
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClaimsUnmarshal_ProjectMetadataRoleIDs(t *testing.T) {
|
||||
payload := `{
|
||||
"users_id":"U0000000003",
|
||||
"email_address":"user@example.com",
|
||||
"role_id":[30],
|
||||
"project_metadata":[
|
||||
{"project_id":101,"alias":"proj-a","role_ids":[44,52]}
|
||||
]
|
||||
}`
|
||||
|
||||
var claims Claims
|
||||
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if len(claims.RoleID) != 1 || claims.RoleID[0] != 30 {
|
||||
t.Fatalf("unexpected base role: %v", claims.RoleID)
|
||||
}
|
||||
|
||||
if len(claims.Projects) != 1 {
|
||||
t.Fatalf("expected 1 project, got %d", len(claims.Projects))
|
||||
}
|
||||
|
||||
if claims.Projects[0].ProjectID != 101 {
|
||||
t.Fatalf("expected project_id=101, got %d", claims.Projects[0].ProjectID)
|
||||
}
|
||||
|
||||
if len(claims.Projects[0].RoleID) != 2 || claims.Projects[0].RoleID[0] != 44 || claims.Projects[0].RoleID[1] != 52 {
|
||||
t.Fatalf("unexpected project role ids: %v", claims.Projects[0].RoleID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimsUnmarshal_ProjectsMetadataSingleObject(t *testing.T) {
|
||||
payload := `{
|
||||
"user_id":"U0000000003",
|
||||
"email":"user@example.com",
|
||||
"role_id":"30",
|
||||
"projects_metadata":{"project_id":202,"alias":"proj-b","role_id":"61"}
|
||||
}`
|
||||
|
||||
var claims Claims
|
||||
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if claims.UsersID != "U0000000003" {
|
||||
t.Fatalf("expected users_id fallback from user_id, got %s", claims.UsersID)
|
||||
}
|
||||
|
||||
if claims.EmailAddress != "user@example.com" {
|
||||
t.Fatalf("expected email fallback from email, got %s", claims.EmailAddress)
|
||||
}
|
||||
|
||||
if len(claims.Projects) != 1 {
|
||||
t.Fatalf("expected 1 project from projects_metadata object, got %d", len(claims.Projects))
|
||||
}
|
||||
|
||||
if len(claims.Projects[0].RoleID) != 1 || claims.Projects[0].RoleID[0] != 61 {
|
||||
t.Fatalf("unexpected project role ids: %v", claims.Projects[0].RoleID)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user