package middleware import ( "authorization/models" "context" "net/http" "net/http/httptest" "os" "sync" "testing" "time" "github.com/golang-jwt/jwt/v5" ) func TestGetJWTSecret(t *testing.T) { // Save original and restore after test originalSecret := os.Getenv("JWT_KEY") defer func() { if originalSecret != "" { os.Setenv("JWT_KEY", originalSecret) } else { os.Unsetenv("JWT_KEY") } }() t.Run("JWT_KEY not set", func(t *testing.T) { // Reset state for testing oldCached := jwtSecretCached oldError := jwtSecretError defer func() { jwtSecretCached = oldCached jwtSecretError = oldError }() os.Unsetenv("JWT_KEY") jwtSecretOnce = sync.Once{} jwtSecretError = nil jwtSecretCached = nil _, err := getJWTSecret() if err == nil { t.Error("Expected error when JWT_KEY is not set") } }) t.Run("JWT_KEY set", func(t *testing.T) { // Reset state for testing oldCached := jwtSecretCached oldError := jwtSecretError defer func() { jwtSecretCached = oldCached jwtSecretError = oldError }() os.Setenv("JWT_KEY", "test-secret-key") jwtSecretOnce = sync.Once{} jwtSecretError = nil jwtSecretCached = nil secret, err := getJWTSecret() if err != nil { t.Errorf("Expected no error, got %v", err) } if string(secret) != "test-secret-key" { t.Errorf("Expected 'test-secret-key', got '%s'", string(secret)) } }) } func TestExtractBearerToken(t *testing.T) { tests := []struct { name string authHeader string wantToken string wantOK bool }{ { name: "Valid Bearer token", authHeader: "Bearer token123", wantToken: "token123", wantOK: true, }, { name: "Empty header", authHeader: "", wantToken: "", wantOK: false, }, { name: "Too short", authHeader: "Bearer", wantToken: "", wantOK: false, }, { name: "Wrong prefix", authHeader: "Basic token123", wantToken: "", wantOK: false, }, { name: "Missing space", authHeader: "Bearertoken123", wantToken: "", wantOK: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { token, ok := extractBearerToken(tt.authHeader) if token != tt.wantToken { t.Errorf("Expected token '%s', got '%s'", tt.wantToken, token) } if ok != tt.wantOK { t.Errorf("Expected ok %v, got %v", tt.wantOK, ok) } }) } } func TestParseAndValidateToken(t *testing.T) { // Setup os.Setenv("JWT_KEY", "test-secret-key") jwtSecretOnce = sync.Once{} jwtSecretError = nil defer os.Unsetenv("JWT_KEY") t.Run("Valid token", func(t *testing.T) { // Create a valid token claims := &models.Claims{ UserID: "user123", Username: "testuser", Role: "admin", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte("test-secret-key")) if err != nil { t.Fatalf("Failed to create token: %v", err) } parsedClaims, err := parseAndValidateToken(tokenString) if err != nil { t.Errorf("Expected no error, got %v", err) } if parsedClaims.UserID != "user123" { t.Errorf("Expected UserID 'user123', got '%s'", parsedClaims.UserID) } }) t.Run("Invalid token", func(t *testing.T) { _, err := parseAndValidateToken("invalid.token.string") if err == nil { t.Error("Expected error for invalid token") } }) t.Run("Expired token", func(t *testing.T) { claims := &models.Claims{ UserID: "user123", Username: "testuser", Role: "admin", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, _ := token.SignedString([]byte("test-secret-key")) _, err := parseAndValidateToken(tokenString) if err == nil { t.Error("Expected error for expired token") } }) } func TestBuildContext(t *testing.T) { claims := &models.Claims{ UserID: "user123", Username: "testuser", Role: "admin", } parent := context.Background() ctx := buildContext(parent, claims) // Check claims if val, ok := ctx.Value(claimsKey).(*models.Claims); !ok || val.UserID != "user123" { t.Error("Claims not properly set in context") } // Check userID if val, ok := ctx.Value(userIDKey).(string); !ok || val != "user123" { t.Error("UserID not properly set in context") } // Check username if val, ok := ctx.Value(usernameKey).(string); !ok || val != "testuser" { t.Error("Username not properly set in context") } // Check role if val, ok := ctx.Value(roleKey).(string); !ok || val != "admin" { t.Error("Role not properly set in context") } } func TestGetClaims(t *testing.T) { claims := &models.Claims{ UserID: "user123", Username: "testuser", Role: "admin", } req := httptest.NewRequest("GET", "/", nil) ctx := context.WithValue(req.Context(), claimsKey, claims) req = req.WithContext(ctx) retrievedClaims, ok := GetClaims(req) if !ok { t.Error("Expected claims to be found") } if retrievedClaims.UserID != "user123" { t.Errorf("Expected UserID 'user123', got '%s'", retrievedClaims.UserID) } } func TestGetUserID(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) ctx := context.WithValue(req.Context(), userIDKey, "user123") req = req.WithContext(ctx) userID, ok := GetUserID(req) if !ok { t.Error("Expected userID to be found") } if userID != "user123" { t.Errorf("Expected 'user123', got '%s'", userID) } } func TestGetUsername(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) ctx := context.WithValue(req.Context(), usernameKey, "testuser") req = req.WithContext(ctx) username, ok := GetUsername(req) if !ok { t.Error("Expected username to be found") } if username != "testuser" { t.Errorf("Expected 'testuser', got '%s'", username) } } func TestGetRole(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) ctx := context.WithValue(req.Context(), roleKey, "admin") req = req.WithContext(ctx) role, ok := GetRole(req) if !ok { t.Error("Expected role to be found") } if role != "admin" { t.Errorf("Expected 'admin', got '%s'", role) } } func TestJWTAuth_NoAuthHeader(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } func TestJWTAuth_InvalidToken(t *testing.T) { os.Setenv("JWT_KEY", "test-secret-key") jwtSecretOnce = sync.Once{} jwtSecretError = nil defer os.Unsetenv("JWT_KEY") handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer invalid.token.here") w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } func TestJWTAuth_ValidToken(t *testing.T) { os.Setenv("JWT_KEY", "test-secret-key") jwtSecretOnce = sync.Once{} jwtSecretError = nil defer os.Unsetenv("JWT_KEY") claims := &models.Claims{ UserID: "user123", Username: "testuser", Role: "admin", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, _ := token.SignedString([]byte("test-secret-key")) handler := func(w http.ResponseWriter, r *http.Request) { // Verify claims are in context if retrievedClaims, ok := GetClaims(r); !ok || retrievedClaims.UserID != "user123" { t.Error("Claims not found or incorrect in context") } w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+tokenString) w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) } } // Additional comprehensive test cases func TestExtractBearerToken_EdgeCases(t *testing.T) { testCases := []struct { name string header string wantToken string wantOk bool }{ {"Multiple spaces", "Bearer token123", " token123", true}, // Extracts everything after "Bearer " {"No space after Bearer", "Bearertoken123", "", false}, {"Lowercase bearer", "bearer token123", "", false}, {"Mixed case", "BeArEr token123", "", false}, {"Extra whitespace", " Bearer token123", "", false}, // Must start with "Bearer " {"Token with spaces", "Bearer token with spaces", "token with spaces", true}, {"Very long token", "Bearer " + string(make([]byte, 5000)), string(make([]byte, 5000)), true}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { gotToken, gotOk := extractBearerToken(tc.header) if gotToken != tc.wantToken || gotOk != tc.wantOk { t.Errorf("extractBearerToken(%q) = (%q, %v), want (%q, %v)", tc.header, gotToken, gotOk, tc.wantToken, tc.wantOk) } }) } } func TestParseAndValidateToken_MalformedTokens(t *testing.T) { os.Setenv("JWT_KEY", "test-secret") defer os.Unsetenv("JWT_KEY") jwtSecretOnce = sync.Once{} jwtSecretError = nil jwtSecretCached = nil testCases := []struct { name string token string }{ {"Empty string", ""}, {"Random string", "not.a.jwt.token"}, {"Only dots", "..."}, {"Two parts only", "header.payload"}, {"Four parts", "part1.part2.part3.part4"}, {"Invalid base64", "!@#$.!@#$.!@#$"}, {"Spaces in token", "part1 .part2 .part3"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { _, err := parseAndValidateToken(tc.token) if err == nil { t.Errorf("Expected error for malformed token %q", tc.name) } }) } } func TestBuildContext_WithDifferentRoles(t *testing.T) { roles := []string{"admin", "user", "guest", "superadmin", "", "role-with-dash"} for _, role := range roles { t.Run("Role: "+role, func(t *testing.T) { claims := &models.Claims{ UserID: "user123", Username: "testuser", Role: role, } req := httptest.NewRequest("GET", "/", nil) newReq := buildContext(req.Context(), claims) reqWithCtx := req.WithContext(newReq) retrievedClaims, ok := GetClaims(reqWithCtx) if !ok { t.Error("Claims not found in context") } if retrievedClaims.Role != role { t.Errorf("Role = %q, want %q", retrievedClaims.Role, role) } }) } } func TestGetClaims_WithoutClaims(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) claims, ok := GetClaims(req) if ok { t.Error("Expected ok=false when claims not in context") } if claims != nil { t.Error("Expected nil claims when not in context") } } func TestGetClaims_WithWrongType(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) ctx := context.WithValue(req.Context(), models.ContextKey("claims"), "wrong type") req = req.WithContext(ctx) claims, ok := GetClaims(req) if ok { t.Error("Expected ok=false when claims are wrong type") } if claims != nil { t.Error("Expected nil claims when wrong type in context") } } func TestGetUserID_WithNoClaims(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) userID, ok := GetUserID(req) if ok { t.Error("Expected ok=false when no claims") } if userID != "" { t.Errorf("Expected empty string, got %q", userID) } } func TestGetUsername_WithNoClaims(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) username, ok := GetUsername(req) if ok { t.Error("Expected ok=false when no claims") } if username != "" { t.Errorf("Expected empty string, got %q", username) } } func TestGetRole_WithNoClaims(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) role, ok := GetRole(req) if ok { t.Error("Expected ok=false when no claims") } if role != "" { t.Errorf("Expected empty string, got %q", role) } } func TestJWTAuth_MissingBearerPrefix(t *testing.T) { os.Setenv("JWT_KEY", "test-secret") defer os.Unsetenv("JWT_KEY") jwtSecretOnce = sync.Once{} jwtSecretError = nil jwtSecretCached = nil handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "InvalidToken") w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) } } func TestJWTAuth_ExpiredToken(t *testing.T) { os.Setenv("JWT_KEY", "test-secret") defer os.Unsetenv("JWT_KEY") jwtSecretOnce = sync.Once{} jwtSecretError = nil jwtSecretCached = nil // Create token that's already expired claims := &models.Claims{ UserID: "user123", Username: "testuser", Role: "admin", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, _ := token.SignedString([]byte("test-secret")) handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+tokenString) w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d for expired token, got %d", http.StatusUnauthorized, w.Code) } } func TestJWTAuth_TokenWithMissingClaims(t *testing.T) { os.Setenv("JWT_KEY", "test-secret") defer os.Unsetenv("JWT_KEY") jwtSecretOnce = sync.Once{} jwtSecretError = nil jwtSecretCached = nil testCases := []struct { name string claims *models.Claims }{ { "Missing UserID", &models.Claims{ Username: "testuser", Role: "admin", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, }, }, { "Missing Username", &models.Claims{ UserID: "user123", Role: "admin", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, }, }, { "Missing Role", &models.Claims{ UserID: "user123", Username: "testuser", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, tc.claims) tokenString, _ := token.SignedString([]byte("test-secret")) handler := func(w http.ResponseWriter, r *http.Request) { claims, ok := GetClaims(r) if !ok { t.Error("Claims should still be in context even if some fields are empty") } // Verify the missing field is empty _ = claims w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+tokenString) w := httptest.NewRecorder() JWTAuth(handler)(w, req) // Token is valid, just missing some claim fields if w.Code != http.StatusOK { t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) } }) } } func TestJWTAuth_ConcurrentRequests(t *testing.T) { os.Setenv("JWT_KEY", "test-secret") defer os.Unsetenv("JWT_KEY") jwtSecretOnce = sync.Once{} jwtSecretError = nil jwtSecretCached = nil claims := &models.Claims{ UserID: "user123", Username: "testuser", Role: "admin", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, _ := token.SignedString([]byte("test-secret")) handler := func(w http.ResponseWriter, r *http.Request) { if _, ok := GetClaims(r); !ok { t.Error("Claims not found in concurrent request") } w.WriteHeader(http.StatusOK) } concurrency := 50 done := make(chan bool, concurrency) for i := 0; i < concurrency; i++ { go func() { req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+tokenString) w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status %d in concurrent request, got %d", http.StatusOK, w.Code) } done <- true }() } for i := 0; i < concurrency; i++ { <-done } } func TestJWTAuth_TokenSignedWithWrongKey(t *testing.T) { os.Setenv("JWT_KEY", "correct-secret") defer os.Unsetenv("JWT_KEY") jwtSecretOnce = sync.Once{} jwtSecretError = nil jwtSecretCached = nil // Create token with wrong key claims := &models.Claims{ UserID: "user123", Username: "testuser", Role: "admin", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, _ := token.SignedString([]byte("wrong-secret")) handler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } req := httptest.NewRequest("GET", "/", nil) req.Header.Set("Authorization", "Bearer "+tokenString) w := httptest.NewRecorder() JWTAuth(handler)(w, req) if w.Code != http.StatusUnauthorized { t.Errorf("Expected status %d for wrong signature, got %d", http.StatusUnauthorized, w.Code) } }