This commit is contained in:
2025-12-04 10:55:25 +08:00
commit 60992c1e44
19 changed files with 1058 additions and 0 deletions
+6
View File
@@ -0,0 +1,6 @@
package middleware
const (
Authorization = "Authorization"
Unauthorized = "Unauthorized"
)
+218
View File
@@ -0,0 +1,218 @@
package middleware
import (
"authorization/helper"
"authorization/models"
"context"
"fmt"
"net/http"
"os"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
)
// contextKey is a custom type for context keys to avoid collisions
type contextKey string
const (
claimsKey contextKey = "claims"
userIDKey contextKey = "user_id"
usernameKey contextKey = "username"
roleKey contextKey = "role"
)
// Token cache entry
type cacheEntry struct {
claims *models.Claims
expiresAt time.Time
}
var (
// Token cache for high-frequency requests
tokenCache = make(map[string]*cacheEntry)
tokenCacheMutex sync.RWMutex
// Cache JWT secret to avoid repeated os.Getenv calls
jwtSecretOnce sync.Once
jwtSecretCached []byte
jwtSecretError error
// Pre-allocate error messages to avoid repeated allocations
errMissingAuth = "missing authorization header"
errInvalidAuthFormat = "invalid authorization header format"
errInvalidToken = "Invalid token"
errExpiredToken = "Invalid or expired token"
errInvalidClaims = "Invalid token claims"
)
// Initialize JWT secret once
func getJWTSecret() ([]byte, error) {
jwtSecretOnce.Do(func() {
secret := os.Getenv("JWT_KEY")
if secret == "" {
jwtSecretError = fmt.Errorf("JWT_KEY not set in environment")
return
}
jwtSecretCached = []byte(secret)
})
return jwtSecretCached, jwtSecretError
}
// Clean expired cache entries periodically
func init() {
go func() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
cleanExpiredTokens()
}
}()
}
func cleanExpiredTokens() {
tokenCacheMutex.Lock()
defer tokenCacheMutex.Unlock()
now := time.Now()
for token, entry := range tokenCache {
if now.After(entry.expiresAt) {
delete(tokenCache, token)
}
}
}
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get the Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized")
return
}
// Fast path: check if header has Bearer prefix without allocation
if len(authHeader) < 8 || authHeader[:7] != "Bearer " {
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidAuthFormat)
return
}
tokenString := authHeader[7:] // Skip "Bearer " without strings.Split allocation
// Check cache first (read lock)
tokenCacheMutex.RLock()
if cached, exists := tokenCache[tokenString]; exists {
if time.Now().Before(cached.expiresAt) {
claims := cached.claims
tokenCacheMutex.RUnlock()
// Add claims to context and proceed
ctx := buildContext(r.Context(), claims)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Token expired in cache, remove it
tokenCacheMutex.RUnlock()
tokenCacheMutex.Lock()
delete(tokenCache, tokenString)
tokenCacheMutex.Unlock()
} else {
tokenCacheMutex.RUnlock()
}
// Parse and validate the token
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
// Get cached JWT secret
return getJWTSecret()
})
if err != nil {
helper.RespondWithError(w, http.StatusUnauthorized, errExpiredToken)
return
}
// Check if token is valid
if !token.Valid {
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidToken)
return
}
// Extract claims
claims, ok := token.Claims.(*models.Claims)
if !ok {
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidClaims)
return
}
// Cache the validated token
expiresAt := time.Now().Add(5 * time.Minute) // Cache for 5 minutes
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) {
expiresAt = claims.ExpiresAt.Time
}
tokenCacheMutex.Lock()
// Limit cache size to prevent memory issues
if len(tokenCache) > 10000000 {
// Remove oldest 10% when cache is full
count := 0
for k := range tokenCache {
delete(tokenCache, k)
count++
if count >= 1000000 {
break
}
}
}
tokenCache[tokenString] = &cacheEntry{
claims: claims,
expiresAt: expiresAt,
}
tokenCacheMutex.Unlock()
// Add claims to request context
ctx := buildContext(r.Context(), claims)
// Call the next handler with updated context
next.ServeHTTP(w, r.WithContext(ctx))
}
}
// buildContext efficiently builds context with claims (reduces allocations)
func buildContext(parent context.Context, claims *models.Claims) context.Context {
ctx := context.WithValue(parent, claimsKey, claims)
ctx = context.WithValue(ctx, userIDKey, claims.UserID)
ctx = context.WithValue(ctx, usernameKey, claims.Username)
ctx = context.WithValue(ctx, roleKey, claims.Role)
return ctx
}
// GetClaims retrieves the JWT claims from the request context
func GetClaims(r *http.Request) (*models.Claims, bool) {
claims, ok := r.Context().Value(claimsKey).(*models.Claims)
return claims, ok
}
// GetUserID retrieves the user ID from the request context
func GetUserID(r *http.Request) (string, bool) {
userID, ok := r.Context().Value(userIDKey).(string)
return userID, ok
}
// GetUsername retrieves the username from the request context
func GetUsername(r *http.Request) (string, bool) {
username, ok := r.Context().Value(usernameKey).(string)
return username, ok
}
// GetRole retrieves the role from the request context
func GetRole(r *http.Request) (string, bool) {
role, ok := r.Context().Value(roleKey).(string)
return role, ok
}