101 lines
2.5 KiB
Go
101 lines
2.5 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"authorization/models"
	"authorization/redisclient"

	sabat "github.com/cespares/response"
)
|
|
|
|
// DefaultRateLimitConfig returns default rate limiting settings
|
|
func DefaultRateLimitConfig() models.RateLimitConfig {
|
|
return models.RateLimitConfig{
|
|
RequestsPerMinute: 100,
|
|
BurstSize: 20,
|
|
}
|
|
}
|
|
|
|
// RateLimiterMiddleware returns a decorator that applies a Redis-backed,
// per-client request limit to an http.HandlerFunc.
//
// Clients are keyed as "user:<id>" when GetUserID finds a user on the request,
// and as "ip:<addr>" (via getClientIP) otherwise. A request that exceeds the
// limit gets a 429 Too Many Requests response and never reaches next.
//
// The middleware fails open. If redisclient.RDB is nil, or checkRateLimit
// returns an error, the problem is logged and the request is passed through.
// An unavailable Redis therefore disables limiting instead of blocking all
// traffic.
func RateLimiterMiddleware(config models.RateLimitConfig) func(http.HandlerFunc) http.HandlerFunc {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			// No Redis client at all: let the request through rather than
			// turning a cache outage into an API outage.
			if redisclient.RDB == nil {
				sabat.LogError(nil, "Rate limiter: Redis not available, allowing request (fail-open)")
				next.ServeHTTP(w, r)
				return
			}

			// Authenticated callers are limited per user; everyone else per address.
			userID, hasUser := GetUserID(r)
			var identifier string
			if hasUser {
				identifier = "user:" + userID
			} else {
				identifier = "ip:" + getClientIP(r)
			}

			allowed, err := checkRateLimit(identifier, config)
			switch {
			case err != nil:
				// Fail open on limiter errors, but leave a trace in the logs.
				sabat.LogError(err, "rate limiter error")
			case !allowed:
				sabat.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
				return
			}

			next.ServeHTTP(w, r)
		}
	}
}
|
|
|
|
// checkRateLimit uses Redis INCR with sliding window
|
|
func checkRateLimit(identifier string, config models.RateLimitConfig) (bool, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
key := fmt.Sprintf("ratelimit:%s", identifier)
|
|
|
|
// Use Redis pipeline for atomic operations
|
|
pipe := redisclient.RDB.Pipeline()
|
|
|
|
incrCmd := pipe.Incr(ctx, key)
|
|
pipe.Expire(ctx, key, time.Minute)
|
|
|
|
_, err := pipe.Exec(ctx)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
count := incrCmd.Val()
|
|
|
|
// Allow burst + requests per minute
|
|
return count <= int64(config.RequestsPerMinute+config.BurstSize), nil
|
|
}
|
|
|
|
// getClientIP returns the address used to identify an unauthenticated client.
//
// Sources are tried in this order:
//  1. the first (left-most) entry of X-Forwarded-For,
//  2. X-Real-IP,
//  3. the host part of r.RemoteAddr.
//
// X-Forwarded-For may carry a comma-separated proxy chain
// ("client, proxy1, proxy2"). Only the trimmed left-most entry is used, so
// the key does not depend on which proxies the request passed through.
// RemoteAddr is "host:port", and the port is dropped because it changes with
// every new TCP connection from the same client. If RemoteAddr has no port,
// it is returned unchanged.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted proxy in front of this service overwrites them. If the service can
// be reached directly, a caller can rotate these headers to get a fresh
// rate-limit bucket per request. Consider honouring them only for requests
// coming from known proxy addresses.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if i := strings.IndexByte(xff, ','); i >= 0 {
			first = xff[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	// "host:port" or "[ipv6]:port" -> host; anything without a port is kept as-is.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|