package middleware

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"authorization/models"
	"authorization/redisclient"

	"github.com/go-redis/redismock/v9"
)

// installRedisMock points redisclient.RDB at a fresh redismock client for the
// duration of the calling test and returns the mock used to declare
// expectations. The previous client is restored through t.Cleanup, so callers
// do not need their own defer.
func installRedisMock(t *testing.T) redismock.ClientMock {
	t.Helper()

	db, mock := redismock.NewClientMock()
	previous := redisclient.RDB
	redisclient.RDB = db
	t.Cleanup(func() { redisclient.RDB = previous })

	return mock
}

// serveThroughRateLimiter sends a single GET /test request through
// RateLimiterMiddleware(config) wrapped around a handler that replies 200 OK.
//
// When remoteAddr is non-empty it overrides the request's RemoteAddr;
// otherwise the default chosen by httptest.NewRequest is kept.
//
// It returns the recorded response status and whether the wrapped handler ran.
func serveThroughRateLimiter(config models.RateLimitConfig, remoteAddr string) (status int, handlerCalled bool) {
	handler := RateLimiterMiddleware(config)(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		w.WriteHeader(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	if remoteAddr != "" {
		req.RemoteAddr = remoteAddr
	}
	rec := httptest.NewRecorder()
	handler(rec, req)

	resp := rec.Result()
	defer resp.Body.Close()

	return resp.StatusCode, handlerCalled
}

// TestDefaultRateLimitConfig pins the default values returned by
// DefaultRateLimitConfig.
func TestDefaultRateLimitConfig(t *testing.T) {
	config := DefaultRateLimitConfig()

	if config.RequestsPerMinute != 100 {
		t.Errorf("RequestsPerMinute = %v, want 100", config.RequestsPerMinute)
	}
	if config.BurstSize != 20 {
		t.Errorf("BurstSize = %v, want 20", config.BurstSize)
	}
}

// TestGetClientIP checks the header precedence expected from getClientIP:
// X-Forwarded-For, then X-Real-IP, then the request's RemoteAddr. The
// RemoteAddr case expects the value to come back unchanged, port included.
func TestGetClientIP(t *testing.T) {
	tests := []struct {
		name          string
		xForwardedFor string
		xRealIP       string
		remoteAddr    string
		expectedIP    string
	}{
		{
			name:          "uses X-Forwarded-For when present",
			xForwardedFor: "192.168.1.1",
			xRealIP:       "192.168.1.2",
			remoteAddr:    "192.168.1.3:1234",
			expectedIP:    "192.168.1.1",
		},
		{
			name:          "uses X-Real-IP when X-Forwarded-For absent",
			xForwardedFor: "",
			xRealIP:       "192.168.1.2",
			remoteAddr:    "192.168.1.3:1234",
			expectedIP:    "192.168.1.2",
		},
		{
			name:          "uses RemoteAddr when both headers absent",
			xForwardedFor: "",
			xRealIP:       "",
			remoteAddr:    "192.168.1.3:1234",
			expectedIP:    "192.168.1.3:1234",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			req.RemoteAddr = tt.remoteAddr
			if tt.xForwardedFor != "" {
				req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
			}
			if tt.xRealIP != "" {
				req.Header.Set("X-Real-IP", tt.xRealIP)
			}

			got := getClientIP(req)
			if got != tt.expectedIP {
				t.Errorf("getClientIP() = %v, want %v", got, tt.expectedIP)
			}
		})
	}
}

// TestCheckRateLimitAllowedRequests verifies that checkRateLimit allows a
// request whose counter is within the limit.
func TestCheckRateLimitAllowedRequests(t *testing.T) {
	mock := installRedisMock(t)

	config := models.RateLimitConfig{
		RequestsPerMinute: 100,
		BurstSize:         20,
	}
	identifier := "user:test123"
	key := "ratelimit:user:test123"

	// Mock Redis INCR returning 10 (within limit)
	mock.ExpectIncr(key).SetVal(10)
	mock.ExpectExpire(key, time.Minute).SetVal(true)

	allowed, err := checkRateLimit(identifier, config)
	if err != nil {
		t.Fatalf("checkRateLimit() error = %v", err)
	}
	if !allowed {
		t.Error("checkRateLimit() should allow request")
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled mock expectations: %v", err)
	}
}

// TestCheckRateLimitExceedsLimit verifies that checkRateLimit rejects a
// request once the counter goes past the limit.
func TestCheckRateLimitExceedsLimit(t *testing.T) {
	mock := installRedisMock(t)

	config := models.RateLimitConfig{
		RequestsPerMinute: 100,
		BurstSize:         20,
	}
	identifier := "user:test123"
	key := "ratelimit:user:test123"

	// Mock Redis INCR returning 121 (exceeds limit of 120)
	mock.ExpectIncr(key).SetVal(121)
	mock.ExpectExpire(key, time.Minute).SetVal(true)

	allowed, err := checkRateLimit(identifier, config)
	if err != nil {
		t.Fatalf("checkRateLimit() error = %v", err)
	}
	if allowed {
		t.Error("checkRateLimit() should block request when limit exceeded")
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled mock expectations: %v", err)
	}
}

// TestCheckRateLimitRedisError verifies that checkRateLimit reports an error
// and does not allow the request when the Redis INCR fails.
func TestCheckRateLimitRedisError(t *testing.T) {
	mock := installRedisMock(t)

	config := models.RateLimitConfig{
		RequestsPerMinute: 100,
		BurstSize:         20,
	}
	identifier := "user:test123"
	key := "ratelimit:user:test123"

	// Mock Redis error. The EXPIRE expectation is registered as well, but
	// ExpectationsWereMet is not asserted: this file does not establish
	// whether EXPIRE is still issued after a failed INCR.
	mock.ExpectIncr(key).SetErr(context.DeadlineExceeded)
	mock.ExpectExpire(key, time.Minute).SetVal(true)

	allowed, err := checkRateLimit(identifier, config)
	if err == nil {
		t.Error("checkRateLimit() should return error when Redis fails")
	}
	if allowed {
		t.Error("checkRateLimit() should not allow when error occurs")
	}
}

// TestRateLimiterMiddlewareRedisNotAvailable verifies the fail-open path when
// no Redis client is configured at all (redisclient.RDB is nil).
func TestRateLimiterMiddlewareRedisNotAvailable(t *testing.T) {
	previous := redisclient.RDB
	redisclient.RDB = nil
	t.Cleanup(func() { redisclient.RDB = previous })

	status, handlerCalled := serveThroughRateLimiter(DefaultRateLimitConfig(), "")

	// With fail-open behavior, requests should be allowed when Redis is unavailable
	if status != http.StatusOK {
		t.Errorf("status = %v, want %v (fail-open behavior)", status, http.StatusOK)
	}
	if !handlerCalled {
		t.Error("handler should be called when Redis is not available (fail-open)")
	}
}

// TestRateLimiterMiddlewareAllowsRequest verifies that a request within the
// limit reaches the wrapped handler.
func TestRateLimiterMiddlewareAllowsRequest(t *testing.T) {
	mock := installRedisMock(t)

	config := models.RateLimitConfig{
		RequestsPerMinute: 100,
		BurstSize:         20,
	}
	key := "ratelimit:ip:192.168.1.1:1234"

	// Mock Redis response for allowed request
	mock.MatchExpectationsInOrder(false)
	mock.ExpectIncr(key).SetVal(5)
	mock.ExpectExpire(key, time.Minute).SetVal(true)

	status, handlerCalled := serveThroughRateLimiter(config, "192.168.1.1:1234")

	if !handlerCalled {
		t.Error("handler should be called when rate limit not exceeded")
	}
	if status != http.StatusOK {
		t.Errorf("status = %v, want %v", status, http.StatusOK)
	}

	// The middleware fails open, so a mock that was never hit (for example
	// because of a mismatched key) would still produce a 200. Checking the
	// expectations ensures the request really went through the limiter.
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled mock expectations: %v", err)
	}
}

