264 lines
7.2 KiB
Go
264 lines
7.2 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"crypto/rsa"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"authorization/helper"
	"authorization/models"
	"authorization/redisclient"

	"github.com/golang-jwt/jwt/v5"
)
|
|
|
|
const (
|
|
claimsKey models.ContextKey = "claims"
|
|
userIDKey models.ContextKey = "user_id"
|
|
roleIDKey models.ContextKey = "role_id"
|
|
)
|
|
|
|
var (
|
|
// Cache RSA public key to avoid repeated file reads
|
|
rsaPublicKeyOnce sync.Once
|
|
rsaPublicKeyCached *rsa.PublicKey
|
|
rsaPublicKeyError error
|
|
|
|
// Pre-allocate error messages to avoid repeated allocations
|
|
errExpiredToken = "Invalid or expired token" // #nosec G101
|
|
|
|
// Redis key prefix for token cache
|
|
redisTokenPrefix = "jwt:token:"
|
|
)
|
|
|
|
// getRSAPublicKey returns the RSA public key from the server's X.509
// certificate; parseAndValidateToken uses it to verify JWT signatures.
//
// Loading happens at most once per process (guarded by rsaPublicKeyOnce).
// Whatever that single attempt produces, the key or the error, is stored in
// rsaPublicKeyCached / rsaPublicKeyError and returned on every later call.
// A failed load is therefore not retried until the process restarts.
func getRSAPublicKey() (*rsa.PublicKey, error) {
	rsaPublicKeyOnce.Do(func() {
		log.Print("Loading RSA public key from PEM certificate file")

		pemData, err := readServerCertificatePEM()
		if err != nil {
			rsaPublicKeyError = fmt.Errorf("failed to read PEM file: %w", err)
			log.Printf("Error reading PEM file: %v", rsaPublicKeyError)
			return
		}
		log.Print("PEM file successfully read")

		publicKey, err := rsaPublicKeyFromPEM(pemData)
		if err != nil {
			// rsaPublicKeyFromPEM has already logged the specific failure.
			rsaPublicKeyError = err
			return
		}

		rsaPublicKeyCached = publicKey
		log.Printf("RSA public key successfully loaded and cached (key size: %d bits)", publicKey.N.BitLen())
	})
	return rsaPublicKeyCached, rsaPublicKeyError
}

// readServerCertificatePEM reads the server certificate PEM file.
//
// It tries "rsa/ServerCertificate.pem" relative to the working directory
// first. It then tries "../rsa/ServerCertificate.pem", which is where the
// file is found when tests run from a subdirectory.
//
// If both reads fail, the returned error carries both causes (the fallback's
// is wrapped). A permission or IO error on the primary path is then not
// hidden behind a not-found error on the fallback.
func readServerCertificatePEM() ([]byte, error) {
	const (
		primaryPath  = "rsa/ServerCertificate.pem"
		fallbackPath = "../rsa/ServerCertificate.pem"
	)

	data, primaryErr := os.ReadFile(primaryPath)
	if primaryErr == nil {
		return data, nil
	}
	data, fallbackErr := os.ReadFile(fallbackPath)
	if fallbackErr == nil {
		return data, nil
	}
	return nil, fmt.Errorf("%v; %w", primaryErr, fallbackErr)
}

