fix all issues

This commit is contained in:
2025-12-04 10:59:46 +08:00
parent e4946b7ad7
commit ca49e8e24b
4 changed files with 184 additions and 86 deletions
+97 -80
View File
@@ -31,7 +31,6 @@ var (
jwtSecretError error
// Pre-allocate error messages to avoid repeated allocations
errMissingAuth = "missing authorization header"
errInvalidAuthFormat = "invalid authorization header format"
errInvalidToken = "Invalid token"
errExpiredToken = "Invalid or expired token"
@@ -74,103 +73,121 @@ func cleanExpiredTokens() {
}
}
// extractBearerToken extracts token from Authorization header
func extractBearerToken(authHeader string) (string, bool) {
if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " {
return "", false
}
return authHeader[7:], true
}
// checkTokenCache retrieves token from cache if valid
func checkTokenCache(tokenString string) (*models.Claims, bool) {
tokenCacheMutex.RLock()
defer tokenCacheMutex.RUnlock()
cached, exists := tokenCache[tokenString]
if !exists {
return nil, false
}
if time.Now().Before(cached.ExpiresAt) {
return cached.Claims, true
}
// Token expired, will be cleaned up later
return nil, false
}
// removeExpiredCacheEntry removes a single expired token from cache
func removeExpiredCacheEntry(tokenString string) {
tokenCacheMutex.Lock()
defer tokenCacheMutex.Unlock()
delete(tokenCache, tokenString)
}
// parseAndValidateToken parses JWT token and validates it
func parseAndValidateToken(tokenString string) (*models.Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return getJWTSecret()
})
if err != nil {
return nil, err
}
if !token.Valid {
return nil, fmt.Errorf("invalid token")
}
claims, ok := token.Claims.(*models.Claims)
if !ok {
return nil, fmt.Errorf("invalid claims")
}
return claims, nil
}
// cacheToken stores validated token in cache
func cacheToken(tokenString string, claims *models.Claims) {
expiresAt := time.Now().Add(5 * time.Minute)
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) {
expiresAt = claims.ExpiresAt.Time
}
tokenCacheMutex.Lock()
defer tokenCacheMutex.Unlock()
// Limit cache size
if len(tokenCache) > 10000000 {
count := 0
for k := range tokenCache {
delete(tokenCache, k)
count++
if count >= 1000000 {
break
}
}
}
tokenCache[tokenString] = &models.CacheEntry{
Claims: claims,
ExpiresAt: expiresAt,
}
}
// JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests
func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get the Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
// Extract token from header
tokenString, ok := extractBearerToken(r.Header.Get("Authorization"))
if !ok {
helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized")
return
}
// Fast path: check if header has Bearer prefix without allocation
if len(authHeader) < 8 || authHeader[:7] != "Bearer " {
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidAuthFormat)
// Check cache first
if claims, found := checkTokenCache(tokenString); found {
ctx := buildContext(r.Context(), claims)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
tokenString := authHeader[7:] // Skip "Bearer " without strings.Split allocation
// Check cache first (read lock)
tokenCacheMutex.RLock()
if cached, exists := tokenCache[tokenString]; exists {
if time.Now().Before(cached.ExpiresAt) {
claims := cached.Claims
tokenCacheMutex.RUnlock()
// Add claims to context and proceed
ctx := buildContext(r.Context(), claims)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Token expired in cache, remove it
tokenCacheMutex.RUnlock()
tokenCacheMutex.Lock()
delete(tokenCache, tokenString)
tokenCacheMutex.Unlock()
} else {
tokenCacheMutex.RUnlock()
}
// Parse and validate the token
token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) {
// Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
// Get cached JWT secret
return getJWTSecret()
})
// Parse and validate token
claims, err := parseAndValidateToken(tokenString)
if err != nil {
helper.RespondWithError(w, http.StatusUnauthorized, errExpiredToken)
return
}
// Check if token is valid
if !token.Valid {
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidToken)
return
}
// Extract claims
claims, ok := token.Claims.(*models.Claims)
if !ok {
helper.RespondWithError(w, http.StatusUnauthorized, errInvalidClaims)
return
}
// Cache the validated token
expiresAt := time.Now().Add(5 * time.Minute) // Cache for 5 minutes
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(expiresAt) {
expiresAt = claims.ExpiresAt.Time
}
cacheToken(tokenString, claims)
tokenCacheMutex.Lock()
// Limit cache size to prevent memory issues
if len(tokenCache) > 10000000 {
// Remove oldest 10% when cache is full
count := 0
for k := range tokenCache {
delete(tokenCache, k)
count++
if count >= 1000000 {
break
}
}
}
tokenCache[tokenString] = &models.CacheEntry{
Claims: claims,
ExpiresAt: expiresAt,
}
tokenCacheMutex.Unlock()
// Add claims to request context
// Add claims to context and proceed
ctx := buildContext(r.Context(), claims)
// Call the next handler with updated context
next.ServeHTTP(w, r.WithContext(ctx))
}
}