This commit is contained in:
2026-03-16 09:27:08 +08:00
parent 4cd58e4fed
commit c76e64f87c
2 changed files with 44 additions and 14 deletions
+29 -12
View File
@@ -2,6 +2,7 @@ package helper
import (
"authentication/models"
"encoding/pem"
"errors"
"log"
"os"
@@ -20,14 +21,22 @@ func ExtractEmailFromToken(tokenString string) (string, error) {
}
token, err := jwt.ParseWithClaims(tokenString, &models.AccessToken{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("unexpected signing method")
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, errors.New("unexpected signing method: expected RSA")
}
secretKey := os.Getenv("JWT_SECRET_KEY")
if secretKey == "" {
return nil, errors.New("JWT secret key not set")
publicKeyPEM := os.Getenv("JWT_PUBLIC_KEY")
if publicKeyPEM == "" {
return nil, errors.New("JWT public key not set")
}
return []byte(secretKey), nil
block, _ := pem.Decode([]byte(publicKeyPEM))
if block == nil {
return nil, errors.New("failed to decode PEM block")
}
pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicKeyPEM))
if err != nil {
return nil, errors.New("failed to parse RSA public key")
}
return pubKey, nil
})
if err == nil && token.Valid {
@@ -42,14 +51,22 @@ func ExtractEmailFromToken(tokenString string) (string, error) {
// If AccessToken parsing failed, try MapClaims for backward compatibility
log.Printf("AccessToken parsing failed: %v, trying MapClaims fallback", err)
token, err = jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("unexpected signing method")
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, errors.New("unexpected signing method: expected RSA")
}
secretKey := os.Getenv("JWT_SECRET_KEY")
if secretKey == "" {
return nil, errors.New("JWT secret key not set")
publicKeyPEM := os.Getenv("JWT_PUBLIC_KEY")
if publicKeyPEM == "" {
return nil, errors.New("JWT public key not set")
}
return []byte(secretKey), nil
block, _ := pem.Decode([]byte(publicKeyPEM))
if block == nil {
return nil, errors.New("failed to decode PEM block")
}
pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicKeyPEM))
if err != nil {
return nil, errors.New("failed to parse RSA public key")
}
return pubKey, nil
})
if err != nil {
+15 -2
View File
@@ -4,6 +4,7 @@ package middleware
import (
"context"
"database/sql"
"encoding/pem"
"fmt"
"net/http"
"net/url"
@@ -155,10 +156,22 @@ func isSessionBlacklisted(sessionID string) bool {
func parseToken(tokenString, secretKey string) (*jwt.Token, error) {
return jwt.ParseWithClaims(tokenString, jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
if token.Method != jwt.SigningMethodRS256 {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(secretKey), nil
publicKeyPEM := os.Getenv("JWT_PUBLIC_KEY")
if publicKeyPEM == "" {
return nil, fmt.Errorf("JWT public key not set")
}
block, _ := pem.Decode([]byte(publicKeyPEM))
if block == nil {
return nil, fmt.Errorf("failed to decode PEM block")
}
pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicKeyPEM))
if err != nil {
return nil, fmt.Errorf("failed to parse RSA public key")
}
return pubKey, nil
})
}