From b6ab8d31870673822d48f88dce62f5ccd7a6c2b0 Mon Sep 17 00:00:00 2001 From: F04C Date: Fri, 13 Mar 2026 16:51:10 +0800 Subject: [PATCH] Refactored refresh token endpoint --- handlers/google_auth.go | 41 +++++++++++++++++++++++++++-------------- handlers/jwt.go | 4 ++-- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/handlers/google_auth.go b/handlers/google_auth.go index 5175a04..48be7eb 100644 --- a/handlers/google_auth.go +++ b/handlers/google_auth.go @@ -473,14 +473,16 @@ func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) { helper.LogInfo("Refresh token handler called") authHeader := r.Header.Get("Authorization") helper.LogInfo("Authorization header: " + authHeader) + if authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix) { accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix)) helper.LogInfo("Access token from header: " + accessToken) + token, err := jwt.ParseWithClaims(accessToken, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } - return []byte(os.Getenv("JWT_SECRET_KEY")), nil + return &rsaPrivateKey.PublicKey, nil }) helper.LogInfo("Parsed token: " + fmt.Sprintf("%v", token)) @@ -517,7 +519,7 @@ func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) { if err != nil && !strings.Contains(err.Error(), "expired") && !strings.Contains(err.Error(), "used before issued") { helper.LogError(err, "Invalid access token format") - helper.RespondWithError(w, http.StatusBadRequest, "Invalid access token format") + helper.RespondWithError(w, http.StatusBadRequest, err.Error()) return } helper.LogInfo("Access token is expired or invalid, proceeding with refresh") @@ -530,22 +532,33 @@ func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) { i, cookie.Name, len(cookie.Value), cookie.Domain, cookie.Path)) } + var refreshToken string + cookie, err := r.Cookie("refresh_token") helper.LogInfo("TRACE: Cookie retrieval - error: " + fmt.Sprintf("%v", err)) - if err != nil { - helper.LogError(err, "Refresh token cookie not found") + + if err == nil && cookie.Value != "" { + refreshToken = cookie.Value + helper.LogInfo("TRACE: Refresh token retrieved from cookie") + } else { + + var body struct { + RefreshToken string `json:"refresh_token"` + } + + decodeErr := json.NewDecoder(r.Body).Decode(&body) + if decodeErr == nil && body.RefreshToken != "" { + refreshToken = body.RefreshToken + helper.LogInfo("TRACE: Refresh token retrieved from request body") + } + } + + if refreshToken == "" { + helper.LogError(errors.New("refresh token not found"), "Refresh token not found in cookie or body") helper.RespondWithError(w, http.StatusUnauthorized, "Refresh token not found") return } - - refreshToken := cookie.Value - helper.LogInfo("TRACE: Refresh token from cookie - length: " + fmt.Sprintf("%d", len(refreshToken))) - if refreshToken == "" { - helper.LogError(errors.New("refresh token cookie is empty"), "refresh token cookie is empty") - helper.RespondWithError(w, http.StatusUnauthorized, "refresh token is empty") - return - } - + helper.LogInfo("TRACE: Refresh token length: " + fmt.Sprintf("%d", len(refreshToken))) // Get client info for security validation userAgent := r.Header.Get("User-Agent") ipAddress := getIPAddress(r) diff --git a/handlers/jwt.go b/handlers/jwt.go index b733688..89ffa05 100644 --- a/handlers/jwt.go +++ b/handlers/jwt.go @@ -625,7 +625,7 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres } go func() { - _, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE id = ?", session.UpdatedAt, session.ID) + _, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE jwt_sessions_id = ?", session.UpdatedAt, session.ID) if err != nil { helper.LogError(err, "Failed to update session activity in DB") } @@ -949,7 +949,7 @@ func UpdateSessionLastActivity(sessionID string) error { func getUserEmailFromID(userID string) (string, error) { var email string - err := db.DB.QueryRow("SELECT email_address FROM users WHERE id = ?", userID).Scan(&email) + err := db.DB.QueryRow("SELECT email_address FROM users WHERE users_id = ?", userID).Scan(&email) if err == nil { return email, nil }