Files
2026-02-18 10:33:42 +08:00

133 lines
2.9 KiB
Go

package middleware
import (
"authentication/helper"
"crypto/rand"
"encoding/base64"
"log"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
const (
	// csrfTokenHeader is the HTTP header used to hand the token to clients on
	// safe requests and to receive it back on state-changing requests.
	// The gosec G101 (hard-coded credential) finding is suppressed because this
	// is a header name, not a secret.
	csrfTokenHeader = "X-CSRF-Token" // #nosec G101

	// csrfCookieName is the name of the cookie that carries the issued token.
	csrfCookieName = "csrf_token"

	// csrfTokenLength is the number of random bytes in a token, before base64.
	csrfTokenLength = 32

	// csrfTokenTTL is how long an issued token stays valid.
	csrfTokenTTL = time.Hour * 24
)
// csrfTokenStore is an in-memory registry of issued CSRF tokens.
type csrfTokenStore struct {
	// tokens maps each issued token to the instant at which it expires.
	tokens map[string]time.Time
	// mu guards tokens; readers take RLock, writers take Lock.
	mu sync.RWMutex
}

var (
	// tokenStore is the single process-wide token registry used by the
	// package-level helpers in this file.
	tokenStore = &csrfTokenStore{
		tokens: make(map[string]time.Time),
	}

	// cleanupOnce guarantees the expiry-sweeping goroutine is started only once.
	cleanupOnce sync.Once
)
// generateCSRFToken returns a new token made of csrfTokenLength bytes from
// crypto/rand, encoded with padded URL-safe base64. The error from the random
// source is returned unchanged, with an empty token.
func generateCSRFToken() (string, error) {
	raw := make([]byte, csrfTokenLength)
	if _, readErr := rand.Read(raw); readErr != nil {
		return "", readErr
	}
	encoded := base64.URLEncoding.EncodeToString(raw)
	return encoded, nil
}
// cleanupExpiredTokens deletes every token whose expiry instant lies strictly
// before the current time. The store's write lock is held for the whole sweep.
func cleanupExpiredTokens() {
	store := tokenStore
	store.mu.Lock()
	defer store.mu.Unlock()

	cutoff := time.Now()
	for tok, expiresAt := range store.tokens {
		if !expiresAt.Before(cutoff) {
			continue
		}
		// Deleting during range is permitted for Go maps.
		delete(store.tokens, tok)
	}
}
// validateToken reports whether token is present in the store and its expiry
// instant is still in the future. It takes only the read lock.
func validateToken(token string) bool {
	tokenStore.mu.RLock()
	defer tokenStore.mu.RUnlock()

	expiresAt, found := tokenStore.tokens[token]
	return found && time.Now().Before(expiresAt)
}
// storeToken records token as valid for csrfTokenTTL from now. A token that is
// already present has its expiry overwritten.
func storeToken(token string) {
	store := tokenStore
	store.mu.Lock()
	defer store.mu.Unlock()

	expiresAt := time.Now().Add(csrfTokenTTL)
	store.tokens[token] = expiresAt
}
// CSRFMiddleware wraps next with header-based CSRF protection backed by the
// in-memory token store.
//
// The first call starts, exactly once per process, a background goroutine that
// runs cleanupExpiredTokens every hour. That goroutine never exits.
//
// GET, HEAD and OPTIONS requests always receive a freshly generated token. The
// token is stored for csrfTokenTTL and sent back two ways:
//   - as the csrfCookieName cookie (HttpOnly, Secure, SameSite=Strict);
//   - in the csrfTokenHeader response header.
// After that, next is invoked.
//
// Any other method must present a token in the csrfTokenHeader request header.
// A value containing '%' is query-unescaped first; if unescaping fails, the raw
// value is kept. Only the header is checked against the store; the cookie is
// not consulted. A missing, unknown or expired token yields 403 Forbidden, and
// next is not called.
//
// Request headers and token values are intentionally never logged. Headers
// carry Cookie/Authorization values, and a logged token is a usable credential.
//
// NOTE(review): every safe-method request adds a store entry that lives for
// csrfTokenTTL, so memory grows with request volume. Consider reusing a
// still-valid token from the request cookie, or capping the store size.
func CSRFMiddleware(next http.Handler) http.Handler {
	cleanupOnce.Do(func() {
		go func() {
			ticker := time.NewTicker(1 * time.Hour)
			defer ticker.Stop()
			for range ticker.C {
				cleanupExpiredTokens()
			}
		}()
	})
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
			// Safe methods: issue a new token for the client to echo back later.
			token, err := generateCSRFToken()
			if err != nil {
				// Record the cause server-side; the client only sees a generic 500.
				log.Printf("csrf: generating token: %v", err)
				helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
				return
			}
			storeToken(token)
			expires := time.Now().Add(csrfTokenTTL)
			http.SetCookie(w, &http.Cookie{
				Name:     csrfCookieName,
				Value:    token,
				Path:     "/",
				HttpOnly: true,
				Secure:   true,
				SameSite: http.SameSiteStrictMode,
				MaxAge:   int(csrfTokenTTL.Seconds()),
				Expires:  expires,
			})
			// The cookie is HttpOnly, so the header is how client code reads the token.
			w.Header().Set(csrfTokenHeader, token)
			next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
			return
		}

		tokenFromHeader := r.Header.Get(csrfTokenHeader)
		if tokenFromHeader == "" {
			helper.RespondWithError(w, http.StatusForbidden, "CSRF token missing from header")
			return
		}
		// Some clients percent-encode the base64 '=' padding; undo that before lookup.
		if strings.Contains(tokenFromHeader, "%") {
			if decoded, err := url.QueryUnescape(tokenFromHeader); err == nil {
				tokenFromHeader = decoded
			}
		}
		if !validateToken(tokenFromHeader) {
			helper.RespondWithError(w, http.StatusForbidden, "Invalid or expired CSRF token")
			return
		}
		next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
	})
}