285 lines
7.4 KiB
Go
285 lines
7.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
)
|
|
|
|
func TestSetHeaders(t *testing.T) {
|
|
os.Setenv("GO_ENV", "production")
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
// Check security headers
|
|
headers := map[string]string{
|
|
"X-DNS-Prefetch-Control": "off",
|
|
"X-Frame-Options": "DENY",
|
|
"X-XSS-Protection": "1; mode=block",
|
|
"X-Content-Type-Options": "nosniff",
|
|
"Content-Security-Policy": "default-src 'self'",
|
|
"Referrer-Policy": "no-referrer",
|
|
"X-Powered-By": "Zig",
|
|
"Strict-Transport-Security": "max-age=63072000; includeSubDomains; preload",
|
|
"Content-Type": "application/json",
|
|
}
|
|
|
|
for header, expected := range headers {
|
|
actual := recorder.Header().Get(header)
|
|
if actual != expected {
|
|
t.Errorf("Expected header %s to be '%s', got '%s'", header, expected, actual)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersDevelopment(t *testing.T) {
|
|
os.Setenv("GO_ENV", "development")
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
// HSTS should not be set in development
|
|
hsts := recorder.Header().Get("Strict-Transport-Security")
|
|
if hsts != "" {
|
|
t.Errorf("Expected no HSTS header in development, got '%s'", hsts)
|
|
}
|
|
|
|
// Other security headers should still be present
|
|
if recorder.Header().Get("X-Frame-Options") != "DENY" {
|
|
t.Error("Expected X-Frame-Options header in development")
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersSSE(t *testing.T) {
|
|
os.Setenv("GO_ENV", "production")
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/stream", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
// Pre-set SSE content type
|
|
recorder.Header().Set("Content-Type", "text/event-stream")
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
// Content-Type should remain text/event-stream
|
|
contentType := recorder.Header().Get("Content-Type")
|
|
if contentType != "text/event-stream" {
|
|
t.Errorf("Expected Content-Type 'text/event-stream', got '%s'", contentType)
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersOptions(t *testing.T) {
|
|
os.Setenv("GO_ENV", "production")
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
handlerCalled := false
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
// OPTIONS should return 200 without calling next handler
|
|
if recorder.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200 for OPTIONS, got %d", recorder.Code)
|
|
}
|
|
|
|
if handlerCalled {
|
|
t.Error("Expected next handler NOT to be called for OPTIONS request")
|
|
}
|
|
|
|
// Security headers should still be set
|
|
if recorder.Header().Get("X-Frame-Options") != "DENY" {
|
|
t.Error("Expected security headers to be set for OPTIONS")
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersAllMethods(t *testing.T) {
|
|
os.Setenv("GO_ENV", "production")
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
methods := []string{
|
|
http.MethodGet,
|
|
http.MethodPost,
|
|
http.MethodPut,
|
|
http.MethodDelete,
|
|
http.MethodPatch,
|
|
}
|
|
|
|
for _, method := range methods {
|
|
t.Run(method, func(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(method, "/test", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
// All methods should have security headers
|
|
if recorder.Header().Get("X-Frame-Options") != "DENY" {
|
|
t.Errorf("Expected X-Frame-Options for %s", method)
|
|
}
|
|
|
|
if recorder.Header().Get("Content-Type") != "application/json" {
|
|
t.Errorf("Expected Content-Type application/json for %s", method)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersEnvironments(t *testing.T) {
|
|
environments := []string{"development", "production", "canary", "debug"}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
for _, env := range environments {
|
|
t.Run(env, func(t *testing.T) {
|
|
os.Setenv("GO_ENV", env)
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
// HSTS should only be set in non-development environments
|
|
hsts := recorder.Header().Get("Strict-Transport-Security")
|
|
if env == "development" {
|
|
if hsts != "" {
|
|
t.Errorf("HSTS should not be set in development, got '%s'", hsts)
|
|
}
|
|
} else {
|
|
if hsts == "" {
|
|
t.Errorf("HSTS should be set in %s environment", env)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersPoweredBy(t *testing.T) {
|
|
os.Setenv("GO_ENV", "production")
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
poweredBy := recorder.Header().Get("X-Powered-By")
|
|
if poweredBy != "Zig" {
|
|
t.Errorf("Expected X-Powered-By 'Zig', got '%s'", poweredBy)
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersCSP(t *testing.T) {
|
|
os.Setenv("GO_ENV", "production")
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
csp := recorder.Header().Get("Content-Security-Policy")
|
|
if csp != "default-src 'self'" {
|
|
t.Errorf("Expected CSP 'default-src 'self'', got '%s'", csp)
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersReferrerPolicy(t *testing.T) {
|
|
os.Setenv("GO_ENV", "production")
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
referrer := recorder.Header().Get("Referrer-Policy")
|
|
if referrer != "no-referrer" {
|
|
t.Errorf("Expected Referrer-Policy 'no-referrer', got '%s'", referrer)
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersXSSProtection(t *testing.T) {
|
|
os.Setenv("GO_ENV", "production")
|
|
defer os.Unsetenv("GO_ENV")
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := SetHeaders(handler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
middleware.ServeHTTP(recorder, req)
|
|
|
|
xss := recorder.Header().Get("X-XSS-Protection")
|
|
if xss != "1; mode=block" {
|
|
t.Errorf("Expected X-XSS-Protection '1; mode=block', got '%s'", xss)
|
|
}
|
|
}
|