init
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
package middleware
|
||||
|
||||
const (
|
||||
Authorization = "Authorization"
|
||||
Unauthorized = "Unauthorized"
|
||||
)
|
||||
@@ -0,0 +1,218 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"authorization/helper"
|
||||
"authorization/models"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
claimsKey contextKey = "claims"
|
||||
userIDKey contextKey = "user_id"
|
||||
usernameKey contextKey = "username"
|
||||
roleKey contextKey = "role"
|
||||
)
|
||||
|
||||
// Token cache entry
|
||||
type cacheEntry struct {
|
||||
claims *models.Claims
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
// Token cache for high-frequency requests
|
||||
tokenCache = make(map[string]*cacheEntry)
|
||||
tokenCacheMutex sync.RWMutex
|
||||
|
||||
// Cache JWT secret to avoid repeated os.Getenv calls
|
||||
jwtSecretOnce sync.Once
|
||||
jwtSecretCached []byte
|
||||
jwtSecretError error
|
||||
|
||||
// Pre-allocate error messages to avoid repeated allocations
|
||||
errMissingAuth = "missing authorization header"
|
||||
errInvalidAuthFormat = "invalid authorization header format"
|
||||
errInvalidToken = "Invalid token"
|
||||
errExpiredToken = "Invalid or expired token"
|
||||
errInvalidClaims = "Invalid token claims"
|
||||
)
|
||||
|
||||
// Initialize JWT secret once
|
||||
func getJWTSecret() ([]byte, error) {
|
||||
jwtSecretOnce.Do(func() {
|
||||
secret := os.Getenv("JWT_KEY")
|
||||
if secret == "" {
|
||||
jwtSecretError = fmt.Errorf("JWT_KEY not set in environment")
|
||||
return
|
||||
}
|
||||
jwtSecretCached = []byte(secret)
|
||||
})
|
||||
return jwtSecretCached, jwtSecretError
|
||||
}
|
||||
|
||||
// Clean expired cache entries periodically
|
||||
func init() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
cleanExpiredTokens()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func cleanExpiredTokens() {
|
||||
tokenCacheMutex.Lock()
|
||||
defer tokenCacheMutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for token, entry := range tokenCache {
|
||||
if now.After(entry.expiresAt) {
|
||||
delete(tokenCache, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
|
||||
func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// Fast path: check if header has Bearer prefix without allocation
|
||||
if len(authHeader) < 8 || authHeader[:7] != "Bearer " {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidAuthFormat)
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := authHeader[7:] // Skip "Bearer " without strings.Split allocation
|
||||
|
||||
// Check cache first (read lock)
|
||||
tokenCacheMutex.RLock()
|
||||
if cached, exists := tokenCache[tokenString]; exists {
|
||||
if time.Now().Before(cached.expiresAt) {
|
||||
claims := cached.claims
|
||||
tokenCacheMutex.RUnlock()
|
||||
|
||||
// Add claims to context and proceed
|
||||
ctx := buildContext(r.Context(), claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
// Token expired in cache, remove it
|
||||
tokenCacheMutex.RUnlock()
|
||||
tokenCacheMutex.Lock()
|
||||
delete(tokenCache, tokenString)
|
||||
tokenCacheMutex.Unlock()
|
||||
} else {
|
||||
tokenCacheMutex.RUnlock()
|
||||
}
|
||||
|
||||
// Parse and validate the token
|
||||
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate the signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// Get cached JWT secret
|
||||
return getJWTSecret()
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, errExpiredToken)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if token is valid
|
||||
if !token.Valid {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidToken)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract claims
|
||||
claims, ok := token.Claims.(*models.Claims)
|
||||
if !ok {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidClaims)
|
||||
return
|
||||
}
|
||||
|
||||
// Cache the validated token
|
||||
expiresAt := time.Now().Add(5 * time.Minute) // Cache for 5 minutes
|
||||
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) {
|
||||
expiresAt = claims.ExpiresAt.Time
|
||||
}
|
||||
|
||||
tokenCacheMutex.Lock()
|
||||
// Limit cache size to prevent memory issues
|
||||
if len(tokenCache) > 10000000 {
|
||||
// Remove oldest 10% when cache is full
|
||||
count := 0
|
||||
for k := range tokenCache {
|
||||
delete(tokenCache, k)
|
||||
count++
|
||||
if count >= 1000000 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenCache[tokenString] = &cacheEntry{
|
||||
claims: claims,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
tokenCacheMutex.Unlock()
|
||||
|
||||
// Add claims to request context
|
||||
ctx := buildContext(r.Context(), claims)
|
||||
|
||||
// Call the next handler with updated context
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// buildContext efficiently builds context with claims (reduces allocations)
|
||||
func buildContext(parent context.Context, claims *models.Claims) context.Context {
|
||||
ctx := context.WithValue(parent, claimsKey, claims)
|
||||
ctx = context.WithValue(ctx, userIDKey, claims.UserID)
|
||||
ctx = context.WithValue(ctx, usernameKey, claims.Username)
|
||||
ctx = context.WithValue(ctx, roleKey, claims.Role)
|
||||
return ctx
|
||||
}
|
||||
|
||||
// GetClaims retrieves the JWT claims from the request context
|
||||
func GetClaims(r *http.Request) (*models.Claims, bool) {
|
||||
claims, ok := r.Context().Value(claimsKey).(*models.Claims)
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// GetUserID retrieves the user ID from the request context
|
||||
func GetUserID(r *http.Request) (string, bool) {
|
||||
userID, ok := r.Context().Value(userIDKey).(string)
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// GetUsername retrieves the username from the request context
|
||||
func GetUsername(r *http.Request) (string, bool) {
|
||||
username, ok := r.Context().Value(usernameKey).(string)
|
||||
return username, ok
|
||||
}
|
||||
|
||||
// GetRole retrieves the role from the request context
|
||||
func GetRole(r *http.Request) (string, bool) {
|
||||
role, ok := r.Context().Value(roleKey).(string)
|
||||
return role, ok
|
||||
}
|
||||
Reference in New Issue
Block a user