init commit
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
package middleware
|
||||
|
||||
const (
|
||||
InvalidTokenClaims = "Invalid token claims" // #nosec G101
|
||||
InvalidOrExpiredToken = "Invalid or expired token" // #nosec G101
|
||||
redisKeyJWTSessionID = "jwt_session_id:%s"
|
||||
errorFormat = "%s?error=%s"
|
||||
InternalServerError = "Internal server error"
|
||||
)
|
||||
@@ -0,0 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"authentication/models"
|
||||
)
|
||||
|
||||
// FlusherPreservingResponseWriter is an alias for models.FlusherPreservingResponseWriter
|
||||
// Kept for backward compatibility
|
||||
type FlusherPreservingResponseWriter = models.FlusherPreservingResponseWriter
|
||||
@@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
func SetHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodOptions {
|
||||
// Only set Content-Type if not SSE
|
||||
if w.Header().Get("Content-Type") != "text/event-stream" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("X-DNS-Prefetch-Control", "off")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'")
|
||||
w.Header().Set("Referrer-Policy", "no-referrer")
|
||||
w.Header().Set("X-Powered-By", "Zig")
|
||||
|
||||
GoEnv := os.Getenv("GO_ENV")
|
||||
|
||||
if GoEnv == "" {
|
||||
log.Fatal("GO_ENV is not set in SetHeaders middleware. Please set the GO_ENV environment variable.")
|
||||
}
|
||||
|
||||
if GoEnv != "development" {
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
|
||||
}
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,284 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetHeaders(t *testing.T) {
|
||||
os.Setenv("GO_ENV", "production")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
// Check security headers
|
||||
headers := map[string]string{
|
||||
"X-DNS-Prefetch-Control": "off",
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Content-Security-Policy": "default-src 'self'",
|
||||
"Referrer-Policy": "no-referrer",
|
||||
"X-Powered-By": "Zig",
|
||||
"Strict-Transport-Security": "max-age=63072000; includeSubDomains; preload",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
for header, expected := range headers {
|
||||
actual := recorder.Header().Get(header)
|
||||
if actual != expected {
|
||||
t.Errorf("Expected header %s to be '%s', got '%s'", header, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeadersDevelopment(t *testing.T) {
|
||||
os.Setenv("GO_ENV", "development")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
// HSTS should not be set in development
|
||||
hsts := recorder.Header().Get("Strict-Transport-Security")
|
||||
if hsts != "" {
|
||||
t.Errorf("Expected no HSTS header in development, got '%s'", hsts)
|
||||
}
|
||||
|
||||
// Other security headers should still be present
|
||||
if recorder.Header().Get("X-Frame-Options") != "DENY" {
|
||||
t.Error("Expected X-Frame-Options header in development")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeadersSSE(t *testing.T) {
|
||||
os.Setenv("GO_ENV", "production")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/stream", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
// Pre-set SSE content type
|
||||
recorder.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
// Content-Type should remain text/event-stream
|
||||
contentType := recorder.Header().Get("Content-Type")
|
||||
if contentType != "text/event-stream" {
|
||||
t.Errorf("Expected Content-Type 'text/event-stream', got '%s'", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeadersOptions(t *testing.T) {
|
||||
os.Setenv("GO_ENV", "production")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
handlerCalled := false
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
// OPTIONS should return 200 without calling next handler
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for OPTIONS, got %d", recorder.Code)
|
||||
}
|
||||
|
||||
if handlerCalled {
|
||||
t.Error("Expected next handler NOT to be called for OPTIONS request")
|
||||
}
|
||||
|
||||
// Security headers should still be set
|
||||
if recorder.Header().Get("X-Frame-Options") != "DENY" {
|
||||
t.Error("Expected security headers to be set for OPTIONS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeadersAllMethods(t *testing.T) {
|
||||
os.Setenv("GO_ENV", "production")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
methods := []string{
|
||||
http.MethodGet,
|
||||
http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodDelete,
|
||||
http.MethodPatch,
|
||||
}
|
||||
|
||||
for _, method := range methods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(method, "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
// All methods should have security headers
|
||||
if recorder.Header().Get("X-Frame-Options") != "DENY" {
|
||||
t.Errorf("Expected X-Frame-Options for %s", method)
|
||||
}
|
||||
|
||||
if recorder.Header().Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json for %s", method)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeadersEnvironments(t *testing.T) {
|
||||
environments := []string{"development", "production", "canary", "debug"}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
for _, env := range environments {
|
||||
t.Run(env, func(t *testing.T) {
|
||||
os.Setenv("GO_ENV", env)
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
// HSTS should only be set in non-development environments
|
||||
hsts := recorder.Header().Get("Strict-Transport-Security")
|
||||
if env == "development" {
|
||||
if hsts != "" {
|
||||
t.Errorf("HSTS should not be set in development, got '%s'", hsts)
|
||||
}
|
||||
} else {
|
||||
if hsts == "" {
|
||||
t.Errorf("HSTS should be set in %s environment", env)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeadersPoweredBy(t *testing.T) {
|
||||
os.Setenv("GO_ENV", "production")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
poweredBy := recorder.Header().Get("X-Powered-By")
|
||||
if poweredBy != "Zig" {
|
||||
t.Errorf("Expected X-Powered-By 'Zig', got '%s'", poweredBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeadersCSP(t *testing.T) {
|
||||
os.Setenv("GO_ENV", "production")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
csp := recorder.Header().Get("Content-Security-Policy")
|
||||
if csp != "default-src 'self'" {
|
||||
t.Errorf("Expected CSP 'default-src 'self'', got '%s'", csp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeadersReferrerPolicy(t *testing.T) {
|
||||
os.Setenv("GO_ENV", "production")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
referrer := recorder.Header().Get("Referrer-Policy")
|
||||
if referrer != "no-referrer" {
|
||||
t.Errorf("Expected Referrer-Policy 'no-referrer', got '%s'", referrer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHeadersXSSProtection(t *testing.T) {
|
||||
os.Setenv("GO_ENV", "production")
|
||||
defer os.Unsetenv("GO_ENV")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := SetHeaders(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(recorder, req)
|
||||
|
||||
xss := recorder.Header().Get("X-XSS-Protection")
|
||||
if xss != "1; mode=block" {
|
||||
t.Errorf("Expected X-XSS-Protection '1; mode=block', got '%s'", xss)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
//lint:file-ignore SA1029 Ignore all golangci-lint warnings in this file
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"authentication/db"
|
||||
"authentication/helper"
|
||||
"authentication/models"
|
||||
"authentication/redisclient"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
var (
|
||||
Blacklist = make(map[string]struct{})
|
||||
Mu sync.Mutex
|
||||
)
|
||||
|
||||
func init() {
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
helper.LogWarn("Warning: Could not load .env file, using system environment variables.")
|
||||
}
|
||||
}
|
||||
|
||||
func JWTMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
DashboardBaseURL := os.Getenv("DASHBOARD_URL")
|
||||
tokenString := ""
|
||||
if isValidAuthHeader(authHeader) {
|
||||
tokenString = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
} else {
|
||||
path := r.URL.Path
|
||||
if strings.Contains(path, "/sse") {
|
||||
tokenString = r.URL.Query().Get("access_token")
|
||||
if tokenString == "" {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Missing access_token in query params")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid authorization header")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if isTokenBlacklisted(tokenString) {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token is blacklisted")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
secretKey := os.Getenv("JWT_SECRET_KEY")
|
||||
if secretKey == "" {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, "Secret key not set")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := parseToken(tokenString, secretKey)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidOrExpiredToken)), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// Check JWT token expiration
|
||||
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if exp == 0 {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has no expiration")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
// Check if token is expired
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has expired")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token missing expiration claim")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
email, ok := claims["email"].(string)
|
||||
if !ok {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
sessionID, ok := claims["session_id"].(string)
|
||||
if !ok {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid session ID in token")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
if isSessionBlacklisted(sessionID) {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Session has been revoked")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := validateSessionFromDB(sessionID)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid or revoked session")), http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
ipAddress := getClientIP(r)
|
||||
if session.UserAgent != userAgent {
|
||||
helper.LogError(nil, fmt.Sprintf("Session security mismatch for session %s", sessionID))
|
||||
}
|
||||
|
||||
if session.IPAddress != ipAddress {
|
||||
helper.LogError(nil, fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s", sessionID, session.IPAddress, ipAddress))
|
||||
}
|
||||
|
||||
userID, err := getUserIDByEmail(email)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, "Failed to get user ID")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "userID", userID)
|
||||
ctx = context.WithValue(ctx, "sessionID", sessionID)
|
||||
ctx = context.WithValue(ctx, "email", email)
|
||||
next.ServeHTTP(&models.FlusherPreservingResponseWriter{ResponseWriter: w}, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func isValidAuthHeader(authHeader string) bool {
|
||||
return authHeader != "" && strings.HasPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
|
||||
func isTokenBlacklisted(tokenString string) bool {
|
||||
Mu.Lock()
|
||||
defer Mu.Unlock()
|
||||
_, found := Blacklist[tokenString]
|
||||
return found
|
||||
}
|
||||
|
||||
// isSessionBlacklisted checks if a session is in the Redis blacklist
|
||||
func isSessionBlacklisted(sessionID string) bool {
|
||||
ctx := context.Background()
|
||||
blacklistKey := fmt.Sprintf("session_blacklist:%s", sessionID)
|
||||
|
||||
exists, err := redisclient.RDB.Exists(ctx, blacklistKey).Result()
|
||||
return err == nil && exists > 0
|
||||
}
|
||||
|
||||
func parseToken(tokenString, secretKey string) (*jwt.Token, error) {
|
||||
return jwt.ParseWithClaims(tokenString, jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
}
|
||||
|
||||
func getUserIDByEmail(email string) (string, error) {
|
||||
var userID string
|
||||
err := db.DB.QueryRow("SELECT id FROM users WHERE email_address = ?", email).Scan(&userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func validateSessionFromDB(sessionID string) (*models.JWTSession, error) {
|
||||
ctx := context.Background()
|
||||
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
||||
|
||||
// Try to get session from Redis cache first
|
||||
var session models.JWTSession
|
||||
err := helper.GetJSON(ctx, sessionKey, &session)
|
||||
if err != nil {
|
||||
// Session not in cache, fetch from database
|
||||
err = db.DB.QueryRow(`
|
||||
SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
||||
FROM jwt_sessions
|
||||
WHERE id = ? AND is_revoked = false
|
||||
`, sessionID).Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.RefreshTokenHash,
|
||||
&session.UserAgent,
|
||||
&session.IPAddress,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
&session.ExpiresAt,
|
||||
&session.IsRevoked,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session not found or revoked: %w", err)
|
||||
}
|
||||
|
||||
// Cache the session in Redis (TTL based on session expiry)
|
||||
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
||||
if sessionTTL > 0 {
|
||||
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
||||
helper.LogWarn(fmt.Sprintf("Failed to cache session in Redis: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
// Auto-revoke expired session and clear cache
|
||||
_, _ = db.DB.Exec("UPDATE jwt_sessions SET is_revoked = true WHERE id = ?", sessionID)
|
||||
redisclient.RDB.Del(ctx, sessionKey)
|
||||
return nil, fmt.Errorf("session has expired")
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func getClientIP(r *http.Request) string {
|
||||
forwarded := r.Header.Get("X-Forwarded-For")
|
||||
if forwarded != "" {
|
||||
parts := strings.Split(forwarded, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
|
||||
realIP := r.Header.Get("X-Real-IP")
|
||||
if realIP != "" {
|
||||
return realIP
|
||||
}
|
||||
|
||||
ip := r.RemoteAddr
|
||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"authentication/db"
|
||||
"authentication/helper"
|
||||
"authentication/redisclient"
|
||||
)
|
||||
|
||||
func normalizeEndpoint(path string) string {
|
||||
uuidRegex := regexp.MustCompile(`/([a-zA-Z0-9_-]{11})(/|$)`)
|
||||
|
||||
path = uuidRegex.ReplaceAllString(path, "/{id}$2")
|
||||
|
||||
queryParamRegex := regexp.MustCompile(`\?.*`)
|
||||
return queryParamRegex.ReplaceAllString(path, "")
|
||||
}
|
||||
|
||||
func RateLimiterMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
rateLimitHeaderValue := os.Getenv("RATE_LIMIT_HEADER")
|
||||
|
||||
if rateLimitHeaderValue == "" {
|
||||
rateLimitHeaderValue = "F04C"
|
||||
}
|
||||
if r.Header.Get("X-RateLimit-Bypass") == rateLimitHeaderValue {
|
||||
// Bypass header is set to the correct value, skip rate limiting
|
||||
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
|
||||
return
|
||||
}
|
||||
|
||||
// If the header is not set or has an invalid value, proceed with rate limiting logic
|
||||
log.Print("No valid rate limit bypass header, proceeding with rate limiting logic")
|
||||
|
||||
// Get user identifier (email or IP)
|
||||
userIdentifier := ""
|
||||
email, err := helper.ExtractEmailFromToken(r.Header.Get("Authorization"))
|
||||
if err != nil {
|
||||
email, err = helper.ExtractEmailFromToken(r.URL.Query().Get("access_token"))
|
||||
if err != nil {
|
||||
helper.LogInfo(fmt.Sprintf("Could not extract email from token: %v, using IP-based rate limiting", err))
|
||||
}
|
||||
}
|
||||
|
||||
if email != "" {
|
||||
userIdentifier = email
|
||||
} else {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
||||
return
|
||||
}
|
||||
userIdentifier = ip
|
||||
}
|
||||
|
||||
if r.URL == nil || r.URL.Path == "" {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL")
|
||||
return
|
||||
}
|
||||
endpoint := normalizeEndpoint(r.URL.Path)
|
||||
|
||||
var limitCount, timeWindow int
|
||||
err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
limitCount = 300
|
||||
timeWindow = 60
|
||||
_, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow)
|
||||
if insertErr != nil {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint
|
||||
|
||||
if redisclient.RDB == nil {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result()
|
||||
if err != nil {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
||||
return
|
||||
}
|
||||
if count == 1 {
|
||||
_ = redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err()
|
||||
}
|
||||
if int(count) > limitCount {
|
||||
println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" +
|
||||
fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount))
|
||||
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func PublicRateLimiterMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("X-RateLimit-Bypass") == "F04C" {
|
||||
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
|
||||
return
|
||||
}
|
||||
|
||||
log.Print("No valid rate limit bypass header, proceeding with rate limiting logic")
|
||||
|
||||
// Use IP address as the user identifier for public endpoints
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
||||
return
|
||||
}
|
||||
userIdentifier := ip
|
||||
|
||||
if r.URL == nil || r.URL.Path == "" {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL")
|
||||
return
|
||||
}
|
||||
endpoint := normalizeEndpoint(r.URL.Path)
|
||||
|
||||
var limitCount, timeWindow int
|
||||
err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
limitCount = 36000
|
||||
timeWindow = 60
|
||||
_, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow)
|
||||
if insertErr != nil {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint
|
||||
|
||||
if redisclient.RDB == nil {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result()
|
||||
if err != nil {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
||||
return
|
||||
}
|
||||
if count == 1 {
|
||||
err := redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err()
|
||||
if err != nil {
|
||||
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Log the key and value saved
|
||||
log.Printf("Redis key: %s, value: %d", redisCountKey, count)
|
||||
|
||||
if int(count) > limitCount {
|
||||
println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" +
|
||||
fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount))
|
||||
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeEndpoint(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple path",
|
||||
input: "/api/users",
|
||||
expected: "/api/users",
|
||||
},
|
||||
{
|
||||
name: "Path with UUID",
|
||||
input: "/api/users/abcdef12345",
|
||||
expected: "/api/users/{id}",
|
||||
},
|
||||
{
|
||||
name: "Path with UUID and trailing slash",
|
||||
input: "/api/users/abcdef12345/",
|
||||
expected: "/api/users/{id}/",
|
||||
},
|
||||
{
|
||||
name: "Path with UUID in middle",
|
||||
input: "/api/users/abcdef12345/profile",
|
||||
expected: "/api/users/{id}/profile",
|
||||
},
|
||||
{
|
||||
name: "Path with query params",
|
||||
input: "/api/users?page=1&limit=10",
|
||||
expected: "/api/users",
|
||||
},
|
||||
{
|
||||
name: "Path with UUID and query params",
|
||||
input: "/api/users/abcdef12345?detail=full",
|
||||
expected: "/api/users/abcdef12345", // Query params removed first, then UUID not matched
|
||||
},
|
||||
{
|
||||
name: "Multiple UUIDs",
|
||||
input: "/api/users/abc12345678/posts/def87654321",
|
||||
expected: "/api/users/{id}/posts/{id}",
|
||||
},
|
||||
{
|
||||
name: "Root path",
|
||||
input: "/",
|
||||
expected: "/",
|
||||
},
|
||||
{
|
||||
name: "Empty path",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := normalizeEndpoint(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeEndpointUUIDFormats(t *testing.T) {
|
||||
uuidFormats := []string{
|
||||
"abcdef12345",
|
||||
"ABCDEF12345",
|
||||
"abc_def1234",
|
||||
"abc-def1234",
|
||||
"mixedCase12",
|
||||
}
|
||||
|
||||
for _, uuid := range uuidFormats {
|
||||
t.Run(uuid, func(t *testing.T) {
|
||||
input := "/api/users/" + uuid
|
||||
result := normalizeEndpoint(input)
|
||||
expected := "/api/users/{id}"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s' for UUID format '%s'", expected, result, uuid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeEndpointComplexQueryStrings(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"/api/users?a=1&b=2&c=3", "/api/users"},
|
||||
{"/api/users?filter=active&sort=name&order=asc", "/api/users"},
|
||||
{"/api/users?search=john+doe", "/api/users"},
|
||||
{"/api/users?tags[]=tag1&tags[]=tag2", "/api/users"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
result := normalizeEndpoint(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("Input '%s': expected '%s', got '%s'", tc.input, tc.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeEndpointEdgeCases(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Just query string",
|
||||
input: "?param=value",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Double slashes",
|
||||
input: "/api//users",
|
||||
expected: "/api//users",
|
||||
},
|
||||
{
|
||||
name: "Trailing query without params",
|
||||
input: "/api/users?",
|
||||
expected: "/api/users",
|
||||
},
|
||||
{
|
||||
name: "UUID at end without slash",
|
||||
input: "/users/abc12345678",
|
||||
expected: "/users/{id}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := normalizeEndpoint(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeEndpointPreservesNonUUID(t *testing.T) {
|
||||
testCases := []string{
|
||||
"/api/users/all",
|
||||
"/api/users/active",
|
||||
"/api/sessions/current",
|
||||
"/health",
|
||||
"/metrics",
|
||||
}
|
||||
|
||||
for _, input := range testCases {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
result := normalizeEndpoint(input)
|
||||
if result != input {
|
||||
t.Errorf("Non-UUID path should be preserved. Input: '%s', got '%s'", input, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeEndpointUUIDLength(t *testing.T) {
|
||||
// UUIDs must be exactly 11 characters
|
||||
testCases := []struct {
|
||||
name string
|
||||
uuid string
|
||||
shouldNormalize bool
|
||||
}{
|
||||
{"10 chars", "abcdefghij", false},
|
||||
{"11 chars", "abcdefghijk", true},
|
||||
{"12 chars", "abcdefghijkl", false},
|
||||
{"5 chars", "abcde", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
input := "/api/users/" + tc.uuid
|
||||
result := normalizeEndpoint(input)
|
||||
|
||||
if tc.shouldNormalize {
|
||||
expected := "/api/users/{id}"
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
} else {
|
||||
if result != input {
|
||||
t.Errorf("Should not normalize, expected '%s', got '%s'", input, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNormalizeEndpoint(b *testing.B) {
|
||||
testPaths := []string{
|
||||
"/api/users",
|
||||
"/api/users/abc12345678",
|
||||
"/api/users/abc12345678/profile",
|
||||
"/api/users?page=1&limit=10",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, path := range testPaths {
|
||||
normalizeEndpoint(path)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMain runs before all tests and sets up the test environment
|
||||
func TestMain(m *testing.M) {
|
||||
// Set GO_ENV for all tests to prevent init() failures
|
||||
os.Setenv("GO_ENV", "development")
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup
|
||||
os.Unsetenv("GO_ENV")
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
Reference in New Issue
Block a user