feat: implement horizontal scaling optimizations for authz service
- Add /health and /ready endpoints for load balancer health checks - Replace in-memory JWT token cache with Redis for multi-replica support - Reduce DB connection pool from 100 to 25 connections per replica - Add distributed rate limiting (100 req/min + 20 burst) using Redis - Implement circuit breakers for DB and Redis to prevent cascading failures This enables the service to scale horizontally with multiple replicas behind a load balancer without exhausting database connections or maintaining separate token caches per instance.
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"authorization/helper"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
@@ -13,6 +14,9 @@ import (
|
|||||||
// DB is the global database connection pool
|
// DB is the global database connection pool
|
||||||
var DB *sql.DB
|
var DB *sql.DB
|
||||||
|
|
||||||
|
// DBCircuitBreaker protects database operations
|
||||||
|
var DBCircuitBreaker *helper.CircuitBreaker
|
||||||
|
|
||||||
func InitDB() (*sql.DB, error) {
|
func InitDB() (*sql.DB, error) {
|
||||||
// Get connection details from environment variables (loaded in main)
|
// Get connection details from environment variables (loaded in main)
|
||||||
connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true",
|
connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true",
|
||||||
@@ -29,13 +33,20 @@ func InitDB() (*sql.DB, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error opening database: %v", err)
|
return nil, fmt.Errorf("error opening database: %v", err)
|
||||||
}
|
}
|
||||||
// Set connection pool parameters
|
// Initialize circuit breaker
|
||||||
DB.SetMaxOpenConns(100) // Maximum number of open connections to the database
|
DBCircuitBreaker = helper.NewCircuitBreaker("database", 5, 2*time.Second)
|
||||||
DB.SetMaxIdleConns(100) // Maximum number of connections in the idle connection pool
|
|
||||||
|
// Set connection pool parameters optimized for horizontal scaling
|
||||||
|
// Lower per-replica to allow more replicas without exhausting DB connections
|
||||||
|
DB.SetMaxOpenConns(25) // Maximum number of open connections to the database
|
||||||
|
DB.SetMaxIdleConns(10) // Maximum number of connections in the idle connection pool
|
||||||
DB.SetConnMaxLifetime(5 * time.Minute) // Maximum amount of time a connection may be reused
|
DB.SetConnMaxLifetime(5 * time.Minute) // Maximum amount of time a connection may be reused
|
||||||
|
|
||||||
// Check if the database connection is working
|
// Check if the database connection is working with circuit breaker
|
||||||
if err := DB.Ping(); err != nil {
|
err = DBCircuitBreaker.Call(func() error {
|
||||||
|
return DB.Ping()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
log.Printf("Database connection lost: %v. Reconnecting...", err)
|
log.Printf("Database connection lost: %v. Reconnecting...", err)
|
||||||
DB, err = InitDB()
|
DB, err = InitDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"authorization/db"
|
||||||
|
"authorization/models"
|
||||||
|
"authorization/redisclient"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HealthHandler provides a basic liveness check
|
||||||
|
// @Summary Health check endpoint
|
||||||
|
// @Description Returns service health status for load balancer health checks
|
||||||
|
// @Tags health
|
||||||
|
// @Produce json
|
||||||
|
// @Success 200 {object} HealthResponse
|
||||||
|
// @Router /health [get]
|
||||||
|
func HealthHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(models.HealthResponse{
|
||||||
|
Status: "ok",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadyHandler checks if the service is ready to handle requests
|
||||||
|
// @Summary Readiness check endpoint
|
||||||
|
// @Description Returns readiness status including database and Redis connectivity
|
||||||
|
// @Tags health
|
||||||
|
// @Produce json
|
||||||
|
// @Success 200 {object} HealthResponse
|
||||||
|
// @Failure 503 {object} HealthResponse
|
||||||
|
// @Router /ready [get]
|
||||||
|
func ReadyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
services := make(map[string]string)
|
||||||
|
allHealthy := true
|
||||||
|
|
||||||
|
// Check database
|
||||||
|
if db.DB != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := db.DB.PingContext(ctx); err != nil {
|
||||||
|
services["database"] = "unhealthy"
|
||||||
|
allHealthy = false
|
||||||
|
} else {
|
||||||
|
services["database"] = "healthy"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
services["database"] = "not_initialized"
|
||||||
|
allHealthy = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Redis
|
||||||
|
if redisclient.RDB != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if _, err := redisclient.RDB.Ping(ctx).Result(); err != nil {
|
||||||
|
services["redis"] = "unhealthy"
|
||||||
|
allHealthy = false
|
||||||
|
} else {
|
||||||
|
services["redis"] = "healthy"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
services["redis"] = "not_initialized"
|
||||||
|
allHealthy = false
|
||||||
|
}
|
||||||
|
|
||||||
|
status := "ready"
|
||||||
|
statusCode := http.StatusOK
|
||||||
|
if !allHealthy {
|
||||||
|
status = "not_ready"
|
||||||
|
statusCode = http.StatusServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(statusCode)
|
||||||
|
json.NewEncoder(w).Encode(models.HealthResponse{
|
||||||
|
Status: status,
|
||||||
|
Services: services,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
package helper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CircuitState represents the state of a circuit breaker
|
||||||
|
type CircuitState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateClosed CircuitState = iota
|
||||||
|
StateOpen
|
||||||
|
StateHalfOpen
|
||||||
|
)
|
||||||
|
|
||||||
|
// CircuitBreaker implements the circuit breaker pattern
|
||||||
|
type CircuitBreaker struct {
|
||||||
|
name string
|
||||||
|
maxFailures int
|
||||||
|
timeout time.Duration
|
||||||
|
resetTimeout time.Duration
|
||||||
|
failures int
|
||||||
|
lastFailureTime time.Time
|
||||||
|
state CircuitState
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCircuitBreaker creates a new circuit breaker
|
||||||
|
func NewCircuitBreaker(name string, maxFailures int, timeout time.Duration) *CircuitBreaker {
|
||||||
|
return &CircuitBreaker{
|
||||||
|
name: name,
|
||||||
|
maxFailures: maxFailures,
|
||||||
|
timeout: timeout,
|
||||||
|
resetTimeout: 30 * time.Second,
|
||||||
|
state: StateClosed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call executes the given function with circuit breaker protection
|
||||||
|
func (cb *CircuitBreaker) Call(fn func() error) error {
|
||||||
|
cb.mutex.Lock()
|
||||||
|
|
||||||
|
// Check if circuit should transition from Open to HalfOpen
|
||||||
|
if cb.state == StateOpen {
|
||||||
|
if time.Since(cb.lastFailureTime) > cb.resetTimeout {
|
||||||
|
cb.state = StateHalfOpen
|
||||||
|
cb.failures = 0
|
||||||
|
} else {
|
||||||
|
cb.mutex.Unlock()
|
||||||
|
return &CircuitBreakerError{
|
||||||
|
Name: cb.name,
|
||||||
|
State: "open",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState := cb.state
|
||||||
|
cb.mutex.Unlock()
|
||||||
|
|
||||||
|
// Execute the function
|
||||||
|
err := fn()
|
||||||
|
|
||||||
|
cb.mutex.Lock()
|
||||||
|
defer cb.mutex.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
cb.failures++
|
||||||
|
cb.lastFailureTime = time.Now()
|
||||||
|
|
||||||
|
if currentState == StateHalfOpen {
|
||||||
|
// If it fails in HalfOpen, go back to Open
|
||||||
|
cb.state = StateOpen
|
||||||
|
} else if cb.failures >= cb.maxFailures {
|
||||||
|
// If too many failures, open the circuit
|
||||||
|
cb.state = StateOpen
|
||||||
|
LogError(err, cb.name+" circuit breaker opened")
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success - reset if in HalfOpen, or reset failure count
|
||||||
|
if cb.state == StateHalfOpen {
|
||||||
|
cb.state = StateClosed
|
||||||
|
cb.failures = 0
|
||||||
|
LogInfo(cb.name + " circuit breaker closed")
|
||||||
|
} else if cb.state == StateClosed && cb.failures > 0 {
|
||||||
|
// Gradually reduce failure count on success
|
||||||
|
cb.failures--
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetState returns the current state of the circuit breaker
|
||||||
|
func (cb *CircuitBreaker) GetState() CircuitState {
|
||||||
|
cb.mutex.RLock()
|
||||||
|
defer cb.mutex.RUnlock()
|
||||||
|
return cb.state
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset manually resets the circuit breaker
|
||||||
|
func (cb *CircuitBreaker) Reset() {
|
||||||
|
cb.mutex.Lock()
|
||||||
|
defer cb.mutex.Unlock()
|
||||||
|
cb.state = StateClosed
|
||||||
|
cb.failures = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CircuitBreakerError represents a circuit breaker error
|
||||||
|
type CircuitBreakerError struct {
|
||||||
|
Name string
|
||||||
|
State string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CircuitBreakerError) Error() string {
|
||||||
|
return "circuit breaker '" + e.Name + "' is " + e.State
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCircuitBreakerError checks if an error is a circuit breaker error
|
||||||
|
func IsCircuitBreakerError(err error) bool {
|
||||||
|
_, ok := err.(*CircuitBreakerError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
+40
-57
@@ -3,7 +3,9 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"authorization/helper"
|
"authorization/helper"
|
||||||
"authorization/models"
|
"authorization/models"
|
||||||
|
"authorization/redisclient"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -21,18 +23,16 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Token cache for high-frequency requests
|
|
||||||
tokenCache = make(map[string]*models.CacheEntry)
|
|
||||||
tokenCacheMutex sync.RWMutex
|
|
||||||
|
|
||||||
// Cache JWT secret to avoid repeated os.Getenv calls
|
// Cache JWT secret to avoid repeated os.Getenv calls
|
||||||
jwtSecretOnce sync.Once
|
jwtSecretOnce sync.Once
|
||||||
jwtSecretCached []byte
|
jwtSecretCached []byte
|
||||||
jwtSecretError error
|
jwtSecretError error
|
||||||
|
|
||||||
// Pre-allocate error messages to avoid repeated allocations
|
// Pre-allocate error messages to avoid repeated allocations
|
||||||
|
|
||||||
errExpiredToken = "Invalid or expired token" // #nosec G101
|
errExpiredToken = "Invalid or expired token" // #nosec G101
|
||||||
|
|
||||||
|
// Redis key prefix for token cache
|
||||||
|
redisTokenPrefix = "jwt:token:"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Initialize JWT secret once
|
// Initialize JWT secret once
|
||||||
@@ -48,29 +48,6 @@ func getJWTSecret() ([]byte, error) {
|
|||||||
return jwtSecretCached, jwtSecretError
|
return jwtSecretCached, jwtSecretError
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean expired cache entries periodically
|
|
||||||
func init() {
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for range ticker.C {
|
|
||||||
cleanExpiredTokens()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanExpiredTokens() {
|
|
||||||
tokenCacheMutex.Lock()
|
|
||||||
defer tokenCacheMutex.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
for token, entry := range tokenCache {
|
|
||||||
if now.After(entry.ExpiresAt) {
|
|
||||||
delete(tokenCache, token)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractBearerToken extracts token from Authorization header
|
// extractBearerToken extracts token from Authorization header
|
||||||
func extractBearerToken(authHeader string) (string, bool) {
|
func extractBearerToken(authHeader string) (string, bool) {
|
||||||
if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " {
|
if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " {
|
||||||
@@ -79,22 +56,27 @@ func extractBearerToken(authHeader string) (string, bool) {
|
|||||||
return authHeader[7:], true
|
return authHeader[7:], true
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkTokenCache retrieves token from cache if valid
|
// checkTokenCache retrieves token from Redis cache if valid
|
||||||
func checkTokenCache(tokenString string) (*models.Claims, bool) {
|
func checkTokenCache(tokenString string) (*models.Claims, bool) {
|
||||||
tokenCacheMutex.RLock()
|
if redisclient.RDB == nil {
|
||||||
defer tokenCacheMutex.RUnlock()
|
|
||||||
|
|
||||||
cached, exists := tokenCache[tokenString]
|
|
||||||
if !exists {
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().Before(cached.ExpiresAt) {
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
return cached.Claims, true
|
defer cancel()
|
||||||
|
|
||||||
|
key := redisTokenPrefix + tokenString
|
||||||
|
val, err := redisclient.RDB.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Token expired, will be cleaned up later
|
var claims models.Claims
|
||||||
return nil, false
|
if err := json.Unmarshal([]byte(val), &claims); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &claims, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeExpiredCacheEntry removes a single expired token from cache
|
// removeExpiredCacheEntry removes a single expired token from cache
|
||||||
@@ -129,32 +111,33 @@ func parseAndValidateToken(tokenString string) (*models.Claims, error) {
|
|||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// cacheToken stores validated token in cache
|
// cacheToken stores validated token in Redis cache
|
||||||
func cacheToken(tokenString string, claims *models.Claims) {
|
func cacheToken(tokenString string, claims *models.Claims) {
|
||||||
expiresAt := time.Now().Add(5 * time.Minute)
|
if redisclient.RDB == nil {
|
||||||
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) {
|
return
|
||||||
expiresAt = claims.ExpiresAt.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenCacheMutex.Lock()
|
// Calculate TTL
|
||||||
defer tokenCacheMutex.Unlock()
|
ttl := 5 * time.Minute
|
||||||
|
if claims.ExpiresAt != nil {
|
||||||
// Limit cache size
|
timeUntilExpiry := time.Until(claims.ExpiresAt.Time)
|
||||||
if len(tokenCache) > 10000000 {
|
if timeUntilExpiry > 0 && timeUntilExpiry < ttl {
|
||||||
count := 0
|
ttl = timeUntilExpiry
|
||||||
for k := range tokenCache {
|
|
||||||
delete(tokenCache, k)
|
|
||||||
count++
|
|
||||||
if count >= 1000000 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenCache[tokenString] = &models.CacheEntry{
|
// Serialize claims to JSON
|
||||||
Claims: claims,
|
claimsJSON, err := json.Marshal(claims)
|
||||||
ExpiresAt: expiresAt,
|
if err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store in Redis with TTL
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
key := redisTokenPrefix + tokenString
|
||||||
|
redisclient.RDB.Set(ctx, key, claimsJSON, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
|
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"authorization/helper"
|
||||||
|
"authorization/models"
|
||||||
|
"authorization/redisclient"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultRateLimitConfig returns default rate limiting settings
|
||||||
|
func DefaultRateLimitConfig() models.RateLimitConfig {
|
||||||
|
return models.RateLimitConfig{
|
||||||
|
RequestsPerMinute: 100,
|
||||||
|
BurstSize: 20,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimiterMiddleware implements distributed rate limiting using Redis
|
||||||
|
func RateLimiterMiddleware(config models.RateLimitConfig) func(http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Skip rate limiting if Redis is not available
|
||||||
|
if redisclient.RDB == nil {
|
||||||
|
helper.RespondWithError(w, http.StatusServiceUnavailable, "Redis not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract user identifier (prefer user_id from JWT, fallback to IP)
|
||||||
|
var identifier string
|
||||||
|
if userID, ok := GetUserID(r); ok {
|
||||||
|
identifier = "user:" + userID
|
||||||
|
} else {
|
||||||
|
identifier = "ip:" + getClientIP(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check rate limit
|
||||||
|
allowed, err := checkRateLimit(identifier, config)
|
||||||
|
if err != nil {
|
||||||
|
// On error, fail open (allow request) but log the error
|
||||||
|
helper.LogError(err, "rate limiter error")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkRateLimit uses Redis INCR with sliding window
|
||||||
|
func checkRateLimit(identifier string, config models.RateLimitConfig) (bool, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
key := fmt.Sprintf("ratelimit:%s", identifier)
|
||||||
|
|
||||||
|
// Use Redis pipeline for atomic operations
|
||||||
|
pipe := redisclient.RDB.Pipeline()
|
||||||
|
|
||||||
|
incrCmd := pipe.Incr(ctx, key)
|
||||||
|
pipe.Expire(ctx, key, time.Minute)
|
||||||
|
|
||||||
|
_, err := pipe.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
count := incrCmd.Val()
|
||||||
|
|
||||||
|
// Allow burst + requests per minute
|
||||||
|
return count <= int64(config.RequestsPerMinute+config.BurstSize), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClientIP extracts the client IP from the request
|
||||||
|
func getClientIP(r *http.Request) string {
|
||||||
|
// Check X-Forwarded-For header first (for proxies/load balancers)
|
||||||
|
forwarded := r.Header.Get("X-Forwarded-For")
|
||||||
|
if forwarded != "" {
|
||||||
|
return forwarded
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check X-Real-IP header
|
||||||
|
realIP := r.Header.Get("X-Real-IP")
|
||||||
|
if realIP != "" {
|
||||||
|
return realIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to RemoteAddr
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
// HealthResponse represents the health check response
|
||||||
|
type HealthResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Services map[string]string `json:"services,omitempty"`
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
// RateLimitConfig holds rate limiting configuration
|
||||||
|
type RateLimitConfig struct {
|
||||||
|
RequestsPerMinute int
|
||||||
|
BurstSize int
|
||||||
|
}
|
||||||
+14
-4
@@ -1,17 +1,20 @@
|
|||||||
// pkg/redisclient/redis.go
|
|
||||||
|
|
||||||
package redisclient
|
package redisclient
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"authorization/helper"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
var RDB *redis.Client
|
var RDB *redis.Client
|
||||||
|
|
||||||
|
// RedisCircuitBreaker protects Redis operations
|
||||||
|
var RedisCircuitBreaker *helper.CircuitBreaker
|
||||||
|
|
||||||
func Init() {
|
func Init() {
|
||||||
redisHost := os.Getenv("REDIS_HOST")
|
redisHost := os.Getenv("REDIS_HOST")
|
||||||
if redisHost == "" {
|
if redisHost == "" {
|
||||||
@@ -39,9 +42,16 @@ func Init() {
|
|||||||
|
|
||||||
RDB = redis.NewClient(opts)
|
RDB = redis.NewClient(opts)
|
||||||
|
|
||||||
// Test connection with authentication
|
// Initialize circuit breaker
|
||||||
|
RedisCircuitBreaker = helper.NewCircuitBreaker("redis", 5, 2*time.Second)
|
||||||
|
|
||||||
|
// Test connection with authentication using circuit breaker
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
if _, err := RDB.Ping(ctx).Result(); err != nil {
|
err := RedisCircuitBreaker.Call(func() error {
|
||||||
|
_, err := RDB.Ping(ctx).Result()
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Could not connect to Redis: %v", err))
|
panic(fmt.Sprintf("Could not connect to Redis: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+9
-1
@@ -10,8 +10,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func SetupRoutes(router *mux.Router, db *sql.DB) {
|
func SetupRoutes(router *mux.Router, db *sql.DB) {
|
||||||
|
// Health check endpoints (no auth required)
|
||||||
|
router.HandleFunc("/health", handlers.HealthHandler).Methods("GET")
|
||||||
|
router.HandleFunc("/ready", handlers.ReadyHandler).Methods("GET")
|
||||||
|
|
||||||
|
// Rate limit configuration
|
||||||
|
rateLimitConfig := middleware.DefaultRateLimitConfig()
|
||||||
|
rateLimiter := middleware.RateLimiterMiddleware(rateLimitConfig)
|
||||||
|
|
||||||
authRoutes := router.PathPrefix("/v1/auth").Subrouter()
|
authRoutes := router.PathPrefix("/v1/auth").Subrouter()
|
||||||
authRoutes.HandleFunc("/check", middleware.JWTAuth(handlers.AuthorizeHandler)).Methods("POST")
|
authRoutes.HandleFunc("/check", rateLimiter(middleware.JWTAuth(handlers.AuthorizeHandler))).Methods("POST")
|
||||||
|
|
||||||
router.PathPrefix("/swagger/").Handler(httpSwagger.WrapHandler)
|
router.PathPrefix("/swagger/").Handler(httpSwagger.WrapHandler)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user