Files
Authorization/middleware/rate_limiter_test.go
T
2025-12-16 10:57:26 +08:00

327 lines
8.0 KiB
Go

package middleware
import (
"authorization/models"
"authorization/redisclient"
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-redis/redismock/v9"
)
// TestDefaultRateLimitConfig pins the values returned by DefaultRateLimitConfig:
// 100 requests per minute with a burst allowance of 20.
func TestDefaultRateLimitConfig(t *testing.T) {
	// Untyped constants so the comparison works whatever integer type the
	// config fields use.
	const (
		wantRequestsPerMinute = 100
		wantBurstSize         = 20
	)

	cfg := DefaultRateLimitConfig()

	if got := cfg.RequestsPerMinute; got != wantRequestsPerMinute {
		t.Errorf("RequestsPerMinute = %v, want 100", got)
	}
	if got := cfg.BurstSize; got != wantBurstSize {
		t.Errorf("BurstSize = %v, want 20", got)
	}
}
// TestGetClientIP checks the source precedence of getClientIP:
// X-Forwarded-For beats X-Real-IP, which beats RemoteAddr. When only
// RemoteAddr is available it is expected verbatim, port included.
func TestGetClientIP(t *testing.T) {
	cases := []struct {
		name       string
		forwarded  string // X-Forwarded-For value; "" leaves the header unset
		realIP     string // X-Real-IP value; "" leaves the header unset
		remoteAddr string
		want       string
	}{
		{
			name:       "uses X-Forwarded-For when present",
			forwarded:  "192.168.1.1",
			realIP:     "192.168.1.2",
			remoteAddr: "192.168.1.3:1234",
			want:       "192.168.1.1",
		},
		{
			name:       "uses X-Real-IP when X-Forwarded-For absent",
			realIP:     "192.168.1.2",
			remoteAddr: "192.168.1.3:1234",
			want:       "192.168.1.2",
		},
		{
			name:       "uses RemoteAddr when both headers absent",
			remoteAddr: "192.168.1.3:1234",
			want:       "192.168.1.3:1234",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			r := httptest.NewRequest(http.MethodGet, "/test", nil)
			r.RemoteAddr = tc.remoteAddr

			// Only set headers that carry a value so "absent" really means absent.
			for _, h := range [...]struct{ key, val string }{
				{"X-Forwarded-For", tc.forwarded},
				{"X-Real-IP", tc.realIP},
			} {
				if h.val == "" {
					continue
				}
				r.Header.Set(h.key, h.val)
			}

			if got := getClientIP(r); got != tc.want {
				t.Errorf("getClientIP() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestCheckRateLimit_AllowedRequests verifies that a counter well under the
// limit is allowed. It also verifies that checkRateLimit issues INCR and then
// EXPIRE (one-minute window) against the "ratelimit:"-prefixed key.
func TestCheckRateLimit_AllowedRequests(t *testing.T) {
	client, redisMock := redismock.NewClientMock()
	saved := redisclient.RDB
	redisclient.RDB = client
	t.Cleanup(func() { redisclient.RDB = saved })

	const (
		id  = "user:test123"
		key = "ratelimit:" + id
	)
	cfg := models.RateLimitConfig{RequestsPerMinute: 100, BurstSize: 20}

	// INCR reports 10 hits so far: comfortably inside the limit.
	redisMock.ExpectIncr(key).SetVal(10)
	redisMock.ExpectExpire(key, time.Minute).SetVal(true)

	ok, err := checkRateLimit(id, cfg)
	if err != nil {
		t.Fatalf("checkRateLimit() error = %v", err)
	}
	if !ok {
		t.Error("checkRateLimit() should allow request")
	}

	if err := redisMock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled mock expectations: %v", err)
	}
}
// TestCheckRateLimit_ExceedsLimit verifies that a counter above
// RequestsPerMinute+BurstSize (100+20) is rejected without an error. It also
// verifies that the EXPIRE is still issued.
func TestCheckRateLimit_ExceedsLimit(t *testing.T) {
	client, redisMock := redismock.NewClientMock()
	saved := redisclient.RDB
	redisclient.RDB = client
	t.Cleanup(func() { redisclient.RDB = saved })

	const (
		id  = "user:test123"
		key = "ratelimit:" + id
	)
	cfg := models.RateLimitConfig{RequestsPerMinute: 100, BurstSize: 20}

	// 121 is one past the combined limit of 120.
	redisMock.ExpectIncr(key).SetVal(121)
	redisMock.ExpectExpire(key, time.Minute).SetVal(true)

	ok, err := checkRateLimit(id, cfg)
	if err != nil {
		t.Fatalf("checkRateLimit() error = %v", err)
	}
	if ok {
		t.Error("checkRateLimit() should block request when limit exceeded")
	}

	if err := redisMock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled mock expectations: %v", err)
	}
}
// TestCheckRateLimit_RedisError verifies that an INCR failure is surfaced as an
// error and that the request is reported as not allowed.
func TestCheckRateLimit_RedisError(t *testing.T) {
	client, redisMock := redismock.NewClientMock()
	saved := redisclient.RDB
	redisclient.RDB = client
	t.Cleanup(func() { redisclient.RDB = saved })

	const (
		id  = "user:test123"
		key = "ratelimit:" + id
	)
	cfg := models.RateLimitConfig{RequestsPerMinute: 100, BurstSize: 20}

	redisMock.ExpectIncr(key).SetErr(context.DeadlineExceeded)
	// NOTE(review): presumably never consumed if checkRateLimit returns on the
	// INCR error. Expectations are intentionally not asserted in this test.
	redisMock.ExpectExpire(key, time.Minute).SetVal(true)

	ok, err := checkRateLimit(id, cfg)
	if err == nil {
		t.Error("checkRateLimit() should return error when Redis fails")
	}
	if ok {
		t.Error("checkRateLimit() should not allow when error occurs")
	}
}
// TestRateLimiterMiddleware_RedisNotAvailable verifies that a nil Redis client
// causes the middleware to answer 503 without invoking the wrapped handler.
func TestRateLimiterMiddleware_RedisNotAvailable(t *testing.T) {
	saved := redisclient.RDB
	redisclient.RDB = nil
	t.Cleanup(func() { redisclient.RDB = saved })

	// Built after RDB is cleared, in case the middleware inspects it eagerly.
	mw := RateLimiterMiddleware(DefaultRateLimitConfig())

	var reached bool
	wrapped := mw(func(http.ResponseWriter, *http.Request) {
		reached = true
	})

	rec := httptest.NewRecorder()
	wrapped(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	resp := rec.Result()
	defer resp.Body.Close()

	if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want {
		t.Errorf("status = %v, want %v", got, want)
	}
	if reached {
		t.Error("handler should not be called when Redis is not available")
	}
}
// TestRateLimiterMiddleware_AllowsRequest verifies that a request whose counter
// is under the limit reaches the wrapped handler with a 200.
//
// The expected key assumes the middleware identifies the client as
// "ip:" + getClientIP(r). With no proxy headers that is RemoteAddr verbatim,
// port included.
//
// The final ExpectationsWereMet check is what makes this test meaningful:
//   - redismock answers any unexpected command with an error.
//   - The middleware fails open on Redis errors
//     (see TestRateLimiterMiddleware_FailsOpenOnError).
//
// Without the check, the test would also pass if the middleware never reached
// Redis or used a different key.
func TestRateLimiterMiddleware_AllowsRequest(t *testing.T) {
	db, mock := redismock.NewClientMock()
	originalRedis := redisclient.RDB
	redisclient.RDB = db
	defer func() { redisclient.RDB = originalRedis }()

	config := models.RateLimitConfig{
		RequestsPerMinute: 100,
		BurstSize:         20,
	}

	// Counter at 5: well within the limit.
	mock.MatchExpectationsInOrder(false)
	mock.ExpectIncr("ratelimit:ip:192.168.1.1:1234").SetVal(5)
	mock.ExpectExpire("ratelimit:ip:192.168.1.1:1234", time.Minute).SetVal(true)

	middleware := RateLimiterMiddleware(config)
	handlerCalled := false
	handler := middleware(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.RemoteAddr = "192.168.1.1:1234"
	w := httptest.NewRecorder()
	handler(w, req)

	resp := w.Result()
	defer resp.Body.Close()

	if !handlerCalled {
		t.Error("handler should be called when rate limit not exceeded")
	}
	if resp.StatusCode != http.StatusOK {
		t.Errorf("status = %v, want %v", resp.StatusCode, http.StatusOK)
	}
	// Proves the request was admitted by the limiter, not by the fail-open path.
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled mock expectations: %v", err)
	}
}
// TestRateLimiterMiddleware_BlocksRequest verifies that a counter past the
// limit (121 > 100+20) yields 429 Too Many Requests and that the wrapped
// handler is skipped.
func TestRateLimiterMiddleware_BlocksRequest(t *testing.T) {
	client, redisMock := redismock.NewClientMock()
	saved := redisclient.RDB
	redisclient.RDB = client
	t.Cleanup(func() { redisclient.RDB = saved })

	const (
		clientAddr = "192.168.1.1:1234"
		key        = "ratelimit:ip:" + clientAddr
	)
	cfg := models.RateLimitConfig{RequestsPerMinute: 100, BurstSize: 20}

	redisMock.MatchExpectationsInOrder(false)
	redisMock.ExpectIncr(key).SetVal(121)
	redisMock.ExpectExpire(key, time.Minute).SetVal(true)

	var reached bool
	wrapped := RateLimiterMiddleware(cfg)(func(rw http.ResponseWriter, _ *http.Request) {
		reached = true
		rw.WriteHeader(http.StatusOK)
	})

	r := httptest.NewRequest(http.MethodGet, "/test", nil)
	r.RemoteAddr = clientAddr
	rec := httptest.NewRecorder()
	wrapped(rec, r)

	resp := rec.Result()
	defer resp.Body.Close()

	if reached {
		t.Error("handler should not be called when rate limit exceeded")
	}
	if got, want := resp.StatusCode, http.StatusTooManyRequests; got != want {
		t.Errorf("status = %v, want %v", got, want)
	}
}
// TestRateLimiterMiddleware_FailsOpenOnError verifies the fail-open policy:
// when Redis errors during the limit check, the request still reaches the
// wrapped handler and gets a 200.
func TestRateLimiterMiddleware_FailsOpenOnError(t *testing.T) {
	client, redisMock := redismock.NewClientMock()
	saved := redisclient.RDB
	redisclient.RDB = client
	t.Cleanup(func() { redisclient.RDB = saved })

	const (
		clientAddr = "192.168.1.1:1234"
		key        = "ratelimit:ip:" + clientAddr
	)
	cfg := models.RateLimitConfig{RequestsPerMinute: 100, BurstSize: 20}

	redisMock.MatchExpectationsInOrder(false)
	redisMock.ExpectIncr(key).SetErr(context.DeadlineExceeded)
	// NOTE(review): likely unused once INCR fails. Expectations are
	// intentionally not asserted here.
	redisMock.ExpectExpire(key, time.Minute).SetVal(true)

	var reached bool
	wrapped := RateLimiterMiddleware(cfg)(func(rw http.ResponseWriter, _ *http.Request) {
		reached = true
		rw.WriteHeader(http.StatusOK)
	})

	r := httptest.NewRequest(http.MethodGet, "/test", nil)
	r.RemoteAddr = clientAddr
	rec := httptest.NewRecorder()
	wrapped(rec, r)

	resp := rec.Result()
	defer resp.Body.Close()

	if !reached {
		t.Error("handler should be called when Redis errors (fail open)")
	}
	if got, want := resp.StatusCode, http.StatusOK; got != want {
		t.Errorf("status = %v, want %v", got, want)
	}
}