0d8f5b9600
- Add /health and /ready endpoints for load balancer health checks - Replace in-memory JWT token cache with Redis for multi-replica support - Reduce DB connection pool from 100 to 25 connections per replica - Add distributed rate limiting (100 req/min + 20 burst) using Redis - Implement circuit breakers for DB and Redis to prevent cascading failures This enables the service to scale horizontally with multiple replicas behind a load balancer without exhausting database connections or maintaining separate token caches per instance.
208 lines
5.4 KiB
Go
208 lines
5.4 KiB
Go
package middleware
|
|
|
|
import (
	"authorization/helper"
	"authorization/models"
	"authorization/redisclient"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/golang-jwt/jwt/v5"
)
|
|
|
|
const (
|
|
claimsKey models.ContextKey = "claims"
|
|
userIDKey models.ContextKey = "user_id"
|
|
usernameKey models.ContextKey = "username"
|
|
roleKey models.ContextKey = "role"
|
|
)
|
|
|
|
// State backing getJWTSecret: the JWT_KEY environment variable is read on the
// first call only, and the resulting secret or error is kept afterwards.
var (
	jwtSecretOnce   sync.Once
	jwtSecretCached []byte
	jwtSecretError  error
)

var (
	// errExpiredToken is the message sent with the 401 response when a token
	// fails parsing or validation. Despite the name it is a plain string,
	// not an error value.
	errExpiredToken = "Invalid or expired token" // #nosec G101

	// redisTokenPrefix is prepended to the raw token string to build the
	// Redis key under which validated claims are cached.
	redisTokenPrefix = "jwt:token:"
)
|
|
|
|
// Initialize JWT secret once
|
|
func getJWTSecret() ([]byte, error) {
|
|
jwtSecretOnce.Do(func() {
|
|
secret := os.Getenv("JWT_KEY")
|
|
if secret == "" {
|
|
jwtSecretError = fmt.Errorf("JWT_KEY not set in environment")
|
|
return
|
|
}
|
|
jwtSecretCached = []byte(secret)
|
|
})
|
|
return jwtSecretCached, jwtSecretError
|
|
}
|
|
|
|
// extractBearerToken pulls the credentials out of an Authorization header
// value of the form "Bearer <token>".
//
// The scheme name is matched case-insensitively ("Bearer", "bearer" and
// "BEARER" are all accepted), because HTTP authentication scheme names are
// case-insensitive. The scheme must be followed by exactly one space and at
// least one further character.
//
// It returns the token and true on success, or "" and false when the header
// is empty, uses a different scheme, or carries no token.
func extractBearerToken(authHeader string) (string, bool) {
	const scheme = "Bearer " // scheme name plus the separating space

	// Require at least one character of token after the scheme prefix.
	if len(authHeader) <= len(scheme) {
		return "", false
	}
	if !strings.EqualFold(authHeader[:len(scheme)], scheme) {
		return "", false
	}
	return authHeader[len(scheme):], true
}
|
|
|
|
// checkTokenCache retrieves token from Redis cache if valid
|
|
func checkTokenCache(tokenString string) (*models.Claims, bool) {
|
|
if redisclient.RDB == nil {
|
|
return nil, false
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
key := redisTokenPrefix + tokenString
|
|
val, err := redisclient.RDB.Get(ctx, key).Result()
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
|
|
var claims models.Claims
|
|
if err := json.Unmarshal([]byte(val), &claims); err != nil {
|
|
return nil, false
|
|
}
|
|
|
|
return &claims, true
|
|
}
|
|
|
|
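// removeExpiredCacheEntry removed a single expired token from the in-memory
// cache. It is kept below only as commented-out reference code: it depends on
// tokenCache and tokenCacheMutex, which belong to that in-memory cache, and
// entries written by cacheToken are given a TTL in Redis instead.
//
// func removeExpiredCacheEntry(tokenString string) {
// 	tokenCacheMutex.Lock()
// 	defer tokenCacheMutex.Unlock()
// 	delete(tokenCache, tokenString)
// }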
|
|
|
|
// parseAndValidateToken parses JWT token and validates it
|
|
func parseAndValidateToken(tokenString string) (*models.Claims, error) {
|
|
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return getJWTSecret()
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !token.Valid {
|
|
return nil, fmt.Errorf("invalid token")
|
|
}
|
|
|
|
claims, ok := token.Claims.(*models.Claims)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid claims")
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// cacheToken stores validated token in Redis cache
|
|
func cacheToken(tokenString string, claims *models.Claims) {
|
|
if redisclient.RDB == nil {
|
|
return
|
|
}
|
|
|
|
// Calculate TTL
|
|
ttl := 5 * time.Minute
|
|
if claims.ExpiresAt != nil {
|
|
timeUntilExpiry := time.Until(claims.ExpiresAt.Time)
|
|
if timeUntilExpiry > 0 && timeUntilExpiry < ttl {
|
|
ttl = timeUntilExpiry
|
|
}
|
|
}
|
|
|
|
// Serialize claims to JSON
|
|
claimsJSON, err := json.Marshal(claims)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Store in Redis with TTL
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
key := redisTokenPrefix + tokenString
|
|
redisclient.RDB.Set(ctx, key, claimsJSON, ttl)
|
|
}
|
|
|
|
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
|
|
func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract token from header
|
|
tokenString, ok := extractBearerToken(r.Header.Get("Authorization"))
|
|
if !ok {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized")
|
|
return
|
|
}
|
|
|
|
// Check cache first
|
|
if claims, found := checkTokenCache(tokenString); found {
|
|
ctx := buildContext(r.Context(), claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
// Parse and validate token
|
|
claims, err := parseAndValidateToken(tokenString)
|
|
if err != nil {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, errExpiredToken)
|
|
return
|
|
}
|
|
|
|
// Cache the validated token
|
|
cacheToken(tokenString, claims)
|
|
|
|
// Add claims to context and proceed
|
|
ctx := buildContext(r.Context(), claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
// buildContext efficiently builds context with claims (reduces allocations)
|
|
func buildContext(parent context.Context, claims *models.Claims) context.Context {
|
|
ctx := context.WithValue(parent, claimsKey, claims)
|
|
ctx = context.WithValue(ctx, userIDKey, claims.UserID)
|
|
ctx = context.WithValue(ctx, usernameKey, claims.Username)
|
|
ctx = context.WithValue(ctx, roleKey, claims.Role)
|
|
return ctx
|
|
}
|
|
|
|
// GetClaims retrieves the JWT claims from the request context
|
|
func GetClaims(r *http.Request) (*models.Claims, bool) {
|
|
claims, ok := r.Context().Value(claimsKey).(*models.Claims)
|
|
return claims, ok
|
|
}
|
|
|
|
// GetUserID retrieves the user ID from the request context
|
|
func GetUserID(r *http.Request) (string, bool) {
|
|
userID, ok := r.Context().Value(userIDKey).(string)
|
|
return userID, ok
|
|
}
|
|
|
|
// GetUsername retrieves the username from the request context
|
|
func GetUsername(r *http.Request) (string, bool) {
|
|
username, ok := r.Context().Value(usernameKey).(string)
|
|
return username, ok
|
|
}
|
|
|
|
// GetRole retrieves the role from the request context
|
|
func GetRole(r *http.Request) (string, bool) {
|
|
role, ok := r.Context().Value(roleKey).(string)
|
|
return role, ok
|
|
}
|