Merge branch 'rj' into 'main'
Refactored refresh token endpoint See merge request psa/uess/authn!6
This commit is contained in:
+27
-14
@@ -473,14 +473,16 @@ func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
helper.LogInfo("Refresh token handler called")
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
helper.LogInfo("Authorization header: " + authHeader)
|
||||
|
||||
if authHeader != "" && strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
accessToken := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
||||
helper.LogInfo("Access token from header: " + accessToken)
|
||||
|
||||
token, err := jwt.ParseWithClaims(accessToken, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(os.Getenv("JWT_SECRET_KEY")), nil
|
||||
return &rsaPrivateKey.PublicKey, nil
|
||||
})
|
||||
helper.LogInfo("Parsed token: " + fmt.Sprintf("%v", token))
|
||||
|
||||
@@ -517,7 +519,7 @@ func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err != nil && !strings.Contains(err.Error(), "expired") && !strings.Contains(err.Error(), "used before issued") {
|
||||
helper.LogError(err, "Invalid access token format")
|
||||
helper.RespondWithError(w, http.StatusBadRequest, "Invalid access token format")
|
||||
helper.RespondWithError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
helper.LogInfo("Access token is expired or invalid, proceeding with refresh")
|
||||
@@ -530,22 +532,33 @@ func HandleTokenRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
i, cookie.Name, len(cookie.Value), cookie.Domain, cookie.Path))
|
||||
}
|
||||
|
||||
var refreshToken string
|
||||
|
||||
cookie, err := r.Cookie("refresh_token")
|
||||
helper.LogInfo("TRACE: Cookie retrieval - error: " + fmt.Sprintf("%v", err))
|
||||
if err != nil {
|
||||
helper.LogError(err, "Refresh token cookie not found")
|
||||
|
||||
if err == nil && cookie.Value != "" {
|
||||
refreshToken = cookie.Value
|
||||
helper.LogInfo("TRACE: Refresh token retrieved from cookie")
|
||||
} else {
|
||||
|
||||
var body struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
decodeErr := json.NewDecoder(r.Body).Decode(&body)
|
||||
if decodeErr == nil && body.RefreshToken != "" {
|
||||
refreshToken = body.RefreshToken
|
||||
helper.LogInfo("TRACE: Refresh token retrieved from request body")
|
||||
}
|
||||
}
|
||||
|
||||
if refreshToken == "" {
|
||||
helper.LogError(errors.New("refresh token not found"), "Refresh token not found in cookie or body")
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Refresh token not found")
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := cookie.Value
|
||||
helper.LogInfo("TRACE: Refresh token from cookie - length: " + fmt.Sprintf("%d", len(refreshToken)))
|
||||
if refreshToken == "" {
|
||||
helper.LogError(errors.New("refresh token cookie is empty"), "refresh token cookie is empty")
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "refresh token is empty")
|
||||
return
|
||||
}
|
||||
|
||||
helper.LogInfo("TRACE: Refresh token length: " + fmt.Sprintf("%d", len(refreshToken)))
|
||||
// Get client info for security validation
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
ipAddress := getIPAddress(r)
|
||||
|
||||
+2
-2
@@ -625,7 +625,7 @@ func RefreshAccessTokenWithEmailFallback(refreshTokenString, userAgent, ipAddres
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE id = ?", session.UpdatedAt, session.ID)
|
||||
_, err := db.DB.Exec("UPDATE jwt_sessions SET updated_at = ? WHERE jwt_sessions_id = ?", session.UpdatedAt, session.ID)
|
||||
if err != nil {
|
||||
helper.LogError(err, "Failed to update session activity in DB")
|
||||
}
|
||||
@@ -949,7 +949,7 @@ func UpdateSessionLastActivity(sessionID string) error {
|
||||
func getUserEmailFromID(userID string) (string, error) {
|
||||
var email string
|
||||
|
||||
err := db.DB.QueryRow("SELECT email_address FROM users WHERE id = ?", userID).Scan(&email)
|
||||
err := db.DB.QueryRow("SELECT email_address FROM users WHERE users_id = ?", userID).Scan(&email)
|
||||
if err == nil {
|
||||
return email, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user