//lint:file-ignore SA1029 Ignore all golangci-lint warnings in this file package middleware import ( "context" "database/sql" "encoding/pem" "fmt" "net/http" "net/url" "os" "strings" "sync" "time" "authentication/db" "authentication/helper" "authentication/models" "authentication/redisclient" "github.com/golang-jwt/jwt/v5" ) var ( Blacklist = make(map[string]struct{}) Mu sync.Mutex ) func JWTMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") DashboardBaseURL := os.Getenv("DASHBOARD_URL") tokenString := "" if isValidAuthHeader(authHeader) { tokenString = strings.TrimPrefix(authHeader, "Bearer ") } else { path := r.URL.Path if strings.Contains(path, "/sse") { tokenString = r.URL.Query().Get("access_token") if tokenString == "" { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Missing access_token in query params")), http.StatusSeeOther) return } } else { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid authorization header")), http.StatusSeeOther) return } } if isTokenBlacklisted(tokenString) { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token is blacklisted")), http.StatusSeeOther) return } secretKey := os.Getenv("JWT_SECRET_KEY") if secretKey == "" { helper.RespondWithError(w, http.StatusInternalServerError, "Secret key not set") return } token, err := parseToken(tokenString, secretKey) if err != nil { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidOrExpiredToken)), http.StatusSeeOther) return } claims, ok := token.Claims.(jwt.MapClaims) if !ok || !token.Valid { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther) return } // Check JWT token expiration if exp, ok := claims["exp"].(float64); ok { if exp == 0 { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has no expiration")), http.StatusSeeOther) return } // Check if token is expired if time.Now().Unix() > int64(exp) { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has expired")), http.StatusSeeOther) return } } else { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token missing expiration claim")), http.StatusSeeOther) return } email, ok := claims["email"].(string) if !ok { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther) return } sessionID, ok := claims["session_id"].(string) if !ok { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid session ID in token")), http.StatusSeeOther) return } if isSessionBlacklisted(sessionID) { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Session has been revoked")), http.StatusSeeOther) return } session, err := validateSessionFromDB(sessionID) if err != nil { http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid or revoked session")), http.StatusSeeOther) return } userAgent := r.Header.Get("User-Agent") ipAddress := getClientIP(r) if session.UserAgent != userAgent { helper.LogError(nil, fmt.Sprintf("Session security mismatch for session %s", sessionID)) } if session.IPAddress != ipAddress { helper.LogError(nil, fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s", sessionID, session.IPAddress, ipAddress)) } userID, err := getUserIDByEmail(email) if err != nil { if err != sql.ErrNoRows { helper.RespondWithError(w, http.StatusInternalServerError, "Failed to get user ID") return } } ctx := context.WithValue(r.Context(), "userID", userID) ctx = context.WithValue(ctx, "sessionID", sessionID) ctx = context.WithValue(ctx, "email", email) next.ServeHTTP(&models.FlusherPreservingResponseWriter{ResponseWriter: w}, r.WithContext(ctx)) }) } func isValidAuthHeader(authHeader string) bool { return authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") } func isTokenBlacklisted(tokenString string) bool { Mu.Lock() defer Mu.Unlock() _, found := Blacklist[tokenString] return found } // isSessionBlacklisted checks if a session is in the Redis blacklist func isSessionBlacklisted(sessionID string) bool { ctx := context.Background() blacklistKey := fmt.Sprintf("session_blacklist:%s", sessionID) exists, err := redisclient.RDB.Exists(ctx, blacklistKey).Result() return err == nil && exists > 0 } func parseToken(tokenString, secretKey string) (*jwt.Token, error) { return jwt.ParseWithClaims(tokenString, jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) { if token.Method != jwt.SigningMethodRS256 { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } publicKeyPEM := os.Getenv("JWT_PUBLIC_KEY") if publicKeyPEM == "" { return nil, fmt.Errorf("JWT public key not set") } block, _ := pem.Decode([]byte(publicKeyPEM)) if block == nil { return nil, fmt.Errorf("failed to decode PEM block") } pubKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicKeyPEM)) if err != nil { return nil, fmt.Errorf("failed to parse RSA public key") } return pubKey, nil }) } func getUserIDByEmail(email string) (string, error) { var userID string err := db.DB.QueryRow("SELECT id FROM users WHERE email_address = ?", email).Scan(&userID) if err != nil { return "", err } return userID, nil } func validateSessionFromDB(sessionID string) (*models.JWTSession, error) { ctx := context.Background() sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID) // Try to get session from Redis cache first var session models.JWTSession err := helper.GetJSON(ctx, sessionKey, &session) if err != nil { // Session not in cache, fetch from database err = db.DB.QueryRow(` SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked FROM jwt_sessions WHERE id = ? AND is_revoked = false `, sessionID).Scan( &session.ID, &session.UsersID, &session.RefreshTokenHash, &session.UserAgent, &session.IPAddress, &session.CreatedAt, &session.UpdatedAt, &session.ExpiresAt, &session.IsRevoked, ) if err != nil { return nil, fmt.Errorf("session not found or revoked: %w", err) } // Cache the session in Redis (TTL based on session expiry) sessionTTL := int(time.Until(session.ExpiresAt).Seconds()) if sessionTTL > 0 { if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil { helper.LogWarn(fmt.Sprintf("Failed to cache session in Redis: %v", err)) } } } if session.ExpiresAt.Before(time.Now()) { // Auto-revoke expired session and clear cache _, _ = db.DB.Exec("UPDATE jwt_sessions SET is_revoked = true WHERE id = ?", sessionID) redisclient.RDB.Del(ctx, sessionKey) return nil, fmt.Errorf("session has expired") } return &session, nil } func getClientIP(r *http.Request) string { forwarded := r.Header.Get("X-Forwarded-For") if forwarded != "" { parts := strings.Split(forwarded, ",") return strings.TrimSpace(parts[0]) } realIP := r.Header.Get("X-Real-IP") if realIP != "" { return realIP } ip := r.RemoteAddr if idx := strings.LastIndex(ip, ":"); idx != -1 { ip = ip[:idx] } return ip }