From f0bc603a5fb86b5980f35de24de7b5d877069630 Mon Sep 17 00:00:00 2001 From: F04C Date: Tue, 24 Mar 2026 16:38:41 +0800 Subject: [PATCH] fixed region fetching in user_attributes --- repository/permission_repository.go | 64 +++++++++++++----------- repository/permission_repository_test.go | 2 +- services/policy_evaluator.go | 33 ++++++++++++ 3 files changed, 70 insertions(+), 29 deletions(-) diff --git a/repository/permission_repository.go b/repository/permission_repository.go index 66aba23..bb9d1a2 100644 --- a/repository/permission_repository.go +++ b/repository/permission_repository.go @@ -13,13 +13,13 @@ func GetPermissionByResourceActionAndRole(resource, action string, roleID int) ( resource, action, roleID) query := ` - SELECT p.permissions_id, p.permission_name, p.description, p.resource, p.action - FROM uess_user_management.permissions p - INNER JOIN uess_user_management.role_permissions rp - ON p.permissions_id = rp.permission_id - WHERE p.resource = ? AND p.action = ? AND rp.role_id = ? AND rp.is_deleted = 0 - LIMIT 1 - ` + SELECT p.permissions_id, p.permission_name, p.description, p.resource, p.action + FROM uess_user_management.permissions p + INNER JOIN uess_user_management.role_permissions rp + ON p.permissions_id = rp.permission_id + WHERE p.resource = ? AND p.action = ? AND rp.role_id = ? AND rp.is_deleted = 0 + LIMIT 1 + ` var perm models.Permission err := db.DB.QueryRow(query, resource, action, roleID).Scan( @@ -47,10 +47,10 @@ func GetPermissionByResourceActionAndRole(resource, action string, roleID int) ( // GetPolicyAttributesByPermission retrieves all policy attributes for a permission func GetPolicyAttributesByPermission(permissionID int) ([]models.PolicyAttribute, error) { query := ` - SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id - FROM uess_user_management.policy_attributes - WHERE permission_id = ? - ` + SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id + FROM uess_user_management.policy_attributes + WHERE permission_id = ? + ` rows, err := db.DB.Query(query, permissionID) if err != nil { @@ -83,12 +83,20 @@ func GetUserAttributes(userID string) (map[string]string, error) { log.Printf("[Repository] GetUserAttributes - userID=%s", userID) query := ` - SELECT attribute_name, attribute_value - FROM uess_user_management.user_attributes - WHERE users_id = ? - ` + SELECT attribute_name, attribute_value + FROM uess_user_management.user_attributes + WHERE users_id = ? - rows, err := db.DB.Query(query, userID) + UNION + + SELECT 'region' AS attribute_name, o.reg AS attribute_value + FROM uess_reference.office o + LEFT JOIN uess_user_management.users u + ON o.id = u.office_id + WHERE u.users_id = ? + ` + + rows, err := db.DB.Query(query, userID, userID) if err != nil { log.Printf("[Repository] ✗ Database error querying user attributes: %v", err) return nil, fmt.Errorf("error querying user attributes: %w", err) @@ -115,10 +123,10 @@ func GetUserByID(userID string) (*models.User, error) { log.Printf("[Repository] GetUserByID - userID=%s", userID) query := ` - SELECT users_id, email_address - FROM uess_user_management.users - WHERE users_id = ? - ` + SELECT users_id, email_address + FROM uess_user_management.users + WHERE users_id = ? + ` var user models.User err := db.DB.QueryRow(query, userID).Scan(&user.UsersID, &user.EmailAddress) @@ -138,10 +146,10 @@ func GetUserByID(userID string) (*models.User, error) { // GetAllPermissions retrieves all permissions (for caching) func GetAllPermissions() ([]models.Permission, error) { query := ` - SELECT id, permission_name, description, resource, action - FROM uess_user_management.permissions - ORDER BY id - ` + SELECT id, permission_name, description, resource, action + FROM uess_user_management.permissions + ORDER BY id + ` rows, err := db.DB.Query(query) if err != nil { @@ -171,10 +179,10 @@ func GetAllPermissions() ([]models.Permission, error) { // GetAllPolicyAttributes retrieves all policy attributes (for caching) func GetAllPolicyAttributes() (map[int][]models.PolicyAttribute, error) { query := ` - SELECT policy_attributes_id, attribute_name, attribute_type, comparison, attribute_value, permission_id - FROM uess_user_management.policy_attributes - ORDER BY permission_id, policy_attributes_id - ` + SELECT policy_attributes_id, attribute_name, attribute_type, comparison, attribute_value, permission_id + FROM uess_user_management.policy_attributes + ORDER BY permission_id, policy_attributes_id + ` rows, err := db.DB.Query(query) if err != nil { diff --git a/repository/permission_repository_test.go b/repository/permission_repository_test.go index a644c9a..3165730 100644 --- a/repository/permission_repository_test.go +++ b/repository/permission_repository_test.go @@ -35,7 +35,7 @@ func TestGetPolicyAttributesByPermissionSuccess(t *testing.T) { AddRow(1, "department", "user", "=", "engineering", 1). AddRow(2, "level", "user", ">=", "5", 1) - mock.ExpectQuery("SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM uess_user_management.policy_attributes WHERE permission_id = \\?"). + mock.ExpectQuery("SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM uess_user_management.policy_attributes WHERE permission_id = ?"). WithArgs(1). WillReturnRows(rows) diff --git a/services/policy_evaluator.go b/services/policy_evaluator.go index 1a362af..f2f6746 100644 --- a/services/policy_evaluator.go +++ b/services/policy_evaluator.go @@ -60,9 +60,17 @@ func compare(actual, expected, operator string) bool { switch operator { case "=": + // Special logic for region: allow '1' and '01' to match + if isRegionComparison(actual, expected) { + return normalizeRegion(actual) == normalizeRegion(expected) + } return actual == expected // case-sensitive case "!=": + // Special logic for region: allow '1' and '01' to match + if isRegionComparison(actual, expected) { + return normalizeRegion(actual) != normalizeRegion(expected) + } return actual != expected // case-sensitive case ">": @@ -94,6 +102,31 @@ func compare(actual, expected, operator string) bool { } } +// Checks if the comparison is for region attribute +func isRegionComparison(actual, expected string) bool { + // Only trigger for region values that are numeric or zero-padded numeric + // This is a heuristic: if both are digits or zero-padded digits, treat as region + return isRegionValue(actual) && isRegionValue(expected) +} + +func isRegionValue(val string) bool { + val = strings.TrimLeft(val, "0") + return len(val) > 0 && isDigits(val) +} + +func isDigits(val string) bool { + for _, r := range val { + if r < '0' || r > '9' { + return false + } + } + return true +} + +func normalizeRegion(val string) string { + return strings.TrimLeft(val, "0") +} + func numericCompare(actual, expected string, compareFn func(float64, float64) bool) bool { actualNum, err1 := strconv.ParseFloat(actual, 64) expectedNum, err2 := strconv.ParseFloat(expected, 64)