feat: implement horizontal scaling optimizations for authz service

- Add /health and /ready endpoints for load balancer health checks
- Replace in-memory JWT token cache with Redis for multi-replica support
- Reduce DB connection pool from 100 to 25 connections per replica
- Add distributed rate limiting (100 req/min + 20 burst) using Redis
- Implement circuit breakers for DB and Redis to prevent cascading failures

This enables the service to scale horizontally with multiple replicas
behind a load balancer without exhausting database connections or
maintaining separate token caches per instance.
This commit is contained in:
2025-12-16 10:03:18 +08:00
parent ee8079e65c
commit 0d8f5b9600
9 changed files with 400 additions and 67 deletions
+40 -57
View File
@@ -3,7 +3,9 @@ package middleware
import (
"authorization/helper"
"authorization/models"
"authorization/redisclient"
"context"
"encoding/json"
"fmt"
"net/http"
"os"
@@ -21,18 +23,16 @@ const (
)
var (
// tokenCache is the in-memory token cache for high-frequency requests,
// keyed by the raw token string.
tokenCache = make(map[string]*models.CacheEntry)
// tokenCacheMutex guards tokenCache.
tokenCacheMutex sync.RWMutex
// Cache JWT secret to avoid repeated os.Getenv calls. getJWTSecret returns
// jwtSecretCached and jwtSecretError.
jwtSecretOnce sync.Once
jwtSecretCached []byte
jwtSecretError error
// Pre-allocate error messages to avoid repeated allocations
errExpiredToken = "Invalid or expired token" // #nosec G101
// redisTokenPrefix is prepended to the raw token string to form the Redis
// key under which a token's claims are cached.
redisTokenPrefix = "jwt:token:"
)
// Initialize JWT secret once
@@ -48,29 +48,6 @@ func getJWTSecret() ([]byte, error) {
return jwtSecretCached, jwtSecretError
}
// init launches a background goroutine that calls cleanExpiredTokens every
// five minutes. The goroutine has no stop signal.
func init() {
	sweepLoop := func() {
		t := time.NewTicker(5 * time.Minute)
		defer t.Stop()

		for {
			<-t.C
			cleanExpiredTokens()
		}
	}

	go sweepLoop()
}
// cleanExpiredTokens removes every entry from tokenCache whose ExpiresAt is
// earlier than the current time. The write lock on tokenCacheMutex is held
// for the whole sweep.
func cleanExpiredTokens() {
	tokenCacheMutex.Lock()
	defer tokenCacheMutex.Unlock()

	cutoff := time.Now()
	for key := range tokenCache {
		if tokenCache[key].ExpiresAt.Before(cutoff) {
			delete(tokenCache, key)
		}
	}
}
// extractBearerToken extracts token from Authorization header
func extractBearerToken(authHeader string) (string, bool) {
if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " {
@@ -79,22 +56,27 @@ func extractBearerToken(authHeader string) (string, bool) {
return authHeader[7:], true
}
// checkTokenCache retrieves token from cache if valid
// checkTokenCache retrieves token from Redis cache if valid
func checkTokenCache(tokenString string) (*models.Claims, bool) {
tokenCacheMutex.RLock()
defer tokenCacheMutex.RUnlock()
cached, exists := tokenCache[tokenString]
if !exists {
if redisclient.RDB == nil {
return nil, false
}
if time.Now().Before(cached.ExpiresAt) {
return cached.Claims, true
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
key := redisTokenPrefix + tokenString
val, err := redisclient.RDB.Get(ctx, key).Result()
if err != nil {
return nil, false
}
// Token expired, will be cleaned up later
return nil, false
var claims models.Claims
if err := json.Unmarshal([]byte(val), &claims); err != nil {
return nil, false
}
return &claims, true
}
// removeExpiredCacheEntry removes a single expired token from cache
@@ -129,32 +111,33 @@ func parseAndValidateToken(tokenString string) (*models.Claims, error) {
return claims, nil
}
// cacheToken stores validated token in cache
// cacheToken stores validated token in Redis cache
func cacheToken(tokenString string, claims *models.Claims) {
expiresAt := time.Now().Add(5 * time.Minute)
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) {
expiresAt = claims.ExpiresAt.Time
if redisclient.RDB == nil {
return
}
tokenCacheMutex.Lock()
defer tokenCacheMutex.Unlock()
// Limit cache size
if len(tokenCache) > 10000000 {
count := 0
for k := range tokenCache {
delete(tokenCache, k)
count++
if count >= 1000000 {
break
}
// Calculate TTL
ttl := 5 * time.Minute
if claims.ExpiresAt != nil {
timeUntilExpiry := time.Until(claims.ExpiresAt.Time)
if timeUntilExpiry > 0 && timeUntilExpiry < ttl {
ttl = timeUntilExpiry
}
}
tokenCache[tokenString] = &models.CacheEntry{
Claims: claims,
ExpiresAt: expiresAt,
// Serialize claims to JSON
claimsJSON, err := json.Marshal(claims)
if err != nil {
return
}
// Store in Redis with TTL
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
key := redisTokenPrefix + tokenString
redisclient.RDB.Set(ctx, key, claimsJSON, ttl)
}
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
+98
View File
@@ -0,0 +1,98 @@
package middleware
import (
	"authorization/helper"
	"authorization/models"
	"authorization/redisclient"
	"context"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"
)
// DefaultRateLimitConfig returns the default rate limiting settings:
// 100 requests per minute with a burst size of 20.
func DefaultRateLimitConfig() models.RateLimitConfig {
	var cfg models.RateLimitConfig
	cfg.RequestsPerMinute = 100
	cfg.BurstSize = 20
	return cfg
}
// RateLimiterMiddleware returns a middleware that applies distributed,
// Redis-backed rate limiting (see checkRateLimit) to the wrapped handler.
//
// Each request is attributed to the user ID reported by GetUserID when one is
// available, otherwise to the client IP returned by getClientIP. Requests over
// the limit are rejected with 429 Too Many Requests and a Retry-After header.
//
// The limiter fails open: if Redis is not configured (redisclient.RDB is nil)
// or the limit check returns an error, the request is passed to next
// unthrottled, so a Redis outage does not take the service down with it.
func RateLimiterMiddleware(config models.RateLimitConfig) func(http.HandlerFunc) http.HandlerFunc {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			// Skip rate limiting if Redis is not available. This matches the
			// fail-open handling of Redis errors below; previously this branch
			// answered 503 and made every wrapped route unusable without Redis.
			if redisclient.RDB == nil {
				next.ServeHTTP(w, r)
				return
			}

			// Extract user identifier (prefer user_id from JWT, fallback to IP).
			// The prefixes keep the two identifier kinds from colliding.
			var identifier string
			if userID, ok := GetUserID(r); ok {
				identifier = "user:" + userID
			} else {
				identifier = "ip:" + getClientIP(r)
			}

			allowed, err := checkRateLimit(identifier, config)
			if err != nil {
				// On error, fail open (allow request) but log the error.
				helper.LogError(err, "rate limiter error")
				next.ServeHTTP(w, r)
				return
			}

			if !allowed {
				// The limiter counts per one-minute window, so 60 seconds is
				// an upper bound on how long the client has to wait.
				w.Header().Set("Retry-After", "60")
				helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
				return
			}

			next.ServeHTTP(w, r)
		}
	}
}
// checkRateLimit records one request for identifier in Redis and reports
// whether the caller is still within its limit.
//
// It is a fixed-window counter, not a sliding window: the first request
// creates the key "ratelimit:<identifier>" and starts a one-minute window,
// each further request increments it, and the counter resets when the key
// expires. A request is allowed while the count is at most
// config.RequestsPerMinute + config.BurstSize.
//
// A non-nil error means Redis could not be reached or updated within the
// 100ms budget; the boolean result is meaningless in that case.
func checkRateLimit(identifier string, config models.RateLimitConfig) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	key := fmt.Sprintf("ratelimit:%s", identifier)

	// A pipeline sends both commands in a single round trip. It is not a
	// transaction, so the commands are not atomic as a pair.
	pipe := redisclient.RDB.Pipeline()
	incrCmd := pipe.Incr(ctx, key)
	ttlCmd := pipe.TTL(ctx, key)
	if _, err := pipe.Exec(ctx); err != nil {
		return false, fmt.Errorf("incrementing rate limit counter %q: %w", key, err)
	}

	// Start the window only when the key has no expiry yet (TTL is negative).
	// Refreshing the expiry on every request would keep the key alive for as
	// long as the client keeps sending, so a throttled client could never
	// recover. Checking the TTL rather than "count == 1" also repairs a key
	// that was left without an expiry by an earlier failed attempt.
	if ttlCmd.Val() < 0 {
		if err := redisclient.RDB.Expire(ctx, key, time.Minute).Err(); err != nil {
			return false, fmt.Errorf("setting rate limit window on %q: %w", key, err)
		}
	}

	// Allow burst + requests per minute.
	return incrCmd.Val() <= int64(config.RequestsPerMinute+config.BurstSize), nil
}
// getClientIP returns a best-effort client IP address for r, without a port.
//
// Sources are tried in order:
//  1. the first entry of X-Forwarded-For (the header may hold a
//     comma-separated chain "client, proxy1, proxy2"),
//  2. X-Real-IP,
//  3. the host part of r.RemoteAddr.
//
// NOTE(review): both headers are supplied by the client unless a trusted
// proxy overwrites them, so a caller can vary them to dodge IP-based limits.
// Confirm that the load balancer in front of this service sets them.
func getClientIP(r *http.Request) string {
	// Use only the first hop. Returning the whole chain would make the same
	// client look different whenever the proxy path changes.
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := strings.TrimSpace(strings.SplitN(forwarded, ",", 2)[0])
		if first != "" {
			return first
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	// RemoteAddr is normally "host:port" with an ephemeral port. Drop the
	// port so that all connections from one client share one identifier.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port present (or an unparsable value): use it unchanged.
		return r.RemoteAddr
	}
	return host
}