From 33c59d1c6d4b92f8994bfb791ba01eedc6071df2 Mon Sep 17 00:00:00 2001 From: F04C Date: Wed, 21 Jan 2026 13:41:15 +0800 Subject: [PATCH] added csrf --- middleware/csrf.go | 123 +++++++++++++++++++++++++++++++++++++++++++++ routes/routes.go | 7 ++- 2 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 middleware/csrf.go diff --git a/middleware/csrf.go b/middleware/csrf.go new file mode 100644 index 0000000..6ac6bc9 --- /dev/null +++ b/middleware/csrf.go @@ -0,0 +1,123 @@ +package middleware + +import ( + "authentication/helper" + "crypto/rand" + "encoding/base64" + "log" + "net/http" + "sync" + "time" +) + +const ( + csrfTokenHeader = "X-CSRF-Token" + csrfCookieName = "csrf_token" + csrfTokenLength = 32 + csrfTokenTTL = 24 * time.Hour +) + +type csrfTokenStore struct { + tokens map[string]time.Time + mu sync.RWMutex +} + +var tokenStore = &csrfTokenStore{ + tokens: make(map[string]time.Time), +} + +var cleanupOnce sync.Once + +func generateCSRFToken() (string, error) { + b := make([]byte, csrfTokenLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} + +func cleanupExpiredTokens() { + tokenStore.mu.Lock() + defer tokenStore.mu.Unlock() + + now := time.Now() + for token, expiry := range tokenStore.tokens { + if now.After(expiry) { + delete(tokenStore.tokens, token) + } + } +} + +func validateToken(token string) bool { + tokenStore.mu.RLock() + defer tokenStore.mu.RUnlock() + + expiry, exists := tokenStore.tokens[token] + if !exists { + return false + } + + return time.Now().Before(expiry) +} + +func storeToken(token string) { + tokenStore.mu.Lock() + defer tokenStore.mu.Unlock() + + tokenStore.tokens[token] = time.Now().Add(csrfTokenTTL) +} + +func CSRFMiddleware(next http.Handler) http.Handler { + cleanupOnce.Do(func() { + go func() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + for range ticker.C { + cleanupExpiredTokens() + } + }() + }) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions { + // For GET requests, generate and set a new CSRF token + token, err := generateCSRFToken() + if err != nil { + helper.RespondWithError(w, http.StatusInternalServerError, InternalServerError) + return + } + log.Print("Generated CSRF token: ", token) + storeToken(token) + + expires := time.Now().Add(csrfTokenTTL) + http.SetCookie(w, &http.Cookie{ + Name: csrfCookieName, + Value: token, + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + MaxAge: int(csrfTokenTTL.Seconds()), + Expires: expires, + }) + w.Header().Set(csrfTokenHeader, token) + log.Print("Set CSRF token cookie and header") + next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r) + return + } + + tokenFromHeader := r.Header.Get(csrfTokenHeader) + if tokenFromHeader == "" { + helper.RespondWithError(w, http.StatusForbidden, "CSRF token missing from header") + return + } + + if !validateToken(tokenFromHeader) { + helper.RespondWithError(w, http.StatusForbidden, "Invalid or expired CSRF token") + return + } + + next.ServeHTTP(&FlusherPreservingResponseWriter{ResponseWriter: w}, r) + }) +} diff --git a/routes/routes.go b/routes/routes.go index 631858c..eff8efc 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -2,6 +2,7 @@ package routes import ( "authentication/handlers" + "authentication/middleware" "database/sql" "github.com/gorilla/mux" @@ -16,9 +17,11 @@ func SetupRoutes(router *mux.Router, db *sql.DB) { authRoutes := router.PathPrefix("/v1/auth").Subrouter() authRoutes.HandleFunc("/login", handlers.GoogleLogin).Methods("GET") authRoutes.HandleFunc("/callback", handlers.GoogleCallback).Methods("GET") - authRoutes.HandleFunc("/refresh_token", handlers.HandleTokenRefresh).Methods("GET", "POST", "OPTIONS") - authRoutes.HandleFunc("/logout", handlers.LogoutHandler).Methods("GET") authRoutes.HandleFunc("/forgot-password", handlers.ForgotPassword).Methods("GET") + csrfProtected := authRoutes.NewRoute().Subrouter() + csrfProtected.Use(middleware.CSRFMiddleware) + csrfProtected.HandleFunc("/refresh_token", handlers.HandleTokenRefresh).Methods("POST", "OPTIONS") + csrfProtected.HandleFunc("/logout", handlers.LogoutHandler).Methods("POST") // authRoutes.HandleFunc("/microsoft/login", handlers.MicrosoftLogin).Methods("GET") // authRoutes.HandleFunc("/microsoft/callback", handlers.MicrosotCallback).Methods("GET")