219 lines
5.7 KiB
Go
219 lines
5.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"authorization/helper"
|
|
"authorization/models"
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// contextKey is the unexported type used for every value this package stores
// in a request context. Because no other package can name this type, its keys
// can never collide with context keys defined elsewhere.
type contextKey string

// Keys under which buildContext stores the verified claims and selected
// fields from them.
const (
	claimsKey   = contextKey("claims")
	userIDKey   = contextKey("user_id")
	usernameKey = contextKey("username")
	roleKey     = contextKey("role")
)
|
|
|
|
// Token cache entry
|
|
type cacheEntry struct {
|
|
claims *models.Claims
|
|
expiresAt time.Time
|
|
}
|
|
|
|
var (
|
|
// Token cache for high-frequency requests
|
|
tokenCache = make(map[string]*cacheEntry)
|
|
tokenCacheMutex sync.RWMutex
|
|
|
|
// Cache JWT secret to avoid repeated os.Getenv calls
|
|
jwtSecretOnce sync.Once
|
|
jwtSecretCached []byte
|
|
jwtSecretError error
|
|
|
|
// Pre-allocate error messages to avoid repeated allocations
|
|
errMissingAuth = "missing authorization header"
|
|
errInvalidAuthFormat = "invalid authorization header format"
|
|
errInvalidToken = "Invalid token"
|
|
errExpiredToken = "Invalid or expired token"
|
|
errInvalidClaims = "Invalid token claims"
|
|
)
|
|
|
|
// Initialize JWT secret once
|
|
func getJWTSecret() ([]byte, error) {
|
|
jwtSecretOnce.Do(func() {
|
|
secret := os.Getenv("JWT_KEY")
|
|
if secret == "" {
|
|
jwtSecretError = fmt.Errorf("JWT_KEY not set in environment")
|
|
return
|
|
}
|
|
jwtSecretCached = []byte(secret)
|
|
})
|
|
return jwtSecretCached, jwtSecretError
|
|
}
|
|
|
|
// Clean expired cache entries periodically
|
|
func init() {
|
|
go func() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
cleanExpiredTokens()
|
|
}
|
|
}()
|
|
}
|
|
|
|
func cleanExpiredTokens() {
|
|
tokenCacheMutex.Lock()
|
|
defer tokenCacheMutex.Unlock()
|
|
|
|
now := time.Now()
|
|
for token, entry := range tokenCache {
|
|
if now.After(entry.expiresAt) {
|
|
delete(tokenCache, token)
|
|
}
|
|
}
|
|
}
|
|
|
|
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
|
|
func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Get the Authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized")
|
|
return
|
|
}
|
|
|
|
// Fast path: check if header has Bearer prefix without allocation
|
|
if len(authHeader) < 8 || authHeader[:7] != "Bearer " {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidAuthFormat)
|
|
return
|
|
}
|
|
|
|
tokenString := authHeader[7:] // Skip "Bearer " without strings.Split allocation
|
|
|
|
// Check cache first (read lock)
|
|
tokenCacheMutex.RLock()
|
|
if cached, exists := tokenCache[tokenString]; exists {
|
|
if time.Now().Before(cached.expiresAt) {
|
|
claims := cached.claims
|
|
tokenCacheMutex.RUnlock()
|
|
|
|
// Add claims to context and proceed
|
|
ctx := buildContext(r.Context(), claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
// Token expired in cache, remove it
|
|
tokenCacheMutex.RUnlock()
|
|
tokenCacheMutex.Lock()
|
|
delete(tokenCache, tokenString)
|
|
tokenCacheMutex.Unlock()
|
|
} else {
|
|
tokenCacheMutex.RUnlock()
|
|
}
|
|
|
|
// Parse and validate the token
|
|
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
// Validate the signing method
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
|
|
// Get cached JWT secret
|
|
return getJWTSecret()
|
|
})
|
|
|
|
if err != nil {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, errExpiredToken)
|
|
return
|
|
}
|
|
|
|
// Check if token is valid
|
|
if !token.Valid {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidToken)
|
|
return
|
|
}
|
|
|
|
// Extract claims
|
|
claims, ok := token.Claims.(*models.Claims)
|
|
if !ok {
|
|
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidClaims)
|
|
return
|
|
}
|
|
|
|
// Cache the validated token
|
|
expiresAt := time.Now().Add(5 * time.Minute) // Cache for 5 minutes
|
|
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) {
|
|
expiresAt = claims.ExpiresAt.Time
|
|
}
|
|
|
|
tokenCacheMutex.Lock()
|
|
// Limit cache size to prevent memory issues
|
|
if len(tokenCache) > 10000000 {
|
|
// Remove oldest 10% when cache is full
|
|
count := 0
|
|
for k := range tokenCache {
|
|
delete(tokenCache, k)
|
|
count++
|
|
if count >= 1000000 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
tokenCache[tokenString] = &cacheEntry{
|
|
claims: claims,
|
|
expiresAt: expiresAt,
|
|
}
|
|
tokenCacheMutex.Unlock()
|
|
|
|
// Add claims to request context
|
|
ctx := buildContext(r.Context(), claims)
|
|
|
|
// Call the next handler with updated context
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
// buildContext efficiently builds context with claims (reduces allocations)
|
|
func buildContext(parent context.Context, claims *models.Claims) context.Context {
|
|
ctx := context.WithValue(parent, claimsKey, claims)
|
|
ctx = context.WithValue(ctx, userIDKey, claims.UserID)
|
|
ctx = context.WithValue(ctx, usernameKey, claims.Username)
|
|
ctx = context.WithValue(ctx, roleKey, claims.Role)
|
|
return ctx
|
|
}
|
|
|
|
// GetClaims retrieves the JWT claims from the request context
|
|
func GetClaims(r *http.Request) (*models.Claims, bool) {
|
|
claims, ok := r.Context().Value(claimsKey).(*models.Claims)
|
|
return claims, ok
|
|
}
|
|
|
|
// GetUserID retrieves the user ID from the request context
|
|
func GetUserID(r *http.Request) (string, bool) {
|
|
userID, ok := r.Context().Value(userIDKey).(string)
|
|
return userID, ok
|
|
}
|
|
|
|
// GetUsername retrieves the username from the request context
|
|
func GetUsername(r *http.Request) (string, bool) {
|
|
username, ok := r.Context().Value(usernameKey).(string)
|
|
return username, ok
|
|
}
|
|
|
|
// GetRole retrieves the role from the request context
|
|
func GetRole(r *http.Request) (string, bool) {
|
|
role, ok := r.Context().Value(roleKey).(string)
|
|
return role, ok
|
|
}
|