187 lines
5.8 KiB
Go
187 lines
5.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"regexp"
|
|
"time"
|
|
|
|
"authentication/db"
|
|
"authentication/helper"
|
|
"authentication/redisclient"
|
|
)
|
|
|
|
// endpointIDSegmentRegex matches one 11-character path segment made of
// [a-zA-Z0-9_-] together with the slash before it and the slash (or end of
// string) after it. Presumably this is the shape of the resource IDs embedded
// in this API's URLs — confirm against the ID generator.
var endpointIDSegmentRegex = regexp.MustCompile(`/([a-zA-Z0-9_-]{11})(/|$)`)

// endpointQueryRegex matches a "?" and the query-string text that follows it.
var endpointQueryRegex = regexp.MustCompile(`\?.*`)

// normalizeEndpoint maps a request path to the key used for rate-limit
// configuration and counters. Any query string is removed, and every
// 11-character [a-zA-Z0-9_-] segment is replaced with "{id}". For example,
// "/posts/AbCdEfGhIjK/comments" becomes "/posts/{id}/comments".
//
// Ordinary 11-letter words (e.g. "permissions") match the ID pattern and are
// collapsed too. A leading segment that is not preceded by a slash is never
// replaced.
func normalizeEndpoint(path string) string {
	path = endpointQueryRegex.ReplaceAllString(path, "")

	// ReplaceAllString replaces only non-overlapping matches. Adjacent ID
	// segments share a slash ("/a/ID1/ID2"), so a single pass would leave the
	// second one untouched; repeat until nothing changes. Every replacement
	// shortens the string (11 chars -> "{id}"), so the loop always terminates.
	for {
		next := endpointIDSegmentRegex.ReplaceAllString(path, "/{id}$2")
		if next == path {
			return path
		}
		path = next
	}
}
|
|
|
|
func RateLimiterMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
rateLimitHeaderValue := os.Getenv("RATE_LIMIT_HEADER")
|
|
|
|
if rateLimitHeaderValue == "" {
|
|
rateLimitHeaderValue = "F04C"
|
|
}
|
|
if r.Header.Get("X-RateLimit-Bypass") == rateLimitHeaderValue {
|
|
// Bypass header is set to the correct value, skip rate limiting
|
|
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
|
|
return
|
|
}
|
|
|
|
// If the header is not set or has an invalid value, proceed with rate limiting logic
|
|
log.Print("No valid rate limit bypass header, proceeding with rate limiting logic")
|
|
|
|
// Get user identifier (email or IP)
|
|
userIdentifier := ""
|
|
email, err := helper.ExtractEmailFromToken(r.Header.Get("Authorization"))
|
|
if err != nil {
|
|
email, err = helper.ExtractEmailFromToken(r.URL.Query().Get("access_token"))
|
|
if err != nil {
|
|
helper.LogInfo(fmt.Sprintf("Could not extract email from token: %v, using IP-based rate limiting", err))
|
|
}
|
|
}
|
|
|
|
if email != "" {
|
|
userIdentifier = email
|
|
} else {
|
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err != nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
userIdentifier = ip
|
|
}
|
|
|
|
if r.URL == nil || r.URL.Path == "" {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL")
|
|
return
|
|
}
|
|
endpoint := normalizeEndpoint(r.URL.Path)
|
|
|
|
var limitCount, timeWindow int
|
|
err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
limitCount = 300
|
|
timeWindow = 60
|
|
_, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow)
|
|
if insertErr != nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
} else {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint
|
|
|
|
if redisclient.RDB == nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized")
|
|
return
|
|
}
|
|
|
|
count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result()
|
|
if err != nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
if count == 1 {
|
|
_ = redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err()
|
|
}
|
|
if int(count) > limitCount {
|
|
println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" +
|
|
fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount))
|
|
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func PublicRateLimiterMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("X-RateLimit-Bypass") == "F04C" {
|
|
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
|
|
return
|
|
}
|
|
|
|
log.Print("No valid rate limit bypass header, proceeding with rate limiting logic")
|
|
|
|
// Use IP address as the user identifier for public endpoints
|
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err != nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
userIdentifier := ip
|
|
|
|
if r.URL == nil || r.URL.Path == "" {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL")
|
|
return
|
|
}
|
|
endpoint := normalizeEndpoint(r.URL.Path)
|
|
|
|
var limitCount, timeWindow int
|
|
err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
limitCount = 36000
|
|
timeWindow = 60
|
|
_, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow)
|
|
if insertErr != nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
} else {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint
|
|
|
|
if redisclient.RDB == nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized")
|
|
return
|
|
}
|
|
|
|
count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result()
|
|
if err != nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
if count == 1 {
|
|
err := redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err()
|
|
if err != nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Log the key and value saved
|
|
log.Printf("Redis key: %s, value: %d", redisCountKey, count)
|
|
|
|
if int(count) > limitCount {
|
|
println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" +
|
|
fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount))
|
|
helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded")
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|