package middleware

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"authorization/models"
	"authorization/redisclient"

	sabat "github.com/cespares/response"
)

// DefaultRateLimitConfig returns the default rate limiting settings:
// 100 requests per minute with an additional burst allowance of 20.
func DefaultRateLimitConfig() models.RateLimitConfig {
	return models.RateLimitConfig{
		RequestsPerMinute: 100,
		BurstSize:         20,
	}
}

// RateLimiterMiddleware returns a middleware that enforces a per-client
// request limit using a counter stored in Redis.
//
// Clients are identified by the user ID returned by GetUserID when one is
// available, otherwise by the client IP. The middleware fails open: if Redis
// is not configured or the limit check returns an error, the request is
// allowed and the problem is logged. Requests over the limit receive
// 429 Too Many Requests.
func RateLimiterMiddleware(config models.RateLimitConfig) func(http.HandlerFunc) http.HandlerFunc {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			// Fail-open: skip rate limiting if Redis is not available
			// (prevents a Redis problem from becoming a full outage).
			if redisclient.RDB == nil {
				sabat.LogError(nil, "Rate limiter: Redis not available, allowing request (fail-open)")
				next.ServeHTTP(w, r)
				return
			}

			// Prefer the authenticated user ID; fall back to the client IP.
			identifier := "ip:" + getClientIP(r)
			if userID, ok := GetUserID(r); ok {
				identifier = "user:" + userID
			}

			allowed, err := checkRateLimit(identifier, config)
			if err != nil {
				// On error, fail open (allow the request) but log the error.
				sabat.LogError(err, "rate limiter error")
				next.ServeHTTP(w, r)
				return
			}
			if !allowed {
				sabat.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
				return
			}

			next.ServeHTTP(w, r)
		}
	}
}

// checkRateLimit increments the Redis counter for identifier and reports
// whether the request is within the limit of RequestsPerMinute + BurstSize.
//
// The counter is a fixed one-minute window: the key's expiry is set only when
// the key has no TTL (i.e. it was just created, or a previous attempt to set
// the expiry failed). Refreshing the expiry on every request would keep
// extending the window, so a client with steady traffic would never have its
// counter reset.
//
// All Redis calls share a 100ms timeout. A non-nil error means the decision
// could not be made reliably; the returned bool is then meaningless.
func checkRateLimit(identifier string, config models.RateLimitConfig) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	key := "ratelimit:" + identifier

	// The pipeline batches both commands into one round trip; it does not
	// make them atomic.
	pipe := redisclient.RDB.Pipeline()
	incrCmd := pipe.Incr(ctx, key)
	ttlCmd := pipe.TTL(ctx, key)
	if _, err := pipe.Exec(ctx); err != nil {
		return false, fmt.Errorf("incrementing rate limit counter %q: %w", key, err)
	}

	// A negative TTL means the key exists without an expiry: start the window.
	if ttlCmd.Val() < 0 {
		if err := redisclient.RDB.Expire(ctx, key, time.Minute).Err(); err != nil {
			return false, fmt.Errorf("setting rate limit window for %q: %w", key, err)
		}
	}

	// Allow burst + requests per minute. Widen before adding so the sum
	// cannot overflow a narrower config field type.
	limit := int64(config.RequestsPerMinute) + int64(config.BurstSize)
	return incrCmd.Val() <= limit, nil
}

// getClientIP extracts the client IP from the request.
//
// It checks, in order: the first entry of X-Forwarded-For, X-Real-IP, and
// finally the host part of RemoteAddr (the port is stripped so that separate
// connections from the same client map to the same address).
//
// NOTE(review): both headers are supplied by the client unless a trusted
// proxy overwrites them; confirm the deployment sits behind such a proxy,
// otherwise the rate-limit key can be spoofed.
func getClientIP(r *http.Request) string {
	// X-Forwarded-For may hold a comma-separated chain "client, proxy1, ...";
	// only the first entry identifies the originating client.
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		if i := strings.IndexByte(forwarded, ','); i >= 0 {
			forwarded = forwarded[:i]
		}
		if ip := strings.TrimSpace(forwarded); ip != "" {
			return ip
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	// RemoteAddr is normally "host:port"; fall back to the raw value if it
	// cannot be split.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}