added csrf

This commit is contained in:
2026-01-21 13:41:15 +08:00
parent 7caf9b069d
commit 33c59d1c6d
2 changed files with 128 additions and 2 deletions
+123
View File
@@ -0,0 +1,123 @@
package middleware
import (
"authentication/helper"
"crypto/rand"
"encoding/base64"
"log"
"net/http"
"sync"
"time"
)
const (
csrfTokenHeader = "X-CSRF-Token"
csrfCookieName = "csrf_token"
csrfTokenLength = 32
csrfTokenTTL = 24 * time.Hour
)
type csrfTokenStore struct {
tokens map[string]time.Time
mu sync.RWMutex
}
var tokenStore = &csrfTokenStore{
tokens: make(map[string]time.Time),
}
var cleanupOnce sync.Once
func generateCSRFToken() (string, error) {
b := make([]byte, csrfTokenLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
func cleanupExpiredTokens() {
tokenStore.mu.Lock()
defer tokenStore.mu.Unlock()
now := time.Now()
for token, expiry := range tokenStore.tokens {
if now.After(expiry) {
delete(tokenStore.tokens, token)
}
}
}
func validateToken(token string) bool {
tokenStore.mu.RLock()
defer tokenStore.mu.RUnlock()
expiry, exists := tokenStore.tokens[token]
if !exists {
return false
}
return time.Now().Before(expiry)
}
func storeToken(token string) {
tokenStore.mu.Lock()
defer tokenStore.mu.Unlock()
tokenStore.tokens[token] = time.Now().Add(csrfTokenTTL)
}
func CSRFMiddleware(next http.Handler) http.Handler {
cleanupOnce.Do(func() {
go func() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
cleanupExpiredTokens()
}
}()
})
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
// For GET requests, generate and set a new CSRF token
token, err := generateCSRFToken()
if err != nil {
helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError)
return
}
log.Print("Generated CSRF token: ", token)
storeToken(token)
expires := time.Now().Add(csrfTokenTTL)
http.SetCookie(w, &http.Cookie{
Name: csrfCookieName,
Value: token,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
MaxAge: int(csrfTokenTTL.Seconds()),
Expires: expires,
})
w.Header().Set(csrfTokenHeader, token)
log.Print("Set CSRF token cookie and header")
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
return
}
tokenFromHeader := r.Header.Get(csrfTokenHeader)
if tokenFromHeader == "" {
helper.RespondWithError(w, http.StatusForbidden, "CSRF token missing from header")
return
}
if !validateToken(tokenFromHeader) {
helper.RespondWithError(w, http.StatusForbidden, "Invalid or expired CSRF token")
return
}
next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r)
})
}
+5 -2
View File
@@ -2,6 +2,7 @@ package routes
import (
"authentication/handlers"
"authentication/middleware"
"database/sql"
"github.com/gorilla/mux"
@@ -16,9 +17,11 @@ func SetupRoutes(router *mux.Router, db *sql.DB) {
authRoutes := router.PathPrefix("/v1/auth").Subrouter()
authRoutes.HandleFunc("/login", handlers.GoogleLogin).Methods("GET")
authRoutes.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET")
authRoutes.HandleFunc("/refresh_token", handlers.HandleTokenRefresh).Methods("GET", "POST", "OPTIONS")
authRoutes.HandleFunc("/logout", handlers.LogoutHandler).Methods("GET")
authRoutes.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET")
csrfProtected := authRoutes.NewRoute().Subrouter()
csrfProtected.Use(middleware.CSRFMiddleware)
csrfProtected.HandleFunc("/refresh_token", handlers.HandleTokenRefresh).Methods("POST", "OPTIONS")
csrfProtected.HandleFunc("/logout", handlers.LogoutHandler).Methods("POST")
// authRoutes.HandleFunc("/microsoft/login", handlers.MicrosoftLogin).Methods("GET")
// authRoutes.HandleFunc("/microsoft/callback", handlers.MicrosotCallback).Methods("GET")