Files
Authorization/middleware/rate_limiter.go
T

101 lines
2.5 KiB
Go

package middleware
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"authorization/models"
	"authorization/redisclient"

	sabat "github.com/cespares/response"
)
// DefaultRateLimitConfig returns the baseline limiter settings: 100 requests
// per minute plus a burst allowance of 20. Any other fields of
// models.RateLimitConfig are left at their zero values.
func DefaultRateLimitConfig() models.RateLimitConfig {
	var cfg models.RateLimitConfig
	cfg.RequestsPerMinute = 100
	cfg.BurstSize = 20
	return cfg
}
// RateLimiterMiddleware returns a decorator that enforces config through a
// Redis-backed counter shared by every instance pointing at the same Redis.
//
// Requests are keyed by "user:<id>" when GetUserID (defined elsewhere in this
// package; per the original comment it prefers users_id from the JWT) reports
// a user, and by "ip:<client IP>" otherwise.
//
// The limiter fails open: when redisclient.RDB is nil, or the Redis check
// returns an error, the failure is logged and the request is passed to next.
// A request over the limit is answered with 429 Too Many Requests and next is
// not called.
func RateLimiterMiddleware(config models.RateLimitConfig) func(http.HandlerFunc) http.HandlerFunc {
	decorate := func(next http.HandlerFunc) http.HandlerFunc {
		guarded := func(w http.ResponseWriter, r *http.Request) {
			// Without a Redis client there is nothing to count against;
			// letting traffic through avoids turning a cache outage into a
			// full service outage.
			if redisclient.RDB == nil {
				sabat.LogError(nil, "Rate limiter: Redis not available, allowing request (fail-open)")
				next(w, r)
				return
			}

			var bucketID string
			if userID, found := GetUserID(r); found {
				bucketID = "user:" + userID
			} else {
				bucketID = "ip:" + getClientIP(r)
			}

			allowed, err := checkRateLimit(bucketID, config)
			switch {
			case err != nil:
				// Fail open on Redis errors as well, but keep a trace of them.
				sabat.LogError(err, "rate limiter error")
				next(w, r)
			case !allowed:
				sabat.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
			default:
				next(w, r)
			}
		}
		return guarded
	}
	return decorate
}
// checkRateLimit records one request for identifier and reports whether it is
// still within the allowance for the current fixed one-minute window.
//
// The allowance per window is config.RequestsPerMinute + config.BurstSize.
// Counters live under "ratelimit:<identifier>:<window index>", where the
// window index is the Unix time divided by the window length, so every
// minute starts a fresh counter.
//
// The Redis round trip is bounded by a 100ms timeout. On any Redis error the
// returned bool is false and the error is returned; the caller decides
// whether to fail open or closed.
func checkRateLimit(identifier string, config models.RateLimitConfig) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	const window = time.Minute

	// Bucket the key by window index. With a single un-bucketed key, the
	// EXPIRE below re-armed the TTL on every request, so the counter only
	// reset after a full idle minute: a client that kept sending even a couple
	// of requests per minute accumulated forever and was eventually locked out
	// for good (blocked requests also INCR, extending the lockout).
	windowIndex := time.Now().Unix() / int64(window/time.Second)
	key := fmt.Sprintf("ratelimit:%s:%d", identifier, windowIndex)

	// A pipeline batches INCR and EXPIRE into one round trip; it is NOT a
	// MULTI/EXEC transaction. Re-arming the TTL is harmless here because a
	// bucketed key only receives writes during its own window; the TTL just
	// guarantees Redis reclaims it shortly after that window ends.
	pipe := redisclient.RDB.Pipeline()
	incrCmd := pipe.Incr(ctx, key)
	pipe.Expire(ctx, key, window)
	if _, err := pipe.Exec(ctx); err != nil {
		return false, fmt.Errorf("rate limit check for %q: %w", identifier, err)
	}

	limit := int64(config.RequestsPerMinute + config.BurstSize)
	return incrCmd.Val() <= limit, nil
}
// getClientIP returns the best-guess client IP for r, used as a rate-limit key.
//
// Precedence:
//  1. The first (left-most) entry of X-Forwarded-For, trimmed. The header
//     may hold a comma-separated chain "client, proxy1, proxy2"; using the
//     whole chain would give one client many different keys.
//  2. X-Real-IP, trimmed.
//  3. The host part of r.RemoteAddr. RemoteAddr is "host:port" and the port
//     changes per connection, so keying on it verbatim let a client reset its
//     limit simply by reconnecting. If RemoteAddr has no port it is returned
//     unchanged.
//
// SECURITY NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled
// unless a trusted proxy in front of this service overwrites or strips them.
// If that is not guaranteed in deployment, a client can spoof these headers to
// evade IP-based limiting. Confirm the proxy setup.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		if i := strings.IndexByte(forwarded, ','); i >= 0 {
			forwarded = forwarded[:i]
		}
		if ip := strings.TrimSpace(forwarded); ip != "" {
			return ip
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	// SplitHostPort also handles bracketed IPv6 such as "[2001:db8::1]:443".
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}