feat: implement horizontal scaling optimizations for authz service

- Add /health and /ready endpoints for load balancer health checks
- Replace in-memory JWT token cache with Redis for multi-replica support
- Reduce DB connection pool from 100 to 25 connections per replica
- Add distributed rate limiting (100 req/min + 20 burst) using Redis
- Implement circuit breakers for DB and Redis to prevent cascading failures

This enables the service to scale horizontally with multiple replicas
behind a load balancer without exhausting database connections or
maintaining separate token caches per instance.
This commit is contained in:
2025-12-16 10:03:18 +08:00
parent ee8079e65c
commit 0d8f5b9600
9 changed files with 400 additions and 67 deletions
+16 -5
View File
@@ -1,6 +1,7 @@
package db package db
import ( import (
"authorization/helper"
"database/sql" "database/sql"
"fmt" "fmt"
"log" "log"
@@ -13,6 +14,9 @@ import (
// DB is the global database connection pool // DB is the global database connection pool
var DB *sql.DB var DB *sql.DB
// DBCircuitBreaker protects database operations
var DBCircuitBreaker *helper.CircuitBreaker
func InitDB() (*sql.DB, error) { func InitDB() (*sql.DB, error) {
// Get connection details from environment variables (loaded in main) // Get connection details from environment variables (loaded in main)
connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true",
@@ -29,13 +33,20 @@ func InitDB() (*sql.DB, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("error opening database: %v", err) return nil, fmt.Errorf("error opening database: %v", err)
} }
// Set connection pool parameters // Initialize circuit breaker
DB.SetMaxOpenConns(100) // Maximum number of open connections to the database DBCircuitBreaker = helper.NewCircuitBreaker("database", 5, 2*time.Second)
DB.SetMaxIdleConns(100) // Maximum number of connections in the idle connection pool
// Set connection pool parameters optimized for horizontal scaling
// Lower per-replica to allow more replicas without exhausting DB connections
DB.SetMaxOpenConns(25) // Maximum number of open connections to the database
DB.SetMaxIdleConns(10) // Maximum number of connections in the idle connection pool
DB.SetConnMaxLifetime(5 * time.Minute) // Maximum amount of time a connection may be reused DB.SetConnMaxLifetime(5 * time.Minute) // Maximum amount of time a connection may be reused
// Check if the database connection is working // Check if the database connection is working with circuit breaker
if err := DB.Ping(); err != nil { err = DBCircuitBreaker.Call(func() error {
return DB.Ping()
})
if err != nil {
log.Printf("Database connection lost: %v. Reconnecting...", err) log.Printf("Database connection lost: %v. Reconnecting...", err)
DB, err = InitDB() DB, err = InitDB()
if err != nil { if err != nil {
+84
View File
@@ -0,0 +1,84 @@
package handlers
import (
"authorization/db"
"authorization/models"
"authorization/redisclient"
"context"
"encoding/json"
"net/http"
"time"
)
// HealthHandler provides a basic liveness check
// @Summary Health check endpoint
// @Description Returns service health status for load balancer health checks
// @Tags health
// @Produce json
// @Success 200 {object} HealthResponse
// @Router /health [get]
func HealthHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(models.HealthResponse{
Status: "ok",
})
}
// ReadyHandler checks if the service is ready to handle requests
// @Summary Readiness check endpoint
// @Description Returns readiness status including database and Redis connectivity
// @Tags health
// @Produce json
// @Success 200 {object} HealthResponse
// @Failure 503 {object} HealthResponse
// @Router /ready [get]
func ReadyHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
services := make(map[string]string)
allHealthy := true
// Check database
if db.DB != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := db.DB.PingContext(ctx); err != nil {
services["database"] = "unhealthy"
allHealthy = false
} else {
services["database"] = "healthy"
}
} else {
services["database"] = "not_initialized"
allHealthy = false
}
// Check Redis
if redisclient.RDB != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if _, err := redisclient.RDB.Ping(ctx).Result(); err != nil {
services["redis"] = "unhealthy"
allHealthy = false
} else {
services["redis"] = "healthy"
}
} else {
services["redis"] = "not_initialized"
allHealthy = false
}
status := "ready"
statusCode := http.StatusOK
if !allHealthy {
status = "not_ready"
statusCode = http.StatusServiceUnavailable
}
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(models.HealthResponse{
Status: status,
Services: services,
})
}
+125
View File
@@ -0,0 +1,125 @@
package helper
import (
"sync"
"time"
)
// CircuitState represents the state of a circuit breaker
type CircuitState int
const (
StateClosed CircuitState = iota
StateOpen
StateHalfOpen
)
// CircuitBreaker implements the circuit breaker pattern
type CircuitBreaker struct {
name string
maxFailures int
timeout time.Duration
resetTimeout time.Duration
failures int
lastFailureTime time.Time
state CircuitState
mutex sync.RWMutex
}
// NewCircuitBreaker creates a new circuit breaker
func NewCircuitBreaker(name string, maxFailures int, timeout time.Duration) *CircuitBreaker {
return &CircuitBreaker{
name: name,
maxFailures: maxFailures,
timeout: timeout,
resetTimeout: 30 * time.Second,
state: StateClosed,
}
}
// Call executes the given function with circuit breaker protection
func (cb *CircuitBreaker) Call(fn func() error) error {
cb.mutex.Lock()
// Check if circuit should transition from Open to HalfOpen
if cb.state == StateOpen {
if time.Since(cb.lastFailureTime) > cb.resetTimeout {
cb.state = StateHalfOpen
cb.failures = 0
} else {
cb.mutex.Unlock()
return &CircuitBreakerError{
Name: cb.name,
State: "open",
}
}
}
currentState := cb.state
cb.mutex.Unlock()
// Execute the function
err := fn()
cb.mutex.Lock()
defer cb.mutex.Unlock()
if err != nil {
cb.failures++
cb.lastFailureTime = time.Now()
if currentState == StateHalfOpen {
// If it fails in HalfOpen, go back to Open
cb.state = StateOpen
} else if cb.failures >= cb.maxFailures {
// If too many failures, open the circuit
cb.state = StateOpen
LogError(err, cb.name+" circuit breaker opened")
}
return err
}
// Success - reset if in HalfOpen, or reset failure count
if cb.state == StateHalfOpen {
cb.state = StateClosed
cb.failures = 0
LogInfo(cb.name + " circuit breaker closed")
} else if cb.state == StateClosed && cb.failures > 0 {
// Gradually reduce failure count on success
cb.failures--
}
return nil
}
// GetState returns the current state of the circuit breaker
func (cb *CircuitBreaker) GetState() CircuitState {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state
}
// Reset manually resets the circuit breaker
func (cb *CircuitBreaker) Reset() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.state = StateClosed
cb.failures = 0
}
// CircuitBreakerError represents a circuit breaker error
type CircuitBreakerError struct {
Name string
State string
}
func (e *CircuitBreakerError) Error() string {
return "circuit breaker '" + e.Name + "' is " + e.State
}
// IsCircuitBreakerError checks if an error is a circuit breaker error
func IsCircuitBreakerError(err error) bool {
_, ok := err.(*CircuitBreakerError)
return ok
}
+40 -57
View File
@@ -3,7 +3,9 @@ package middleware
import ( import (
"authorization/helper" "authorization/helper"
"authorization/models" "authorization/models"
"authorization/redisclient"
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@@ -21,18 +23,16 @@ const (
) )
var ( var (
// Token cache for high-frequency requests
tokenCache = make(map[string]*models.CacheEntry)
tokenCacheMutex sync.RWMutex
// Cache JWT secret to avoid repeated os.Getenv calls // Cache JWT secret to avoid repeated os.Getenv calls
jwtSecretOnce sync.Once jwtSecretOnce sync.Once
jwtSecretCached []byte jwtSecretCached []byte
jwtSecretError error jwtSecretError error
// Pre-allocate error messages to avoid repeated allocations // Pre-allocate error messages to avoid repeated allocations
errExpiredToken = "Invalid or expired token" // #nosec G101 errExpiredToken = "Invalid or expired token" // #nosec G101
// Redis key prefix for token cache
redisTokenPrefix = "jwt:token:"
) )
// Initialize JWT secret once // Initialize JWT secret once
@@ -48,29 +48,6 @@ func getJWTSecret() ([]byte, error) {
return jwtSecretCached, jwtSecretError return jwtSecretCached, jwtSecretError
} }
// Clean expired cache entries periodically
func init() {
go func() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
cleanExpiredTokens()
}
}()
}
func cleanExpiredTokens() {
tokenCacheMutex.Lock()
defer tokenCacheMutex.Unlock()
now := time.Now()
for token, entry := range tokenCache {
if now.After(entry.ExpiresAt) {
delete(tokenCache, token)
}
}
}
// extractBearerToken extracts token from Authorization header // extractBearerToken extracts token from Authorization header
func extractBearerToken(authHeader string) (string, bool) { func extractBearerToken(authHeader string) (string, bool) {
if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " { if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " {
@@ -79,22 +56,27 @@ func extractBearerToken(authHeader string) (string, bool) {
return authHeader[7:], true return authHeader[7:], true
} }
// checkTokenCache retrieves token from cache if valid // checkTokenCache retrieves token from Redis cache if valid
func checkTokenCache(tokenString string) (*models.Claims, bool) { func checkTokenCache(tokenString string) (*models.Claims, bool) {
tokenCacheMutex.RLock() if redisclient.RDB == nil {
defer tokenCacheMutex.RUnlock()
cached, exists := tokenCache[tokenString]
if !exists {
return nil, false return nil, false
} }
if time.Now().Before(cached.ExpiresAt) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
return cached.Claims, true defer cancel()
key := redisTokenPrefix + tokenString
val, err := redisclient.RDB.Get(ctx, key).Result()
if err != nil {
return nil, false
} }
// Token expired, will be cleaned up later var claims models.Claims
return nil, false if err := json.Unmarshal([]byte(val), &claims); err != nil {
return nil, false
}
return &claims, true
} }
// removeExpiredCacheEntry removes a single expired token from cache // removeExpiredCacheEntry removes a single expired token from cache
@@ -129,32 +111,33 @@ func parseAndValidateToken(tokenString string) (*models.Claims, error) {
return claims, nil return claims, nil
} }
// cacheToken stores validated token in cache // cacheToken stores validated token in Redis cache
func cacheToken(tokenString string, claims *models.Claims) { func cacheToken(tokenString string, claims *models.Claims) {
expiresAt := time.Now().Add(5 * time.Minute) if redisclient.RDB == nil {
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) { return
expiresAt = claims.ExpiresAt.Time
} }
tokenCacheMutex.Lock() // Calculate TTL
defer tokenCacheMutex.Unlock() ttl := 5 * time.Minute
if claims.ExpiresAt != nil {
// Limit cache size timeUntilExpiry := time.Until(claims.ExpiresAt.Time)
if len(tokenCache) > 10000000 { if timeUntilExpiry > 0 && timeUntilExpiry < ttl {
count := 0 ttl = timeUntilExpiry
for k := range tokenCache {
delete(tokenCache, k)
count++
if count >= 1000000 {
break
}
} }
} }
tokenCache[tokenString] = &models.CacheEntry{ // Serialize claims to JSON
Claims: claims, claimsJSON, err := json.Marshal(claims)
ExpiresAt: expiresAt, if err != nil {
return
} }
// Store in Redis with TTL
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
key := redisTokenPrefix + tokenString
redisclient.RDB.Set(ctx, key, claimsJSON, ttl)
} }
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests // JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
+98
View File
@@ -0,0 +1,98 @@
package middleware
import (
"authorization/helper"
"authorization/models"
"authorization/redisclient"
"context"
"fmt"
"net/http"
"time"
)
// DefaultRateLimitConfig returns default rate limiting settings
func DefaultRateLimitConfig() models.RateLimitConfig {
return models.RateLimitConfig{
RequestsPerMinute: 100,
BurstSize: 20,
}
}
// RateLimiterMiddleware implements distributed rate limiting using Redis
func RateLimiterMiddleware(config models.RateLimitConfig) func(http.HandlerFunc) http.HandlerFunc {
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Skip rate limiting if Redis is not available
if redisclient.RDB == nil {
helper.RespondWithError(w, http.StatusServiceUnavailable, "Redis not available")
return
}
// Extract user identifier (prefer user_id from JWT, fallback to IP)
var identifier string
if userID, ok := GetUserID(r); ok {
identifier = "user:" + userID
} else {
identifier = "ip:" + getClientIP(r)
}
// Check rate limit
allowed, err := checkRateLimit(identifier, config)
if err != nil {
// On error, fail open (allow request) but log the error
helper.LogError(err, "rate limiter error")
next.ServeHTTP(w, r)
return
}
if !allowed {
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
return
}
next.ServeHTTP(w, r)
}
}
}
// checkRateLimit uses Redis INCR with sliding window
func checkRateLimit(identifier string, config models.RateLimitConfig) (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
key := fmt.Sprintf("ratelimit:%s", identifier)
// Use Redis pipeline for atomic operations
pipe := redisclient.RDB.Pipeline()
incrCmd := pipe.Incr(ctx, key)
pipe.Expire(ctx, key, time.Minute)
_, err := pipe.Exec(ctx)
if err != nil {
return false, err
}
count := incrCmd.Val()
// Allow burst + requests per minute
return count <= int64(config.RequestsPerMinute+config.BurstSize), nil
}
// getClientIP extracts the client IP from the request
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header first (for proxies/load balancers)
forwarded := r.Header.Get("X-Forwarded-For")
if forwarded != "" {
return forwarded
}
// Check X-Real-IP header
realIP := r.Header.Get("X-Real-IP")
if realIP != "" {
return realIP
}
// Fallback to RemoteAddr
return r.RemoteAddr
}
+7
View File
@@ -0,0 +1,7 @@
package models
// HealthResponse represents the health check response
type HealthResponse struct {
Status string `json:"status"`
Services map[string]string `json:"services,omitempty"`
}
+7
View File
@@ -0,0 +1,7 @@
package models
// RateLimitConfig holds rate limiting configuration
type RateLimitConfig struct {
RequestsPerMinute int
BurstSize int
}
+14 -4
View File
@@ -1,17 +1,20 @@
// pkg/redisclient/redis.go
package redisclient package redisclient
import ( import (
"authorization/helper"
"context" "context"
"fmt" "fmt"
"os" "os"
"time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
var RDB *redis.Client var RDB *redis.Client
// RedisCircuitBreaker protects Redis operations
var RedisCircuitBreaker *helper.CircuitBreaker
func Init() { func Init() {
redisHost := os.Getenv("REDIS_HOST") redisHost := os.Getenv("REDIS_HOST")
if redisHost == "" { if redisHost == "" {
@@ -39,9 +42,16 @@ func Init() {
RDB = redis.NewClient(opts) RDB = redis.NewClient(opts)
// Test connection with authentication // Initialize circuit breaker
RedisCircuitBreaker = helper.NewCircuitBreaker("redis", 5, 2*time.Second)
// Test connection with authentication using circuit breaker
ctx := context.Background() ctx := context.Background()
if _, err := RDB.Ping(ctx).Result(); err != nil { err := RedisCircuitBreaker.Call(func() error {
_, err := RDB.Ping(ctx).Result()
return err
})
if err != nil {
panic(fmt.Sprintf("Could not connect to Redis: %v", err)) panic(fmt.Sprintf("Could not connect to Redis: %v", err))
} }
+9 -1
View File
@@ -10,8 +10,16 @@ import (
) )
func SetupRoutes(router *mux.Router, db *sql.DB) { func SetupRoutes(router *mux.Router, db *sql.DB) {
// Health check endpoints (no auth required)
router.HandleFunc("/health", handlers.HealthHandler).Methods("GET")
router.HandleFunc("/ready", handlers.ReadyHandler).Methods("GET")
// Rate limit configuration
rateLimitConfig := middleware.DefaultRateLimitConfig()
rateLimiter := middleware.RateLimiterMiddleware(rateLimitConfig)
authRoutes := router.PathPrefix("/v1/auth").Subrouter() authRoutes := router.PathPrefix("/v1/auth").Subrouter()
authRoutes.HandleFunc("/check", middleware.JWTAuth(handlers.AuthorizeHandler)).Methods("POST") authRoutes.HandleFunc("/check", rateLimiter(middleware.JWTAuth(handlers.AuthorizeHandler))).Methods("POST")
router.PathPrefix("/swagger/").Handler(httpSwagger.WrapHandler) router.PathPrefix("/swagger/").Handler(httpSwagger.WrapHandler)
} }