// Authorization/middleware/jwt.go
// (repository viewer metadata: 291 lines, 7.8 KiB, Go)

package middleware
import (
"authorization/helper"
"authorization/models"
"authorization/redisclient"
"context"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"log"
"net/http"
"os"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Context keys under which buildContext stores the authenticated token's data
// on the request context. GetClaims, GetUserID and GetRole read them back.
const (
// claimsKey holds the full *models.Claims value.
claimsKey models.ContextKey = "claims"
// userIDKey holds claims.UsersID.
userIDKey models.ContextKey = "users_id"
// roleIDKey holds the de-duplicated []int of role IDs collected by buildContext.
roleIDKey models.ContextKey = "role_id"
)
var (
// Result of loading the RSA public key. getRSAPublicKey fills these inside
// rsaPublicKeyOnce.Do, so the PEM file is read at most once per process and
// a load failure (rsaPublicKeyError) is remembered just like a success.
rsaPublicKeyOnce sync.Once
rsaPublicKeyCached *rsa.PublicKey
rsaPublicKeyError error
// errExpiredToken is the message JWTAuth sends with a 401 response when
// parsing or validating the token fails.
errExpiredToken = "Invalid or expired token" // #nosec G101
// redisTokenPrefix is prepended to the raw token string to form the Redis
// key used by checkTokenCache and cacheToken.
redisTokenPrefix = "jwt:v2:token:" // #nosec G101
)
// getRSAPublicKey returns the RSA public key embedded in the first CERTIFICATE
// block of the server certificate PEM file.
//
// The file is looked up at "rsa/ServerCertificate.pem" and, if that cannot be
// read, at "../rsa/ServerCertificate.pem". The load runs a single time under
// rsaPublicKeyOnce; every later call returns the stored key or the stored
// error, so a failed load is not retried.
func getRSAPublicKey() (*rsa.PublicKey, error) {
	rsaPublicKeyOnce.Do(func() {
		// fail stores the load error and logs it using the given format.
		fail := func(logFormat string, loadErr error) {
			rsaPublicKeyError = loadErr
			log.Printf(logFormat, rsaPublicKeyError)
		}

		log.Print("Loading RSA public key from PEM certificate file")

		// The fallback path covers runs started from a subdirectory (e.g. tests).
		raw, readErr := os.ReadFile("rsa/ServerCertificate.pem")
		if readErr != nil {
			raw, readErr = os.ReadFile("../rsa/ServerCertificate.pem")
		}
		if readErr != nil {
			fail("Error reading PEM file: %v", fmt.Errorf("failed to read PEM file: %w", readErr))
			return
		}
		log.Print("PEM file successfully read")

		// Walk the PEM blocks in order and stop at the first certificate.
		var certBlock *pem.Block
		for block, rest := pem.Decode(raw); block != nil; block, rest = pem.Decode(rest) {
			if block.Type == "CERTIFICATE" {
				certBlock = block
				log.Print("Certificate block found in PEM file")
				break
			}
		}
		if certBlock == nil {
			fail("Error: %v", fmt.Errorf("no certificate block found in PEM file"))
			return
		}

		cert, parseErr := x509.ParseCertificate(certBlock.Bytes)
		if parseErr != nil {
			fail("Error parsing certificate: %v", fmt.Errorf("failed to parse certificate: %w", parseErr))
			return
		}
		log.Print("Certificate successfully parsed")

		// Only RSA keys are usable by the RSA signing methods checked elsewhere.
		switch key := cert.PublicKey.(type) {
		case *rsa.PublicKey:
			rsaPublicKeyCached = key
			log.Printf("RSA public key successfully loaded and cached (key size: %d bits)", key.N.BitLen())
		default:
			fail("Error: %v", fmt.Errorf("certificate does not contain RSA public key"))
		}
	})
	return rsaPublicKeyCached, rsaPublicKeyError
}
// extractBearerToken returns the credential that follows the "Bearer " prefix
// of an Authorization header value. The prefix comparison is case-sensitive.
// The boolean is false when the prefix is absent or nothing follows it.
func extractBearerToken(authHeader string) (string, bool) {
	const scheme = "Bearer "

	// The header must be strictly longer than the prefix to carry a token.
	if len(authHeader) <= len(scheme) {
		return "", false
	}
	if authHeader[:len(scheme)] != scheme {
		return "", false
	}
	return authHeader[len(scheme):], true
}
// checkTokenCache looks up the claims stored in Redis for tokenString under
// the key redisTokenPrefix+tokenString.
//
// It reports false when no Redis client is configured, when the lookup
// returns any error within its 100ms budget, or when the stored value cannot
// be decoded as JSON into models.Claims.
func checkTokenCache(tokenString string) (*models.Claims, bool) {
	if redisclient.RDB == nil {
		return nil, false
	}

	lookupCtx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer done()

	cached, err := redisclient.RDB.Get(lookupCtx, redisTokenPrefix+tokenString).Result()
	if err != nil {
		return nil, false
	}

	claims := new(models.Claims)
	if json.Unmarshal([]byte(cached), claims) != nil {
		return nil, false
	}
	return claims, true
}
// removeExpiredCacheEntry removes a single expired token from cache
// func removeExpiredCacheEntry(tokenString string) {
// tokenCacheMutex.Lock()
// defer tokenCacheMutex.Unlock()
// delete(tokenCache, tokenString)
// }
// parseAndValidateToken parses tokenString into models.Claims and verifies it
// with the key returned by getRSAPublicKey.
//
// Tokens whose signing method is not an RSA method are rejected before the
// key is handed out. An error is returned when parsing fails, when the parsed
// token is not marked valid, or when its claims are not *models.Claims.
func parseAndValidateToken(tokenString string) (*models.Claims, error) {
	log.Print("Starting JWT token parsing and verification with RSA")

	// keyFunc supplies the verification key, but only for RSA-signed tokens.
	keyFunc := func(tok *jwt.Token) (interface{}, error) {
		if _, isRSA := tok.Method.(*jwt.SigningMethodRSA); isRSA {
			log.Print("Token signing method verified: RSA")
			return getRSAPublicKey()
		}
		log.Printf("Token verification failed: unexpected signing method: %v", tok.Header["alg"])
		return nil, fmt.Errorf("unexpected signing method: %v", tok.Header["alg"])
	}

	parsed, err := jwt.ParseWithClaims(tokenString, &models.Claims{}, keyFunc)
	switch {
	case err != nil:
		log.Printf("Token parsing failed: %v", err)
		return nil, err
	case !parsed.Valid:
		log.Print("Token validation failed: token is invalid")
		return nil, fmt.Errorf("invalid token")
	}
	log.Print("Token successfully parsed and validated")

	if claims, ok := parsed.Claims.(*models.Claims); ok {
		log.Printf("Token verified successfully for user: (UserID: %s)", claims.UsersID)
		return claims, nil
	}
	log.Print("Token validation failed: invalid claims structure")
	return nil, fmt.Errorf("invalid claims")
}
// cacheToken stores the validated claims in Redis as JSON under the key
// redisTokenPrefix+tokenString so later requests can skip signature checks.
//
// The entry lives for at most 5 minutes and never longer than the token's
// remaining lifetime. Nothing is written when no Redis client is configured,
// when claims is nil, or when the token has already expired. Caching is
// best-effort: serialization and Redis failures are logged, not returned.
func cacheToken(tokenString string, claims *models.Claims) {
	if redisclient.RDB == nil || claims == nil {
		return
	}

	ttl := 5 * time.Minute
	if claims.ExpiresAt != nil {
		remaining := time.Until(claims.ExpiresAt.Time)
		if remaining <= 0 {
			// The token expired after validation. Caching it would keep it
			// accepted by the cache fast path for the full default TTL.
			return
		}
		if remaining < ttl {
			ttl = remaining
		}
	}

	claimsJSON, err := json.Marshal(claims)
	if err != nil {
		log.Printf("Failed to serialize claims for token cache: %v", err)
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	// The token itself is deliberately kept out of the log output.
	if err := redisclient.RDB.Set(ctx, redisTokenPrefix+tokenString, claimsJSON, ttl).Err(); err != nil {
		log.Printf("Failed to store claims in token cache: %v", err)
	}
}
// JWTAuth is a middleware that lets next run only for requests that carry a
// valid bearer JWT in the Authorization header.
//
// Responses:
//   - 401 "Unauthorized" when the header is missing or is not a bearer token.
//   - 401 errExpiredToken when the token fails parsing or validation.
//
// On success the token's claims, user ID and roles are attached to the request
// context (see buildContext) before next is called. Claims found in the Redis
// cache are used as-is; otherwise the token is fully validated and then cached.
func JWTAuth(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		tokenString, ok := extractBearerToken(r.Header.Get("Authorization"))
		if !ok {
			helper.RespondWithError(w, http.StatusUnauthorized, "Unauthorized")
			return
		}

		// Fast path: a cache hit means this token was validated recently.
		claims, cached := checkTokenCache(tokenString)
		if !cached {
			var err error
			claims, err = parseAndValidateToken(tokenString)
			if err != nil {
				helper.RespondWithError(w, http.StatusUnauthorized, errExpiredToken)
				return
			}
			cacheToken(tokenString, claims)
		}

		next.ServeHTTP(w, r.WithContext(buildContext(r.Context(), claims)))
	}
}
// buildContext returns a child of parent carrying three values: the claims
// themselves (claimsKey), claims.UsersID (userIDKey), and a []int of role IDs
// (roleIDKey).
//
// The role list is gathered from claims.RoleID, then claims.AdditionalRoleID,
// then every project's RoleID. Each ID is kept only the first time it is seen,
// so the slice is duplicate-free and in first-occurrence order.
func buildContext(parent context.Context, claims *models.Claims) context.Context {
	ctx := context.WithValue(parent, claimsKey, claims)
	ctx = context.WithValue(ctx, userIDKey, claims.UsersID)

	seen := make(map[int]struct{})
	roles := make([]int, 0, len(claims.RoleID)+len(claims.AdditionalRoleID))

	// addRole appends id unless it has already been collected.
	addRole := func(id int) {
		if _, dup := seen[id]; dup {
			return
		}
		seen[id] = struct{}{}
		roles = append(roles, id)
	}

	for _, id := range claims.RoleID {
		addRole(id)
	}
	for _, id := range claims.AdditionalRoleID {
		addRole(id)
	}
	for _, project := range claims.Projects {
		for _, id := range project.RoleID {
			addRole(id)
		}
	}

	return context.WithValue(ctx, roleIDKey, roles)
}
// GetClaims returns the *models.Claims stored on the request context under
// claimsKey. The boolean is false when no value of that type is present.
func GetClaims(r *http.Request) (*models.Claims, bool) {
	if claims, ok := r.Context().Value(claimsKey).(*models.Claims); ok {
		return claims, true
	}
	return nil, false
}
// GetUserID returns the user ID string stored on the request context under
// userIDKey. The boolean is false when no string value is present.
func GetUserID(r *http.Request) (string, bool) {
	if id, ok := r.Context().Value(userIDKey).(string); ok {
		return id, true
	}
	return "", false
}
// GetRole returns the role ID slice stored on the request context under
// roleIDKey. The boolean is false when no []int value is present.
func GetRole(r *http.Request) ([]int, bool) {
	if roles, ok := r.Context().Value(roleIDKey).([]int); ok {
		return roles, true
	}
	return nil, false
}