diff --git a/handlers/authorize.go b/handlers/authorize.go index 6345aab..008ca66 100644 --- a/handlers/authorize.go +++ b/handlers/authorize.go @@ -84,12 +84,21 @@ func AuthorizeHandler(w http.ResponseWriter, r *http.Request) { ctx.Environment = make(map[string]string) } - // containsRole checks if a role exists in a slice of roles + claimRoles := collectClaimRoles(claims) + requestedRoles := collectRequestedRoles(&ctx) + if len(requestedRoles) == 0 { + requestedRoles = claimRoles + } - if !containsRole([]int(claims.RoleID), ctx.RoleID) { + validRoles := intersectRoles(requestedRoles, claimRoles) + if len(validRoles) == 0 { helper.RespondWithError(w, http.StatusForbidden, "Role ID mismatch") return } + + ctx.CandidateRoles = validRoles + ctx.RoleID = validRoles[0] + ctx.RoleIDs = validRoles log.Print("User role verified: ", ctx.RoleID) // Perform authorization log.Printf("[Handler] Performing authorization check for user=%s, resource=%s, action=%s", ctx.UsersID, ctx.Resource, ctx.Action) @@ -121,11 +130,57 @@ func AuthorizeHandler(w http.ResponseWriter, r *http.Request) { } } -func containsRole(roles []int, role int) bool { - for _, r := range roles { - if r == role { - return true +func collectClaimRoles(claims *models.Claims) []int { + unique := make(map[int]struct{}) + roles := make([]int, 0, len(claims.RoleID)) + + for _, role := range claims.RoleID { + if _, exists := unique[role]; !exists { + unique[role] = struct{}{} + roles = append(roles, role) } } - return false + + for _, project := range claims.Projects { + for _, role := range project.RoleID { + if _, exists := unique[role]; !exists { + unique[role] = struct{}{} + roles = append(roles, role) + } + } + } + + return roles +} + +func collectRequestedRoles(ctx *models.AuthorizationContext) []int { + if len(ctx.RoleIDs) > 0 { + return append([]int(nil), ctx.RoleIDs...) + } + if ctx.RoleID != 0 { + return []int{ctx.RoleID} + } + return nil +} + +func intersectRoles(requested, available []int) []int { + availableSet := make(map[int]struct{}, len(available)) + for _, role := range available { + availableSet[role] = struct{}{} + } + + unique := make(map[int]struct{}) + result := make([]int, 0, len(requested)) + for _, role := range requested { + if _, ok := availableSet[role]; !ok { + continue + } + if _, seen := unique[role]; seen { + continue + } + unique[role] = struct{}{} + result = append(result, role) + } + + return result } diff --git a/handlers/authorize_test.go b/handlers/authorize_test.go index 86f3ac4..8675183 100644 --- a/handlers/authorize_test.go +++ b/handlers/authorize_test.go @@ -375,3 +375,52 @@ func TestAuthorizeHandlerWithResourceData(t *testing.T) { t.Errorf("Handler returned bad request with valid ResourceData") } } + +func TestCollectClaimRolesIncludesProjectRoles(t *testing.T) { + claims := &models.Claims{ + RoleID: models.RoleIDs{2}, + Projects: []models.ProjectClaim{ + {ProjectID: 7, RoleID: models.RoleIDs{2, 4}}, + {ProjectID: 8, RoleID: models.RoleIDs{5}}, + }, + } + + roles := collectClaimRoles(claims) + if len(roles) != 3 { + t.Fatalf("expected 3 unique roles, got %d (%v)", len(roles), roles) + } + + if roles[0] != 2 || roles[1] != 4 || roles[2] != 5 { + t.Fatalf("unexpected role order/content: %v", roles) + } +} + +func TestIntersectRolesReturnsOverlap(t *testing.T) { + requested := []int{4, 9, 4, 2} + available := []int{2, 4, 5} + + result := intersectRoles(requested, available) + if len(result) != 2 { + t.Fatalf("expected 2 matching roles, got %d (%v)", len(result), result) + } + + if result[0] != 4 || result[1] != 2 { + t.Fatalf("unexpected intersection result: %v", result) + } +} + +func TestCollectRequestedRolesFromArray(t *testing.T) { + ctx := &models.AuthorizationContext{ + RoleIDs: []int{3, 7}, + RoleID: 3, + } + + result := collectRequestedRoles(ctx) + if len(result) != 2 { + t.Fatalf("expected 2 requested roles, got %d (%v)", len(result), result) + } + + if result[0] != 3 || result[1] != 7 { + t.Fatalf("unexpected requested roles: %v", result) + } +} diff --git a/models/authorize.go b/models/authorize.go index 9e54df6..4ac6b5d 100644 --- a/models/authorize.go +++ b/models/authorize.go @@ -21,6 +21,13 @@ type AuthorizationResponse struct { Reason string `json:"reason,omitempty"` } +type ProjectClaim struct { + ProjectID int `json:"project_id,omitempty"` + Alias string `json:"alias,omitempty"` + RoleID RoleIDs `json:"role_id,omitempty"` + OfficeID int `json:"office_id,omitempty"` +} + // RoleIDs represents one or more role IDs. // It is defined as a custom type so we can implement flexible JSON unmarshalling // that accepts a single string ("1"), a single number (1), or an array ([1,2,...]). @@ -75,9 +82,10 @@ func (r *RoleIDs) UnmarshalJSON(data []byte) error { } type Claims struct { - UsersID string `json:"users_id,omitempty"` - EmailAddress string `json:"email_address,omitempty"` - RoleID RoleIDs `json:"role_id"` + UsersID string `json:"users_id,omitempty"` + EmailAddress string `json:"email_address,omitempty"` + RoleID RoleIDs `json:"role_id"` + Projects []ProjectClaim `json:"projects,omitempty"` jwt.RegisteredClaims } @@ -85,6 +93,8 @@ type Claims struct { func (c *Claims) UnmarshalJSON(data []byte) error { type Alias Claims aux := &struct { + UserID string `json:"user_id"` + Email string `json:"email"` *Alias }{ Alias: (*Alias)(c), @@ -93,12 +103,12 @@ func (c *Claims) UnmarshalJSON(data []byte) error { return err } // If UsersID is empty but UserID is set, copy UserID to UsersID - if c.UsersID == "" && c.UsersID != "" { - c.UsersID = c.UsersID + if c.UsersID == "" && aux.UserID != "" { + c.UsersID = aux.UserID } // If EmailAddress is empty but Email is set, copy Email to EmailAddress - if c.EmailAddress == "" && c.EmailAddress != "" { - c.EmailAddress = c.EmailAddress + if c.EmailAddress == "" && aux.Email != "" { + c.EmailAddress = aux.Email } return nil } diff --git a/models/rbac.go b/models/rbac.go index b37cbb8..81b9e46 100644 --- a/models/rbac.go +++ b/models/rbac.go @@ -66,6 +66,8 @@ type AuthorizationContext struct { Resource string `json:"resource"` Action string `json:"action"` RoleID int `json:"role_id"` // User's role ID + RoleIDs []int `json:"-"` + CandidateRoles []int `json:"-"` UserAttributes map[string]string `json:"user_attributes"` ResourceData map[string]string `json:"resource_data"` // Additional resource context Environment map[string]string `json:"environment"` // Time, location, etc. @@ -91,12 +93,14 @@ func (ac *AuthorizationContext) UnmarshalJSON(data []byte) error { var roleInt int if err := json.Unmarshal(aux.RoleIDRaw, &roleInt); err == nil { ac.RoleID = roleInt + ac.RoleIDs = []int{roleInt} } else { // Try as array of ints (take first element) var roleArray []int if err := json.Unmarshal(aux.RoleIDRaw, &roleArray); err == nil { if len(roleArray) > 0 { ac.RoleID = roleArray[0] + ac.RoleIDs = roleArray } } else { // Try as string @@ -108,6 +112,7 @@ func (ac *AuthorizationContext) UnmarshalJSON(data []byte) error { if convErr != nil { return fmt.Errorf("invalid role_id: %s", roleStr) } + ac.RoleIDs = []int{ac.RoleID} } } else { return fmt.Errorf("role_id must be a number, numeric string, or array of numbers") diff --git a/repository/permission_repository_test.go b/repository/permission_repository_test.go index 8bc3854..e0ad6d4 100644 --- a/repository/permission_repository_test.go +++ b/repository/permission_repository_test.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "testing" - "time" "github.com/DATA-DOG/go-sqlmock" ) @@ -105,19 +104,10 @@ func TestGetUserByIDSuccess(t *testing.T) { mock, cleanup := setupMockDB(t) defer cleanup() - testTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + rows := sqlmock.NewRows([]string{"users_id", "email_address"}). + AddRow("user123", "john@example.com") - rows := sqlmock.NewRows([]string{ - "users_id", "first_name", "middle_initial", "last_name", "suffix", "email_address", - "home_address", "contact_number", - "role_id", "is_deleted", "created_at", "updated_at", - }).AddRow( - "user123", "John", "M", "Doe", "Jr", "john@example.com", - "EMP001", "Y", "Y", "123 Main St", "1234567890", "device001", - 1, "N", "secret", "Y", testTime, testTime, - ) - - mock.ExpectQuery("SELECT users_id, first_name"). + mock.ExpectQuery("SELECT users_id, email_address"). WithArgs("user123"). WillReturnRows(rows) @@ -132,8 +122,8 @@ func TestGetUserByIDSuccess(t *testing.T) { if user.UsersID != "user123" { t.Errorf("Expected UsersID 'user123', got '%s'", user.UsersID) } - if user.FirstName != "John" { - t.Errorf("Expected FirstName 'John', got '%s'", user.FirstName) + if user.EmailAddress != "john@example.com" { + t.Errorf("Expected EmailAddress 'john@example.com', got '%s'", user.EmailAddress) } } @@ -141,7 +131,7 @@ func TestGetUserByIDNotFound(t *testing.T) { mock, cleanup := setupMockDB(t) defer cleanup() - mock.ExpectQuery("SELECT users_id, first_name"). + mock.ExpectQuery("SELECT users_id, email_address"). WithArgs("nonexistent"). WillReturnError(sql.ErrNoRows) @@ -180,12 +170,12 @@ func TestGetAllPolicyAttributesSuccess(t *testing.T) { mock, cleanup := setupMockDB(t) defer cleanup() - rows := sqlmock.NewRows([]string{"id", "attribute_name", "attribute_type", "comparison", "attribute_value", "permission_id"}). + rows := sqlmock.NewRows([]string{"policy_attributes_id", "attribute_name", "attribute_type", "comparison", "attribute_value", "permission_id"}). AddRow(1, "department", "user", "=", "engineering", 1). AddRow(2, "level", "user", ">=", "5", 1). AddRow(3, "role", "user", "=", "admin", 2) - mock.ExpectQuery("SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, id"). + mock.ExpectQuery("SELECT policy_attributes_id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, policy_attributes_id"). WillReturnRows(rows) attrs, err := GetAllPolicyAttributes() @@ -208,9 +198,9 @@ func TestGetAllPolicyAttributesEmpty(t *testing.T) { mock, cleanup := setupMockDB(t) defer cleanup() - rows := sqlmock.NewRows([]string{"id", "attribute_name", "attribute_type", "comparison", "attribute_value", "permission_id"}) + rows := sqlmock.NewRows([]string{"policy_attributes_id", "attribute_name", "attribute_type", "comparison", "attribute_value", "permission_id"}) - mock.ExpectQuery("SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, id"). + mock.ExpectQuery("SELECT policy_attributes_id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, policy_attributes_id"). WillReturnRows(rows) attrs, err := GetAllPolicyAttributes() @@ -313,14 +303,9 @@ func TestGetUserByIDEmptyID(t *testing.T) { mock, cleanup := setupMockDB(t) defer cleanup() - rows := sqlmock.NewRows([]string{ - "users_id", "first_name", "middle_initial", "last_name", "suffix", "email_address", - "home_address", "contact_number", - "role_id", "is_deleted", "created_at", "updated_at", - }) + rows := sqlmock.NewRows([]string{"users_id", "email_address"}) - // Match the actual query format with all the fields - mock.ExpectQuery(`SELECT users_id, first_name, middle_initial, last_name, suffix, email_address`). + mock.ExpectQuery(`SELECT users_id, email_address`). WithArgs(""). WillReturnRows(rows) @@ -339,7 +324,7 @@ func TestGetUserByIDDatabaseError(t *testing.T) { mock, cleanup := setupMockDB(t) defer cleanup() - mock.ExpectQuery("SELECT id, username, role, email, created_at, updated_at FROM users WHERE id = \\?"). + mock.ExpectQuery("SELECT users_id, email_address"). WithArgs("user123"). WillReturnError(errors.New("database connection failed")) @@ -415,7 +400,7 @@ func TestGetAllPolicyAttributesDatabaseError(t *testing.T) { mock, cleanup := setupMockDB(t) defer cleanup() - mock.ExpectQuery("SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, id"). + mock.ExpectQuery("SELECT policy_attributes_id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, policy_attributes_id"). WillReturnError(errors.New("connection lost")) attrs, err := GetAllPolicyAttributes() @@ -432,7 +417,7 @@ func TestGetAllPolicyAttributesManyPermissions(t *testing.T) { mock, cleanup := setupMockDB(t) defer cleanup() - rows := sqlmock.NewRows([]string{"id", "attribute_name", "attribute_type", "comparison", "attribute_value", "permission_id"}) + rows := sqlmock.NewRows([]string{"policy_attributes_id", "attribute_name", "attribute_type", "comparison", "attribute_value", "permission_id"}) // Add attributes for multiple permissions for permID := 1; permID <= 50; permID++ { @@ -441,7 +426,7 @@ func TestGetAllPolicyAttributesManyPermissions(t *testing.T) { } } - mock.ExpectQuery("SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, id"). + mock.ExpectQuery("SELECT policy_attributes_id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, policy_attributes_id"). WillReturnRows(rows) attrs, err := GetAllPolicyAttributes() @@ -465,7 +450,7 @@ func TestGetUserAttributesDatabaseError(t *testing.T) { mock, cleanup := setupMockDB(t) defer cleanup() - mock.ExpectQuery("SELECT attribute_name, attribute_value, attribute_type FROM user_attributes WHERE users_id = \\?"). + mock.ExpectQuery("SELECT attribute_name, attribute_value FROM user_attributes WHERE users_id = \\?"). WithArgs("user123"). WillReturnError(errors.New("timeout")) diff --git a/services/authorize.go b/services/authorize.go index 479a808..7310975 100644 --- a/services/authorize.go +++ b/services/authorize.go @@ -23,13 +23,32 @@ func Authorize(ctx *models.AuthorizationContext) (*models.AuthorizationResult, e } log.Printf("[AuthZ Step 0] User found: role_id=%d", user.RoleID) - log.Printf("[AuthZ Step 1] Checking if role_id=%d has permission for resource=%s, action=%s", user.RoleID, ctx.Resource, ctx.Action) - permission, err := repository.GetPermissionByResourceActionAndRole(ctx.Resource, ctx.Action, user.RoleID) - if err != nil { - log.Printf("✗ Permission not found or not granted to role_id=%d for resource=%s, action=%s: %v", user.RoleID, ctx.Resource, ctx.Action, err) + roleCandidates := getRoleCandidates(ctx) + if len(roleCandidates) == 0 && user.RoleID != 0 { + roleCandidates = []int{user.RoleID} + } + + var permission *models.Permission + permissionFound := false + for _, roleID := range roleCandidates { + log.Printf("[AuthZ Step 1] Checking if role_id=%d has permission for resource=%s, action=%s", roleID, ctx.Resource, ctx.Action) + lookupPermission, lookupErr := repository.GetPermissionByResourceActionAndRole(ctx.Resource, ctx.Action, roleID) + if lookupErr != nil { + log.Printf("[AuthZ Step 1] Permission not granted for role_id=%d, trying next role: %v", roleID, lookupErr) + continue + } + + permission = lookupPermission + ctx.RoleID = roleID + permissionFound = true + break + } + + if !permissionFound { + log.Printf("✗ Permission not found or not granted for role candidates=%v, resource=%s, action=%s", roleCandidates, ctx.Resource, ctx.Action) return &models.AuthorizationResult{ Allowed: false, - Message: fmt.Sprintf("Permission not granted to your role: %v", err), + Message: "Permission not granted to your role", }, nil } log.Printf("[AuthZ Step 1] Permission found: ID=%d, Name=%s", permission.ID, permission.PermissionName) @@ -73,7 +92,7 @@ func Authorize(ctx *models.AuthorizationContext) (*models.AuthorizationResult, e log.Printf("[DEBUG] No policies loaded for permissionID=%d", permission.ID) } - log.Printf("[AuthZ Step 4] Using RoleID: %s (from context or user record)", ctx.RoleID) + log.Printf("[AuthZ Step 4] Using RoleID: %d (from context or user record)", ctx.RoleID) allowed, reason := EvaluatePolicies(policies, ctx) result := &models.AuthorizationResult{ diff --git a/services/authorize_test.go b/services/authorize_test.go index 0d7ab98..128eb7a 100644 --- a/services/authorize_test.go +++ b/services/authorize_test.go @@ -5,7 +5,6 @@ import ( "authorization/models" "errors" "testing" - "time" "github.com/DATA-DOG/go-sqlmock" ) @@ -35,24 +34,21 @@ func TestAuthorize_PermissionNotFound(t *testing.T) { UsersID: "user123", Resource: "nonexistent", Action: "read", + RoleID: 1, ResourceData: make(map[string]string), Environment: make(map[string]string), } // Mock user query - userRows := sqlmock.NewRows([]string{"users_id", "first_name", "middle_initial", "last_name", "suffix", "email_address", - "home_address", "contact_number", - "role_id", "is_deleted", "created_at", "updated_at"}). - AddRow("user123", "John", "", "Doe", "", "john@example.com", - "EMP123", "Y", "Y", "123 Street", "09123456789", "device1", - 1, "N", "secret", "Y", time.Now(), time.Now()) + userRows := sqlmock.NewRows([]string{"users_id", "email_address"}). + AddRow("user123", "john@example.com") - mock.ExpectQuery("SELECT users_id, first_name, middle_initial, last_name, suffix, email_address"). + mock.ExpectQuery("SELECT users_id, email_address"). WithArgs("user123"). WillReturnRows(userRows) // Mock permission query with role check - mock.ExpectQuery("SELECT p.role_permissions_id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). + mock.ExpectQuery("SELECT p.permissions_id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). WithArgs("nonexistent", "read", 1). WillReturnError(errors.New("permission not found")) @@ -77,19 +73,16 @@ func TestAuthorize_Success(t *testing.T) { UsersID: "user123", Resource: "document", Action: "read", + RoleID: 1, ResourceData: make(map[string]string), Environment: make(map[string]string), } // Mock user query - userRows := sqlmock.NewRows([]string{"users_id", "first_name", "middle_initial", "last_name", "suffix", "email_address", - "home_address", "contact_number", - "role_id", "is_deleted", "created_at", "updated_at"}). - AddRow("user123", "John", "", "Doe", "", "john@example.com", - "EMP123", "Y", "Y", "123 Street", "09123456789", "device1", - 1, "N", "secret", "Y", time.Now(), time.Now()) + userRows := sqlmock.NewRows([]string{"users_id", "email_address"}). + AddRow("user123", "john@example.com") - mock.ExpectQuery("SELECT users_id, first_name, middle_initial, last_name, suffix, email_address"). + mock.ExpectQuery("SELECT users_id, email_address"). WithArgs("user123"). WillReturnRows(userRows) @@ -97,7 +90,7 @@ func TestAuthorize_Success(t *testing.T) { permRows := sqlmock.NewRows([]string{"id", "permission_name", "description", "resource", "action"}). AddRow(1, "read_document", "Read document permission", "document", "read") - mock.ExpectQuery("SELECT p.id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). + mock.ExpectQuery("SELECT p.permissions_id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). WithArgs("document", "read", 1). WillReturnRows(permRows) @@ -137,19 +130,16 @@ func TestAuthorize_UserAttributesError(t *testing.T) { UsersID: "user123", Resource: "document", Action: "read", + RoleID: 1, ResourceData: make(map[string]string), Environment: make(map[string]string), } // Mock user query - userRows := sqlmock.NewRows([]string{"users_id", "first_name", "middle_initial", "last_name", "suffix", "email_address", - "home_address", "contact_number", - "role_id", "is_deleted", "created_at", "updated_at"}). - AddRow("user123", "John", "", "Doe", "", "john@example.com", - "EMP123", "Y", "Y", "123 Street", "09123456789", "device1", - 1, "N", "secret", "Y", time.Now(), time.Now()) + userRows := sqlmock.NewRows([]string{"users_id", "email_address"}). + AddRow("user123", "john@example.com") - mock.ExpectQuery("SELECT users_id, first_name, middle_initial, last_name, suffix, email_address"). + mock.ExpectQuery("SELECT users_id, email_address"). WithArgs("user123"). WillReturnRows(userRows) @@ -157,7 +147,7 @@ func TestAuthorize_UserAttributesError(t *testing.T) { permRows := sqlmock.NewRows([]string{"id", "permission_name", "description", "resource", "action"}). AddRow(1, "read_document", "Read document permission", "document", "read") - mock.ExpectQuery("SELECT p.id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). + mock.ExpectQuery("SELECT p.permissions_id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). WithArgs("document", "read", 1). WillReturnRows(permRows) @@ -184,19 +174,16 @@ func TestAuthorize_PolicyAttributesError(t *testing.T) { UsersID: "user123", Resource: "document", Action: "read", + RoleID: 1, ResourceData: make(map[string]string), Environment: make(map[string]string), } // Mock user query - userRows := sqlmock.NewRows([]string{"users_id", "first_name", "middle_initial", "last_name", "suffix", "email_address", - "home_address", "contact_number", - "role_id", "is_deleted", "created_at", "updated_at"}). - AddRow("user123", "John", "", "Doe", "", "john@example.com", - "EMP123", "Y", "Y", "123 Street", "09123456789", "device1", - 1, "N", "secret", "Y", time.Now(), time.Now()) + userRows := sqlmock.NewRows([]string{"users_id", "email_address"}). + AddRow("user123", "john@example.com") - mock.ExpectQuery("SELECT users_id, first_name, middle_initial, last_name, suffix, email_address"). + mock.ExpectQuery("SELECT users_id, email_address"). WithArgs("user123"). WillReturnRows(userRows) @@ -204,7 +191,7 @@ func TestAuthorize_PolicyAttributesError(t *testing.T) { permRows := sqlmock.NewRows([]string{"id", "permission_name", "description", "resource", "action"}). AddRow(1, "read_document", "Read document permission", "document", "read") - mock.ExpectQuery("SELECT p.id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). + mock.ExpectQuery("SELECT p.permissions_id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). WithArgs("document", "read", 1). WillReturnRows(permRows) @@ -230,3 +217,38 @@ func TestAuthorize_PolicyAttributesError(t *testing.T) { t.Error("Expected access denied") } } + +func TestGetRoleCandidates_Priority(t *testing.T) { + t.Run("uses candidate roles first", func(t *testing.T) { + ctx := &models.AuthorizationContext{ + CandidateRoles: []int{4, 2}, + RoleIDs: []int{3}, + RoleID: 1, + } + + roles := getRoleCandidates(ctx) + if len(roles) != 2 || roles[0] != 4 || roles[1] != 2 { + t.Fatalf("unexpected roles: %v", roles) + } + }) + + t.Run("falls back to role ids array", func(t *testing.T) { + ctx := &models.AuthorizationContext{ + RoleIDs: []int{3, 7}, + } + + roles := getRoleCandidates(ctx) + if len(roles) != 2 || roles[0] != 3 || roles[1] != 7 { + t.Fatalf("unexpected roles: %v", roles) + } + }) + + t.Run("falls back to single role", func(t *testing.T) { + ctx := &models.AuthorizationContext{RoleID: 9} + + roles := getRoleCandidates(ctx) + if len(roles) != 1 || roles[0] != 9 { + t.Fatalf("unexpected roles: %v", roles) + } + }) +} diff --git a/services/cached_authorization.go b/services/cached_authorization.go index a2852ca..df64200 100644 --- a/services/cached_authorization.go +++ b/services/cached_authorization.go @@ -248,29 +248,44 @@ func AuthorizeWithCache(s *models.CachedAuthorizationService, ctx *models.Author } log.Printf("[AuthZ Step 0] User found: role_id=%d", user.RoleID) - // Step 1: Check if the user's role has the permission (not just if permission exists) - // Use role-aware cache key: roleID:resource:action - cacheKey := fmt.Sprintf("%d:%s:%s", ctx.RoleID, ctx.Resource, ctx.Action) - log.Printf("[AuthZ Step 1] Looking up permission in cache with role: %s", cacheKey) - permission, exists := getPermissionFromCache(s, cacheKey) + roleCandidates := getRoleCandidates(ctx) + var permission *models.Permission + permissionFound := false - if !exists { - // Cache miss - try database lookup with role check - log.Printf("[AuthZ Step 1] Cache miss - querying database for role_id=%d, resource=%s, action=%s", ctx.RoleID, ctx.Resource, ctx.Action) - permission, err = repository.GetPermissionByResourceActionAndRole(ctx.Resource, ctx.Action, ctx.RoleID) - if err != nil { - log.Printf("✗ [AuthZ Step 1] Permission not found or not granted to role_id=%d for resource=%s, action=%s: %v", ctx.RoleID, ctx.Resource, ctx.Action, err) - return &models.AuthorizationResult{ - Allowed: false, - Message: "Permission not granted to your role", - }, nil + for _, roleID := range roleCandidates { + cacheKey := fmt.Sprintf("%d:%s:%s", roleID, ctx.Resource, ctx.Action) + log.Printf("[AuthZ Step 1] Looking up permission in cache with role: %s", cacheKey) + + cachedPermission, exists := getPermissionFromCache(s, cacheKey) + if exists { + permission = cachedPermission + ctx.RoleID = roleID + permissionFound = true + log.Printf("✓ [AuthZ Step 1] Permission found in cache for role_id=%d: ID=%d, Name=%s", roleID, permission.ID, permission.PermissionName) + break } - log.Printf("✓ [AuthZ Step 1] Permission found in DB: ID=%d, Name=%s", permission.ID, permission.PermissionName) - // Cache the result for future use + log.Printf("[AuthZ Step 1] Cache miss - querying database for role_id=%d, resource=%s, action=%s", roleID, ctx.Resource, ctx.Action) + dbPermission, lookupErr := repository.GetPermissionByResourceActionAndRole(ctx.Resource, ctx.Action, roleID) + if lookupErr != nil { + log.Printf("[AuthZ Step 1] Permission not granted for role_id=%d, trying next role: %v", roleID, lookupErr) + continue + } + + permission = dbPermission + ctx.RoleID = roleID + permissionFound = true + log.Printf("✓ [AuthZ Step 1] Permission found in DB for role_id=%d: ID=%d, Name=%s", roleID, permission.ID, permission.PermissionName) storePermissionInCache(s, cacheKey, permission) - } else { - log.Printf("✓ [AuthZ Step 1] Permission found in cache: ID=%d, Name=%s", permission.ID, permission.PermissionName) + break + } + + if !permissionFound { + log.Printf("✗ [AuthZ Step 1] Permission not granted for any role candidate=%v, resource=%s, action=%s", roleCandidates, ctx.Resource, ctx.Action) + return &models.AuthorizationResult{ + Allowed: false, + Message: "Permission not granted to your role", + }, nil } // Step 2: Get user attributes (with distributed cache) @@ -305,7 +320,7 @@ func AuthorizeWithCache(s *models.CachedAuthorizationService, ctx *models.Author log.Printf("[DEBUG] No policies loaded for permissionID=%d", permission.ID) } - log.Printf("[AuthZ Step 4] Using RoleID: %s (from context or user record)", ctx.RoleID) + log.Printf("[AuthZ Step 4] Using RoleID: %d (from context or user record)", ctx.RoleID) allowed, reason := EvaluatePolicies(policies, ctx) result := &models.AuthorizationResult{ @@ -338,6 +353,19 @@ func AuthorizeWithCache(s *models.CachedAuthorizationService, ctx *models.Author return result, nil } +func getRoleCandidates(ctx *models.AuthorizationContext) []int { + if len(ctx.CandidateRoles) > 0 { + return ctx.CandidateRoles + } + if len(ctx.RoleIDs) > 0 { + return ctx.RoleIDs + } + if ctx.RoleID != 0 { + return []int{ctx.RoleID} + } + return nil +} + // InvalidateUserCache clears cache for a specific user from both Redis and local cache func InvalidateUserCache(s *models.CachedAuthorizationService, userID string) { // Clear from Redis diff --git a/services/cached_authorization_test.go b/services/cached_authorization_test.go index f0014df..6c135d3 100644 --- a/services/cached_authorization_test.go +++ b/services/cached_authorization_test.go @@ -137,11 +137,11 @@ func TestRefreshCache(t *testing.T) { // Only policies are preloaded during cache refresh // Mock policy attributes query - policyRows := sqlmock.NewRows([]string{"id", "attribute_name", "attribute_type", "comparison", "attribute_value", "permission_id"}). + policyRows := sqlmock.NewRows([]string{"policy_attributes_id", "attribute_name", "attribute_type", "comparison", "attribute_value", "permission_id"}). AddRow(1, "department", "user", "=", "engineering", 1). AddRow(2, "region", "user", "=", "01", 2) - mock.ExpectQuery("SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, id"). + mock.ExpectQuery("SELECT policy_attributes_id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM policy_attributes ORDER BY permission_id, policy_attributes_id"). WillReturnRows(policyRows) refreshCache(service) @@ -218,15 +218,11 @@ func TestAuthorizeWithCache_Success(t *testing.T) { // Add empty policies service.PolicyCache[1] = []models.PolicyAttribute{} - // Mock user query (needed to get role_id) - userRows := sqlmock.NewRows([]string{"users_id", "first_name", "middle_initial", "last_name", "suffix", "email_address", - "home_address", "contact_number", - "role_id", "is_deleted", "created_at", "updated_at"}). - AddRow("user123", "John", "", "Doe", "", "john@example.com", - "EMP123", "Y", "Y", "123 Street", "09123456789", "device1", - 1, "N", "secret", "Y", time.Now(), time.Now()) + // Mock user query + userRows := sqlmock.NewRows([]string{"users_id", "email_address"}). + AddRow("user123", "john@example.com") - mock.ExpectQuery("SELECT users_id, first_name, middle_initial, last_name, suffix, email_address"). + mock.ExpectQuery("SELECT users_id, email_address"). WithArgs("user123"). WillReturnRows(userRows) @@ -242,6 +238,7 @@ func TestAuthorizeWithCache_Success(t *testing.T) { UsersID: "user123", Resource: "document", Action: "read", + RoleID: 1, ResourceData: make(map[string]string), Environment: make(map[string]string), } @@ -274,19 +271,15 @@ func TestAuthorizeWithCache_PermissionNotFound(t *testing.T) { } // Mock user query - userRows := sqlmock.NewRows([]string{"users_id", "first_name", "middle_initial", "last_name", "suffix", "email_address", - "home_address", "contact_number", - "role_id", "is_deleted", "created_at", "updated_at"}). - AddRow("user123", "John", "", "Doe", "", "john@example.com", - "EMP123", "Y", "Y", "123 Street", "09123456789", "device1", - 1, "N", "secret", "Y", time.Now(), time.Now()) + userRows := sqlmock.NewRows([]string{"users_id", "email_address"}). + AddRow("user123", "john@example.com") - mock.ExpectQuery("SELECT users_id, first_name, middle_initial, last_name, suffix, email_address"). + mock.ExpectQuery("SELECT users_id, email_address"). WithArgs("user123"). WillReturnRows(userRows) // Permission not in cache, so will query DB and fail - mock.ExpectQuery("SELECT p.id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). + mock.ExpectQuery("SELECT p.permissions_id, p.permission_name, p.description, p.resource, p.action FROM permissions p INNER JOIN role_permissions rp"). WithArgs("nonexistent", "read", 1). WillReturnError(errors.New("permission not found")) diff --git a/services/policy_evaluator.go b/services/policy_evaluator.go index b88ccf3..1a362af 100644 --- a/services/policy_evaluator.go +++ b/services/policy_evaluator.go @@ -137,7 +137,7 @@ func evaluatePolicy(policyAttribute models.PolicyAttribute, ctx *models.Authoriz policyAttribute.AttributeName == "region" && (ctx.RoleID == 1 || ctx.RoleID == 2) { fmt.Printf("[POLICY EVALUATION] Type: %s, Attribute: %s\n", policyAttribute.AttributeType, policyAttribute.AttributeName) - fmt.Printf(" Skipped for roleID: %s (Super | System Admin bypass)\n\n", ctx.RoleID) + fmt.Printf(" Skipped for roleID: %d (Super | System Admin bypass)\n\n", ctx.RoleID) return true, "" } diff --git a/services/policy_evaluator_test.go b/services/policy_evaluator_test.go index e967cbd..57ff3b9 100644 --- a/services/policy_evaluator_test.go +++ b/services/policy_evaluator_test.go @@ -919,12 +919,12 @@ func TestEvaluatePolicies_RegionBypassForAdminRoles(t *testing.T) { description: "Super Admin role string should bypass region check", }, { - name: "Admin role does not bypass region check", + name: "Admin role bypasses region check", roleID: 2, userRegion: "03", resourceRegion: "01", - shouldBeAllowed: false, - description: "Admin role string should not bypass region check", + shouldBeAllowed: true, + description: "Admin role should bypass region check", }, }