package middleware import ( "database/sql" "fmt" "log" "net" "net/http" "os" "regexp" "time" "authentication/db" "authentication/helper" "authentication/redisclient" ) func normalizeEndpoint(path string) string { uuidRegex := regexp.MustCompile(`/([a-zA-Z0-9_-]{11})(/|$)`) path = uuidRegex.ReplaceAllString(path, "/{id}$2") queryParamRegex := regexp.MustCompile(`\?.*`) return queryParamRegex.ReplaceAllString(path, "") } func RateLimiterMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rateLimitHeaderValue := os.Getenv("RATE_LIMIT_HEADER") if rateLimitHeaderValue == "" { rateLimitHeaderValue = "F04C" } if r.Header.Get("X-RateLimit-Bypass") == rateLimitHeaderValue { // Bypass header is set to the correct value, skip rate limiting next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r) return } // If the header is not set or has an invalid value, proceed with rate limiting logic log.Print("No valid rate limit bypass header, proceeding with rate limiting logic") // Get user identifier (email or IP) userIdentifier := "" email, err := helper.ExtractEmailFromToken(r.Header.Get("Authorization")) if err != nil { email, err = helper.ExtractEmailFromToken(r.URL.Query().Get("access_token")) if err != nil { helper.LogInfo(fmt.Sprintf("Could not extract email from token: %v, using IP-based rate limiting", err)) } } if email != "" { userIdentifier = email } else { ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) return } userIdentifier = ip } if r.URL == nil || r.URL.Path == "" { helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL") return } endpoint := normalizeEndpoint(r.URL.Path) var limitCount, timeWindow int err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow) if err != nil { if err == sql.ErrNoRows { limitCount = 300 timeWindow = 60 _, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow) if insertErr != nil { helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) return } } else { helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) return } } redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint if redisclient.RDB == nil { helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized") return } count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result() if err != nil { helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) return } if count == 1 { _ = redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err() } if int(count) > limitCount { println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" + fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount)) helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded") return } next.ServeHTTP(w, r) }) } func PublicRateLimiterMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("X-RateLimit-Bypass") == "F04C" { next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r) return } log.Print("No valid rate limit bypass header, proceeding with rate limiting logic") // Use IP address as the user identifier for public endpoints ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) return } userIdentifier := ip if r.URL == nil || r.URL.Path == "" { helper.RespondWithError(w, http.StatusInternalServerError, "Invalid request URL") return } endpoint := normalizeEndpoint(r.URL.Path) var limitCount, timeWindow int err = db.DB.QueryRow("SELECT limit_count, time_window FROM rate_limiter WHERE identifier = ?", endpoint).Scan(&limitCount, &timeWindow) if err != nil { if err == sql.ErrNoRows { limitCount = 36000 timeWindow = 60 _, insertErr := db.DB.Exec("INSERT INTO rate_limiter (identifier, limit_count, time_window) VALUES (?, ?, ?)", endpoint, limitCount, timeWindow) if insertErr != nil { helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) return } } else { helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) return } } redisCountKey := "ratelimit_count:" + userIdentifier + ":" + endpoint if redisclient.RDB == nil { helper.RespondWithError(w, http.StatusInternalServerError, "Redis client not initialized") return } count, err := redisclient.RDB.Incr(r.Context(), redisCountKey).Result() if err != nil { helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) return } if count == 1 { err := redisclient.RDB.Expire(r.Context(), redisCountKey, time.Duration(timeWindow)*time.Second).Err() if err != nil { helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) return } } // Log the key and value saved log.Printf("Redis key: %s, value: %d", redisCountKey, count) if int(count) > limitCount { println("Rate limit exceeded: user=" + userIdentifier + " endpoint=" + endpoint + " count=" + fmt.Sprintf("%d", count) + " limit=" + fmt.Sprintf("%d", limitCount)) helper.RespondWithError(w, http.StatusTooManyRequests, "Rate limit exceeded") return } next.ServeHTTP(w, r) }) }