package middleware import ( "authorization/helper" "authorization/models" "authorization/redisclient" "context" "crypto/rsa" "crypto/x509" "encoding/json" "encoding/pem" "fmt" "log" "net/http" "os" "sync" "time" "github.com/golang-jwt/jwt/v5" ) const ( claimsKey models.ContextKey = "claims" userIDKey models.ContextKey = "users_id" roleIDKey models.ContextKey = "role_id" ) var ( // Cache RSA public key to avoid repeated file reads rsaPublicKeyOnce sync.Once rsaPublicKeyCached *rsa.PublicKey rsaPublicKeyError error // Pre-allocate error messages to avoid repeated allocations errExpiredToken = "Invalid or expired token" // #nosec G101 // Redis key prefix for token cache redisTokenPrefix = "jwt:v2:token:" ) func getRSAPublicKey() (*rsa.PublicKey, error) { rsaPublicKeyOnce.Do(func() { log.Print("Loading RSA public key from PEM certificate file") // Read PEM file - use path relative to executable or try both common paths pemData, err := os.ReadFile("rsa/ServerCertificate.pem") if err != nil { // Try alternate path when running tests from subdirectory pemData, err = os.ReadFile("../rsa/ServerCertificate.pem") if err != nil { rsaPublicKeyError = fmt.Errorf("failed to read PEM file: %w", err) log.Printf("Error reading PEM file: %v", rsaPublicKeyError) return } } log.Print("PEM file successfully read") // Parse PEM blocks to find the certificate var certBlock *pem.Block for { block, rest := pem.Decode(pemData) if block == nil { break } if block.Type == "CERTIFICATE" { certBlock = block log.Print("Certificate block found in PEM file") break } pemData = rest } if certBlock == nil { rsaPublicKeyError = fmt.Errorf("no certificate block found in PEM file") log.Printf("Error: %v", rsaPublicKeyError) return } // Parse certificate cert, err := x509.ParseCertificate(certBlock.Bytes) if err != nil { rsaPublicKeyError = fmt.Errorf("failed to parse certificate: %w", err) log.Printf("Error parsing certificate: %v", rsaPublicKeyError) return } log.Print("Certificate successfully parsed") // Extract RSA public key publicKey, ok := cert.PublicKey.(*rsa.PublicKey) if !ok { rsaPublicKeyError = fmt.Errorf("certificate does not contain RSA public key") log.Printf("Error: %v", rsaPublicKeyError) return } rsaPublicKeyCached = publicKey log.Printf("RSA public key successfully loaded and cached (key size: %d bits)", publicKey.N.BitLen()) }) return rsaPublicKeyCached, rsaPublicKeyError } // extractBearerToken extracts token from Authorization header func extractBearerToken(authHeader string) (string, bool) { if authHeader == "" || len(authHeader) < 8 || authHeader[:7] != "Bearer " { return "", false } return authHeader[7:], true } // checkTokenCache retrieves token from Redis cache if valid func checkTokenCache(tokenString string) (*models.Claims, bool) { if redisclient.RDB == nil { return nil, false } ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() key := redisTokenPrefix + tokenString val, err := redisclient.RDB.Get(ctx, key).Result() if err != nil { return nil, false } var claims models.Claims if err := json.Unmarshal([]byte(val), &claims); err != nil { return nil, false } return &claims, true } // removeExpiredCacheEntry removes a single expired token from cache // func removeExpiredCacheEntry(tokenString string) { // tokenCacheMutex.Lock() // defer tokenCacheMutex.Unlock() // delete(tokenCache, tokenString) // } // parseAndValidateToken parses JWT token and validates it using RSA func parseAndValidateToken(tokenString string) (*models.Claims, error) { log.Print("Starting JWT token parsing and verification with RSA") token, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { log.Printf("Token verification failed: unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } log.Print("Token signing method verified: RSA") return getRSAPublicKey() }) if err != nil { log.Printf("Token parsing failed: %v", err) return nil, err } if !token.Valid { log.Print("Token validation failed: token is invalid") return nil, fmt.Errorf("invalid token") } log.Print("Token successfully parsed and validated") claims, ok := token.Claims.(*models.Claims) if !ok { log.Print("Token validation failed: invalid claims structure") return nil, fmt.Errorf("invalid claims") } log.Printf("Token verified successfully for user: (UserID: %s)", claims.UsersID) return claims, nil } // cacheToken stores validated token in Redis cache func cacheToken(tokenString string, claims *models.Claims) { if redisclient.RDB == nil { return } // Calculate TTL ttl := 5 * time.Minute if claims.ExpiresAt != nil { timeUntilExpiry := time.Until(claims.ExpiresAt.Time) if timeUntilExpiry > 0 && timeUntilExpiry < ttl { ttl = timeUntilExpiry } } // Serialize claims to JSON claimsJSON, err := json.Marshal(claims) if err != nil { return } // Store in Redis with TTL ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() key := redisTokenPrefix + tokenString redisclient.RDB.Set(ctx, key, claimsJSON, ttl) } // JWTAuth is a middleware that validates JWT tokens with caching for high-frequency requests func JWTAuth(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Extract token from header tokenString, ok := extractBearerToken(r.Header.Get("Authorization")) if !ok { helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized") return } // Check cache first if claims, found := checkTokenCache(tokenString); found { ctx := buildContext(r.Context(), claims) next.ServeHTTP(w, r.WithContext(ctx)) return } log.Print("1") // Parse and validate token claims, err := parseAndValidateToken(tokenString) if err != nil { helper.RespondWithError(w, http.StatusUnauthorized, errExpiredToken) return } log.Print("2") // Cache the validated token cacheToken(tokenString, claims) // Add claims to context and proceed ctx := buildContext(r.Context(), claims) next.ServeHTTP(w, r.WithContext(ctx)) } } // buildContext efficiently builds context with claims (reduces allocations) func buildContext(parent context.Context, claims *models.Claims) context.Context { ctx := context.WithValue(parent, claimsKey, claims) ctx = context.WithValue(ctx, userIDKey, claims.UsersID) roles := make([]int, 0, len(claims.RoleID)+len(claims.AdditionalRoleID)) unique := make(map[int]struct{}) for _, role := range claims.RoleID { if _, exists := unique[role]; !exists { unique[role] = struct{}{} roles = append(roles, role) } } for _, role := range claims.AdditionalRoleID { if _, exists := unique[role]; !exists { unique[role] = struct{}{} roles = append(roles, role) } } for _, project := range claims.Projects { for _, role := range project.RoleID { if _, exists := unique[role]; !exists { unique[role] = struct{}{} roles = append(roles, role) } } } ctx = context.WithValue(ctx, roleIDKey, roles) return ctx } // GetClaims retrieves the JWT claims from the request context func GetClaims(r *http.Request) (*models.Claims, bool) { claims, ok := r.Context().Value(claimsKey).(*models.Claims) return claims, ok } // GetUserID retrieves the user ID from the request context func GetUserID(r *http.Request) (string, bool) { userID, ok := r.Context().Value(userIDKey).(string) return userID, ok } // GetRole retrieves the roles from the request context func GetRole(r *http.Request) ([]int, bool) { role, ok := r.Context().Value(roleIDKey).([]int) return role, ok }