Files
Authorization/middleware/jwt_test.go
T
2026-01-27 10:45:15 +08:00

566 lines
14 KiB
Go

package middleware
import (
"authorization/models"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
// TestMain runs every test in this package with GO_ENV set to "development"
// and unsets the variable again before exiting with the suite's exit code.
// NOTE(review): presumably the code under test consults GO_ENV (e.g. to decide
// which certificate to load); confirm against the non-test sources.
func TestMain(m *testing.M) {
	// Fail loudly rather than silently running the suite in the wrong environment.
	if err := os.Setenv("GO_ENV", "development"); err != nil {
		panic("TestMain: setting GO_ENV: " + err.Error())
	}
	code := m.Run()
	// Best-effort cleanup; the process exits immediately afterwards.
	_ = os.Unsetenv("GO_ENV")
	os.Exit(code)
}
// generateTestRSACertificate is a test helper that creates a fresh 2048-bit RSA
// key and a self-signed X.509 certificate for it, valid for one year starting
// now. It returns the private key together with the certificate encoded as a
// PEM "CERTIFICATE" block. Any failure aborts the calling test via t.Fatalf.
func generateTestRSACertificate(t *testing.T) (privateKey *rsa.PrivateKey, certPEM []byte) {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Failed to generate RSA key: %v", err)
	}

	issuedAt := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{Organization: []string{"Test Org"}},
		NotBefore:             issuedAt,
		NotAfter:              issuedAt.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	// Self-signed: the template serves as both the subject and the issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("Failed to create certificate: %v", err)
	}

	block := &pem.Block{Type: "CERTIFICATE", Bytes: der}
	return key, pem.EncodeToMemory(block)
}
// TestGetRSAPublicKey checks that getRSAPublicKey loads a non-nil RSA public key
// without error. The package-level cache (the sync.Once guarding the load plus
// the cached key and error) is cleared first so the load really executes
// instead of returning a result memoized by an earlier test.
// NOTE(review): this depends on whatever certificate getRSAPublicKey reads being
// present in the test environment; confirm where it is loaded from.
func TestGetRSAPublicKey(t *testing.T) {
	t.Run("Valid PEM file", func(t *testing.T) {
		// Reset state for testing
		rsaPublicKeyOnce = sync.Once{}
		rsaPublicKeyError = nil
		rsaPublicKeyCached = nil

		key, err := getRSAPublicKey()
		if err != nil {
			// Stop here: checking the key after a failed load only adds noise.
			t.Fatalf("Expected no error, got %v", err)
		}
		if key == nil {
			t.Error("Expected RSA public key, got nil")
		}
	})
}
// TestExtractBearerToken is a table-driven check of extractBearerToken for a
// well-formed "Bearer <token>" header and several malformed headers, verifying
// both the returned token and the ok flag.
func TestExtractBearerToken(t *testing.T) {
	type bearerCase struct {
		name       string
		authHeader string
		wantToken  string
		wantOK     bool
	}

	// Rejected headers rely on the zero values: empty token, ok == false.
	cases := []bearerCase{
		{name: "Valid Bearer token", authHeader: "Bearer token123", wantToken: "token123", wantOK: true},
		{name: "Empty header", authHeader: ""},
		{name: "Too short", authHeader: "Bearer"},
		{name: "Wrong prefix", authHeader: "Basic token123"},
		{name: "Missing space", authHeader: "Bearertoken123"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotToken, gotOK := extractBearerToken(tc.authHeader)
			if gotToken != tc.wantToken {
				t.Errorf("Expected token '%s', got '%s'", tc.wantToken, gotToken)
			}
			if gotOK != tc.wantOK {
				t.Errorf("Expected ok %v, got %v", tc.wantOK, gotOK)
			}
		})
	}
}
// TestParseAndValidateToken exercises parseAndValidateToken after clearing the
// package-level RSA public-key cache (the sync.Once plus cached key and error),
// so each run performs a fresh key load.
func TestParseAndValidateToken(t *testing.T) {
	// Reset state
	rsaPublicKeyOnce = sync.Once{}
	rsaPublicKeyError = nil
	rsaPublicKeyCached = nil

	t.Run("Valid token", func(t *testing.T) {
		// Generate test RSA key and certificate
		_, certPEM := generateTestRSACertificate(t)

		// Create the file under t.TempDir(), which the testing package removes
		// automatically. Close it explicitly so write/close errors surface, and
		// so it is never still open when the directory is cleaned up.
		tmpFile, err := os.CreateTemp(t.TempDir(), "test-cert-*.pem")
		if err != nil {
			t.Fatalf("Failed to create temp file: %v", err)
		}
		if _, err := tmpFile.Write(certPEM); err != nil {
			_ = tmpFile.Close() // best effort; the write error is what we report
			t.Fatalf("Failed to write cert: %v", err)
		}
		if err := tmpFile.Close(); err != nil {
			t.Fatalf("Failed to close cert file: %v", err)
		}

		// Nothing here points the key loader at the file just written, so the
		// full validation path cannot run yet; skip until a real certificate is
		// available.
		t.Skip("Requires actual RSA certificate file")
	})

	t.Run("Invalid token", func(t *testing.T) {
		// Reset state
		rsaPublicKeyOnce = sync.Once{}
		rsaPublicKeyError = nil
		rsaPublicKeyCached = nil

		_, err := parseAndValidateToken("invalid.token.string")
		if err == nil {
			t.Error("Expected error for invalid token")
		}
	})
}
func TestBuildContext(t *testing.T) {
claims := &models.Claims{
UserID: "user123",
RoleID: "admin",
}
parent := context.Background()
ctx := buildContext(parent, claims)
// Check claims
if val, ok := ctx.Value(claimsKey).(*models.Claims); !ok || val.UserID != "user123" {
t.Error("Claims not properly set in context")
}
// Check userID
if val, ok := ctx.Value(userIDKey).(string); !ok || val != "user123" {
t.Error("UserID not properly set in context")
}
// Check role
if val, ok := ctx.Value(roleIDKey).(string); !ok || val != "admin" {
t.Error("Role not properly set in context")
}
}
// TestGetClaims verifies that GetClaims returns the *models.Claims stored under
// claimsKey in the request context.
func TestGetClaims(t *testing.T) {
	claims := &models.Claims{
		UserID: "user123",
		RoleID: "admin",
	}
	req := httptest.NewRequest("GET", "/", nil)
	ctx := context.WithValue(req.Context(), claimsKey, claims)
	req = req.WithContext(ctx)

	retrievedClaims, ok := GetClaims(req)
	if !ok {
		// retrievedClaims may be nil here; dereferencing it below would panic.
		t.Fatal("Expected claims to be found")
	}
	if retrievedClaims.UserID != "user123" {
		t.Errorf("Expected UserID 'user123', got '%s'", retrievedClaims.UserID)
	}
}
// TestGetUserID verifies that GetUserID returns the string stored under
// userIDKey in the request context.
func TestGetUserID(t *testing.T) {
	const want = "user123"
	base := httptest.NewRequest(http.MethodGet, "/", nil)
	req := base.WithContext(context.WithValue(base.Context(), userIDKey, want))

	got, found := GetUserID(req)
	if !found {
		t.Error("Expected userID to be found")
	}
	if got != want {
		t.Errorf("Expected 'user123', got '%s'", got)
	}
}
// TestGetRole verifies that GetRole returns the string stored under roleIDKey
// in the request context.
func TestGetRole(t *testing.T) {
	const want = "admin"
	base := httptest.NewRequest(http.MethodGet, "/", nil)
	req := base.WithContext(context.WithValue(base.Context(), roleIDKey, want))

	got, found := GetRole(req)
	if !found {
		t.Error("Expected role to be found")
	}
	if got != want {
		t.Errorf("Expected 'admin', got '%s'", got)
	}
}
// TestJWTAuthNoAuthHeader verifies that a request without an Authorization
// header is answered with 401 and never reaches the protected handler.
func TestJWTAuthNoAuthHeader(t *testing.T) {
	called := false
	handler := func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}
	req := httptest.NewRequest("GET", "/", nil)
	w := httptest.NewRecorder()

	JWTAuth(handler)(w, req)

	if w.Code != http.StatusUnauthorized {
		t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
	}
	// The recorder keeps the first status written, so a 401 alone does not
	// prove the request was blocked; check the handler never ran.
	if called {
		t.Error("Protected handler must not run when the Authorization header is missing")
	}
}
// TestJWTAuthInvalidToken verifies that a syntactically bogus bearer token is
// rejected with 401 and never reaches the protected handler. The RSA public-key
// cache is reset first so no key memoized by an earlier test is reused.
func TestJWTAuthInvalidToken(t *testing.T) {
	// Reset RSA state
	rsaPublicKeyOnce = sync.Once{}
	rsaPublicKeyError = nil
	rsaPublicKeyCached = nil

	called := false
	handler := func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}
	req := httptest.NewRequest("GET", "/", nil)
	req.Header.Set("Authorization", "Bearer invalid.token.here")
	w := httptest.NewRecorder()

	JWTAuth(handler)(w, req)

	if w.Code != http.StatusUnauthorized {
		t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
	}
	// A 401 status alone does not prove the handler was skipped.
	if called {
		t.Error("Protected handler must not run for an invalid token")
	}
}
// TestJWTAuthValidToken is a placeholder for the happy-path check. It is
// skipped because it needs a real RSA certificate/key pair to sign and verify
// tokens.
func TestJWTAuthValidToken(t *testing.T) {
	const reason = "Requires RSA certificate setup - integration test"
	t.Skip(reason)
}
// Additional comprehensive test cases
// TestExtractBearerTokenEdgeCases pins down extractBearerToken's handling of
// unusual headers. A header is accepted only when it starts with the exact,
// case-sensitive prefix "Bearer "; everything after that prefix is returned
// verbatim, including any further spaces.
func TestExtractBearerTokenEdgeCases(t *testing.T) {
	// A NUL-filled payload: exercises length handling, not token content.
	longToken := string(make([]byte, 5000))

	testCases := []struct {
		name      string
		header    string
		wantToken string
		wantOk    bool
	}{
		// Two spaces: only the first belongs to the prefix, so the token keeps
		// one leading space. (With a single space the expectation below could
		// never hold.)
		{"Multiple spaces", "Bearer  token123", " token123", true},
		{"No space after Bearer", "Bearertoken123", "", false},
		{"Lowercase bearer", "bearer token123", "", false},
		{"Mixed case", "BeArEr token123", "", false},
		{"Extra whitespace", " Bearer token123", "", false}, // Must start with "Bearer "
		{"Token with spaces", "Bearer token with spaces", "token with spaces", true},
		{"Very long token", "Bearer " + longToken, longToken, true},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			gotToken, gotOk := extractBearerToken(tc.header)
			if gotToken != tc.wantToken || gotOk != tc.wantOk {
				t.Errorf("extractBearerToken(%q) = (%q, %v), want (%q, %v)", tc.header, gotToken, gotOk, tc.wantToken, tc.wantOk)
			}
		})
	}
}
// TestParseAndValidateTokenMalformedTokens verifies that parseAndValidateToken
// returns an error for strings that are not well-formed JWTs. The RSA
// public-key cache is reset first so each run performs a fresh key load.
func TestParseAndValidateTokenMalformedTokens(t *testing.T) {
	// Reset RSA state
	rsaPublicKeyOnce = sync.Once{}
	rsaPublicKeyError = nil
	rsaPublicKeyCached = nil

	testCases := []struct {
		name  string
		token string
	}{
		{"Empty string", ""},
		{"Random string", "not.a.jwt.token"},
		{"Only dots", "..."},
		{"Two parts only", "header.payload"},
		{"Four parts", "part1.part2.part3.part4"},
		{"Invalid base64", "!@#$.!@#$.!@#$"},
		{"Spaces in token", "part1 .part2 .part3"},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := parseAndValidateToken(tc.token)
			if err == nil {
				// Report the actual input (the subtest name already identifies the case).
				t.Errorf("Expected error for malformed token %q", tc.token)
			}
		})
	}
}
// TestBuildContextWithDifferentRoles verifies that claims stored by buildContext
// round-trip through GetClaims with the RoleID intact, for a range of role
// strings including the empty role.
func TestBuildContextWithDifferentRoles(t *testing.T) {
	roles := []string{"admin", "user", "guest", "superadmin", "", "role-with-dash"}
	for _, role := range roles {
		t.Run("Role: "+role, func(t *testing.T) {
			claims := &models.Claims{
				UserID: "user123",
				RoleID: role,
			}
			req := httptest.NewRequest("GET", "/", nil)
			ctx := buildContext(req.Context(), claims)

			retrievedClaims, ok := GetClaims(req.WithContext(ctx))
			if !ok {
				// retrievedClaims may be nil here; dereferencing it below would panic.
				t.Fatal("Claims not found in context")
			}
			if retrievedClaims.RoleID != role {
				t.Errorf("Role = %q, want %q", retrievedClaims.RoleID, role)
			}
		})
	}
}
// TestGetClaimsWithoutClaims verifies that GetClaims reports (nil, false) for a
// request whose context carries no claims.
func TestGetClaimsWithoutClaims(t *testing.T) {
	claims, ok := GetClaims(httptest.NewRequest(http.MethodGet, "/", nil))

	if ok {
		t.Error("Expected ok=false when claims not in context")
	}
	if claims != nil {
		t.Error("Expected nil claims when not in context")
	}
}
func TestGetClaimsWithWrongType(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
ctx := context.WithValue(req.Context(), models.ContextKey("claims"), "wrong type")
req = req.WithContext(ctx)
claims, ok := GetClaims(req)
if ok {
t.Error("Expected ok=false when claims are wrong type")
}
if claims != nil {
t.Error("Expected nil claims when wrong type in context")
}
}
// TestGetUserIDWithNoClaims verifies that GetUserID reports ("", false) for a
// request whose context holds no user ID.
func TestGetUserIDWithNoClaims(t *testing.T) {
	userID, ok := GetUserID(httptest.NewRequest(http.MethodGet, "/", nil))

	if ok {
		t.Error("Expected ok=false when no claims")
	}
	if len(userID) != 0 {
		t.Errorf("Expected empty string, got %q", userID)
	}
}
// TestGetRoleWithNoClaims verifies that GetRole reports ("", false) for a
// request whose context holds no role.
func TestGetRoleWithNoClaims(t *testing.T) {
	role, ok := GetRole(httptest.NewRequest(http.MethodGet, "/", nil))

	if ok {
		t.Error("Expected ok=false when no claims")
	}
	if len(role) != 0 {
		t.Errorf("Expected empty string, got %q", role)
	}
}
// TestJWTAuthMissingBearerPrefix verifies that an Authorization header without
// the "Bearer " scheme is answered with 401 and never reaches the protected
// handler. The RSA public-key cache is reset first.
func TestJWTAuthMissingBearerPrefix(t *testing.T) {
	// Reset RSA state
	rsaPublicKeyOnce = sync.Once{}
	rsaPublicKeyError = nil
	rsaPublicKeyCached = nil

	called := false
	handler := func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}
	req := httptest.NewRequest("GET", "/", nil)
	req.Header.Set("Authorization", "InvalidToken")
	w := httptest.NewRecorder()

	JWTAuth(handler)(w, req)

	if w.Code != http.StatusUnauthorized {
		t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
	}
	// A 401 status alone does not prove the handler was skipped.
	if called {
		t.Error("Protected handler must not run when the Bearer prefix is missing")
	}
}
// TestJWTAuthExpiredToken is a placeholder for checking that expired tokens are
// rejected. It is skipped because it needs a real RSA certificate/key pair to
// sign a token the middleware would otherwise accept.
func TestJWTAuthExpiredToken(t *testing.T) {
	const reason = "Requires RSA certificate setup - integration test"
	t.Skip(reason)
}
// TestJWTAuthTokenWithMissingClaims checks that a token missing either UserID or
// RoleID still passes the middleware. The claims must reach the handler with the
// present field preserved and the missing one empty. Currently skipped pending
// RSA certificate setup.
func TestJWTAuthTokenWithMissingClaims(t *testing.T) {
	t.Skip("Requires RSA certificate setup - integration test")

	// NOTE(review): these tokens are HS256-signed with a shared secret. If the
	// middleware verifies signatures against an RSA public key (as the skip
	// reason suggests), they will be rejected with 401 once this test is
	// enabled. Sign with the matching RSA private key before un-skipping.
	testCases := []struct {
		name   string
		claims *models.Claims
	}{
		{
			"Missing UserID",
			&models.Claims{
				RoleID: "admin",
				RegisteredClaims: jwt.RegisteredClaims{
					ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
				},
			},
		},
		{
			"Missing Role",
			&models.Claims{
				UserID: "user123",
				RegisteredClaims: jwt.RegisteredClaims{
					ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
				},
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			token := jwt.NewWithClaims(jwt.SigningMethodHS256, tc.claims)
			tokenString, err := token.SignedString([]byte("test-secret"))
			if err != nil {
				t.Fatalf("Failed to sign token: %v", err)
			}

			handler := func(w http.ResponseWriter, r *http.Request) {
				claims, ok := GetClaims(r)
				switch {
				case !ok:
					t.Error("Claims should still be in context even if some fields are empty")
				default:
					// The present field must survive; the missing one must stay empty.
					if claims.UserID != tc.claims.UserID {
						t.Errorf("UserID = %q, want %q", claims.UserID, tc.claims.UserID)
					}
					if claims.RoleID != tc.claims.RoleID {
						t.Errorf("RoleID = %q, want %q", claims.RoleID, tc.claims.RoleID)
					}
				}
				w.WriteHeader(http.StatusOK)
			}

			req := httptest.NewRequest("GET", "/", nil)
			req.Header.Set("Authorization", "Bearer "+tokenString)
			w := httptest.NewRecorder()
			JWTAuth(handler)(w, req)

			// Token is valid, just missing some claim fields
			if w.Code != http.StatusOK {
				t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
			}
		})
	}
}
// TestJWTAuthConcurrentRequests sends the same valid token through the
// middleware from many goroutines at once. It checks that every request is
// accepted and that the claims reach the handler. Currently skipped pending RSA
// certificate setup.
func TestJWTAuthConcurrentRequests(t *testing.T) {
	t.Skip("Requires RSA certificate setup - integration test")

	// NOTE(review): the token is HS256-signed with a shared secret. If the
	// middleware verifies against an RSA public key (as the skip reason
	// suggests), every request will get 401 once enabled. Sign with the
	// matching RSA private key before un-skipping.
	claims := &models.Claims{
		UserID: "user123",
		RoleID: "admin",
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenString, err := token.SignedString([]byte("test-secret"))
	if err != nil {
		t.Fatalf("Failed to sign token: %v", err)
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		if _, ok := GetClaims(r); !ok {
			t.Error("Claims not found in concurrent request")
		}
		w.WriteHeader(http.StatusOK)
	}

	const concurrency = 50
	var wg sync.WaitGroup
	wg.Add(concurrency)
	for i := 0; i < concurrency; i++ {
		go func() {
			defer wg.Done()
			req := httptest.NewRequest("GET", "/", nil)
			req.Header.Set("Authorization", "Bearer "+tokenString)
			w := httptest.NewRecorder()
			JWTAuth(handler)(w, req)
			// t.Errorf is safe to call from other goroutines (unlike t.Fatal).
			if w.Code != http.StatusOK {
				t.Errorf("Expected status %d in concurrent request, got %d", http.StatusOK, w.Code)
			}
		}()
	}
	wg.Wait()
}
// TestJWTAuthTokenSignedWithWrongKey checks that tokens whose signature does not
// come from the configured key are rejected with 401 and never reach the
// protected handler. It covers two cases: a symmetric HS256 signature, and an
// RS256 signature from an RSA key unrelated to the configured certificate.
// Currently skipped pending RSA certificate setup.
func TestJWTAuthTokenSignedWithWrongKey(t *testing.T) {
	t.Skip("Requires RSA certificate setup - integration test")

	claims := &models.Claims{
		UserID: "user123",
		RoleID: "admin",
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
		},
	}

	// A freshly generated key pair cannot match the configured certificate.
	wrongRSAKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Failed to generate RSA key: %v", err)
	}

	testCases := []struct {
		name   string
		method jwt.SigningMethod
		key    any
	}{
		// Original coverage: a symmetric signature with an arbitrary secret.
		{"HS256 with unrelated secret", jwt.SigningMethodHS256, []byte("wrong-secret")},
		// A genuinely wrong key: an RSA signature from a different key pair.
		{"RS256 with unrelated RSA key", jwt.SigningMethodRS256, wrongRSAKey},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			tokenString, err := jwt.NewWithClaims(tc.method, claims).SignedString(tc.key)
			if err != nil {
				t.Fatalf("Failed to sign token: %v", err)
			}

			called := false
			handler := func(w http.ResponseWriter, r *http.Request) {
				called = true
				w.WriteHeader(http.StatusOK)
			}
			req := httptest.NewRequest("GET", "/", nil)
			req.Header.Set("Authorization", "Bearer "+tokenString)
			w := httptest.NewRecorder()
			JWTAuth(handler)(w, req)

			if w.Code != http.StatusUnauthorized {
				t.Errorf("Expected status %d for wrong signature, got %d", http.StatusUnauthorized, w.Code)
			}
			if called {
				t.Error("Protected handler must not run for a token signed with the wrong key")
			}
		})
	}
}