added unit testing
This commit is contained in:
@@ -0,0 +1,347 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"authorization/models"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func TestGetJWTSecret(t *testing.T) {
|
||||
// Save original and restore after test
|
||||
originalSecret := os.Getenv("JWT_KEY")
|
||||
defer func() {
|
||||
if originalSecret != "" {
|
||||
os.Setenv("JWT_KEY", originalSecret)
|
||||
} else {
|
||||
os.Unsetenv("JWT_KEY")
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("JWT_KEY not set", func(t *testing.T) {
|
||||
// Reset state for testing
|
||||
oldCached := jwtSecretCached
|
||||
oldError := jwtSecretError
|
||||
defer func() {
|
||||
jwtSecretCached = oldCached
|
||||
jwtSecretError = oldError
|
||||
}()
|
||||
|
||||
os.Unsetenv("JWT_KEY")
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
jwtSecretCached = nil
|
||||
|
||||
_, err := getJWTSecret()
|
||||
if err == nil {
|
||||
t.Error("Expected error when JWT_KEY is not set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JWT_KEY set", func(t *testing.T) {
|
||||
// Reset state for testing
|
||||
oldCached := jwtSecretCached
|
||||
oldError := jwtSecretError
|
||||
defer func() {
|
||||
jwtSecretCached = oldCached
|
||||
jwtSecretError = oldError
|
||||
}()
|
||||
|
||||
os.Setenv("JWT_KEY", "test-secret-key")
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
jwtSecretCached = nil
|
||||
|
||||
secret, err := getJWTSecret()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if string(secret) != "test-secret-key" {
|
||||
t.Errorf("Expected 'test-secret-key', got '%s'", string(secret))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractBearerToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
wantToken string
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Bearer token",
|
||||
authHeader: "Bearer token123",
|
||||
wantToken: "token123",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "Empty header",
|
||||
authHeader: "",
|
||||
wantToken: "",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
authHeader: "Bearer",
|
||||
wantToken: "",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "Wrong prefix",
|
||||
authHeader: "Basic token123",
|
||||
wantToken: "",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "Missing space",
|
||||
authHeader: "Bearertoken123",
|
||||
wantToken: "",
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, ok := extractBearerToken(tt.authHeader)
|
||||
if token != tt.wantToken {
|
||||
t.Errorf("Expected token '%s', got '%s'", tt.wantToken, token)
|
||||
}
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("Expected ok %v, got %v", tt.wantOK, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAndValidateToken(t *testing.T) {
|
||||
// Setup
|
||||
os.Setenv("JWT_KEY", "test-secret-key")
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
defer os.Unsetenv("JWT_KEY")
|
||||
|
||||
t.Run("Valid token", func(t *testing.T) {
|
||||
// Create a valid token
|
||||
claims := &models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte("test-secret-key"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create token: %v", err)
|
||||
}
|
||||
|
||||
parsedClaims, err := parseAndValidateToken(tokenString)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if parsedClaims.UserID != "user123" {
|
||||
t.Errorf("Expected UserID 'user123', got '%s'", parsedClaims.UserID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Invalid token", func(t *testing.T) {
|
||||
_, err := parseAndValidateToken("invalid.token.string")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Expired token", func(t *testing.T) {
|
||||
claims := &models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte("test-secret-key"))
|
||||
|
||||
_, err := parseAndValidateToken(tokenString)
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired token")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildContext(t *testing.T) {
|
||||
claims := &models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
}
|
||||
|
||||
parent := context.Background()
|
||||
ctx := buildContext(parent, claims)
|
||||
|
||||
// Check claims
|
||||
if val, ok := ctx.Value(claimsKey).(*models.Claims); !ok || val.UserID != "user123" {
|
||||
t.Error("Claims not properly set in context")
|
||||
}
|
||||
|
||||
// Check userID
|
||||
if val, ok := ctx.Value(userIDKey).(string); !ok || val != "user123" {
|
||||
t.Error("UserID not properly set in context")
|
||||
}
|
||||
|
||||
// Check username
|
||||
if val, ok := ctx.Value(usernameKey).(string); !ok || val != "testuser" {
|
||||
t.Error("Username not properly set in context")
|
||||
}
|
||||
|
||||
// Check role
|
||||
if val, ok := ctx.Value(roleKey).(string); !ok || val != "admin" {
|
||||
t.Error("Role not properly set in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClaims(t *testing.T) {
|
||||
claims := &models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := context.WithValue(req.Context(), claimsKey, claims)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
retrievedClaims, ok := GetClaims(req)
|
||||
if !ok {
|
||||
t.Error("Expected claims to be found")
|
||||
}
|
||||
if retrievedClaims.UserID != "user123" {
|
||||
t.Errorf("Expected UserID 'user123', got '%s'", retrievedClaims.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserID(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := context.WithValue(req.Context(), userIDKey, "user123")
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
userID, ok := GetUserID(req)
|
||||
if !ok {
|
||||
t.Error("Expected userID to be found")
|
||||
}
|
||||
if userID != "user123" {
|
||||
t.Errorf("Expected 'user123', got '%s'", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsername(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := context.WithValue(req.Context(), usernameKey, "testuser")
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
username, ok := GetUsername(req)
|
||||
if !ok {
|
||||
t.Error("Expected username to be found")
|
||||
}
|
||||
if username != "testuser" {
|
||||
t.Errorf("Expected 'testuser', got '%s'", username)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRole(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := context.WithValue(req.Context(), roleKey, "admin")
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
role, ok := GetRole(req)
|
||||
if !ok {
|
||||
t.Error("Expected role to be found")
|
||||
}
|
||||
if role != "admin" {
|
||||
t.Errorf("Expected 'admin', got '%s'", role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_NoAuthHeader(t *testing.T) {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
JWTAuth(handler)(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_InvalidToken(t *testing.T) {
|
||||
os.Setenv("JWT_KEY", "test-secret-key")
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
defer os.Unsetenv("JWT_KEY")
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid.token.here")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
JWTAuth(handler)(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_ValidToken(t *testing.T) {
|
||||
os.Setenv("JWT_KEY", "test-secret-key")
|
||||
jwtSecretOnce = sync.Once{}
|
||||
jwtSecretError = nil
|
||||
defer os.Unsetenv("JWT_KEY")
|
||||
|
||||
claims := &models.Claims{
|
||||
UserID: "user123",
|
||||
Username: "testuser",
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, _ := token.SignedString([]byte("test-secret-key"))
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify claims are in context
|
||||
if retrievedClaims, ok := GetClaims(r); !ok || retrievedClaims.UserID != "user123" {
|
||||
t.Error("Claims not found or incorrect in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenString)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
JWTAuth(handler)(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,326 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"authorization/models"
|
||||
"authorization/redisclient"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redismock/v9"
|
||||
)
|
||||
|
||||
func TestDefaultRateLimitConfig(t *testing.T) {
|
||||
config := DefaultRateLimitConfig()
|
||||
|
||||
if config.RequestsPerMinute != 100 {
|
||||
t.Errorf("RequestsPerMinute = %v, want 100", config.RequestsPerMinute)
|
||||
}
|
||||
|
||||
if config.BurstSize != 20 {
|
||||
t.Errorf("BurstSize = %v, want 20", config.BurstSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
xForwardedFor string
|
||||
xRealIP string
|
||||
remoteAddr string
|
||||
expectedIP string
|
||||
}{
|
||||
{
|
||||
name: "uses X-Forwarded-For when present",
|
||||
xForwardedFor: "192.168.1.1",
|
||||
xRealIP: "192.168.1.2",
|
||||
remoteAddr: "192.168.1.3:1234",
|
||||
expectedIP: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "uses X-Real-IP when X-Forwarded-For absent",
|
||||
xForwardedFor: "",
|
||||
xRealIP: "192.168.1.2",
|
||||
remoteAddr: "192.168.1.3:1234",
|
||||
expectedIP: "192.168.1.2",
|
||||
},
|
||||
{
|
||||
name: "uses RemoteAddr when both headers absent",
|
||||
xForwardedFor: "",
|
||||
xRealIP: "",
|
||||
remoteAddr: "192.168.1.3:1234",
|
||||
expectedIP: "192.168.1.3:1234",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
if tt.xRealIP != "" {
|
||||
req.Header.Set("X-Real-IP", tt.xRealIP)
|
||||
}
|
||||
|
||||
got := getClientIP(req)
|
||||
if got != tt.expectedIP {
|
||||
t.Errorf("getClientIP() = %v, want %v", got, tt.expectedIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRateLimit_AllowedRequests(t *testing.T) {
|
||||
db, mock := redismock.NewClientMock()
|
||||
originalRedis := redisclient.RDB
|
||||
redisclient.RDB = db
|
||||
defer func() { redisclient.RDB = originalRedis }()
|
||||
|
||||
config := models.RateLimitConfig{
|
||||
RequestsPerMinute: 100,
|
||||
BurstSize: 20,
|
||||
}
|
||||
|
||||
identifier := "user:test123"
|
||||
key := "ratelimit:user:test123"
|
||||
|
||||
// Mock Redis INCR returning 10 (within limit)
|
||||
mock.ExpectIncr(key).SetVal(10)
|
||||
mock.ExpectExpire(key, time.Minute).SetVal(true)
|
||||
|
||||
allowed, err := checkRateLimit(identifier, config)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("checkRateLimit() error = %v", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
t.Error("checkRateLimit() should allow request")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRateLimit_ExceedsLimit(t *testing.T) {
|
||||
db, mock := redismock.NewClientMock()
|
||||
originalRedis := redisclient.RDB
|
||||
redisclient.RDB = db
|
||||
defer func() { redisclient.RDB = originalRedis }()
|
||||
|
||||
config := models.RateLimitConfig{
|
||||
RequestsPerMinute: 100,
|
||||
BurstSize: 20,
|
||||
}
|
||||
|
||||
identifier := "user:test123"
|
||||
key := "ratelimit:user:test123"
|
||||
|
||||
// Mock Redis INCR returning 121 (exceeds limit of 120)
|
||||
mock.ExpectIncr(key).SetVal(121)
|
||||
mock.ExpectExpire(key, time.Minute).SetVal(true)
|
||||
|
||||
allowed, err := checkRateLimit(identifier, config)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("checkRateLimit() error = %v", err)
|
||||
}
|
||||
|
||||
if allowed {
|
||||
t.Error("checkRateLimit() should block request when limit exceeded")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRateLimit_RedisError(t *testing.T) {
|
||||
db, mock := redismock.NewClientMock()
|
||||
originalRedis := redisclient.RDB
|
||||
redisclient.RDB = db
|
||||
defer func() { redisclient.RDB = originalRedis }()
|
||||
|
||||
config := models.RateLimitConfig{
|
||||
RequestsPerMinute: 100,
|
||||
BurstSize: 20,
|
||||
}
|
||||
|
||||
identifier := "user:test123"
|
||||
key := "ratelimit:user:test123"
|
||||
|
||||
// Mock Redis error
|
||||
mock.ExpectIncr(key).SetErr(context.DeadlineExceeded)
|
||||
mock.ExpectExpire(key, time.Minute).SetVal(true)
|
||||
|
||||
allowed, err := checkRateLimit(identifier, config)
|
||||
|
||||
if err == nil {
|
||||
t.Error("checkRateLimit() should return error when Redis fails")
|
||||
}
|
||||
|
||||
if allowed {
|
||||
t.Error("checkRateLimit() should not allow when error occurs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterMiddleware_RedisNotAvailable(t *testing.T) {
|
||||
originalRedis := redisclient.RDB
|
||||
redisclient.RDB = nil
|
||||
defer func() { redisclient.RDB = originalRedis }()
|
||||
|
||||
config := DefaultRateLimitConfig()
|
||||
middleware := RateLimiterMiddleware(config)
|
||||
|
||||
handlerCalled := false
|
||||
handler := middleware(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Errorf("status = %v, want %v", resp.StatusCode, http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
if handlerCalled {
|
||||
t.Error("handler should not be called when Redis is not available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterMiddleware_AllowsRequest(t *testing.T) {
|
||||
db, mock := redismock.NewClientMock()
|
||||
originalRedis := redisclient.RDB
|
||||
redisclient.RDB = db
|
||||
defer func() { redisclient.RDB = originalRedis }()
|
||||
|
||||
config := models.RateLimitConfig{
|
||||
RequestsPerMinute: 100,
|
||||
BurstSize: 20,
|
||||
}
|
||||
|
||||
// Mock Redis response for allowed request
|
||||
mock.MatchExpectationsInOrder(false)
|
||||
mock.ExpectIncr("ratelimit:ip:192.168.1.1:1234").SetVal(5)
|
||||
mock.ExpectExpire("ratelimit:ip:192.168.1.1:1234", time.Minute).SetVal(true)
|
||||
|
||||
middleware := RateLimiterMiddleware(config)
|
||||
|
||||
handlerCalled := false
|
||||
handler := middleware(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
if !handlerCalled {
|
||||
t.Error("handler should be called when rate limit not exceeded")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("status = %v, want %v", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterMiddleware_BlocksRequest(t *testing.T) {
|
||||
db, mock := redismock.NewClientMock()
|
||||
originalRedis := redisclient.RDB
|
||||
redisclient.RDB = db
|
||||
defer func() { redisclient.RDB = originalRedis }()
|
||||
|
||||
config := models.RateLimitConfig{
|
||||
RequestsPerMinute: 100,
|
||||
BurstSize: 20,
|
||||
}
|
||||
|
||||
// Mock Redis response for blocked request
|
||||
mock.MatchExpectationsInOrder(false)
|
||||
mock.ExpectIncr("ratelimit:ip:192.168.1.1:1234").SetVal(121)
|
||||
mock.ExpectExpire("ratelimit:ip:192.168.1.1:1234", time.Minute).SetVal(true)
|
||||
|
||||
middleware := RateLimiterMiddleware(config)
|
||||
|
||||
handlerCalled := false
|
||||
handler := middleware(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
if handlerCalled {
|
||||
t.Error("handler should not be called when rate limit exceeded")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusTooManyRequests {
|
||||
t.Errorf("status = %v, want %v", resp.StatusCode, http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterMiddleware_FailsOpenOnError(t *testing.T) {
|
||||
db, mock := redismock.NewClientMock()
|
||||
originalRedis := redisclient.RDB
|
||||
redisclient.RDB = db
|
||||
defer func() { redisclient.RDB = originalRedis }()
|
||||
|
||||
config := models.RateLimitConfig{
|
||||
RequestsPerMinute: 100,
|
||||
BurstSize: 20,
|
||||
}
|
||||
|
||||
// Mock Redis error
|
||||
mock.MatchExpectationsInOrder(false)
|
||||
mock.ExpectIncr("ratelimit:ip:192.168.1.1:1234").SetErr(context.DeadlineExceeded)
|
||||
mock.ExpectExpire("ratelimit:ip:192.168.1.1:1234", time.Minute).SetVal(true)
|
||||
|
||||
middleware := RateLimiterMiddleware(config)
|
||||
|
||||
handlerCalled := false
|
||||
handler := middleware(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
if !handlerCalled {
|
||||
t.Error("handler should be called when Redis errors (fail open)")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("status = %v, want %v", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user