diff --git a/handlers/jwt.go b/handlers/jwt.go index f7a421d..048ed5a 100644 --- a/handlers/jwt.go +++ b/handlers/jwt.go @@ -96,6 +96,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) userID = helper.UUIDGenerator() } + log.Print("userID:", userID) roleID, err := services.GetRoleIDsFromEmail(email) if err != nil { return "", "", fmt.Errorf("error checking role in database: %w", err) @@ -125,7 +126,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) session := models.JWTSession{ ID: sessionID, - UserID: userID, + UsersID: userID, RefreshTokenHash: refreshTokenHash, UserAgent: userAgent, IPAddress: ipAddress, @@ -166,7 +167,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) } } - accessToken, err := generateAccessToken(email, sessionID, userID, roleIDsStr) + accessToken, err := generateAccessToken(email, sessionID, userID, roleID) if err != nil { return "", "", fmt.Errorf(errFormatWithContext, errMsgFailedToGenerateAccessToken, err) } @@ -175,7 +176,7 @@ func GenerateTokens(email, userAgent, ipAddress string) (string, string, error) return accessToken, refreshToken, nil } -func generateAccessToken(email, sessionID, userID, roleID string) (string, error) { +func generateAccessToken(email, sessionID, userID string, roleID []int) (string, error) { AccessTokenExpiration := os.Getenv("ACCESS_TOKEN_EXPIRATION_MINUTES") if AccessTokenExpiration == "" { log.Println("AccessTokenExpiration not set (in minutes), defaulting to 45 minutes") @@ -186,7 +187,7 @@ func generateAccessToken(email, sessionID, userID, roleID string) (string, error claims := &models.AccessToken{ Email: email, - UserID: userID, + UsersID: userID, RoleID: roleID, SessionID: sessionID, Exp: expirationTime, @@ -251,7 +252,7 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string WHERE refresh_token_hash = ? AND is_revoked = false `, refreshTokenHash).Scan( &session.ID, - &session.UserID, + &session.UsersID, &session.UserAgent, &session.IPAddress, &session.CreatedAt, @@ -265,8 +266,8 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string return "", fmt.Errorf("invalid refresh token: %w", err) } - helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UserID: %s, Created: %s, Expires: %s", - session.ID, session.UserID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime))) + helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s", + session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime))) sessionTTL := int(time.Until(session.ExpiresAt).Seconds()) if sessionTTL > 0 { @@ -275,8 +276,8 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string } } } else { - helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UserID: %s, Expires: %s", - session.ID, session.UserID, session.ExpiresAt.Format(timeFormatDateTime))) + helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s", + session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime))) } if session.IsRevoked { @@ -307,18 +308,18 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string } // Get user email from user ID (with caching) - email, err := getUserEmailFromIDCached(session.UserID) + email, err := getUserEmailFromIDCached(session.UsersID) if err != nil { - helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UserID)) + helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID)) // For registrants or users not yet in the main tables, we still want to allow refresh // but we need to get the email from somewhere else. Since we don't store email in session, // we'll need to handle this gracefully by allowing the refresh to continue with a placeholder // The email will be properly resolved when they complete registration - helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UserID)) + helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables, but allowing refresh for potential registrant", session.UsersID)) // For now, we'll use a placeholder email pattern and let the access token generation handle it // The system should work as long as the session is valid - email = fmt.Sprintf("registrant_%s@pending.local", session.UserID) + email = fmt.Sprintf("registrant_%s@pending.local", session.UsersID) } helper.LogInfo(fmt.Sprintf("Generating new access token for email: %s, session: %s", email, session.ID)) @@ -326,24 +327,15 @@ func RefreshAccessToken(refreshTokenString, userAgent, ipAddress string) (string userID, err := helper.FetchUserIDFromDB(email) if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email)) - userID = session.UserID // Fallback to session's user ID + userID = session.UsersID // Fallback to session's user ID } roleIDs, err := services.GetRoleIDsFromEmail(email) - var roleIDsStr string if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email)) - roleIDsStr = "" - } else { - for i, r := range roleIDs { - if i > 0 { - roleIDsStr += "," - } - roleIDsStr += fmt.Sprintf("%d", r) - } + roleIDs = []int{} } - - accessToken, err := generateAccessToken(email, session.ID, userID, roleIDsStr) + accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs) if err != nil { helper.LogError(err, "Failed to generate access token during refresh") return "", fmt.Errorf("failed to generate access token: %w", err) @@ -411,7 +403,7 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres WHERE refresh_token_hash = ? AND is_revoked = false `, refreshTokenHash).Scan( &session.ID, - &session.UserID, + &session.UsersID, &session.UserAgent, &session.IPAddress, &session.CreatedAt, @@ -425,8 +417,8 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres return "", fmt.Errorf("invalid refresh token: %w", err) } - helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UserID: %s, Created: %s, Expires: %s", - session.ID, session.UserID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime))) + helper.LogInfo(fmt.Sprintf("Session found in DB - ID: %s, UsersID: %s, Created: %s, Expires: %s", + session.ID, session.UsersID, session.CreatedAt.Format(timeFormatDateTime), session.ExpiresAt.Format(timeFormatDateTime))) sessionTTL := int(time.Until(session.ExpiresAt).Seconds()) if sessionTTL > 0 { @@ -435,8 +427,8 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres } } } else { - helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UserID: %s, Expires: %s", - session.ID, session.UserID, session.ExpiresAt.Format(timeFormatDateTime))) + helper.LogInfo(fmt.Sprintf("Session found in Redis cache - ID: %s, UsersID: %s, Expires: %s", + session.ID, session.UsersID, session.ExpiresAt.Format(timeFormatDateTime))) } if session.IsRevoked { @@ -467,15 +459,15 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres } // Get user email from user ID (with caching), with fallback to provided email - email, err := getUserEmailFromIDCached(session.UserID) + email, err := getUserEmailFromIDCached(session.UsersID) if err != nil { - helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UserID)) + helper.LogError(err, fmt.Sprintf("Failed to get email for user %s", session.UsersID)) if emailFallback != "" { - helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UserID, emailFallback)) + helper.LogInfo(fmt.Sprintf("Using email fallback for user ID %s: %s", session.UsersID, emailFallback)) email = emailFallback } else { - helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UserID)) + helper.LogWarn(fmt.Sprintf("User ID %s not found in database tables and no email fallback provided", session.UsersID)) return "", fmt.Errorf("failed to get user email: %w", err) } } @@ -485,24 +477,15 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres userID, err := helper.FetchUserIDFromDB(email) if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch user ID for email %s during refresh", email)) - userID = session.UserID // Fallback to session's user ID + userID = session.UsersID // Fallback to session's user ID } roleIDs, err := services.GetRoleIDsFromEmail(email) - var roleIDsStr string if err != nil { helper.LogError(err, fmt.Sprintf("Failed to fetch role ID for email %s during refresh", email)) - roleIDsStr = "" - } else { - for i, r := range roleIDs { - if i > 0 { - roleIDsStr += "," - } - roleIDsStr += fmt.Sprintf("%d", r) - } + roleIDs = []int{} } - - accessToken, err := generateAccessToken(email, session.ID, userID, roleIDsStr) + accessToken, err := generateAccessToken(email, session.ID, userID, roleIDs) if err != nil { helper.LogError(err, "Failed to generate access token during refresh") return "", fmt.Errorf("failed to generate access token: %w", err) @@ -623,7 +606,7 @@ func ValidateSession(sessionID string) (*models.JWTSession, error) { WHERE id = ? `, sessionID).Scan( &session.ID, - &session.UserID, + &session.UsersID, &session.RefreshTokenHash, &session.UserAgent, &session.IPAddress, @@ -716,12 +699,12 @@ func GetSessionUserInfo(sessionID string) (string, string, error) { return "", "", fmt.Errorf("%s", errMsgSessionNotFoundInCache) } - email, err := getUserEmailFromIDCached(session.UserID) + email, err := getUserEmailFromIDCached(session.UsersID) if err != nil { return "", "", fmt.Errorf("failed to get user email: %w", err) } - return session.UserID, email, nil + return session.UsersID, email, nil } func InvalidateUserSessionsInCache(userID string) error { @@ -767,11 +750,11 @@ func CleanupExpiredSessions() error { userIDsToCleanup := make(map[string]bool) for rows.Next() { var session models.ExpiredSession - if err := rows.Scan(&session.ID, &session.UserID, &session.RefreshTokenHash); err != nil { + if err := rows.Scan(&session.ID, &session.UsersID, &session.RefreshTokenHash); err != nil { continue } expiredSessions = append(expiredSessions, session) - userIDsToCleanup[session.UserID] = true + userIDsToCleanup[session.UsersID] = true } _, err = db.DB.Exec("DELETE FROM jwt_sessions WHERE expires_at < ?", time.Now()) @@ -809,7 +792,7 @@ func GetUserSessions(userID string) ([]models.JWTSession, error) { var session models.JWTSession err := rows.Scan( &session.ID, - &session.UserID, + &session.UsersID, &session.RefreshTokenHash, &session.UserAgent, &session.IPAddress, diff --git a/middleware/jwt.go b/middleware/jwt.go index cb584ee..9d720ff 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -186,7 +186,7 @@ func validateSessionFromDB(sessionID string) (*models.JWTSession, error) { WHERE id = ? AND is_revoked = false `, sessionID).Scan( &session.ID, - &session.UserID, + &session.UsersID, &session.RefreshTokenHash, &session.UserAgent, &session.IPAddress, diff --git a/models/jwt.go b/models/jwt.go index 8c0d0d8..5c07549 100644 --- a/models/jwt.go +++ b/models/jwt.go @@ -8,8 +8,8 @@ import ( type AccessToken struct { Email string `json:"email"` - UserID string `json:"user_id"` - RoleID string `json:"role_id"` + UsersID string `json:"users_id"` + RoleID []int `json:"role_id"` SessionID string `json:"session_id"` Exp int64 `json:"exp"` jwt.RegisteredClaims @@ -17,7 +17,7 @@ type AccessToken struct { type JWTSession struct { ID string `json:"id" db:"id"` - UserID string `json:"user_id" db:"user_id"` + UsersID string `json:"users_id" db:"users_id"` RefreshTokenHash string `json:"refresh_token_hash" db:"refresh_token_hash"` UserAgent string `json:"user_agent" db:"user_agent"` IPAddress string `json:"ip_address" db:"ip_address"` @@ -29,6 +29,6 @@ type JWTSession struct { type ExpiredSession struct { ID string - UserID string + UsersID string RefreshTokenHash string } diff --git a/models/jwt_test.go b/models/jwt_test.go index 8aaac21..a6c198d 100644 --- a/models/jwt_test.go +++ b/models/jwt_test.go @@ -68,7 +68,7 @@ func TestJWTSessionCreation(t *testing.T) { now := time.Now() session := &JWTSession{ ID: "session-id-123", - UserID: "user-456", + UsersID: "user-456", RefreshTokenHash: "hash123", UserAgent: "Mozilla/5.0", IPAddress: "192.168.1.1", @@ -82,8 +82,8 @@ func TestJWTSessionCreation(t *testing.T) { t.Errorf("Expected session ID 'session-id-123', got '%s'", session.ID) } - if session.UserID != "user-456" { - t.Errorf("Expected user ID 'user-456', got '%s'", session.UserID) + if session.UsersID != "user-456" { + t.Errorf("Expected user ID 'user-456', got '%s'", session.UsersID) } if session.RefreshTokenHash != "hash123" { @@ -182,7 +182,7 @@ func TestJWTSessionRevokedStatus(t *testing.T) { func TestExpiredSessionCreation(t *testing.T) { expiredSession := ExpiredSession{ ID: "expired-id-123", - UserID: "user-789", + UsersID: "user-789", RefreshTokenHash: "expired-hash", } @@ -190,8 +190,8 @@ func TestExpiredSessionCreation(t *testing.T) { t.Errorf("Expected ID 'expired-id-123', got '%s'", expiredSession.ID) } - if expiredSession.UserID != "user-789" { - t.Errorf("Expected UserID 'user-789', got '%s'", expiredSession.UserID) + if expiredSession.UsersID != "user-789" { + t.Errorf("Expected UsersID 'user-789', got '%s'", expiredSession.UsersID) } if expiredSession.RefreshTokenHash != "expired-hash" { diff --git a/services/users.go b/services/users.go index e671bf4..f82fc02 100644 --- a/services/users.go +++ b/services/users.go @@ -29,9 +29,9 @@ func CheckEmailInDB(email string) (bool, error) { func GetUserIDFromEmail(email string) (string, error) { log.Print(email) - query := `SELECT user_id + query := `SELECT users_id FROM ( - SELECT user_id, 1 AS priority + SELECT users_id, 1 AS priority FROM users WHERE email_address = ? AND is_deleted = 0