diff --git a/handlers/jwt.go b/handlers/jwt.go index 29808c2..f7a421d 100644 --- a/handlers/jwt.go +++ b/handlers/jwt.go @@ -96,7 +96,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) userID = helper.UUIDGenerator() } - roleID, err := services.GetRoleIDFromEmail(email) + roleID, err := services.GetRoleIDsFromEmail(email) if err != nil { return "", "", fmt.Errorf("error checking role in database: %w", err) } @@ -136,7 +136,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) } _, err = db.DB.Exec(` - INSERT INTO jwt_sessions (id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked) + INSERT INTO jwt_sessions (jwt_sessions_id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) `, sessionID, userID, refreshTokenHash, userAgent, ipAddress, currentTime, currentTime, expiresAt, false) if err != nil { @@ -155,7 +155,18 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) } } - accessToken, err := generateAccessToken(email, sessionID, userID, roleID) + // Convert roleIDs slice to a comma-separated string for the token claim + var roleIDsStr string + if len(roleID) > 0 { + for i, r := range roleID { + if i > 0 { + roleIDsStr += "," + } + roleIDsStr += fmt.Sprintf("%d", r) + } + } + + accessToken, err := generateAccessToken(email, sessionID, userID, roleIDsStr) if err != nil { return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err) } @@ -318,13 +329,21 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string userID = session.UserID // Fallback to session's user ID } - roleID, err := services.GetRoleIDFromEmail(email) + roleIDs, err := services.GetRoleIDsFromEmail(email) + var roleIDsStr string if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email)) - roleID = "" + roleIDsStr = "" + } else { + for i, r := range roleIDs { + if i > 0 { + roleIDsStr += "," + } + roleIDsStr += fmt.Sprintf("%d", r) + } } - accessToken, err := generateAccessToken(email, session.ID, userID, roleID) + accessToken, err := generateAccessToken(email, session.ID, userID, roleIDsStr) if err != nil { helper.LogError(err, "Failed to generate access token during refresh") return "", fmt.Errorf("failed to generate access token: %w", err) @@ -469,13 +488,21 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres userID = session.UserID // Fallback to session's user ID } - roleID, err := services.GetRoleIDFromEmail(email) + roleIDs, err := services.GetRoleIDsFromEmail(email) + var roleIDsStr string if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email)) - roleID = "" + roleIDsStr = "" + } else { + for i, r := range roleIDs { + if i > 0 { + roleIDsStr += "," + } + roleIDsStr += fmt.Sprintf("%d", r) + } } - accessToken, err := generateAccessToken(email, session.ID, userID, roleID) + accessToken, err := generateAccessToken(email, session.ID, userID, roleIDsStr) if err != nil { helper.LogError(err, "Failed to generate access token during refresh") return "", fmt.Errorf("failed to generate access token: %w", err) diff --git a/services/users.go b/services/users.go index e84d436..e671bf4 100644 --- a/services/users.go +++ b/services/users.go @@ -7,7 +7,7 @@ import ( func GetUserID(email string) (string, error) { log.Print(email) - query := `SELECT user_id FROM users WHERE email_address = ? AND is_deleted = 0 LIMIT 1;` + query := `SELECT users_id FROM uess_user_management.users WHERE email_address = ? AND is_deleted = 0 LIMIT 1;` var id string err := db.DB.QueryRow(query, email).Scan(&id) if err != nil { @@ -51,23 +51,29 @@ func GetUserIDFromEmail(email string) (string, error) { return id, nil } -func GetRoleIDFromEmail(email string) (string, error) { +func GetRoleIDsFromEmail(email string) ([]int, error) { log.Print(email) - query := `SELECT role_id - FROM ( - SELECT r.id AS role_id, 1 AS priority - FROM roles r - JOIN users u ON u.role_id = r.id - WHERE u.email_address = ? - AND u.is_deleted = 0 - ) t - ORDER BY priority ASC - LIMIT 1; - ` - var roleID string - err := db.DB.QueryRow(query, email).Scan(&roleID) + query := `SELECT ur.role_id + FROM uess_user_management.user_roles ur + JOIN uess_user_management.users u ON ur.users_id = u.users_id + WHERE u.email_address = ? + AND u.is_deleted = 0` + rows, err := db.DB.Query(query, email) if err != nil { - return "", err + return nil, err } - return roleID, nil + defer rows.Close() + + var roleIDs []int + for rows.Next() { + var roleID int + if err := rows.Scan(&roleID); err != nil { + return nil, err + } + roleIDs = append(roleIDs, roleID) + } + if err := rows.Err(); err != nil { + return nil, err + } + return roleIDs, nil }