Files
2025-11-25 15:12:31 +08:00

285 lines
7.4 KiB
Go

package middleware
import (
"net/http"
"net/http/httptest"
"os"
"testing"
)
func TestSetHeaders(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// Check security headers
headers := map[string]string{
"X-DNS-Prefetch-Control": "off",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"X-Content-Type-Options": "nosniff",
"Content-Security-Policy": "default-src 'self'",
"Referrer-Policy": "no-referrer",
"X-Powered-By": "Zig",
"Strict-Transport-Security": "max-age=63072000; includeSubDomains; preload",
"Content-Type": "application/json",
}
for header, expected := range headers {
actual := recorder.Header().Get(header)
if actual != expected {
t.Errorf("Expected header %s to be '%s', got '%s'", header, expected, actual)
}
}
}
func TestSetHeadersDevelopment(t *testing.T) {
os.Setenv("GO_ENV", "development")
defer os.Unsetenv("GO_ENV")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// HSTS should not be set in development
hsts := recorder.Header().Get("Strict-Transport-Security")
if hsts != "" {
t.Errorf("Expected no HSTS header in development, got '%s'", hsts)
}
// Other security headers should still be present
if recorder.Header().Get("X-Frame-Options") != "DENY" {
t.Error("Expected X-Frame-Options header in development")
}
}
// TestSetHeadersSSE checks that a response whose Content-Type is
// text/event-stream still reports that Content-Type after being served
// through SetHeaders with GO_ENV=production.
func TestSetHeadersSSE(t *testing.T) {
	os.Setenv("GO_ENV", "production")
	defer os.Unsetenv("GO_ENV")

	// The recorder starts out already marked as an SSE response.
	rec := httptest.NewRecorder()
	rec.Header().Set("Content-Type", "text/event-stream")

	// The downstream handler also declares the SSE content type itself.
	streamHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
	})
	wrapped := SetHeaders(streamHandler)

	wrapped.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/stream", nil))

	// The SSE content type must survive the middleware.
	if got := rec.Header().Get("Content-Type"); got != "text/event-stream" {
		t.Errorf("Expected Content-Type 'text/event-stream', got '%s'", got)
	}
}
func TestSetHeadersOptions(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
handlerCalled := false
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// OPTIONS should return 200 without calling next handler
if recorder.Code != http.StatusOK {
t.Errorf("Expected status 200 for OPTIONS, got %d", recorder.Code)
}
if handlerCalled {
t.Error("Expected next handler NOT to be called for OPTIONS request")
}
// Security headers should still be set
if recorder.Header().Get("X-Frame-Options") != "DENY" {
t.Error("Expected security headers to be set for OPTIONS")
}
}
func TestSetHeadersAllMethods(t *testing.T) {
os.Setenv("GO_ENV", "production")
defer os.Unsetenv("GO_ENV")
methods := []string{
http.MethodGet,
http.MethodPost,
http.MethodPut,
http.MethodDelete,
http.MethodPatch,
}
for _, method := range methods {
t.Run(method, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware := SetHeaders(handler)
req := httptest.NewRequest(method, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// All methods should have security headers
if recorder.Header().Get("X-Frame-Options") != "DENY" {
t.Errorf("Expected X-Frame-Options for %s", method)
}
if recorder.Header().Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json for %s", method)
}
})
}
}
func TestSetHeadersEnvironments(t *testing.T) {
environments := []string{"development", "production", "canary", "debug"}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
for _, env := range environments {
t.Run(env, func(t *testing.T) {
os.Setenv("GO_ENV", env)
defer os.Unsetenv("GO_ENV")
middleware := SetHeaders(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
recorder := httptest.NewRecorder()
middleware.ServeHTTP(recorder, req)
// HSTS should only be set in non-development environments
hsts := recorder.Header().Get("Strict-Transport-Security")
if env == "development" {
if hsts != "" {
t.Errorf("HSTS should not be set in development, got '%s'", hsts)
}
} else {
if hsts == "" {
t.Errorf("HSTS should be set in %s environment", env)
}
}
})
}
}
// TestSetHeadersPoweredBy checks that a GET response served through
// SetHeaders with GO_ENV=production reports X-Powered-By as "Zig".
func TestSetHeadersPoweredBy(t *testing.T) {
	os.Setenv("GO_ENV", "production")
	defer os.Unsetenv("GO_ENV")

	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := SetHeaders(okHandler)

	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	if got := rec.Header().Get("X-Powered-By"); got != "Zig" {
		t.Errorf("Expected X-Powered-By 'Zig', got '%s'", got)
	}
}
// TestSetHeadersCSP checks that a GET response served through SetHeaders with
// GO_ENV=production carries Content-Security-Policy "default-src 'self'".
func TestSetHeadersCSP(t *testing.T) {
	os.Setenv("GO_ENV", "production")
	defer os.Unsetenv("GO_ENV")

	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := SetHeaders(okHandler)

	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	if got := rec.Header().Get("Content-Security-Policy"); got != "default-src 'self'" {
		t.Errorf("Expected CSP 'default-src 'self'', got '%s'", got)
	}
}
// TestSetHeadersReferrerPolicy checks that a GET response served through
// SetHeaders with GO_ENV=production carries Referrer-Policy "no-referrer".
func TestSetHeadersReferrerPolicy(t *testing.T) {
	os.Setenv("GO_ENV", "production")
	defer os.Unsetenv("GO_ENV")

	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := SetHeaders(okHandler)

	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	if got := rec.Header().Get("Referrer-Policy"); got != "no-referrer" {
		t.Errorf("Expected Referrer-Policy 'no-referrer', got '%s'", got)
	}
}
// TestSetHeadersXSSProtection checks that a GET response served through
// SetHeaders with GO_ENV=production carries X-XSS-Protection "1; mode=block".
func TestSetHeadersXSSProtection(t *testing.T) {
	os.Setenv("GO_ENV", "production")
	defer os.Unsetenv("GO_ENV")

	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := SetHeaders(okHandler)

	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	if got := rec.Header().Get("X-XSS-Protection"); got != "1; mode=block" {
		t.Errorf("Expected X-XSS-Protection '1; mode=block', got '%s'", got)
	}
}