Files
Authentication/middleware/rate_limiter.go
T
2025-11-25 15:12:31 +08:00

187 lines
5.8 KiB
Go

package middleware
import (
"database/sql"
"fmt"
"log"
"net"
"net/http"
"os"
"regexp"
"time"
"authentication/db"
"authentication/helper"
"authentication/redisclient"
)
// endpointIDSegmentRegex matches one path segment of exactly 11 characters
// drawn from [a-zA-Z0-9_-], together with the '/' before it and the '/' (or
// end of input) after it. Compiled once at package scope because
// normalizeEndpoint runs on every rate-limited request.
// NOTE(review): the original local was named uuidRegex, but 11 characters is
// not a UUID; presumably these are short opaque IDs — confirm the ID format.
var endpointIDSegmentRegex = regexp.MustCompile(`/[a-zA-Z0-9_-]{11}(/|$)`)

// endpointQueryRegex matches a '?' and everything after it, across newlines.
var endpointQueryRegex = regexp.MustCompile(`(?s)\?.*`)

// normalizeEndpoint collapses a request path into the form used as the
// rate-limit identifier. It drops any "?query" suffix, then replaces every
// 11-character [a-zA-Z0-9_-] path segment with "{id}", so
// "/users/abcdefghijk/posts" becomes "/users/{id}/posts".
func normalizeEndpoint(path string) string {
	// Strip the query first so an ID directly followed by '?' is still seen as
	// ending its segment.
	path = endpointQueryRegex.ReplaceAllString(path, "")
	// Each match consumes the '/' after the ID, so an ID segment immediately
	// following another one is skipped in a single pass. Repeat until stable.
	// "{id}" never matches, so this converges after at most two rewrites.
	for {
		next := endpointIDSegmentRegex.ReplaceAllString(path, "/{id}$1")
		if next == path {
			return next
		}
		path = next
	}
}
func RateLimiterMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rateLimitHeaderValue := os.Getenv("RATE_LIMIT_HEADER")
if rateLimitHeaderValue == "" {
rateLimitHeaderValue = "F04C"
}
if r.Header.Get("X-RateLimit-Bypass") == rateLimitHeaderValue {
// Bypass header is set to the correct value, skip rate limiting
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
return
}
// If the header is not set or has an invalid value, proceed with rate limiting logic
log.Print("No valid rate limit bypass header, proceeding with rate limiting logic")
// Get user identifier (email or IP)
userIdentifier := ""
email, err := helper.ExtractEmailFromToken(r.Header.Get("Authorization"))
if err != nil {
email, err = helper.ExtractEmailFromToken(r.URL.Query().Get("access_token"))
if err != nil {
helper.LogInfo(fmt.Sprintf("Could not extract email from token: %v, using IP-based rate limiting", err))
}
}
if email != "" {
userIdentifier = email
} else {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
userIdentifier = ip
}
if r.URL == nil || r.URL.Path == "" {
helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL")
return
}
endpoint := normalizeEndpoint(r.URL.Path)
var limitCount, timeWindow int
err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow)
if err != nil {
if err == sql.ErrNoRows {
limitCount = 300
timeWindow = 60
_, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow)
if insertErr != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
} else {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
}
redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint
if redisclient.RDB == nil {
helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized")
return
}
count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result()
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
if count == 1 {
_ = redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err()
}
if int(count) > limitCount {
println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" +
fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount))
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
return
}
next.ServeHTTP(w, r)
})
}
func PublicRateLimiterMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-RateLimit-Bypass") == "F04C" {
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
return
}
log.Print("No valid rate limit bypass header, proceeding with rate limiting logic")
// Use IP address as the user identifier for public endpoints
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
userIdentifier := ip
if r.URL == nil || r.URL.Path == "" {
helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL")
return
}
endpoint := normalizeEndpoint(r.URL.Path)
var limitCount, timeWindow int
err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow)
if err != nil {
if err == sql.ErrNoRows {
limitCount = 36000
timeWindow = 60
_, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow)
if insertErr != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
} else {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
}
redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint
if redisclient.RDB == nil {
helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized")
return
}
count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result()
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
if count == 1 {
err := redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err()
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
}
// Log the key and value saved
log.Printf("Redis key: %s, value: %d", redisCountKey, count)
if int(count) > limitCount {
println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" +
fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount))
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
return
}
next.ServeHTTP(w, r)
})
}