133 lines
2.9 KiB
Go
133 lines
2.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"authentication/helper"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Settings shared by the CSRF token store and CSRFMiddleware.
const (
	// csrfTokenHeader names the header used to return a token on safe
	// requests and to receive it on state-changing ones.
	csrfTokenHeader = "X-CSRF-Token" // #nosec G101

	// csrfCookieName names the cookie the token is also written to.
	csrfCookieName = "csrf_token"

	// csrfTokenLength is the number of random bytes in each token
	// (before base64 encoding).
	csrfTokenLength = 32

	// csrfTokenTTL is how long an issued token stays valid.
	csrfTokenTTL = time.Hour * 24
)
|
|
|
|
// csrfTokenStore maps each issued CSRF token to the instant it expires.
// Every access to tokens in this file is made while holding mu.
type csrfTokenStore struct {
	tokens map[string]time.Time // token -> expiry instant
	mu     sync.RWMutex         // guards tokens
}

var (
	// tokenStore is the package-level registry of issued tokens, shared by
	// every handler returned from CSRFMiddleware.
	tokenStore = &csrfTokenStore{tokens: map[string]time.Time{}}

	// cleanupOnce makes sure the expired-token sweeper goroutine is started
	// at most once, no matter how many times CSRFMiddleware is called.
	cleanupOnce sync.Once
)
|
|
|
|
func generateCSRFToken() (string, error) {
|
|
b := make([]byte, csrfTokenLength)
|
|
_, err := rand.Read(b)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return base64.URLEncoding.EncodeToString(b), nil
|
|
}
|
|
|
|
func cleanupExpiredTokens() {
|
|
tokenStore.mu.Lock()
|
|
defer tokenStore.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
for token, expiry := range tokenStore.tokens {
|
|
if now.After(expiry) {
|
|
delete(tokenStore.tokens, token)
|
|
}
|
|
}
|
|
}
|
|
|
|
func validateToken(token string) bool {
|
|
tokenStore.mu.RLock()
|
|
defer tokenStore.mu.RUnlock()
|
|
|
|
expiry, exists := tokenStore.tokens[token]
|
|
if !exists {
|
|
return false
|
|
}
|
|
|
|
return time.Now().Before(expiry)
|
|
}
|
|
|
|
func storeToken(token string) {
|
|
tokenStore.mu.Lock()
|
|
defer tokenStore.mu.Unlock()
|
|
|
|
tokenStore.tokens[token] = time.Now().Add(csrfTokenTTL)
|
|
}
|
|
|
|
func CSRFMiddleware(next http.Handler) http.Handler {
|
|
cleanupOnce.Do(func() {
|
|
go func() {
|
|
ticker := time.NewTicker(1 * time.Hour)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
cleanupExpiredTokens()
|
|
}
|
|
}()
|
|
})
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
log.Print("Request headers: ", r.Header)
|
|
if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
|
|
// For GET requests, generate and set a new CSRF token
|
|
token, err := generateCSRFToken()
|
|
if err != nil {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
|
|
return
|
|
}
|
|
log.Print("Generated CSRF token: ", token)
|
|
storeToken(token)
|
|
|
|
expires := time.Now().Add(csrfTokenTTL)
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: csrfCookieName,
|
|
Value: token,
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteStrictMode,
|
|
MaxAge: int(csrfTokenTTL.Seconds()),
|
|
Expires: expires,
|
|
})
|
|
w.Header().Set(csrfTokenHeader, token)
|
|
log.Print("Set CSRF token cookie and header")
|
|
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
|
|
return
|
|
}
|
|
|
|
tokenFromHeader := r.Header.Get(csrfTokenHeader)
|
|
if tokenFromHeader == "" {
|
|
helper.RespondWithError(w, http.StatusForbidden, "CSRF token missing from header")
|
|
return
|
|
}
|
|
|
|
if strings.Contains(tokenFromHeader, "%") {
|
|
if decoded, err := url.QueryUnescape(tokenFromHeader); err == nil {
|
|
tokenFromHeader = decoded
|
|
}
|
|
}
|
|
|
|
if !validateToken(tokenFromHeader) {
|
|
helper.RespondWithError(w, http.StatusForbidden, "Invalid or expired CSRF token")
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
|
|
})
|
|
}
|