// rsaPublicKeyFromPEM returns the RSA public key of the first CERTIFICATE
// block in pemData. Any other PEM blocks that precede it are skipped.
//
// Each step's progress or failure is logged.
// It returns an error when:
//   - no CERTIFICATE block exists;
//   - the certificate fails to parse;
//   - the certificate's key is not RSA.
func rsaPublicKeyFromPEM(pemData []byte) (*rsa.PublicKey, error) {
	var certBlock *pem.Block
	rest := pemData
	for {
		block, remaining := pem.Decode(rest)
		if block == nil {
			break
		}
		if block.Type == "CERTIFICATE" {
			certBlock = block
			log.Print("Certificate block found in PEM file")
			break
		}
		rest = remaining
	}
	if certBlock == nil {
		err := fmt.Errorf("no certificate block found in PEM file")
		log.Printf("Error: %v", err)
		return nil, err
	}

	cert, err := x509.ParseCertificate(certBlock.Bytes)
	if err != nil {
		err = fmt.Errorf("failed to parse certificate: %w", err)
		log.Printf("Error parsing certificate: %v", err)
		return nil, err
	}
	log.Print("Certificate successfully parsed")

	publicKey, ok := cert.PublicKey.(*rsa.PublicKey)
	if !ok {
		err := fmt.Errorf("certificate does not contain RSA public key")
		log.Printf("Error: %v", err)
		return nil, err
	}
	return publicKey, nil
}
|
|
|
|
// extractBearerToken extracts token from Authorization header
|
|
func extractBearerToken(authHeader string) (string, bool) {
|
|
if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " {
|
|
return "", false
|
|
}
|
|
return authHeader[7:], true
|
|
}
|
|
|
|
// checkTokenCache looks up previously validated claims for tokenString in
// Redis under the key redisTokenPrefix+tokenString.
//
// It returns (claims, true) only for a usable, unexpired cache hit.
// It reports a miss (nil, false) when:
//   - Redis is not configured;
//   - the lookup fails or exceeds its 100ms budget;
//   - the stored value cannot be decoded;
//   - the cached claims are already past their exp time.
//
// Lookup failures are deliberately treated as misses so the caller falls
// back to full signature verification.
//
// NOTE(review): the raw bearer token is used as the Redis key (here and in
// cacheToken), so anyone with Redis read access can harvest live tokens.
// Consider hashing it, e.g. SHA-256, in both places.
func checkTokenCache(tokenString string) (*models.Claims, bool) {
	if redisclient.RDB == nil {
		return nil, false
	}

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	val, err := redisclient.RDB.Get(ctx, redisTokenPrefix+tokenString).Result()
	if err != nil {
		return nil, false
	}

	var claims models.Claims
	if err := json.Unmarshal([]byte(val), &claims); err != nil {
		return nil, false
	}

	// Defence in depth: the entry's TTL is meant to end no later than the
	// token does. Still, never honour a cached token at or past its own exp
	// claim, for example with clock skew between this host and Redis. A
	// token counts as valid only while now < exp.
	if claims.ExpiresAt != nil && !time.Now().Before(claims.ExpiresAt.Time) {
		return nil, false
	}

	return &claims, true
}
|
|
|
|
// removeExpiredCacheEntry removes a single expired token from cache
|
|
// func removeExpiredCacheEntry(tokenString string) {
|
|
// tokenCacheMutex.Lock()
|
|
// defer tokenCacheMutex.Unlock()
|
|
// delete(tokenCache, tokenString)
|
|
// }
|
|
|
|
// parseAndValidateToken verifies tokenString's signature with the server's
// RSA public key and returns its decoded claims.
//
// Only RSA (*jwt.SigningMethodRSA) signatures are accepted; any other "alg"
// is rejected before the key is consulted. Every failure is logged and
// returned as a non-nil error with nil claims.
func parseAndValidateToken(tokenString string) (*models.Claims, error) {
	log.Print("Starting JWT token parsing and verification with RSA")

	keyFunc := func(t *jwt.Token) (interface{}, error) {
		if _, isRSA := t.Method.(*jwt.SigningMethodRSA); !isRSA {
			alg := t.Header["alg"]
			log.Printf("Token verification failed: unexpected signing method: %v", alg)
			return nil, fmt.Errorf("unexpected signing method: %v", alg)
		}
		log.Print("Token signing method verified: RSA")
		return getRSAPublicKey()
	}

	parsed, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, keyFunc)
	switch {
	case err != nil:
		log.Printf("Token parsing failed: %v", err)
		return nil, err
	case !parsed.Valid:
		log.Print("Token validation failed: token is invalid")
		return nil, fmt.Errorf("invalid token")
	}
	log.Print("Token successfully parsed and validated")

	claims, isClaims := parsed.Claims.(*models.Claims)
	if !isClaims {
		log.Print("Token validation failed: invalid claims structure")
		return nil, fmt.Errorf("invalid claims")
	}

	log.Printf("Token verified successfully for user: (UserID: %s)", claims.UserID)
	return claims, nil
}
|
|
|
|
// cacheToken stores the validated claims for tokenString in Redis, under
// redisTokenPrefix+tokenString. Later requests carrying the same token can
// then skip signature verification.
//
// The entry lives at most 5 minutes and never beyond the token's exp claim.
// A token already at or past exp is not cached at all.
//
// Caching is best-effort: if Redis is not configured, nothing is stored. A
// marshal failure or a Redis write error is logged and otherwise ignored;
// the next request simply re-verifies the token. claims must be non-nil.
//
// NOTE(review): the raw bearer token is used as the Redis key (here and in
// checkTokenCache); consider hashing it in both places.
func cacheToken(tokenString string, claims *models.Claims) {
	if redisclient.RDB == nil {
		return
	}

	ttl := 5 * time.Minute
	if claims.ExpiresAt != nil {
		remaining := time.Until(claims.ExpiresAt.Time)
		if remaining <= 0 {
			// The token can lapse between verification and this call.
			// Falling back to the default TTL would let the cache serve an
			// expired token, so skip caching instead.
			return
		}
		if remaining < ttl {
			ttl = remaining
		}
	}

	claimsJSON, err := json.Marshal(claims)
	if err != nil {
		log.Printf("Token cache: failed to marshal claims: %v", err)
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if err := redisclient.RDB.Set(ctx, redisTokenPrefix+tokenString, claimsJSON, ttl).Err(); err != nil {
		log.Printf("Token cache: failed to store token: %v", err)
	}
}
|
|
|
|
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
|
|
func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract token from header
|
|
tokenString, ok := extractBearerToken(r.Header.Get("Authorization"))
|
|
if !ok {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized")
|
|
return
|
|
}
|
|
|
|
// Check cache first
|
|
if claims, found := checkTokenCache(tokenString); found {
|
|
ctx := buildContext(r.Context(), claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
log.Print("1")
|
|
// Parse and validate token
|
|
claims, err := parseAndValidateToken(tokenString)
|
|
if err != nil {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, errExpiredToken)
|
|
return
|
|
}
|
|
|
|
log.Print("2")
|
|
// Cache the validated token
|
|
cacheToken(tokenString, claims)
|
|
|
|
// Add claims to context and proceed
|
|
ctx := buildContext(r.Context(), claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
// buildContext returns a child of parent carrying three values:
//   - claims under claimsKey;
//   - claims.UserID under userIDKey;
//   - claims.RoleID under roleIDKey.
//
// Handlers read them back through GetClaims, GetUserID and GetRole. Each
// value adds one context layer. claims must be non-nil, since its fields
// are dereferenced.
func buildContext(parent context.Context, claims *models.Claims) context.Context {
	withClaims := context.WithValue(parent, claimsKey, claims)
	withUserID := context.WithValue(withClaims, userIDKey, claims.UserID)
	return context.WithValue(withUserID, roleIDKey, claims.RoleID)
}
|
|
|
|
// GetClaims returns the JWT claims stored in r's context under claimsKey
// (placed there by buildContext). ok is false when no *models.Claims value
// is present.
func GetClaims(r *http.Request) (*models.Claims, bool) {
	stored := r.Context().Value(claimsKey)
	if claims, ok := stored.(*models.Claims); ok {
		return claims, true
	}
	return nil, false
}
|
|
|
|
// GetUserID returns the user ID stored in r's context under userIDKey
// (placed there by buildContext). ok is false when the value is missing or
// is not a string.
func GetUserID(r *http.Request) (string, bool) {
	if userID, ok := r.Context().Value(userIDKey).(string); ok {
		return userID, true
	}
	return "", false
}
|
|
|
|
// GetRole returns the role stored in r's context under roleIDKey (placed
// there by buildContext from claims.RoleID). ok is false when the value is
// missing or is not a string.
//
// NOTE(review): if models.Claims.RoleID is not a string type, this always
// reports false; confirm the field's type.
func GetRole(r *http.Request) (string, bool) {
	value := r.Context().Value(roleIDKey)
	role, ok := value.(string)
	if !ok {
		return "", false
	}
	return role, true
}
|