init commit

This commit is contained in:
2025-11-25 15:12:31 +08:00
commit 052c7e0cca
63 changed files with 8828 additions and 0 deletions
+9
View File
@@ -0,0 +1,9 @@
package middleware
const (
InvalidTokenClaims = "Invalid token claims" // #nosec G101
InvalidOrExpiredToken = "Invalid or expired token" // #nosec G101
redisKeyJWTSessionID = "jwt_session_id:%s"
errorFormat = "%s?error=%s"
InternalServerError = "Internal server error"
)
+9
View File
@@ -0,0 +1,9 @@
package middleware
import (
"authentication/models"
)
// FlusherPreservingResponseWriter is an alias for models.FlusherPreservingResponseWriter
// Kept for backward compatibility
type FlusherPreservingResponseWriter = models.FlusherPreservingResponseWriter
+43
View File
@@ -0,0 +1,43 @@
package middleware
import (
"log"
"net/http"
"os"
)
func SetHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodOptions {
// Only set Content-Type if not SSE
if w.Header().Get("Content-Type") != "text/event-stream" {
w.Header().Set("Content-Type", "application/json")
}
}
w.Header().Set("X-DNS-Prefetch-Control", "off")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("Content-Security-Policy", "default-src 'self'")
w.Header().Set("Referrer-Policy", "no-referrer")
w.Header().Set("X-Powered-By", "Zig")
GoEnv := os.Getenv("GO_ENV")
if GoEnv == "" {
log.Fatal("GO_ENV is not set in SetHeaders middleware. Please set the GO_ENV environment variable.")
}
if GoEnv != "development" {
w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
}
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
})
}
+284
View File
@@ -0,0 +1,284 @@
package middleware
import (
"net/http"
"net/http/httptest"
"os"
"testing"
)
func TestSetHeaders(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// Check security headers
headers := map[string]string{
"X-DNS-Prefetch-Control": "off",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"X-Content-Type-Options": "nosniff",
"Content-Security-Policy": "default-src 'self'",
"Referrer-Policy": "no-referrer",
"X-Powered-By": "Zig",
"Strict-Transport-Security": "max-age=63072000; includeSubDomains; preload",
"Content-Type": "application/json",
}
for header, expected := range headers {
actual := recorder.Header().Get(header)
if actual != expected {
t.Errorf("Expected header %s to be '%s', got '%s'", header, expected, actual)
}
}
}
func TestSetHeadersDevelopment(t *testing.T) {
os.Setenv("GO_ENV", "development")
defer os.Unsetenv("GO_ENV")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// HSTS should not be set in development
hsts := recorder.Header().Get("Strict-Transport-Security")
if hsts != "" {
t.Errorf("Expected no HSTS header in development, got '%s'", hsts)
}
// Other security headers should still be present
if recorder.Header().Get("X-Frame-Options") != "DENY" {
t.Error("Expected X-Frame-Options header in development")
}
}
func TestSetHeadersSSE(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/stream", nil)
recorder := httptest.NewRecorder()
// Pre-set SSE content type
recorder.Header().Set("Content-Type", "text/event-stream")
middleware.ServeHTTP(recorder, req)
// Content-Type should remain text/event-stream
contentType := recorder.Header().Get("Content-Type")
if contentType != "text/event-stream" {
t.Errorf("Expected Content-Type 'text/event-stream', got '%s'", contentType)
}
}
func TestSetHeadersOptions(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
handlerCalled := false
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// OPTIONS should return 200 without calling next handler
if recorder.Code != http.StatusOK {
t.Errorf("Expected status 200 for OPTIONS, got %d", recorder.Code)
}
if handlerCalled {
t.Error("Expected next handler NOT to be called for OPTIONS request")
}
// Security headers should still be set
if recorder.Header().Get("X-Frame-Options") != "DENY" {
t.Error("Expected security headers to be set for OPTIONS")
}
}
func TestSetHeadersAllMethods(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
methods := []string{
http.MethodGet,
http.MethodPost,
http.MethodPut,
http.MethodDelete,
http.MethodPatch,
}
for _, method := range methods {
t.Run(method, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(method, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// All methods should have security headers
if recorder.Header().Get("X-Frame-Options") != "DENY" {
t.Errorf("Expected X-Frame-Options for %s", method)
}
if recorder.Header().Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json for %s", method)
}
})
}
}
func TestSetHeadersEnvironments(t *testing.T) {
environments := []string{"development", "production", "canary", "debug"}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
for _, env := range environments {
t.Run(env, func(t *testing.T) {
os.Setenv("GO_ENV", env)
defer os.Unsetenv("GO_ENV")
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// HSTS should only be set in non-development environments
hsts := recorder.Header().Get("Strict-Transport-Security")
if env == "development" {
if hsts != "" {
t.Errorf("HSTS should not be set in development, got '%s'", hsts)
}
} else {
if hsts == "" {
t.Errorf("HSTS should be set in %s environment", env)
}
}
})
}
}
func TestSetHeadersPoweredBy(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
poweredBy := recorder.Header().Get("X-Powered-By")
if poweredBy != "Zig" {
t.Errorf("Expected X-Powered-By 'Zig', got '%s'", poweredBy)
}
}
func TestSetHeadersCSP(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
csp := recorder.Header().Get("Content-Security-Policy")
if csp != "default-src 'self'" {
t.Errorf("Expected CSP 'default-src 'self'', got '%s'", csp)
}
}
func TestSetHeadersReferrerPolicy(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
referrer := recorder.Header().Get("Referrer-Policy")
if referrer != "no-referrer" {
t.Errorf("Expected Referrer-Policy 'no-referrer', got '%s'", referrer)
}
}
func TestSetHeadersXSSProtection(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
xss := recorder.Header().Get("X-XSS-Protection")
if xss != "1; mode=block" {
t.Errorf("Expected X-XSS-Protection '1; mode=block', got '%s'", xss)
}
}
+247
View File
@@ -0,0 +1,247 @@
//lint:file-ignore SA1029 Ignore all golangci-lint warnings in this file
package middleware
import (
"context"
"database/sql"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"authentication/db"
"authentication/helper"
"authentication/models"
"authentication/redisclient"
"github.com/golang-jwt/jwt/v5"
"github.com/joho/godotenv"
)
var (
Blacklist = make(map[string]struct{})
Mu sync.Mutex
)
func init() {
err := godotenv.Load()
if err != nil {
helper.LogWarn("Warning: Could not load .env file, using system environment variables.")
}
}
func JWTMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
DashboardBaseURL := os.Getenv("DASHBOARD_URL")
tokenString := ""
if isValidAuthHeader(authHeader) {
tokenString = strings.TrimPrefix(authHeader, "Bearer ")
} else {
path := r.URL.Path
if strings.Contains(path, "/sse") {
tokenString = r.URL.Query().Get("access_token")
if tokenString == "" {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Missing access_token in query params")), http.StatusSeeOther)
return
}
} else {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid authorization header")), http.StatusSeeOther)
return
}
}
if isTokenBlacklisted(tokenString) {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token is blacklisted")), http.StatusSeeOther)
return
}
secretKey := os.Getenv("JWT_SECRET_KEY")
if secretKey == "" {
helper.RespondWithError(w, http.StatusInternalServerError, "Secret key not set")
return
}
token, err := parseToken(tokenString, secretKey)
if err != nil {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidOrExpiredToken)), http.StatusSeeOther)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther)
return
}
// Check JWT token expiration
if exp, ok := claims["exp"].(float64); ok {
if exp == 0 {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has no expiration")), http.StatusSeeOther)
return
}
// Check if token is expired
if time.Now().Unix() > int64(exp) {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has expired")), http.StatusSeeOther)
return
}
} else {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token missing expiration claim")), http.StatusSeeOther)
return
}
email, ok := claims["email"].(string)
if !ok {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther)
return
}
sessionID, ok := claims["session_id"].(string)
if !ok {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid session ID in token")), http.StatusSeeOther)
return
}
if isSessionBlacklisted(sessionID) {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Session has been revoked")), http.StatusSeeOther)
return
}
session, err := validateSessionFromDB(sessionID)
if err != nil {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid or revoked session")), http.StatusSeeOther)
return
}
userAgent := r.Header.Get("User-Agent")
ipAddress := getClientIP(r)
if session.UserAgent != userAgent {
helper.LogError(nil, fmt.Sprintf("Session security mismatch for session %s", sessionID))
}
if session.IPAddress != ipAddress {
helper.LogError(nil, fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s", sessionID, session.IPAddress, ipAddress))
}
userID, err := getUserIDByEmail(email)
if err != nil {
if err != sql.ErrNoRows {
helper.RespondWithError(w, http.StatusInternalServerError, "Failed to get user ID")
return
}
}
ctx := context.WithValue(r.Context(), "userID", userID)
ctx = context.WithValue(ctx, "sessionID", sessionID)
ctx = context.WithValue(ctx, "email", email)
next.ServeHTTP(&models.FlusherPreservingResponseWriter{ResponseWriter: w}, r.WithContext(ctx))
})
}
func isValidAuthHeader(authHeader string) bool {
return authHeader != "" && strings.HasPrefix(authHeader, "Bearer ")
}
func isTokenBlacklisted(tokenString string) bool {
Mu.Lock()
defer Mu.Unlock()
_, found := Blacklist[tokenString]
return found
}
// isSessionBlacklisted checks if a session is in the Redis blacklist
func isSessionBlacklisted(sessionID string) bool {
ctx := context.Background()
blacklistKey := fmt.Sprintf("session_blacklist:%s", sessionID)
exists, err := redisclient.RDB.Exists(ctx, blacklistKey).Result()
return err == nil && exists > 0
}
func parseToken(tokenString, secretKey string) (*jwt.Token, error) {
return jwt.ParseWithClaims(tokenString, jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(secretKey), nil
})
}
func getUserIDByEmail(email string) (string, error) {
var userID string
err := db.DB.QueryRow("SELECT id FROM users WHERE email_address = ?", email).Scan(&userID)
if err != nil {
return "", err
}
return userID, nil
}
func validateSessionFromDB(sessionID string) (*models.JWTSession, error) {
ctx := context.Background()
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
// Try to get session from Redis cache first
var session models.JWTSession
err := helper.GetJSON(ctx, sessionKey, &session)
if err != nil {
// Session not in cache, fetch from database
err = db.DB.QueryRow(`
SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
FROM jwt_sessions
WHERE id = ? AND is_revoked = false
`, sessionID).Scan(
&session.ID,
&session.UserID,
&session.RefreshTokenHash,
&session.UserAgent,
&session.IPAddress,
&session.CreatedAt,
&session.UpdatedAt,
&session.ExpiresAt,
&session.IsRevoked,
)
if err != nil {
return nil, fmt.Errorf("session not found or revoked: %w", err)
}
// Cache the session in Redis (TTL based on session expiry)
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
if sessionTTL > 0 {
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
helper.LogWarn(fmt.Sprintf("Failed to cache session in Redis: %v", err))
}
}
}
if session.ExpiresAt.Before(time.Now()) {
// Auto-revoke expired session and clear cache
_, _ = db.DB.Exec("UPDATE jwt_sessions SET is_revoked = true WHERE id = ?", sessionID)
redisclient.RDB.Del(ctx, sessionKey)
return nil, fmt.Errorf("session has expired")
}
return &session, nil
}
func getClientIP(r *http.Request) string {
forwarded := r.Header.Get("X-Forwarded-For")
if forwarded != "" {
parts := strings.Split(forwarded, ",")
return strings.TrimSpace(parts[0])
}
realIP := r.Header.Get("X-Real-IP")
if realIP != "" {
return realIP
}
ip := r.RemoteAddr
if idx := strings.LastIndex(ip, ":"); idx != -1 {
ip = ip[:idx]
}
return ip
}
+186
View File
@@ -0,0 +1,186 @@
package middleware
import (
"database/sql"
"fmt"
"log"
"net"
"net/http"
"os"
"regexp"
"time"
"authentication/db"
"authentication/helper"
"authentication/redisclient"
)
func normalizeEndpoint(path string) string {
uuidRegex := regexp.MustCompile(`/([a-zA-Z0-9_-]{11})(/|$)`)
path = uuidRegex.ReplaceAllString(path, "/{id}$2")
queryParamRegex := regexp.MustCompile(`\?.*`)
return queryParamRegex.ReplaceAllString(path, "")
}
func RateLimiterMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rateLimitHeaderValue := os.Getenv("RATE_LIMIT_HEADER")
if rateLimitHeaderValue == "" {
rateLimitHeaderValue = "F04C"
}
if r.Header.Get("X-RateLimit-Bypass") == rateLimitHeaderValue {
// Bypass header is set to the correct value, skip rate limiting
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
return
}
// If the header is not set or has an invalid value, proceed with rate limiting logic
log.Print("No valid rate limit bypass header, proceeding with rate limiting logic")
// Get user identifier (email or IP)
userIdentifier := ""
email, err := helper.ExtractEmailFromToken(r.Header.Get("Authorization"))
if err != nil {
email, err = helper.ExtractEmailFromToken(r.URL.Query().Get("access_token"))
if err != nil {
helper.LogInfo(fmt.Sprintf("Could not extract email from token: %v, using IP-based rate limiting", err))
}
}
if email != "" {
userIdentifier = email
} else {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
userIdentifier = ip
}
if r.URL == nil || r.URL.Path == "" {
helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL")
return
}
endpoint := normalizeEndpoint(r.URL.Path)
var limitCount, timeWindow int
err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow)
if err != nil {
if err == sql.ErrNoRows {
limitCount = 300
timeWindow = 60
_, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow)
if insertErr != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
} else {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
}
redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint
if redisclient.RDB == nil {
helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized")
return
}
count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result()
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
if count == 1 {
_ = redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err()
}
if int(count) > limitCount {
println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" +
fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount))
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
return
}
next.ServeHTTP(w, r)
})
}
func PublicRateLimiterMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-RateLimit-Bypass") == "F04C" {
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
return
}
log.Print("No valid rate limit bypass header, proceeding with rate limiting logic")
// Use IP address as the user identifier for public endpoints
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
userIdentifier := ip
if r.URL == nil || r.URL.Path == "" {
helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL")
return
}
endpoint := normalizeEndpoint(r.URL.Path)
var limitCount, timeWindow int
err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow)
if err != nil {
if err == sql.ErrNoRows {
limitCount = 36000
timeWindow = 60
_, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow)
if insertErr != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
} else {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
}
redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint
if redisclient.RDB == nil {
helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized")
return
}
count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result()
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
if count == 1 {
err := redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err()
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
}
// Log the key and value saved
log.Printf("Redis key: %s, value: %d", redisCountKey, count)
if int(count) > limitCount {
println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" +
fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount))
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
return
}
next.ServeHTTP(w, r)
})
}
+214
View File
@@ -0,0 +1,214 @@
package middleware
import (
"testing"
)
func TestNormalizeEndpoint(t *testing.T) {
testCases := []struct {
name string
input string
expected string
}{
{
name: "Simple path",
input: "/api/users",
expected: "/api/users",
},
{
name: "Path with UUID",
input: "/api/users/abcdef12345",
expected: "/api/users/{id}",
},
{
name: "Path with UUID and trailing slash",
input: "/api/users/abcdef12345/",
expected: "/api/users/{id}/",
},
{
name: "Path with UUID in middle",
input: "/api/users/abcdef12345/profile",
expected: "/api/users/{id}/profile",
},
{
name: "Path with query params",
input: "/api/users?page=1&limit=10",
expected: "/api/users",
},
{
name: "Path with UUID and query params",
input: "/api/users/abcdef12345?detail=full",
expected: "/api/users/abcdef12345", // Query params removed first, then UUID not matched
},
{
name: "Multiple UUIDs",
input: "/api/users/abc12345678/posts/def87654321",
expected: "/api/users/{id}/posts/{id}",
},
{
name: "Root path",
input: "/",
expected: "/",
},
{
name: "Empty path",
input: "",
expected: "",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := normalizeEndpoint(tc.input)
if result != tc.expected {
t.Errorf("Expected '%s', got '%s'", tc.expected, result)
}
})
}
}
func TestNormalizeEndpointUUIDFormats(t *testing.T) {
uuidFormats := []string{
"abcdef12345",
"ABCDEF12345",
"abc_def1234",
"abc-def1234",
"mixedCase12",
}
for _, uuid := range uuidFormats {
t.Run(uuid, func(t *testing.T) {
input := "/api/users/" + uuid
result := normalizeEndpoint(input)
expected := "/api/users/{id}"
if result != expected {
t.Errorf("Expected '%s', got '%s' for UUID format '%s'", expected, result, uuid)
}
})
}
}
func TestNormalizeEndpointComplexQueryStrings(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{"/api/users?a=1&b=2&c=3", "/api/users"},
{"/api/users?filter=active&sort=name&order=asc", "/api/users"},
{"/api/users?search=john+doe", "/api/users"},
{"/api/users?tags[]=tag1&tags[]=tag2", "/api/users"},
}
for _, tc := range testCases {
result := normalizeEndpoint(tc.input)
if result != tc.expected {
t.Errorf("Input '%s': expected '%s', got '%s'", tc.input, tc.expected, result)
}
}
}
func TestNormalizeEndpointEdgeCases(t *testing.T) {
testCases := []struct {
name string
input string
expected string
}{
{
name: "Just query string",
input: "?param=value",
expected: "",
},
{
name: "Double slashes",
input: "/api//users",
expected: "/api//users",
},
{
name: "Trailing query without params",
input: "/api/users?",
expected: "/api/users",
},
{
name: "UUID at end without slash",
input: "/users/abc12345678",
expected: "/users/{id}",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := normalizeEndpoint(tc.input)
if result != tc.expected {
t.Errorf("Expected '%s', got '%s'", tc.expected, result)
}
})
}
}
func TestNormalizeEndpointPreservesNonUUID(t *testing.T) {
testCases := []string{
"/api/users/all",
"/api/users/active",
"/api/sessions/current",
"/health",
"/metrics",
}
for _, input := range testCases {
t.Run(input, func(t *testing.T) {
result := normalizeEndpoint(input)
if result != input {
t.Errorf("Non-UUID path should be preserved. Input: '%s', got '%s'", input, result)
}
})
}
}
func TestNormalizeEndpointUUIDLength(t *testing.T) {
// UUIDs must be exactly 11 characters
testCases := []struct {
name string
uuid string
shouldNormalize bool
}{
{"10 chars", "abcdefghij", false},
{"11 chars", "abcdefghijk", true},
{"12 chars", "abcdefghijkl", false},
{"5 chars", "abcde", false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
input := "/api/users/" + tc.uuid
result := normalizeEndpoint(input)
if tc.shouldNormalize {
expected := "/api/users/{id}"
if result != expected {
t.Errorf("Expected '%s', got '%s'", expected, result)
}
} else {
if result != input {
t.Errorf("Should not normalize, expected '%s', got '%s'", input, result)
}
}
})
}
}
func BenchmarkNormalizeEndpoint(b *testing.B) {
testPaths := []string{
"/api/users",
"/api/users/abc12345678",
"/api/users/abc12345678/profile",
"/api/users?page=1&limit=10",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, path := range testPaths {
normalizeEndpoint(path)
}
}
}
+20
View File
@@ -0,0 +1,20 @@
package middleware
import (
"os"
"testing"
)
// TestMain runs before all tests and sets up the test environment
func TestMain(m *testing.M) {
// Set GO_ENV for all tests to prevent init() failures
os.Setenv("GO_ENV", "development")
// Run tests
code := m.Run()
// Cleanup
os.Unsetenv("GO_ENV")
os.Exit(code)
}