package middleware import ( "authorization/models" "context" "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "math/big" "net/http" "net/http/httptest" "os" "sync" "testing" "time" "github.com/golang-jwt/jwt/v5" ) func TestMain(m *testing.M) { os.Setenv("GO_ENV", "development") code := m.Run() os.Unsetenv("GO_ENV") os.Exit(code) } // Test helper to generate RSA key pair and certificate func generateTestRSACertificate(t *testing.T) (privateKey *rsa.PrivateKey, certPEM []byte) { t.Helper() privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("Failed to generate RSA key: %v", err) } template := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ Organization: []string{"Test Org"}, }, NotBefore: time.Now(), NotAfter: time.Now().Add(365 * 24 * time.Hour), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, } certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) if err != nil { t.Fatalf("Failed to create certificate: %v", err) } certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) return privateKey, certPEM } func TestGetRSAPublicKey(t *testing.T) { t.Run("Valid PEM file", func(t *testing.T) { // Reset state for testing rsaPublicKeyOnce = sync.Once{} rsaPublicKeyError = nil rsaPublicKeyCached = nil key, err := getRSAPublicKey() if err != nil { t.Errorf("Expected no error, got %v", err) } if key == nil { t.Error("Expected RSA public key, got nil") } }) } func TestExtractBearerToken(t *testing.T) { tests := []struct { name string authHeader string wantToken string wantOK bool }{ { name: "Valid Bearer token", authHeader: "Bearer token123", wantToken: "token123", wantOK: true, }, { name: "Empty header", authHeader: "", wantToken: "", wantOK: false, }, { name: "Too short", authHeader: "Bearer", wantToken: "", wantOK: false, }, { name: "Wrong prefix", authHeader: "Basic token123", wantToken: "", wantOK: false, }, { name: "Missing space", authHeader: "Bearertoken123", wantToken: "", wantOK: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { token, ok := extractBearerToken(tt.authHeader) if token != tt.wantToken { t.Errorf("Expected token '%s', got '%s'", tt.wantToken, token) } if ok != tt.wantOK { t.Errorf("Expected ok %v, got %v", tt.wantOK, ok) } }) } } func TestParseAndValidateToken(t *testing.T) { // Reset state rsaPublicKeyOnce = sync.Once{} rsaPublicKeyError = nil rsaPublicKeyCached = nil t.Run("Valid token", func(t *testing.T) { // Generate test RSA key and certificate _, certPEM := generateTestRSACertificate(t) // Create temporary test certificate file tmpFile, err := os.CreateTemp("", "test-cert-*.pem") if err != nil { t.Fatalf("Failed to create temp file: %v", err) } defer os.Remove(tmpFile.Name()) if _, err := tmpFile.Write(certPEM); err != nil { t.Fatalf("Failed to write cert: %v", err) } tmpFile.Close() // Note: This test would need the actual PEM file to be present // For now, we'll skip the full validation t.Skip("Requires actual RSA certificate file") }) t.Run("Invalid token", func(t *testing.T) { // Reset state rsaPublicKeyOnce = sync.Once{} rsaPublicKeyError = nil rsaPublicKeyCached = nil _, err := parseAndValidateToken("invalid.token.string") if err == nil { t.Error("Expected error for invalid token") } }) } func TestBuildContext(t *testing.T) { claims := &models.Claims{ UsersID: "user123", RoleID: models.RoleIDs{3}, } parent := context.Background() ctx := buildContext(parent, claims) // Check claims if val, ok := ctx.Value(claimsKey).(*models.Claims); !ok || val.UsersID != "user123" { t.Error("Claims not properly set in context") } // Check userID if val, ok := ctx.Value(userIDKey).(string); !ok || val != "user123" { t.Error("UserID not properly set in context") } // Check role if val, ok := ctx.Value(roleIDKey).([]int); !ok || len(val) == 0 || val[0] != 3 { t.Error("Role not properly set in context") } } func TestBuildContextIncludesAdditionalRoles(t *testing.T) { claims := &models.Claims{ UsersID: "user123", RoleID: models.RoleIDs{30}, AdditionalRoleID: models.RoleIDs{4, 5, 30}, } ctx := buildContext(context.Background(), claims) val, ok := ctx.Value(roleIDKey).([]int) if !ok { t.Fatal("Role not properly set in context") } if len(val) != 3 { t.Fatalf("expected 3 unique roles, got %d (%v)", len(val), val) } if val[0] != 30 || val[1] != 4 || val[2] != 5 { t.Fatalf("unexpected roles in context: %v", val) } } func TestBuildContextIncludesProjectRoles(t *testing.T) { claims := &models.Claims{ UsersID: "user123", RoleID: models.RoleIDs{30}, AdditionalRoleID: models.RoleIDs{4}, Projects: []models.ProjectClaim{ {ProjectID: 10, RoleID: models.RoleIDs{44, 52}}, {ProjectID: 11, RoleID: models.RoleIDs{30, 52, 61}}, }, } ctx := buildContext(context.Background(), claims) val, ok := ctx.Value(roleIDKey).([]int) if !ok { t.Fatal("Role not properly set in context") } if len(val) != 5 { t.Fatalf("expected 5 unique roles, got %d (%v)", len(val), val) } if val[0] != 30 || val[1] != 4 || val[2] != 44 || val[3] != 52 || val[4] != 61 { t.Fatalf("unexpected roles in context: %v", val) } } func TestGetClaims(t *testing.T) { claims := &models.Claims{ UsersID: "user123", RoleID: models.RoleIDs{3}, } req := httptest.NewRequest("GET", "/", nil) ctx := context.WithValue(req.Context(), claimsKey, claims) req = req.WithContext(ctx) retrievedClaims, ok := GetClaims(req) if !ok { t.Error("Expected claims to be found") } if retrievedClaims.UsersID != "user123" { t.Errorf("Expected UserID 'user123', got '%s'", retrievedClaims.UsersID) } } func TestGetUserID(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) ctx := context.WithValue(req.Context(), userIDKey, "user123") req = req.WithContext(ctx) userID, ok := GetUserID(req) if !ok { t.Error("Expected userID to be found") } if userID != "user123" { t.Errorf("Expected 'user123', got '%s'", userID) } } func TestGetRole(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) ctx := context.WithValue(req.Context(), roleIDKey, []int{3}) req = req.WithContext(ctx) role, ok := GetRole(req) if !ok { t.Error("Expected role to be found") } if len(role) == 0 { t.Errorf("Expected at least one role, got '%v'", role) } else if role[0] != 3 { t.Errorf("Expected first role to be 3, got '%v'", role[0]) } } func TestJWTAuthNoAuthHeader(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } func TestJWTAuthInvalidToken(t *testing.T) { // Reset RSA state rsaPublicKeyOnce = sync.Once{} rsaPublicKeyError = nil rsaPublicKeyCached = nil handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer invalid.token.here") w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } func TestJWTAuthValidToken(t *testing.T) { t.Skip("Requires RSA certificate setup - integration test") } // Additional comprehensive test cases func TestExtractBearerTokenEdgeCases(t *testing.T) { testCases := []struct { name string header string wantToken string wantOk bool }{ {"Multiple spaces", "Bearer token123", " token123", true}, // Extracts everything after "Bearer " {"No space after Bearer", "Bearertoken123", "", false}, {"Lowercase bearer", "bearer token123", "", false}, {"Mixed case", "BeArEr token123", "", false}, {"Extra whitespace", " Bearer token123", "", false}, // Must start with "Bearer " {"Token with spaces", "Bearer token with spaces", "token with spaces", true}, {"Very long token", "Bearer " + string(make([]byte, 5000)), string(make([]byte, 5000)), true}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { gotToken, gotOk := extractBearerToken(tc.header) if gotToken != tc.wantToken || gotOk != tc.wantOk { t.Errorf("extractBearerToken(%q) = (%q, %v), want (%q, %v)", tc.header, gotToken, gotOk, tc.wantToken, tc.wantOk) } }) } } func TestParseAndValidateTokenMalformedTokens(t *testing.T) { // Reset RSA state rsaPublicKeyOnce = sync.Once{} rsaPublicKeyError = nil rsaPublicKeyCached = nil testCases := []struct { name string token string }{ {"Empty string", ""}, {"Random string", "not.a.jwt.token"}, {"Only dots", "..."}, {"Two parts only", "header.payload"}, {"Four parts", "part1.part2.part3.part4"}, {"Invalid base64", "!@#$.!@#$.!@#$"}, {"Spaces in token", "part1 .part2 .part3"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { _, err := parseAndValidateToken(tc.token) if err == nil { t.Errorf("Expected error for malformed token %q", tc.name) } }) } } func TestBuildContextWithDifferentRoles(t *testing.T) { roles := []int{3, 4, 5, 6, 7, 8} for _, role := range roles { t.Run("Role: "+string(rune(role)), func(t *testing.T) { claims := &models.Claims{ UsersID: "user123", RoleID: models.RoleIDs{role}, } req := httptest.NewRequest("GET", "/", nil) newReq := buildContext(req.Context(), claims) reqWithCtx := req.WithContext(newReq) retrievedClaims, ok := GetClaims(reqWithCtx) if !ok { t.Error("Claims not found in context") } if len(retrievedClaims.RoleID) == 0 || retrievedClaims.RoleID[0] != role { t.Errorf("Role = %v, want %v", retrievedClaims.RoleID, role) } }) } } func TestGetClaimsWithoutClaims(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) claims, ok := GetClaims(req) if ok { t.Error("Expected ok=false when claims not in context") } if claims != nil { t.Error("Expected nil claims when not in context") } } func TestGetClaimsWithWrongType(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) ctx := context.WithValue(req.Context(), models.ContextKey("claims"), "wrong type") req = req.WithContext(ctx) claims, ok := GetClaims(req) if ok { t.Error("Expected ok=false when claims are wrong type") } if claims != nil { t.Error("Expected nil claims when wrong type in context") } } func TestGetUserIDWithNoClaims(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) userID, ok := GetUserID(req) if ok { t.Error("Expected ok=false when no claims") } if userID != "" { t.Errorf("Expected empty string, got %q", userID) } } func TestGetRoleWithNoClaims(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) role, ok := GetRole(req) if ok { t.Error("Expected ok=false when no claims") } if role != nil && len(role) != 0 { t.Errorf("Expected no roles, got %v", role) } } func TestJWTAuthMissingBearerPrefix(t *testing.T) { // Reset RSA state rsaPublicKeyOnce = sync.Once{} rsaPublicKeyError = nil rsaPublicKeyCached = nil handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "InvalidToken") w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } func TestJWTAuthExpiredToken(t *testing.T) { t.Skip("Requires RSA certificate setup - integration test") } func TestJWTAuthTokenWithMissingClaims(t *testing.T) { t.Skip("Requires RSA certificate setup - integration test") testCases := []struct { name string claims *models.Claims }{ { "Missing UserID", &models.Claims{ RoleID: models.RoleIDs{3}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, }, }, { "Missing Role", &models.Claims{ UsersID: "user123", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, tc.claims) tokenString, _ := token.SignedString([]byte("test-secret")) handler := func(w http.ResponseWriter, r *http.Request) { claims, ok := GetClaims(r) if !ok { t.Error("Claims should still be in context even if some fields are empty") } // Verify the missing field is empty _ = claims w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+tokenString) w := httptest.NewRecorder() JWTAuth(handler)(w, req) // Token is valid, just missing some claim fields if w.Code != http.StatusOK { t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) } }) } } func TestJWTAuthConcurrentRequests(t *testing.T) { t.Skip("Requires RSA certificate setup - integration test") claims := &models.Claims{ UsersID: "user123", RoleID: models.RoleIDs{3}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, _ := token.SignedString([]byte("test-secret")) handler := func(w http.ResponseWriter, r *http.Request) { if _, ok := GetClaims(r); !ok { t.Error("Claims not found in concurrent request") } w.WriteHeader(http.StatusOK) } concurrency := 50 done := make(chan bool, concurrency) for i := 0; i < concurrency; i++ { go func() { req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+tokenString) w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status %d in concurrent request, got %d", http.StatusOK, w.Code) } done <- true }() } for i := 0; i < concurrency; i++ { <-done } } func TestJWTAuthTokenSignedWithWrongKey(t *testing.T) { t.Skip("Requires RSA certificate setup - integration test") // Create token with wrong key claims := &models.Claims{ UsersID: "user123", RoleID: models.RoleIDs{3}, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, _ := token.SignedString([]byte("wrong-secret")) handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+tokenString) w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d for wrong signature, got %d", http.StatusUnauthorized, w.Code) } }