Merge branch 'rj' into 'main'

Refactored refresh token endpoint

See merge request psa/uess/authn!6
This commit is contained in:
2026-03-13 16:57:47 +08:00
2 changed files with 29 additions and 16 deletions
+27 -14
View File
@@ -473,14 +473,16 @@ func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
helper.LogInfo("Refresh token handler called") helper.LogInfo("Refresh token handler called")
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
helper.LogInfo("Authorization header: " + authHeader) helper.LogInfo("Authorization header: " + authHeader)
if authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix) { if authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix) {
accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix)) accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
helper.LogInfo("Access token from header: " + accessToken) helper.LogInfo("Access token from header: " + accessToken)
token, err := jwt.ParseWithClaims(accessToken, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(accessToken, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
} }
return []byte(os.Getenv("JWT_SECRET_KEY")), nil return &rsaPrivateKey.PublicKey, nil
}) })
helper.LogInfo("Parsed token: " + fmt.Sprintf("%v", token)) helper.LogInfo("Parsed token: " + fmt.Sprintf("%v", token))
@@ -517,7 +519,7 @@ func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
if err != nil && !strings.Contains(err.Error(), "expired") && !strings.Contains(err.Error(), "used before issued") { if err != nil && !strings.Contains(err.Error(), "expired") && !strings.Contains(err.Error(), "used before issued") {
helper.LogError(err, "Invalid access token format") helper.LogError(err, "Invalid access token format")
helper.RespondWithError(w, http.StatusBadRequest, "Invalid access token format") helper.RespondWithError(w, http.StatusBadRequest, err.Error())
return return
} }
helper.LogInfo("Access token is expired or invalid, proceeding with refresh") helper.LogInfo("Access token is expired or invalid, proceeding with refresh")
@@ -530,22 +532,33 @@ func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
i, cookie.Name, len(cookie.Value), cookie.Domain, cookie.Path)) i, cookie.Name, len(cookie.Value), cookie.Domain, cookie.Path))
} }
var refreshToken string
cookie, err := r.Cookie("refresh_token") cookie, err := r.Cookie("refresh_token")
helper.LogInfo("TRACE: Cookie retrieval - error: " + fmt.Sprintf("%v", err)) helper.LogInfo("TRACE: Cookie retrieval - error: " + fmt.Sprintf("%v", err))
if err != nil {
helper.LogError(err, "Refresh token cookie not found") if err == nil && cookie.Value != "" {
refreshToken = cookie.Value
helper.LogInfo("TRACE: Refresh token retrieved from cookie")
} else {
var body struct {
RefreshToken string `json:"refresh_token"`
}
decodeErr := json.NewDecoder(r.Body).Decode(&body)
if decodeErr == nil && body.RefreshToken != "" {
refreshToken = body.RefreshToken
helper.LogInfo("TRACE: Refresh token retrieved from request body")
}
}
if refreshToken == "" {
helper.LogError(errors.New("refresh token not found"), "Refresh token not found in cookie or body")
helper.RespondWithError(w, http.StatusUnauthorized, "Refresh token not found") helper.RespondWithError(w, http.StatusUnauthorized, "Refresh token not found")
return return
} }
helper.LogInfo("TRACE: Refresh token length: " + fmt.Sprintf("%d", len(refreshToken)))
refreshToken := cookie.Value
helper.LogInfo("TRACE: Refresh token from cookie - length: " + fmt.Sprintf("%d", len(refreshToken)))
if refreshToken == "" {
helper.LogError(errors.New("refresh token cookie is empty"), "refresh token cookie is empty")
helper.RespondWithError(w, http.StatusUnauthorized, "refresh token is empty")
return
}
// Get client info for security validation // Get client info for security validation
userAgent := r.Header.Get("User-Agent") userAgent := r.Header.Get("User-Agent")
ipAddress := getIPAddress(r) ipAddress := getIPAddress(r)
+2 -2
View File
@@ -625,7 +625,7 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
} }
go func() { go func() {
_, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE id = ?", session.UpdatedAt, session.ID) _, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE jwt_sessions_id = ?", session.UpdatedAt, session.ID)
if err != nil { if err != nil {
helper.LogError(err, "Failed to update session activity in DB") helper.LogError(err, "Failed to update session activity in DB")
} }
@@ -949,7 +949,7 @@ func UpdateSessionLastActivity(sessionID string) error {
func getUserEmailFromID(userID string) (string, error) { func getUserEmailFromID(userID string) (string, error) {
var email string var email string
err := db.DB.QueryRow("SELECT email_address FROM users WHERE id = ?", userID).Scan(&email) err := db.DB.QueryRow("SELECT email_address FROM users WHERE users_id = ?", userID).Scan(&email)
if err == nil { if err == nil {
return email, nil return email, nil
} }