fix all issues
This commit is contained in:
+97
-80
@@ -31,7 +31,6 @@ var (
|
||||
jwtSecretError error
|
||||
|
||||
// Pre-allocate error messages to avoid repeated allocations
|
||||
errMissingAuth = "missing authorization header"
|
||||
errInvalidAuthFormat = "invalid authorization header format"
|
||||
errInvalidToken = "Invalid token"
|
||||
errExpiredToken = "Invalid or expired token"
|
||||
@@ -74,103 +73,121 @@ func cleanExpiredTokens() {
|
||||
}
|
||||
}
|
||||
|
||||
// extractBearerToken extracts token from Authorization header
|
||||
func extractBearerToken(authHeader string) (string, bool) {
|
||||
if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " {
|
||||
return "", false
|
||||
}
|
||||
return authHeader[7:], true
|
||||
}
|
||||
|
||||
// checkTokenCache retrieves token from cache if valid
|
||||
func checkTokenCache(tokenString string) (*models.Claims, bool) {
|
||||
tokenCacheMutex.RLock()
|
||||
defer tokenCacheMutex.RUnlock()
|
||||
|
||||
cached, exists := tokenCache[tokenString]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Now().Before(cached.ExpiresAt) {
|
||||
return cached.Claims, true
|
||||
}
|
||||
|
||||
// Token expired, will be cleaned up later
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// removeExpiredCacheEntry removes a single expired token from cache
|
||||
func removeExpiredCacheEntry(tokenString string) {
|
||||
tokenCacheMutex.Lock()
|
||||
defer tokenCacheMutex.Unlock()
|
||||
delete(tokenCache, tokenString)
|
||||
}
|
||||
|
||||
// parseAndValidateToken parses JWT token and validates it
|
||||
func parseAndValidateToken(tokenString string) (*models.Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return getJWTSecret()
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*models.Claims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid claims")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// cacheToken stores validated token in cache
|
||||
func cacheToken(tokenString string, claims *models.Claims) {
|
||||
expiresAt := time.Now().Add(5 * time.Minute)
|
||||
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) {
|
||||
expiresAt = claims.ExpiresAt.Time
|
||||
}
|
||||
|
||||
tokenCacheMutex.Lock()
|
||||
defer tokenCacheMutex.Unlock()
|
||||
|
||||
// Limit cache size
|
||||
if len(tokenCache) > 10000000 {
|
||||
count := 0
|
||||
for k := range tokenCache {
|
||||
delete(tokenCache, k)
|
||||
count++
|
||||
if count >= 1000000 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenCache[tokenString] = &models.CacheEntry{
|
||||
Claims: claims,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
|
||||
func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
// Extract token from header
|
||||
tokenString, ok := extractBearerToken(r.Header.Get("Authorization"))
|
||||
if !ok {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// Fast path: check if header has Bearer prefix without allocation
|
||||
if len(authHeader) < 8 || authHeader[:7] != "Bearer " {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidAuthFormat)
|
||||
// Check cache first
|
||||
if claims, found := checkTokenCache(tokenString); found {
|
||||
ctx := buildContext(r.Context(), claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := authHeader[7:] // Skip "Bearer " without strings.Split allocation
|
||||
|
||||
// Check cache first (read lock)
|
||||
tokenCacheMutex.RLock()
|
||||
if cached, exists := tokenCache[tokenString]; exists {
|
||||
if time.Now().Before(cached.ExpiresAt) {
|
||||
claims := cached.Claims
|
||||
tokenCacheMutex.RUnlock()
|
||||
|
||||
// Add claims to context and proceed
|
||||
ctx := buildContext(r.Context(), claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
// Token expired in cache, remove it
|
||||
tokenCacheMutex.RUnlock()
|
||||
tokenCacheMutex.Lock()
|
||||
delete(tokenCache, tokenString)
|
||||
tokenCacheMutex.Unlock()
|
||||
} else {
|
||||
tokenCacheMutex.RUnlock()
|
||||
}
|
||||
|
||||
// Parse and validate the token
|
||||
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate the signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// Get cached JWT secret
|
||||
return getJWTSecret()
|
||||
})
|
||||
|
||||
// Parse and validate token
|
||||
claims, err := parseAndValidateToken(tokenString)
|
||||
if err != nil {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, errExpiredToken)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if token is valid
|
||||
if !token.Valid {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidToken)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract claims
|
||||
claims, ok := token.Claims.(*models.Claims)
|
||||
if !ok {
|
||||
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidClaims)
|
||||
return
|
||||
}
|
||||
|
||||
// Cache the validated token
|
||||
expiresAt := time.Now().Add(5 * time.Minute) // Cache for 5 minutes
|
||||
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) {
|
||||
expiresAt = claims.ExpiresAt.Time
|
||||
}
|
||||
cacheToken(tokenString, claims)
|
||||
|
||||
tokenCacheMutex.Lock()
|
||||
// Limit cache size to prevent memory issues
|
||||
if len(tokenCache) > 10000000 {
|
||||
// Remove oldest 10% when cache is full
|
||||
count := 0
|
||||
for k := range tokenCache {
|
||||
delete(tokenCache, k)
|
||||
count++
|
||||
if count >= 1000000 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenCache[tokenString] = &models.CacheEntry{
|
||||
Claims: claims,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
tokenCacheMutex.Unlock()
|
||||
|
||||
// Add claims to request context
|
||||
// Add claims to context and proceed
|
||||
ctx := buildContext(r.Context(), claims)
|
||||
|
||||
// Call the next handler with updated context
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user