From 0d8f5b96001240c41bafb59117528fdf23e5b090 Mon Sep 17 00:00:00 2001 From: F04C Date: Tue, 16 Dec 2025 10:03:18 +0800 Subject: [PATCH] feat: implement horizontal scaling optimizations for authz service - Add /health and /ready endpoints for load balancer health checks - Replace in-memory JWT token cache with Redis for multi-replica support - Reduce DB connection pool from 100 to 25 connections per replica - Add distributed rate limiting (100 req/min + 20 burst) using Redis - Implement circuit breakers for DB and Redis to prevent cascading failures This enables the service to scale horizontally with multiple replicas behind a load balancer without exhausting database connections or maintaining separate token caches per instance. --- db/db.go | 21 +++++-- handlers/health.go | 84 +++++++++++++++++++++++++ helper/circuit_breaker.go | 125 +++++++++++++++++++++++++++++++++++++ middleware/jwt.go | 97 ++++++++++++---------------- middleware/rate_limiter.go | 98 +++++++++++++++++++++++++++++ models/health.go | 7 +++ models/rate_limiter.go | 7 +++ redisclient/redis.go | 18 ++++-- routes/routes.go | 10 ++- 9 files changed, 400 insertions(+), 67 deletions(-) create mode 100644 handlers/health.go create mode 100644 helper/circuit_breaker.go create mode 100644 middleware/rate_limiter.go create mode 100644 models/health.go create mode 100644 models/rate_limiter.go diff --git a/db/db.go b/db/db.go index 39435e1..b872bb9 100644 --- a/db/db.go +++ b/db/db.go @@ -1,6 +1,7 @@ package db import ( + "authorization/helper" "database/sql" "fmt" "log" @@ -13,6 +14,9 @@ import ( // DB is the global database connection pool var DB *sql.DB +// DBCircuitBreaker protects database operations +var DBCircuitBreaker *helper.CircuitBreaker + func InitDB() (*sql.DB, error) { // Get connection details from environment variables (loaded in main) connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", @@ -29,13 +33,20 @@ func InitDB() (*sql.DB, error) { if err != nil { return nil, fmt.Errorf("error opening database: %v", err) } - // Set connection pool parameters - DB.SetMaxOpenConns(100) // Maximum number of open connections to the database - DB.SetMaxIdleConns(100) // Maximum number of connections in the idle connection pool + // Initialize circuit breaker + DBCircuitBreaker = helper.NewCircuitBreaker("database", 5, 2*time.Second) + + // Set connection pool parameters optimized for horizontal scaling + // Lower per-replica to allow more replicas without exhausting DB connections + DB.SetMaxOpenConns(25) // Maximum number of open connections to the database + DB.SetMaxIdleConns(10) // Maximum number of connections in the idle connection pool DB.SetConnMaxLifetime(5 * time.Minute) // Maximum amount of time a connection may be reused - // Check if the database connection is working - if err := DB.Ping(); err != nil { + // Check if the database connection is working with circuit breaker + err = DBCircuitBreaker.Call(func() error { + return DB.Ping() + }) + if err != nil { log.Printf("Database connection lost: %v. Reconnecting...", err) DB, err = InitDB() if err != nil { diff --git a/handlers/health.go b/handlers/health.go new file mode 100644 index 0000000..fa4c026 --- /dev/null +++ b/handlers/health.go @@ -0,0 +1,84 @@ +package handlers + +import ( + "authorization/db" + "authorization/models" + "authorization/redisclient" + "context" + "encoding/json" + "net/http" + "time" +) + +// HealthHandler provides a basic liveness check +// @Summary Health check endpoint +// @Description Returns service health status for load balancer health checks +// @Tags health +// @Produce json +// @Success 200 {object} HealthResponse +// @Router /health [get] +func HealthHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(models.HealthResponse{ + Status: "ok", + }) +} + +// ReadyHandler checks if the service is ready to handle requests +// @Summary Readiness check endpoint +// @Description Returns readiness status including database and Redis connectivity +// @Tags health +// @Produce json +// @Success 200 {object} HealthResponse +// @Failure 503 {object} HealthResponse +// @Router /ready [get] +func ReadyHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + services := make(map[string]string) + allHealthy := true + + // Check database + if db.DB != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := db.DB.PingContext(ctx); err != nil { + services["database"] = "unhealthy" + allHealthy = false + } else { + services["database"] = "healthy" + } + } else { + services["database"] = "not_initialized" + allHealthy = false + } + + // Check Redis + if redisclient.RDB != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if _, err := redisclient.RDB.Ping(ctx).Result(); err != nil { + services["redis"] = "unhealthy" + allHealthy = false + } else { + services["redis"] = "healthy" + } + } else { + services["redis"] = "not_initialized" + allHealthy = false + } + + status := "ready" + statusCode := http.StatusOK + if !allHealthy { + status = "not_ready" + statusCode = http.StatusServiceUnavailable + } + + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(models.HealthResponse{ + Status: status, + Services: services, + }) +} diff --git a/helper/circuit_breaker.go b/helper/circuit_breaker.go new file mode 100644 index 0000000..994a129 --- /dev/null +++ b/helper/circuit_breaker.go @@ -0,0 +1,125 @@ +package helper + +import ( + "sync" + "time" +) + +// CircuitState represents the state of a circuit breaker +type CircuitState int + +const ( + StateClosed CircuitState = iota + StateOpen + StateHalfOpen +) + +// CircuitBreaker implements the circuit breaker pattern +type CircuitBreaker struct { + name string + maxFailures int + timeout time.Duration + resetTimeout time.Duration + failures int + lastFailureTime time.Time + state CircuitState + mutex sync.RWMutex +} + +// NewCircuitBreaker creates a new circuit breaker +func NewCircuitBreaker(name string, maxFailures int, timeout time.Duration) *CircuitBreaker { + return &CircuitBreaker{ + name: name, + maxFailures: maxFailures, + timeout: timeout, + resetTimeout: 30 * time.Second, + state: StateClosed, + } +} + +// Call executes the given function with circuit breaker protection +func (cb *CircuitBreaker) Call(fn func() error) error { + cb.mutex.Lock() + + // Check if circuit should transition from Open to HalfOpen + if cb.state == StateOpen { + if time.Since(cb.lastFailureTime) > cb.resetTimeout { + cb.state = StateHalfOpen + cb.failures = 0 + } else { + cb.mutex.Unlock() + return &CircuitBreakerError{ + Name: cb.name, + State: "open", + } + } + } + + currentState := cb.state + cb.mutex.Unlock() + + // Execute the function + err := fn() + + cb.mutex.Lock() + defer cb.mutex.Unlock() + + if err != nil { + cb.failures++ + cb.lastFailureTime = time.Now() + + if currentState == StateHalfOpen { + // If it fails in HalfOpen, go back to Open + cb.state = StateOpen + } else if cb.failures >= cb.maxFailures { + // If too many failures, open the circuit + cb.state = StateOpen + LogError(err, cb.name+" circuit breaker opened") + } + + return err + } + + // Success - reset if in HalfOpen, or reset failure count + if cb.state == StateHalfOpen { + cb.state = StateClosed + cb.failures = 0 + LogInfo(cb.name + " circuit breaker closed") + } else if cb.state == StateClosed && cb.failures > 0 { + // Gradually reduce failure count on success + cb.failures-- + } + + return nil +} + +// GetState returns the current state of the circuit breaker +func (cb *CircuitBreaker) GetState() CircuitState { + cb.mutex.RLock() + defer cb.mutex.RUnlock() + return cb.state +} + +// Reset manually resets the circuit breaker +func (cb *CircuitBreaker) Reset() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + cb.state = StateClosed + cb.failures = 0 +} + +// CircuitBreakerError represents a circuit breaker error +type CircuitBreakerError struct { + Name string + State string +} + +func (e *CircuitBreakerError) Error() string { + return "circuit breaker '" + e.Name + "' is " + e.State +} + +// IsCircuitBreakerError checks if an error is a circuit breaker error +func IsCircuitBreakerError(err error) bool { + _, ok := err.(*CircuitBreakerError) + return ok +} diff --git a/middleware/jwt.go b/middleware/jwt.go index 67c17f0..5b3edb0 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -3,7 +3,9 @@ package middleware import ( "authorization/helper" "authorization/models" + "authorization/redisclient" "context" + "encoding/json" "fmt" "net/http" "os" @@ -21,18 +23,16 @@ const ( ) var ( - // Token cache for high-frequency requests - tokenCache = make(map[string]*models.CacheEntry) - tokenCacheMutex sync.RWMutex - // Cache JWT secret to avoid repeated os.Getenv calls jwtSecretOnce sync.Once jwtSecretCached []byte jwtSecretError error // Pre-allocate error messages to avoid repeated allocations - errExpiredToken = "Invalid or expired token" // #nosec G101 + + // Redis key prefix for token cache + redisTokenPrefix = "jwt:token:" ) // Initialize JWT secret once @@ -48,29 +48,6 @@ func getJWTSecret() ([]byte, error) { return jwtSecretCached, jwtSecretError } -// Clean expired cache entries periodically -func init() { - go func() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - for range ticker.C { - cleanExpiredTokens() - } - }() -} - -func cleanExpiredTokens() { - tokenCacheMutex.Lock() - defer tokenCacheMutex.Unlock() - - now := time.Now() - for token, entry := range tokenCache { - if now.After(entry.ExpiresAt) { - delete(tokenCache, token) - } - } -} - // extractBearerToken extracts token from Authorization header func extractBearerToken(authHeader string) (string, bool) { if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " { @@ -79,22 +56,27 @@ func extractBearerToken(authHeader string) (string, bool) { return authHeader[7:], true } -// checkTokenCache retrieves token from cache if valid +// checkTokenCache retrieves token from Redis cache if valid func checkTokenCache(tokenString string) (*models.Claims, bool) { - tokenCacheMutex.RLock() - defer tokenCacheMutex.RUnlock() - - cached, exists := tokenCache[tokenString] - if !exists { + if redisclient.RDB == nil { return nil, false } - if time.Now().Before(cached.ExpiresAt) { - return cached.Claims, true + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + key := redisTokenPrefix + tokenString + val, err := redisclient.RDB.Get(ctx, key).Result() + if err != nil { + return nil, false } - // Token expired, will be cleaned up later - return nil, false + var claims models.Claims + if err := json.Unmarshal([]byte(val), &claims); err != nil { + return nil, false + } + + return &claims, true } // removeExpiredCacheEntry removes a single expired token from cache @@ -129,32 +111,33 @@ func parseAndValidateToken(tokenString string) (*models.Claims, error) { return claims, nil } -// cacheToken stores validated token in cache +// cacheToken stores validated token in Redis cache func cacheToken(tokenString string, claims *models.Claims) { - expiresAt := time.Now().Add(5 * time.Minute) - if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) { - expiresAt = claims.ExpiresAt.Time + if redisclient.RDB == nil { + return } - tokenCacheMutex.Lock() - defer tokenCacheMutex.Unlock() - - // Limit cache size - if len(tokenCache) > 10000000 { - count := 0 - for k := range tokenCache { - delete(tokenCache, k) - count++ - if count >= 1000000 { - break - } + // Calculate TTL + ttl := 5 * time.Minute + if claims.ExpiresAt != nil { + timeUntilExpiry := time.Until(claims.ExpiresAt.Time) + if timeUntilExpiry > 0 && timeUntilExpiry < ttl { + ttl = timeUntilExpiry } } - tokenCache[tokenString] = &models.CacheEntry{ - Claims: claims, - ExpiresAt: expiresAt, + // Serialize claims to JSON + claimsJSON, err := json.Marshal(claims) + if err != nil { + return } + + // Store in Redis with TTL + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + key := redisTokenPrefix + tokenString + redisclient.RDB.Set(ctx, key, claimsJSON, ttl) } // JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go new file mode 100644 index 0000000..8241dd2 --- /dev/null +++ b/middleware/rate_limiter.go @@ -0,0 +1,98 @@ +package middleware + +import ( + "authorization/helper" + "authorization/models" + "authorization/redisclient" + "context" + "fmt" + "net/http" + "time" +) + +// DefaultRateLimitConfig returns default rate limiting settings +func DefaultRateLimitConfig() models.RateLimitConfig { + return models.RateLimitConfig{ + RequestsPerMinute: 100, + BurstSize: 20, + } +} + +// RateLimiterMiddleware implements distributed rate limiting using Redis +func RateLimiterMiddleware(config models.RateLimitConfig) func(http.HandlerFunc) http.HandlerFunc { + return func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Skip rate limiting if Redis is not available + if redisclient.RDB == nil { + helper.RespondWithError(w, http.StatusServiceUnavailable, "Redis not available") + return + } + + // Extract user identifier (prefer user_id from JWT, fallback to IP) + var identifier string + if userID, ok := GetUserID(r); ok { + identifier = "user:" + userID + } else { + identifier = "ip:" + getClientIP(r) + } + + // Check rate limit + allowed, err := checkRateLimit(identifier, config) + if err != nil { + // On error, fail open (allow request) but log the error + helper.LogError(err, "rate limiter error") + next.ServeHTTP(w, r) + return + } + + if !allowed { + helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded") + return + } + + next.ServeHTTP(w, r) + } + } +} + +// checkRateLimit uses Redis INCR with sliding window +func checkRateLimit(identifier string, config models.RateLimitConfig) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + key := fmt.Sprintf("ratelimit:%s", identifier) + + // Use Redis pipeline for atomic operations + pipe := redisclient.RDB.Pipeline() + + incrCmd := pipe.Incr(ctx, key) + pipe.Expire(ctx, key, time.Minute) + + _, err := pipe.Exec(ctx) + if err != nil { + return false, err + } + + count := incrCmd.Val() + + // Allow burst + requests per minute + return count <= int64(config.RequestsPerMinute+config.BurstSize), nil +} + +// getClientIP extracts the client IP from the request +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header first (for proxies/load balancers) + forwarded := r.Header.Get("X-Forwarded-For") + if forwarded != "" { + return forwarded + } + + // Check X-Real-IP header + realIP := r.Header.Get("X-Real-IP") + if realIP != "" { + return realIP + } + + // Fallback to RemoteAddr + return r.RemoteAddr +} diff --git a/models/health.go b/models/health.go new file mode 100644 index 0000000..947cdf0 --- /dev/null +++ b/models/health.go @@ -0,0 +1,7 @@ +package models + +// HealthResponse represents the health check response +type HealthResponse struct { + Status string `json:"status"` + Services map[string]string `json:"services,omitempty"` +} diff --git a/models/rate_limiter.go b/models/rate_limiter.go new file mode 100644 index 0000000..d288e41 --- /dev/null +++ b/models/rate_limiter.go @@ -0,0 +1,7 @@ +package models + +// RateLimitConfig holds rate limiting configuration +type RateLimitConfig struct { + RequestsPerMinute int + BurstSize int +} diff --git a/redisclient/redis.go b/redisclient/redis.go index c4fd78d..9257751 100644 --- a/redisclient/redis.go +++ b/redisclient/redis.go @@ -1,17 +1,20 @@ -// pkg/redisclient/redis.go - package redisclient import ( + "authorization/helper" "context" "fmt" "os" + "time" "github.com/redis/go-redis/v9" ) var RDB *redis.Client +// RedisCircuitBreaker protects Redis operations +var RedisCircuitBreaker *helper.CircuitBreaker + func Init() { redisHost := os.Getenv("REDIS_HOST") if redisHost == "" { @@ -39,9 +42,16 @@ func Init() { RDB = redis.NewClient(opts) - // Test connection with authentication + // Initialize circuit breaker + RedisCircuitBreaker = helper.NewCircuitBreaker("redis", 5, 2*time.Second) + + // Test connection with authentication using circuit breaker ctx := context.Background() - if _, err := RDB.Ping(ctx).Result(); err != nil { + err := RedisCircuitBreaker.Call(func() error { + _, err := RDB.Ping(ctx).Result() + return err + }) + if err != nil { panic(fmt.Sprintf("Could not connect to Redis: %v", err)) } diff --git a/routes/routes.go b/routes/routes.go index 6b3a46b..6d0e65e 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -10,8 +10,16 @@ import ( ) func SetupRoutes(router *mux.Router, db *sql.DB) { + // Health check endpoints (no auth required) + router.HandleFunc("/health", handlers.HealthHandler).Methods("GET") + router.HandleFunc("/ready", handlers.ReadyHandler).Methods("GET") + + // Rate limit configuration + rateLimitConfig := middleware.DefaultRateLimitConfig() + rateLimiter := middleware.RateLimiterMiddleware(rateLimitConfig) + authRoutes := router.PathPrefix("/v1/auth").Subrouter() - authRoutes.HandleFunc("/check", middleware.JWTAuth(handlers.AuthorizeHandler)).Methods("POST") + authRoutes.HandleFunc("/check", rateLimiter(middleware.JWTAuth(handlers.AuthorizeHandler))).Methods("POST") router.PathPrefix("/swagger/").Handler(httpSwagger.WrapHandler) }