// TestRateLimiterMiddlewareBlocksRequest verifies that a request over the
// limit is answered with 429 and never reaches the wrapped handler.
func TestRateLimiterMiddlewareBlocksRequest(t *testing.T) {
	mock := installRedisMock(t)

	config := models.RateLimitConfig{
		RequestsPerMinute: 100,
		BurstSize:         20,
	}
	key := "ratelimit:ip:192.168.1.1:1234"

	// Mock Redis response for blocked request
	mock.MatchExpectationsInOrder(false)
	mock.ExpectIncr(key).SetVal(121)
	mock.ExpectExpire(key, time.Minute).SetVal(true)

	status, handlerCalled := serveThroughRateLimiter(config, "192.168.1.1:1234")

	if handlerCalled {
		t.Error("handler should not be called when rate limit exceeded")
	}
	if status != http.StatusTooManyRequests {
		t.Errorf("status = %v, want %v", status, http.StatusTooManyRequests)
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unfulfilled mock expectations: %v", err)
	}
}

// TestRateLimiterMiddlewareFailsOpenOnError verifies that a Redis error during
// the limit check lets the request through instead of rejecting it.
func TestRateLimiterMiddlewareFailsOpenOnError(t *testing.T) {
	mock := installRedisMock(t)

	config := models.RateLimitConfig{
		RequestsPerMinute: 100,
		BurstSize:         20,
	}
	key := "ratelimit:ip:192.168.1.1:1234"

	// Mock Redis error. ExpectationsWereMet is intentionally not asserted:
	// this file does not establish whether EXPIRE is issued after a failed
	// INCR, so the EXPIRE expectation may legitimately stay unconsumed.
	mock.MatchExpectationsInOrder(false)
	mock.ExpectIncr(key).SetErr(context.DeadlineExceeded)
	mock.ExpectExpire(key, time.Minute).SetVal(true)

	status, handlerCalled := serveThroughRateLimiter(config, "192.168.1.1:1234")

	if !handlerCalled {
		t.Error("handler should be called when Redis errors (fail open)")
	}
	if status != http.StatusOK {
		t.Errorf("status = %v, want %v", status, http.StatusOK)
	}
}