fixed jwt parsing from HMAC to RSA

This commit is contained in:
2026-01-05 14:03:36 +08:00
parent acdc53ec24
commit 6fe17327d8
+75 -15
View File
@@ -5,8 +5,12 @@ import (
"authorization/models"
"authorization/redisclient"
"context"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"log"
"net/http"
"os"
"sync"
@@ -23,10 +27,10 @@ const (
)
var (
// Cache JWT secret to avoid repeated os.Getenv calls
jwtSecretOnce sync.Once
jwtSecretCached []byte
jwtSecretError error
// Cache RSA public key to avoid repeated file reads
rsaPublicKeyOnce sync.Once
rsaPublicKeyCached *rsa.PublicKey
rsaPublicKeyError error
// Pre-allocate error messages to avoid repeated allocations
errExpiredToken = "Invalid or expired token" // #nosec G101
@@ -35,17 +39,61 @@ var (
redisTokenPrefix = "jwt:token:"
)
// Initialize JWT secret once
func getJWTSecret() ([]byte, error) {
jwtSecretOnce.Do(func() {
secret := os.Getenv("JWT_KEY")
if secret == "" {
jwtSecretError = fmt.Errorf("JWT_KEY not set in environment")
func getRSAPublicKey() (*rsa.PublicKey, error) {
rsaPublicKeyOnce.Do(func() {
log.Print("Loading RSA public key from PEM certificate file")
// Read PEM file
pemData, err := os.ReadFile("rsa/ServerCertificate.pem")
if err != nil {
rsaPublicKeyError = fmt.Errorf("failed to read PEM file: %w", err)
log.Printf("Error reading PEM file: %v", rsaPublicKeyError)
return
}
jwtSecretCached = []byte(secret)
log.Print("PEM file successfully read")
// Parse PEM blocks to find the certificate
var certBlock *pem.Block
for {
block, rest := pem.Decode(pemData)
if block == nil {
break
}
if block.Type == "CERTIFICATE" {
certBlock = block
log.Print("Certificate block found in PEM file")
break
}
pemData = rest
}
if certBlock == nil {
rsaPublicKeyError = fmt.Errorf("no certificate block found in PEM file")
log.Printf("Error: %v", rsaPublicKeyError)
return
}
// Parse certificate
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
rsaPublicKeyError = fmt.Errorf("failed to parse certificate: %w", err)
log.Printf("Error parsing certificate: %v", rsaPublicKeyError)
return
}
log.Print("Certificate successfully parsed")
// Extract RSA public key
publicKey, ok := cert.PublicKey.(*rsa.PublicKey)
if !ok {
rsaPublicKeyError = fmt.Errorf("certificate does not contain RSA public key")
log.Printf("Error: %v", rsaPublicKeyError)
return
}
rsaPublicKeyCached = publicKey
log.Printf("RSA public key successfully loaded and cached (key size: %d bits)", publicKey.N.BitLen())
})
return jwtSecretCached, jwtSecretError
return rsaPublicKeyCached, rsaPublicKeyError
}
// extractBearerToken extracts token from Authorization header
@@ -86,28 +134,38 @@ func checkTokenCache(tokenString string) (*models.Claims, bool) {
// delete(tokenCache, tokenString)
// }
// parseAndValidateToken parses JWT token and validates it
// parseAndValidateToken parses JWT token and validates it using RSA
func parseAndValidateToken(tokenString string) (*models.Claims, error) {
log.Print("Starting JWT token parsing and verification with RSA")
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
log.Printf("Token verification failed: unexpected signing method: %v", token.Header["alg"])
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return getJWTSecret()
log.Print("Token signing method verified: RSA")
return getRSAPublicKey()
})
if err != nil {
log.Printf("Token parsing failed: %v", err)
return nil, err
}
if !token.Valid {
log.Print("Token validation failed: token is invalid")
return nil, fmt.Errorf("invalid token")
}
log.Print("Token successfully parsed and validated")
claims, ok := token.Claims.(*models.Claims)
if !ok {
log.Print("Token validation failed: invalid claims structure")
return nil, fmt.Errorf("invalid claims")
}
log.Printf("Token verified successfully for user: %s (UserID: %s)", claims.Username, claims.UserID)
return claims, nil
}
@@ -157,6 +215,7 @@ func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
return
}
log.Print("1")
// Parse and validate token
claims, err := parseAndValidateToken(tokenString)
if err != nil {
@@ -164,6 +223,7 @@ func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
return
}
log.Print("2")
// Cache the validated token
cacheToken(tokenString, claims)