240 lines
7.1 KiB
Go
240 lines
7.1 KiB
Go
//lint:file-ignore SA1029 Ignore all golangci-lint warnings in this file
|
|
package middleware
|
|
|
|
import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"authentication/db"
	"authentication/helper"
	"authentication/models"
	"authentication/redisclient"

	"github.com/golang-jwt/jwt/v5"
)
|
|
|
|
// Blacklist is the in-memory set of raw token strings that
// isTokenBlacklisted treats as rejected. Access it only while holding Mu.
var Blacklist = map[string]struct{}{}

// Mu serialises access to Blacklist.
var Mu sync.Mutex
|
|
|
|
func JWTMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
DashboardBaseURL := os.Getenv("DASHBOARD_URL")
|
|
tokenString := ""
|
|
if isValidAuthHeader(authHeader) {
|
|
tokenString = strings.TrimPrefix(authHeader, "Bearer ")
|
|
} else {
|
|
path := r.URL.Path
|
|
if strings.Contains(path, "/sse") {
|
|
tokenString = r.URL.Query().Get("access_token")
|
|
if tokenString == "" {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Missing access_token in query params")), http.StatusSeeOther)
|
|
return
|
|
}
|
|
} else {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid authorization header")), http.StatusSeeOther)
|
|
return
|
|
}
|
|
}
|
|
|
|
if isTokenBlacklisted(tokenString) {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token is blacklisted")), http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
secretKey := os.Getenv("JWT_SECRET_KEY")
|
|
if secretKey == "" {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, "Secret key not set")
|
|
return
|
|
}
|
|
|
|
token, err := parseToken(tokenString, secretKey)
|
|
if err != nil {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidOrExpiredToken)), http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok || !token.Valid {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
// Check JWT token expiration
|
|
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
if exp == 0 {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has no expiration")), http.StatusSeeOther)
|
|
return
|
|
}
|
|
// Check if token is expired
|
|
if time.Now().Unix() > int64(exp) {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token has expired")), http.StatusSeeOther)
|
|
return
|
|
}
|
|
} else {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Token missing expiration claim")), http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
email, ok := claims["email"].(string)
|
|
if !ok {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape(InvalidTokenClaims)), http.StatusSeeOther)
|
|
return
|
|
}
|
|
sessionID, ok := claims["session_id"].(string)
|
|
if !ok {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid session ID in token")), http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
if isSessionBlacklisted(sessionID) {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Session has been revoked")), http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
session, err := validateSessionFromDB(sessionID)
|
|
if err != nil {
|
|
http.Redirect(w, r, fmt.Sprintf(errorFormat, DashboardBaseURL, url.QueryEscape("Invalid or revoked session")), http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
userAgent := r.Header.Get("User-Agent")
|
|
ipAddress := getClientIP(r)
|
|
if session.UserAgent != userAgent {
|
|
helper.LogError(nil, fmt.Sprintf("Session security mismatch for session %s", sessionID))
|
|
}
|
|
|
|
if session.IPAddress != ipAddress {
|
|
helper.LogError(nil, fmt.Sprintf("Session IP address mismatch for session %s: expected %s, got %s", sessionID, session.IPAddress, ipAddress))
|
|
}
|
|
|
|
userID, err := getUserIDByEmail(email)
|
|
if err != nil {
|
|
if err != sql.ErrNoRows {
|
|
helper.RespondWithError(w, http.StatusInternalServerError, "Failed to get user ID")
|
|
return
|
|
}
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), "userID", userID)
|
|
ctx = context.WithValue(ctx, "sessionID", sessionID)
|
|
ctx = context.WithValue(ctx, "email", email)
|
|
next.ServeHTTP(&models.FlusherPreservingResponseWriter{ResponseWriter: w}, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// isValidAuthHeader reports whether authHeader begins with the exact,
// case-sensitive prefix "Bearer " (including the trailing space). An empty
// header can never carry that prefix, so it is reported as invalid.
func isValidAuthHeader(authHeader string) bool {
	const bearerPrefix = "Bearer "
	return strings.HasPrefix(authHeader, bearerPrefix)
}
|
|
|
|
func isTokenBlacklisted(tokenString string) bool {
|
|
Mu.Lock()
|
|
defer Mu.Unlock()
|
|
_, found := Blacklist[tokenString]
|
|
return found
|
|
}
|
|
|
|
// isSessionBlacklisted checks if a session is in the Redis blacklist
|
|
func isSessionBlacklisted(sessionID string) bool {
|
|
ctx := context.Background()
|
|
blacklistKey := fmt.Sprintf("session_blacklist:%s", sessionID)
|
|
|
|
exists, err := redisclient.RDB.Exists(ctx, blacklistKey).Result()
|
|
return err == nil && exists > 0
|
|
}
|
|
|
|
func parseToken(tokenString, secretKey string) (*jwt.Token, error) {
|
|
return jwt.ParseWithClaims(tokenString, jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(secretKey), nil
|
|
})
|
|
}
|
|
|
|
func getUserIDByEmail(email string) (string, error) {
|
|
var userID string
|
|
err := db.DB.QueryRow("SELECT id FROM users WHERE email_address = ?", email).Scan(&userID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return userID, nil
|
|
}
|
|
|
|
func validateSessionFromDB(sessionID string) (*models.JWTSession, error) {
|
|
ctx := context.Background()
|
|
sessionKey := fmt.Sprintf(redisKeyJWTSessionID, sessionID)
|
|
|
|
// Try to get session from Redis cache first
|
|
var session models.JWTSession
|
|
err := helper.GetJSON(ctx, sessionKey, &session)
|
|
if err != nil {
|
|
// Session not in cache, fetch from database
|
|
err = db.DB.QueryRow(`
|
|
SELECT id, user_id, refresh_token_hash, user_agent, ip_address, created_at, updated_at, expires_at, is_revoked
|
|
FROM jwt_sessions
|
|
WHERE id = ? AND is_revoked = false
|
|
`, sessionID).Scan(
|
|
&session.ID,
|
|
&session.UserID,
|
|
&session.RefreshTokenHash,
|
|
&session.UserAgent,
|
|
&session.IPAddress,
|
|
&session.CreatedAt,
|
|
&session.UpdatedAt,
|
|
&session.ExpiresAt,
|
|
&session.IsRevoked,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("session not found or revoked: %w", err)
|
|
}
|
|
|
|
// Cache the session in Redis (TTL based on session expiry)
|
|
sessionTTL := int(time.Until(session.ExpiresAt).Seconds())
|
|
if sessionTTL > 0 {
|
|
if err := helper.SetJSON(ctx, sessionKey, session, &sessionTTL); err != nil {
|
|
helper.LogWarn(fmt.Sprintf("Failed to cache session in Redis: %v", err))
|
|
}
|
|
}
|
|
}
|
|
|
|
if session.ExpiresAt.Before(time.Now()) {
|
|
// Auto-revoke expired session and clear cache
|
|
_, _ = db.DB.Exec("UPDATE jwt_sessions SET is_revoked = true WHERE id = ?", sessionID)
|
|
redisclient.RDB.Del(ctx, sessionKey)
|
|
return nil, fmt.Errorf("session has expired")
|
|
}
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
// getClientIP returns a best-effort client address for r.
//
// Sources are consulted in this order:
//  1. the first comma-separated entry of X-Forwarded-For, trimmed of spaces;
//  2. the X-Real-IP header;
//  3. the host part of r.RemoteAddr.
//
// A blank first X-Forwarded-For entry is skipped rather than returned as an
// empty address. RemoteAddr is split with net.SplitHostPort so that IPv6
// forms such as "[::1]:8080" yield "::1"; when RemoteAddr carries no port (or
// is otherwise unparsable) it is returned unchanged.
//
// NOTE(review): both headers are client-controlled unless a trusted proxy
// overwrites them, so the result must not be relied on for access control.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first, _, _ := strings.Cut(forwarded, ",")
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}

	if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
		return realIP
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port present (or malformed): fall back to the raw value.
		return r.RemoteAddr
	}
	return host
}
|