fixed region fetching in user_attributes
This commit is contained in:
@@ -13,13 +13,13 @@ func GetPermissionByResourceActionAndRole(resource, action string, roleID int) (
|
|||||||
resource, action, roleID)
|
resource, action, roleID)
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT p.permissions_id, p.permission_name, p.description, p.resource, p.action
|
SELECT p.permissions_id, p.permission_name, p.description, p.resource, p.action
|
||||||
FROM uess_user_management.permissions p
|
FROM uess_user_management.permissions p
|
||||||
INNER JOIN uess_user_management.role_permissions rp
|
INNER JOIN uess_user_management.role_permissions rp
|
||||||
ON p.permissions_id = rp.permission_id
|
ON p.permissions_id = rp.permission_id
|
||||||
WHERE p.resource = ? AND p.action = ? AND rp.role_id = ? AND rp.is_deleted = 0
|
WHERE p.resource = ? AND p.action = ? AND rp.role_id = ? AND rp.is_deleted = 0
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
`
|
`
|
||||||
|
|
||||||
var perm models.Permission
|
var perm models.Permission
|
||||||
err := db.DB.QueryRow(query, resource, action, roleID).Scan(
|
err := db.DB.QueryRow(query, resource, action, roleID).Scan(
|
||||||
@@ -47,10 +47,10 @@ func GetPermissionByResourceActionAndRole(resource, action string, roleID int) (
|
|||||||
// GetPolicyAttributesByPermission retrieves all policy attributes for a permission
|
// GetPolicyAttributesByPermission retrieves all policy attributes for a permission
|
||||||
func GetPolicyAttributesByPermission(permissionID int) ([]models.PolicyAttribute, error) {
|
func GetPolicyAttributesByPermission(permissionID int) ([]models.PolicyAttribute, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id
|
SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id
|
||||||
FROM uess_user_management.policy_attributes
|
FROM uess_user_management.policy_attributes
|
||||||
WHERE permission_id = ?
|
WHERE permission_id = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := db.DB.Query(query, permissionID)
|
rows, err := db.DB.Query(query, permissionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -83,12 +83,20 @@ func GetUserAttributes(userID string) (map[string]string, error) {
|
|||||||
log.Printf("[Repository] GetUserAttributes - userID=%s", userID)
|
log.Printf("[Repository] GetUserAttributes - userID=%s", userID)
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT attribute_name, attribute_value
|
SELECT attribute_name, attribute_value
|
||||||
FROM uess_user_management.user_attributes
|
FROM uess_user_management.user_attributes
|
||||||
WHERE users_id = ?
|
WHERE users_id = ?
|
||||||
`
|
|
||||||
|
|
||||||
rows, err := db.DB.Query(query, userID)
|
UNION
|
||||||
|
|
||||||
|
SELECT 'region' AS attribute_name, o.reg AS attribute_value
|
||||||
|
FROM uess_reference.office o
|
||||||
|
LEFT JOIN uess_user_management.users u
|
||||||
|
ON o.id = u.office_id
|
||||||
|
WHERE u.users_id = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := db.DB.Query(query, userID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[Repository] ✗ Database error querying user attributes: %v", err)
|
log.Printf("[Repository] ✗ Database error querying user attributes: %v", err)
|
||||||
return nil, fmt.Errorf("error querying user attributes: %w", err)
|
return nil, fmt.Errorf("error querying user attributes: %w", err)
|
||||||
@@ -115,10 +123,10 @@ func GetUserByID(userID string) (*models.User, error) {
|
|||||||
log.Printf("[Repository] GetUserByID - userID=%s", userID)
|
log.Printf("[Repository] GetUserByID - userID=%s", userID)
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT users_id, email_address
|
SELECT users_id, email_address
|
||||||
FROM uess_user_management.users
|
FROM uess_user_management.users
|
||||||
WHERE users_id = ?
|
WHERE users_id = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
var user models.User
|
var user models.User
|
||||||
err := db.DB.QueryRow(query, userID).Scan(&user.UsersID, &user.EmailAddress)
|
err := db.DB.QueryRow(query, userID).Scan(&user.UsersID, &user.EmailAddress)
|
||||||
@@ -138,10 +146,10 @@ func GetUserByID(userID string) (*models.User, error) {
|
|||||||
// GetAllPermissions retrieves all permissions (for caching)
|
// GetAllPermissions retrieves all permissions (for caching)
|
||||||
func GetAllPermissions() ([]models.Permission, error) {
|
func GetAllPermissions() ([]models.Permission, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, permission_name, description, resource, action
|
SELECT id, permission_name, description, resource, action
|
||||||
FROM uess_user_management.permissions
|
FROM uess_user_management.permissions
|
||||||
ORDER BY id
|
ORDER BY id
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := db.DB.Query(query)
|
rows, err := db.DB.Query(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -171,10 +179,10 @@ func GetAllPermissions() ([]models.Permission, error) {
|
|||||||
// GetAllPolicyAttributes retrieves all policy attributes (for caching)
|
// GetAllPolicyAttributes retrieves all policy attributes (for caching)
|
||||||
func GetAllPolicyAttributes() (map[int][]models.PolicyAttribute, error) {
|
func GetAllPolicyAttributes() (map[int][]models.PolicyAttribute, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT policy_attributes_id, attribute_name, attribute_type, comparison, attribute_value, permission_id
|
SELECT policy_attributes_id, attribute_name, attribute_type, comparison, attribute_value, permission_id
|
||||||
FROM uess_user_management.policy_attributes
|
FROM uess_user_management.policy_attributes
|
||||||
ORDER BY permission_id, policy_attributes_id
|
ORDER BY permission_id, policy_attributes_id
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := db.DB.Query(query)
|
rows, err := db.DB.Query(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func TestGetPolicyAttributesByPermissionSuccess(t *testing.T) {
|
|||||||
AddRow(1, "department", "user", "=", "engineering", 1).
|
AddRow(1, "department", "user", "=", "engineering", 1).
|
||||||
AddRow(2, "level", "user", ">=", "5", 1)
|
AddRow(2, "level", "user", ">=", "5", 1)
|
||||||
|
|
||||||
mock.ExpectQuery("SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM uess_user_management.policy_attributes WHERE permission_id = \\?").
|
mock.ExpectQuery("SELECT id, attribute_name, attribute_type, comparison, attribute_value, permission_id FROM uess_user_management.policy_attributes WHERE permission_id = ?").
|
||||||
WithArgs(1).
|
WithArgs(1).
|
||||||
WillReturnRows(rows)
|
WillReturnRows(rows)
|
||||||
|
|
||||||
|
|||||||
@@ -60,9 +60,17 @@ func compare(actual, expected, operator string) bool {
|
|||||||
|
|
||||||
switch operator {
|
switch operator {
|
||||||
case "=":
|
case "=":
|
||||||
|
// Special logic for region: allow '1' and '01' to match
|
||||||
|
if isRegionComparison(actual, expected) {
|
||||||
|
return normalizeRegion(actual) == normalizeRegion(expected)
|
||||||
|
}
|
||||||
return actual == expected // case-sensitive
|
return actual == expected // case-sensitive
|
||||||
|
|
||||||
case "!=":
|
case "!=":
|
||||||
|
// Special logic for region: allow '1' and '01' to match
|
||||||
|
if isRegionComparison(actual, expected) {
|
||||||
|
return normalizeRegion(actual) != normalizeRegion(expected)
|
||||||
|
}
|
||||||
return actual != expected // case-sensitive
|
return actual != expected // case-sensitive
|
||||||
|
|
||||||
case ">":
|
case ">":
|
||||||
@@ -94,6 +102,31 @@ func compare(actual, expected, operator string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks if the comparison is for region attribute
|
||||||
|
func isRegionComparison(actual, expected string) bool {
|
||||||
|
// Only trigger for region values that are numeric or zero-padded numeric
|
||||||
|
// This is a heuristic: if both are digits or zero-padded digits, treat as region
|
||||||
|
return isRegionValue(actual) && isRegionValue(expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRegionValue(val string) bool {
|
||||||
|
val = strings.TrimLeft(val, "0")
|
||||||
|
return len(val) > 0 && isDigits(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDigits(val string) bool {
|
||||||
|
for _, r := range val {
|
||||||
|
if r < '0' || r > '9' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeRegion(val string) string {
|
||||||
|
return strings.TrimLeft(val, "0")
|
||||||
|
}
|
||||||
|
|
||||||
func numericCompare(actual, expected string, compareFn func(float64, float64) bool) bool {
|
func numericCompare(actual, expected string, compareFn func(float64, float64) bool) bool {
|
||||||
actualNum, err1 := strconv.ParseFloat(actual, 64)
|
actualNum, err1 := strconv.ParseFloat(actual, 64)
|
||||||
expectedNum, err2 := strconv.ParseFloat(expected, 64)
|
expectedNum, err2 := strconv.ParseFloat(expected, 64)
|
||||||
|
|||||||
Reference in New Issue
Block a user