Files
Authentication/middleware/jwt.go
T
2025-11-26 11:31:09 +08:00

240 lines
7.1 KiB
Go

//lint:file-ignore SA1029 This file stores request context values under plain string keys ("userID", "sessionID", "email")
package middleware
import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"authentication/db"
	"authentication/helper"
	"authentication/models"
	"authentication/redisclient"

	"github.com/golang-jwt/jwt/v5"
)
var (
	// Blacklist is the in-process set of raw token strings that
	// isTokenBlacklisted reports as blacklisted. It starts out empty.
	// NOTE(review): the code that adds entries is not in this file view;
	// writers are expected to hold Mu — confirm against callers.
	Blacklist = map[string]struct{}{}

	// Mu serialises access to Blacklist.
	Mu sync.Mutex
)
func JWTMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
DashboardBaseURL := os.Getenv("DASHBOARD_URL")
tokenString := ""
if isValidAuthHeader(authHeader) {
tokenString = strings.TrimPrefix(authHeader, "Bearer ")
} else {
path := r.URL.Path
if strings.Contains(path, "/sse") {
tokenString = r.URL.Query().Get("access_token")
if tokenString == "" {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Missing access_token in query params")), http.StatusSeeOther)
return
}
} else {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid authorization header")), http.StatusSeeOther)
return
}
}
if isTokenBlacklisted(tokenString) {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token is blacklisted")), http.StatusSeeOther)
return
}
secretKey := os.Getenv("JWT_SECRET_KEY")
if secretKey == "" {
helper.RespondWithError(w, http.StatusInternalServerError, "Secret key not set")
return
}
token, err := parseToken(tokenString, secretKey)
if err != nil {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidOrExpiredToken)), http.StatusSeeOther)
return
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther)
return
}
// Check JWT token expiration
if exp, ok := claims["exp"].(float64); ok {
if exp == 0 {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has no expiration")), http.StatusSeeOther)
return
}
// Check if token is expired
if time.Now().Unix() > int64(exp) {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has expired")), http.StatusSeeOther)
return
}
} else {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token missing expiration claim")), http.StatusSeeOther)
return
}
email, ok := claims["email"].(string)
if !ok {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther)
return
}
sessionID, ok := claims["session_id"].(string)
if !ok {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid session ID in token")), http.StatusSeeOther)
return
}
if isSessionBlacklisted(sessionID) {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Session has been revoked")), http.StatusSeeOther)
return
}
session, err := validateSessionFromDB(sessionID)
if err != nil {
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid or revoked session")), http.StatusSeeOther)
return
}
userAgent := r.Header.Get("User-Agent")
ipAddress := getClientIP(r)
if session.UserAgent != userAgent {
helper.LogError(nil, fmt.Sprintf("Session security mismatch for session %s", sessionID))
}
if session.IPAddress != ipAddress {
helper.LogError(nil, fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s", sessionID, session.IPAddress, ipAddress))
}
userID, err := getUserIDByEmail(email)
if err != nil {
if err != sql.ErrNoRows {
helper.RespondWithError(w, http.StatusInternalServerError, "Failed to get user ID")
return
}
}
ctx := context.WithValue(r.Context(), "userID", userID)
ctx = context.WithValue(ctx, "sessionID", sessionID)
ctx = context.WithValue(ctx, "email", email)
next.ServeHTTP(&models.FlusherPreservingResponseWriter{ResponseWriter: w}, r.WithContext(ctx))
})
}
// isValidAuthHeader reports whether authHeader begins with the exact,
// case-sensitive prefix "Bearer " (including the trailing space). An empty
// header therefore yields false.
func isValidAuthHeader(authHeader string) bool {
	const bearerPrefix = "Bearer "
	return strings.HasPrefix(authHeader, bearerPrefix)
}
// isTokenBlacklisted reports whether tokenString is a key of the in-memory
// Blacklist map. The lookup is performed while holding Mu.
func isTokenBlacklisted(tokenString string) bool {
	Mu.Lock()
	_, listed := Blacklist[tokenString]
	Mu.Unlock()
	return listed
}
// isSessionBlacklisted reports whether the Redis key
// "session_blacklist:<sessionID>" exists.
//
// A Redis error is treated the same as "not blacklisted": the function
// returns false and the error is not surfaced to the caller.
func isSessionBlacklisted(sessionID string) bool {
	key := fmt.Sprintf("session_blacklist:%s", sessionID)
	count, err := redisclient.RDB.Exists(context.Background(), key).Result()
	if err != nil {
		return false
	}
	return count > 0
}
// parseToken parses tokenString into a token with jwt.MapClaims, verifying
// it with secretKey. Tokens whose signing method is not an HMAC variant are
// rejected with an "unexpected signing method" error before the key is
// handed to the parser.
func parseToken(tokenString, secretKey string) (*jwt.Token, error) {
	// hmacKey supplies the verification key, but only for HMAC-signed tokens.
	hmacKey := func(tok *jwt.Token) (interface{}, error) {
		_, isHMAC := tok.Method.(*jwt.SigningMethodHMAC)
		if !isHMAC {
			return nil, fmt.Errorf("unexpected signing method: %v", tok.Header["alg"])
		}
		return []byte(secretKey), nil
	}

	return jwt.ParseWithClaims(tokenString, jwt.MapClaims{}, hmacKey)
}
// getUserIDByEmail returns the id column of the users row whose
// email_address equals email. Any error from the query or scan is returned
// unchanged together with an empty ID; the caller checks for
// sql.ErrNoRows.
func getUserIDByEmail(email string) (string, error) {
	var id string
	row := db.DB.QueryRow("SELECT id FROM users WHERE email_address = ?", email)
	if err := row.Scan(&id); err != nil {
		return "", err
	}
	return id, nil
}
// validateSessionFromDB loads the JWT session identified by sessionID and
// checks that it has not expired.
//
// The session is read from the Redis cache first (key built from
// redisKeyJWTSessionID). If that read returns an error, the non-revoked row
// is fetched from jwt_sessions and, when it still has a positive remaining
// lifetime, written back to the cache with that lifetime as TTL in seconds.
//
// An expired session is marked revoked in the database and evicted from the
// cache on a best-effort basis: failures of either step are logged as
// warnings and do not change the result.
//
// It returns the session on success, or an error when the session cannot be
// found (or is revoked) in the database, or when it has expired.
func validateSessionFromDB(sessionID string) (*models.JWTSession, error) {
	ctx := context.Background()
	sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)

	var session models.JWTSession
	if err := helper.GetJSON(ctx, sessionKey, &session); err != nil {
		// Cache read failed; fall back to the database.
		err = db.DB.QueryRow(`
			SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
			FROM jwt_sessions
			WHERE id = ? AND is_revoked = false
		`, sessionID).Scan(
			&session.ID,
			&session.UserID,
			&session.RefreshTokenHash,
			&session.UserAgent,
			&session.IPAddress,
			&session.CreatedAt,
			&session.UpdatedAt,
			&session.ExpiresAt,
			&session.IsRevoked,
		)
		if err != nil {
			return nil, fmt.Errorf("session not found or revoked: %w", err)
		}

		// Cache only sessions that still have time left; TTL is the remaining lifetime.
		sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
		if sessionTTL > 0 {
			if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
				helper.LogWarn(fmt.Sprintf("Failed to cache session in Redis: %v", err))
			}
		}
	}

	if session.ExpiresAt.Before(time.Now()) {
		// Best-effort cleanup: the session is rejected regardless, but
		// failures are logged instead of being silently dropped.
		if _, err := db.DB.Exec("UPDATE jwt_sessions SET is_revoked = true WHERE id = ?", sessionID); err != nil {
			helper.LogWarn(fmt.Sprintf("Failed to auto-revoke expired session %s: %v", sessionID, err))
		}
		if err := redisclient.RDB.Del(ctx, sessionKey).Err(); err != nil {
			helper.LogWarn(fmt.Sprintf("Failed to evict expired session %s from Redis: %v", sessionID, err))
		}
		return nil, fmt.Errorf("session has expired")
	}

	return &session, nil
}
// getClientIP returns the client address for r.
//
// Sources are consulted in this order:
//  1. the first comma-separated entry of X-Forwarded-For, trimmed;
//  2. the X-Real-IP header, as sent;
//  3. the host part of r.RemoteAddr.
//
// RemoteAddr is split with net.SplitHostPort so that IPv6 addresses come
// back without brackets ("[::1]:8080" yields "::1"). If RemoteAddr is not in
// host:port form it is returned unchanged.
//
// Both headers are client-supplied, so the result is only as trustworthy as
// whatever sits in front of this server.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first, _, _ := strings.Cut(forwarded, ",")
		return strings.TrimSpace(first)
	}

	if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
		return realIP
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port (or otherwise unparsable): use the value as-is.
		return r.RemoteAddr
	}
	return host
}