fixed jwt parsing from HMAC to RSA
This commit is contained in:
+75
-15
@@ -5,8 +5,12 @@ import (
|
||||
"authorization/models"
|
||||
"authorization/redisclient"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -23,10 +27,10 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// Cache JWT secret to avoid repeated os.Getenv calls
|
||||
jwtSecretOnce sync.Once
|
||||
jwtSecretCached []byte
|
||||
jwtSecretError error
|
||||
// Cache RSA public key to avoid repeated file reads
|
||||
rsaPublicKeyOnce sync.Once
|
||||
rsaPublicKeyCached *rsa.PublicKey
|
||||
rsaPublicKeyError error
|
||||
|
||||
// Pre-allocate error messages to avoid repeated allocations
|
||||
errExpiredToken = "Invalid or expired token" // #nosec G101
|
||||
@@ -35,17 +39,61 @@ var (
|
||||
redisTokenPrefix = "jwt:token:"
|
||||
)
|
||||
|
||||
// Initialize JWT secret once
|
||||
func getJWTSecret() ([]byte, error) {
|
||||
jwtSecretOnce.Do(func() {
|
||||
secret := os.Getenv("JWT_KEY")
|
||||
if secret == "" {
|
||||
jwtSecretError = fmt.Errorf("JWT_KEY not set in environment")
|
||||
func getRSAPublicKey() (*rsa.PublicKey, error) {
|
||||
rsaPublicKeyOnce.Do(func() {
|
||||
log.Print("Loading RSA public key from PEM certificate file")
|
||||
|
||||
// Read PEM file
|
||||
pemData, err := os.ReadFile("rsa/ServerCertificate.pem")
|
||||
if err != nil {
|
||||
rsaPublicKeyError = fmt.Errorf("failed to read PEM file: %w", err)
|
||||
log.Printf("Error reading PEM file: %v", rsaPublicKeyError)
|
||||
return
|
||||
}
|
||||
jwtSecretCached = []byte(secret)
|
||||
log.Print("PEM file successfully read")
|
||||
|
||||
// Parse PEM blocks to find the certificate
|
||||
var certBlock *pem.Block
|
||||
for {
|
||||
block, rest := pem.Decode(pemData)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type == "CERTIFICATE" {
|
||||
certBlock = block
|
||||
log.Print("Certificate block found in PEM file")
|
||||
break
|
||||
}
|
||||
pemData = rest
|
||||
}
|
||||
|
||||
if certBlock == nil {
|
||||
rsaPublicKeyError = fmt.Errorf("no certificate block found in PEM file")
|
||||
log.Printf("Error: %v", rsaPublicKeyError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse certificate
|
||||
cert, err := x509.ParseCertificate(certBlock.Bytes)
|
||||
if err != nil {
|
||||
rsaPublicKeyError = fmt.Errorf("failed to parse certificate: %w", err)
|
||||
log.Printf("Error parsing certificate: %v", rsaPublicKeyError)
|
||||
return
|
||||
}
|
||||
log.Print("Certificate successfully parsed")
|
||||
|
||||
// Extract RSA public key
|
||||
publicKey, ok := cert.PublicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
rsaPublicKeyError = fmt.Errorf("certificate does not contain RSA public key")
|
||||
log.Printf("Error: %v", rsaPublicKeyError)
|
||||
return
|
||||
}
|
||||
|
||||
rsaPublicKeyCached = publicKey
|
||||
log.Printf("RSA public key successfully loaded and cached (key size: %d bits)", publicKey.N.BitLen())
|
||||
})
|
||||
return jwtSecretCached, jwtSecretError
|
||||
return rsaPublicKeyCached, rsaPublicKeyError
|
||||
}
|
||||
|
||||
// extractBearerToken extracts token from Authorization header
|
||||
@@ -86,28 +134,38 @@ func checkTokenCache(tokenString string) (*models.Claims, bool) {
|
||||
// delete(tokenCache, tokenString)
|
||||
// }
|
||||
|
||||
// parseAndValidateToken parses JWT token and validates it
|
||||
// parseAndValidateToken parses JWT token and validates it using RSA
|
||||
func parseAndValidateToken(tokenString string) (*models.Claims, error) {
|
||||
log.Print("Starting JWT token parsing and verification with RSA")
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
log.Printf("Token verification failed: unexpected signing method: %v", token.Header["alg"])
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return getJWTSecret()
|
||||
log.Print("Token signing method verified: RSA")
|
||||
return getRSAPublicKey()
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Token parsing failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
log.Print("Token validation failed: token is invalid")
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
log.Print("Token successfully parsed and validated")
|
||||
|
||||
claims, ok := token.Claims.(*models.Claims)
|
||||
if !ok {
|
||||
log.Print("Token validation failed: invalid claims structure")
|
||||
return nil, fmt.Errorf("invalid claims")
|
||||
}
|
||||
|
||||
log.Printf("Token verified successfully for user: %s (UserID: %s)", claims.Username, claims.UserID)
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
@@ -157,6 +215,7 @@ func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
log.Print("1")
|
||||
// Parse and validate token
|
||||
claims, err := parseAndValidateToken(tokenString)
|
||||
if err != nil {
|
||||
@@ -164,6 +223,7 @@ func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
log.Print("2")
|
||||
// Cache the validated token
|
||||
cacheToken(tokenString, claims)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user