329 lines
8.1 KiB
Go
329 lines
8.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"authorization/models"
|
|
"authorization/redisclient"
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-redis/redismock/v9"
|
|
)
|
|
|
|
func TestDefaultRateLimitConfig(t *testing.T) {
|
|
config := DefaultRateLimitConfig()
|
|
|
|
if config.RequestsPerMinute != 100 {
|
|
t.Errorf("RequestsPerMinute = %v, want 100", config.RequestsPerMinute)
|
|
}
|
|
|
|
if config.BurstSize != 20 {
|
|
t.Errorf("BurstSize = %v, want 20", config.BurstSize)
|
|
}
|
|
}
|
|
|
|
func TestGetClientIP(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
xForwardedFor string
|
|
xRealIP string
|
|
remoteAddr string
|
|
expectedIP string
|
|
}{
|
|
{
|
|
name: "uses X-Forwarded-For when present",
|
|
xForwardedFor: "192.168.1.1",
|
|
xRealIP: "192.168.1.2",
|
|
remoteAddr: "192.168.1.3:1234",
|
|
expectedIP: "192.168.1.1",
|
|
},
|
|
{
|
|
name: "uses X-Real-IP when X-Forwarded-For absent",
|
|
xForwardedFor: "",
|
|
xRealIP: "192.168.1.2",
|
|
remoteAddr: "192.168.1.3:1234",
|
|
expectedIP: "192.168.1.2",
|
|
},
|
|
{
|
|
name: "uses RemoteAddr when both headers absent",
|
|
xForwardedFor: "",
|
|
xRealIP: "",
|
|
remoteAddr: "192.168.1.3:1234",
|
|
expectedIP: "192.168.1.3:1234",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = tt.remoteAddr
|
|
if tt.xForwardedFor != "" {
|
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
|
}
|
|
if tt.xRealIP != "" {
|
|
req.Header.Set("X-Real-IP", tt.xRealIP)
|
|
}
|
|
|
|
got := getClientIP(req)
|
|
if got != tt.expectedIP {
|
|
t.Errorf("getClientIP() = %v, want %v", got, tt.expectedIP)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCheckRateLimitAllowedRequests(t *testing.T) {
|
|
db, mock := redismock.NewClientMock()
|
|
originalRedis := redisclient.RDB
|
|
redisclient.RDB = db
|
|
defer func() { redisclient.RDB = originalRedis }()
|
|
|
|
config := models.RateLimitConfig{
|
|
RequestsPerMinute: 100,
|
|
BurstSize: 20,
|
|
}
|
|
|
|
identifier := "user:test123"
|
|
key := "ratelimit:user:test123"
|
|
|
|
// Mock Redis INCR returning 10 (within limit)
|
|
mock.ExpectIncr(key).SetVal(10)
|
|
mock.ExpectExpire(key, time.Minute).SetVal(true)
|
|
|
|
allowed, err := checkRateLimit(identifier, config)
|
|
|
|
if err != nil {
|
|
t.Fatalf("checkRateLimit() error = %v", err)
|
|
}
|
|
|
|
if !allowed {
|
|
t.Error("checkRateLimit() should allow request")
|
|
}
|
|
|
|
if err := mock.ExpectationsWereMet(); err != nil {
|
|
t.Errorf("unfulfilled mock expectations: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestCheckRateLimitExceedsLimit(t *testing.T) {
|
|
db, mock := redismock.NewClientMock()
|
|
originalRedis := redisclient.RDB
|
|
redisclient.RDB = db
|
|
defer func() { redisclient.RDB = originalRedis }()
|
|
|
|
config := models.RateLimitConfig{
|
|
RequestsPerMinute: 100,
|
|
BurstSize: 20,
|
|
}
|
|
|
|
identifier := "user:test123"
|
|
key := "ratelimit:user:test123"
|
|
|
|
// Mock Redis INCR returning 121 (exceeds limit of 120)
|
|
mock.ExpectIncr(key).SetVal(121)
|
|
mock.ExpectExpire(key, time.Minute).SetVal(true)
|
|
|
|
allowed, err := checkRateLimit(identifier, config)
|
|
|
|
if err != nil {
|
|
t.Fatalf("checkRateLimit() error = %v", err)
|
|
}
|
|
|
|
if allowed {
|
|
t.Error("checkRateLimit() should block request when limit exceeded")
|
|
}
|
|
|
|
if err := mock.ExpectationsWereMet(); err != nil {
|
|
t.Errorf("unfulfilled mock expectations: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestCheckRateLimitRedisError(t *testing.T) {
|
|
db, mock := redismock.NewClientMock()
|
|
originalRedis := redisclient.RDB
|
|
redisclient.RDB = db
|
|
defer func() { redisclient.RDB = originalRedis }()
|
|
|
|
config := models.RateLimitConfig{
|
|
RequestsPerMinute: 100,
|
|
BurstSize: 20,
|
|
}
|
|
|
|
identifier := "user:test123"
|
|
key := "ratelimit:user:test123"
|
|
|
|
// Mock Redis error
|
|
mock.ExpectIncr(key).SetErr(context.DeadlineExceeded)
|
|
mock.ExpectExpire(key, time.Minute).SetVal(true)
|
|
|
|
allowed, err := checkRateLimit(identifier, config)
|
|
|
|
if err == nil {
|
|
t.Error("checkRateLimit() should return error when Redis fails")
|
|
}
|
|
|
|
if allowed {
|
|
t.Error("checkRateLimit() should not allow when error occurs")
|
|
}
|
|
}
|
|
|
|
func TestRateLimiterMiddlewareRedisNotAvailable(t *testing.T) {
|
|
originalRedis := redisclient.RDB
|
|
redisclient.RDB = nil
|
|
defer func() { redisclient.RDB = originalRedis }()
|
|
|
|
config := DefaultRateLimitConfig()
|
|
middleware := RateLimiterMiddleware(config)
|
|
|
|
handlerCalled := false
|
|
handler := middleware(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler(w, req)
|
|
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
// With fail-open behavior, requests should be allowed when Redis is unavailable
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("status = %v, want %v (fail-open behavior)", resp.StatusCode, http.StatusOK)
|
|
}
|
|
|
|
if !handlerCalled {
|
|
t.Error("handler should be called when Redis is not available (fail-open)")
|
|
}
|
|
}
|
|
|
|
func TestRateLimiterMiddlewareAllowsRequest(t *testing.T) {
|
|
db, mock := redismock.NewClientMock()
|
|
originalRedis := redisclient.RDB
|
|
redisclient.RDB = db
|
|
defer func() { redisclient.RDB = originalRedis }()
|
|
|
|
config := models.RateLimitConfig{
|
|
RequestsPerMinute: 100,
|
|
BurstSize: 20,
|
|
}
|
|
|
|
// Mock Redis response for allowed request
|
|
mock.MatchExpectationsInOrder(false)
|
|
mock.ExpectIncr("ratelimit:ip:192.168.1.1:1234").SetVal(5)
|
|
mock.ExpectExpire("ratelimit:ip:192.168.1.1:1234", time.Minute).SetVal(true)
|
|
|
|
middleware := RateLimiterMiddleware(config)
|
|
|
|
handlerCalled := false
|
|
handler := middleware(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
|
w := httptest.NewRecorder()
|
|
|
|
handler(w, req)
|
|
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
if !handlerCalled {
|
|
t.Error("handler should be called when rate limit not exceeded")
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("status = %v, want %v", resp.StatusCode, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiterMiddlewareBlocksRequest(t *testing.T) {
|
|
db, mock := redismock.NewClientMock()
|
|
originalRedis := redisclient.RDB
|
|
redisclient.RDB = db
|
|
defer func() { redisclient.RDB = originalRedis }()
|
|
|
|
config := models.RateLimitConfig{
|
|
RequestsPerMinute: 100,
|
|
BurstSize: 20,
|
|
}
|
|
|
|
// Mock Redis response for blocked request
|
|
mock.MatchExpectationsInOrder(false)
|
|
mock.ExpectIncr("ratelimit:ip:192.168.1.1:1234").SetVal(121)
|
|
mock.ExpectExpire("ratelimit:ip:192.168.1.1:1234", time.Minute).SetVal(true)
|
|
|
|
middleware := RateLimiterMiddleware(config)
|
|
|
|
handlerCalled := false
|
|
handler := middleware(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
|
w := httptest.NewRecorder()
|
|
|
|
handler(w, req)
|
|
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
if handlerCalled {
|
|
t.Error("handler should not be called when rate limit exceeded")
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusTooManyRequests {
|
|
t.Errorf("status = %v, want %v", resp.StatusCode, http.StatusTooManyRequests)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiterMiddlewareFailsOpenOnError(t *testing.T) {
|
|
db, mock := redismock.NewClientMock()
|
|
originalRedis := redisclient.RDB
|
|
redisclient.RDB = db
|
|
defer func() { redisclient.RDB = originalRedis }()
|
|
|
|
config := models.RateLimitConfig{
|
|
RequestsPerMinute: 100,
|
|
BurstSize: 20,
|
|
}
|
|
|
|
// Mock Redis error
|
|
mock.MatchExpectationsInOrder(false)
|
|
mock.ExpectIncr("ratelimit:ip:192.168.1.1:1234").SetErr(context.DeadlineExceeded)
|
|
mock.ExpectExpire("ratelimit:ip:192.168.1.1:1234", time.Minute).SetVal(true)
|
|
|
|
middleware := RateLimiterMiddleware(config)
|
|
|
|
handlerCalled := false
|
|
handler := middleware(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:1234"
|
|
w := httptest.NewRecorder()
|
|
|
|
handler(w, req)
|
|
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
if !handlerCalled {
|
|
t.Error("handler should be called when Redis errors (fail open)")
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("status = %v, want %v", resp.StatusCode, http.StatusOK)
|
|
}
|
|
}
|