Files
Authorization/helper/circuit_breaker_test.go
T
2025-12-16 10:13:24 +08:00

460 lines
9.5 KiB
Go

package helper
import (
"errors"
"sync"
"testing"
"time"
)
// TestNewCircuitBreaker checks that NewCircuitBreaker records the supplied
// name, failure threshold and timeout, starts in the expected state with a
// zero failure count, and uses a 30-second reset timeout.
func TestNewCircuitBreaker(t *testing.T) {
	const wantReset = 30 * time.Second

	cases := []struct {
		name        string
		cbName      string
		maxFailures int
		timeout     time.Duration
		wantState   CircuitState
	}{
		{
			name:        "creates circuit breaker with correct defaults",
			cbName:      "test-service",
			maxFailures: 5,
			timeout:     2 * time.Second,
			wantState:   StateClosed,
		},
		{
			name:        "creates circuit breaker with different parameters",
			cbName:      "db-service",
			maxFailures: 3,
			timeout:     1 * time.Second,
			wantState:   StateClosed,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := NewCircuitBreaker(tc.cbName, tc.maxFailures, tc.timeout)

			if got.name != tc.cbName {
				t.Errorf("name = %v, want %v", got.name, tc.cbName)
			}
			if got.maxFailures != tc.maxFailures {
				t.Errorf("maxFailures = %v, want %v", got.maxFailures, tc.maxFailures)
			}
			if got.timeout != tc.timeout {
				t.Errorf("timeout = %v, want %v", got.timeout, tc.timeout)
			}
			if got.state != tc.wantState {
				t.Errorf("state = %v, want %v", got.state, tc.wantState)
			}
			if got.resetTimeout != wantReset {
				t.Errorf("resetTimeout = %v, want %v", got.resetTimeout, wantReset)
			}
			if got.failures != 0 {
				t.Errorf("failures = %v, want 0", got.failures)
			}
		})
	}
}
// TestCircuitBreaker_Call_Success checks that Call returns nil for a
// succeeding function and that the breaker remains closed afterwards.
func TestCircuitBreaker_Call_Success(t *testing.T) {
	breaker := NewCircuitBreaker("test", 3, 1*time.Second)

	if err := Call(breaker, func() error { return nil }); err != nil {
		t.Errorf("Call() error = %v, want nil", err)
	}

	if state := GetState(breaker); state != StateClosed {
		t.Errorf("state = %v, want %v", state, StateClosed)
	}
}
// TestCircuitBreaker_Call_FailuresOpenCircuit drives a breaker with a
// threshold of 3 through consecutive failures. It checks three things:
//   - the first two failures leave the breaker closed;
//   - the third opens it;
//   - further calls are rejected with a CircuitBreakerError.
func TestCircuitBreaker_Call_FailuresOpenCircuit(t *testing.T) {
	const threshold = 3

	breaker := NewCircuitBreaker("test", threshold, 1*time.Second)
	boom := func() error { return errors.New("service error") }

	// Failures below the threshold must surface the error but keep the
	// breaker closed.
	for attempt := 0; attempt < threshold-1; attempt++ {
		if err := Call(breaker, boom); err == nil {
			t.Errorf("Call() iteration %d: expected error, got nil", attempt)
		}
		if state := GetState(breaker); state != StateClosed {
			t.Errorf("iteration %d: state = %v, want %v", attempt, state, StateClosed)
		}
	}

	// The failure that reaches the threshold trips the breaker.
	if err := Call(breaker, boom); err == nil {
		t.Error("Call() expected error, got nil")
	}
	if state := GetState(breaker); state != StateOpen {
		t.Errorf("state = %v, want %v", state, StateOpen)
	}

	// While open, Call is expected to reject with a CircuitBreakerError.
	if err := Call(breaker, boom); !IsCircuitBreakerError(err) {
		t.Errorf("expected CircuitBreakerError, got %v", err)
	}
}
// TestCircuitBreaker_Call_OpenToHalfOpen opens a breaker and waits out its
// (shortened) reset timeout. It then checks that a succeeding call is let
// through and leaves the breaker closed.
func TestCircuitBreaker_Call_OpenToHalfOpen(t *testing.T) {
	breaker := NewCircuitBreaker("test", 2, 1*time.Second)
	// Shrink the reset window so the test does not have to wait the default.
	breaker.resetTimeout = 100 * time.Millisecond

	boom := func() error { return errors.New("service error") }
	ok := func() error { return nil }

	// Two failures reach the threshold of 2 and should open the circuit.
	for i := 0; i < 2; i++ {
		Call(breaker, boom)
	}
	if state := GetState(breaker); state != StateOpen {
		t.Fatalf("state = %v, want %v", state, StateOpen)
	}

	// Sleep past resetTimeout so the next call is allowed through.
	time.Sleep(150 * time.Millisecond)

	if err := Call(breaker, ok); err != nil {
		t.Errorf("Call() error = %v, want nil", err)
	}
	// The successful trial call is expected to close the circuit again.
	if state := GetState(breaker); state != StateClosed {
		t.Errorf("state = %v, want %v", state, StateClosed)
	}
}
// TestCircuitBreaker_Call_HalfOpenFailReturnsToOpen verifies two things. Once
// resetTimeout has elapsed, the breaker lets a trial call through. If that
// trial call fails, the breaker goes back to StateOpen.
//
// The probe flag prevents a vacuous pass. If the breaker never left
// StateOpen, Call would still return a non-nil (circuit breaker) error and the
// state would still be StateOpen. Checking only err and state therefore cannot
// distinguish "trial call ran and failed" from "call was rejected outright".
func TestCircuitBreaker_Call_HalfOpenFailReturnsToOpen(t *testing.T) {
	cb := NewCircuitBreaker("test", 2, 1*time.Second)
	// Shorten the reset window so the test only sleeps briefly.
	cb.resetTimeout = 100 * time.Millisecond

	failFn := func() error {
		return errors.New("service error")
	}

	// Open the circuit: two failures reach the threshold of 2.
	for i := 0; i < 2; i++ {
		if err := Call(cb, failFn); err == nil {
			t.Fatalf("setup call %d: expected error, got nil", i)
		}
	}
	if GetState(cb) != StateOpen {
		t.Fatalf("state = %v, want %v", GetState(cb), StateOpen)
	}

	// Wait for the reset timeout so the next call is treated as a trial.
	time.Sleep(150 * time.Millisecond)

	// Fail in HalfOpen state - should return to Open.
	probeRan := false
	probe := func() error {
		probeRan = true
		return errors.New("service error")
	}
	err := Call(cb, probe)
	if err == nil {
		t.Error("Call() expected error, got nil")
	}
	if !probeRan {
		t.Errorf("trial function was not executed after resetTimeout (err = %v); breaker never left the open state", err)
	}
	if GetState(cb) != StateOpen {
		t.Errorf("state = %v, want %v", GetState(cb), StateOpen)
	}
}
// TestCircuitBreaker_Call_GradualFailureReduction records three failures on a
// breaker with a threshold of 5. It then checks that a single success lowers
// the failure counter by exactly one instead of clearing it.
func TestCircuitBreaker_Call_GradualFailureReduction(t *testing.T) {
	breaker := NewCircuitBreaker("test", 5, 1*time.Second)
	boom := func() error { return errors.New("service error") }
	ok := func() error { return nil }

	// Three failures, still below the threshold of 5.
	for i := 0; i < 3; i++ {
		Call(breaker, boom)
	}

	// Read the counter under the breaker's read lock.
	breaker.mutex.RLock()
	afterFailures := breaker.failures
	breaker.mutex.RUnlock()
	if afterFailures != 3 {
		t.Fatalf("failures = %v, want 3", afterFailures)
	}

	Call(breaker, ok)

	breaker.mutex.RLock()
	afterSuccess := breaker.failures
	breaker.mutex.RUnlock()
	if afterSuccess != 2 {
		t.Errorf("failures after success = %v, want 2", afterSuccess)
	}
}
func TestCircuitBreaker_GetState(t *testing.T) {
tests := []struct {
name string
setupFunc func(*CircuitBreaker)
wantState CircuitState
}{
{
name: "returns closed state",
setupFunc: func(cb *CircuitBreaker) {
cb.state = StateClosed
},
wantState: StateClosed,
},
{
name: "returns open state",
setupFunc: func(cb *CircuitBreaker) {
cb.state = StateOpen
},
wantState: StateOpen,
},
{
name: "returns half-open state",
setupFunc: func(cb *CircuitBreaker) {
cb.state = StateHalfOpen
},
wantState: StateHalfOpen,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cb := NewCircuitBreaker("test", 3, 1*time.Second)
tt.setupFunc(cb)
got := GetState(cb)
if got != tt.wantState {
t.Errorf("GetState() = %v, want %v", got, tt.wantState)
}
})
}
}
// TestCircuitBreaker_Reset opens a breaker and calls Reset. It then checks
// that the breaker is closed again and its failure counter is back to zero.
func TestCircuitBreaker_Reset(t *testing.T) {
	breaker := NewCircuitBreaker("test", 2, 1*time.Second)

	// Two failures reach the threshold of 2 and should open the circuit.
	boom := func() error { return errors.New("error") }
	Call(breaker, boom)
	Call(breaker, boom)
	if state := GetState(breaker); state != StateOpen {
		t.Fatalf("state = %v, want %v", state, StateOpen)
	}

	Reset(breaker)

	if state := GetState(breaker); state != StateClosed {
		t.Errorf("state after Reset() = %v, want %v", state, StateClosed)
	}

	breaker.mutex.RLock()
	remaining := breaker.failures
	breaker.mutex.RUnlock()
	if remaining != 0 {
		t.Errorf("failures after Reset() = %v, want 0", remaining)
	}
}
// TestCircuitBreakerError_Error checks the message format produced by
// (*CircuitBreakerError).Error for a couple of name/state combinations.
func TestCircuitBreakerError_Error(t *testing.T) {
	for _, tc := range []struct {
		label string
		input *CircuitBreakerError
		want  string
	}{
		{
			label: "formats error message correctly",
			input: &CircuitBreakerError{Name: "database", State: "open"},
			want:  "circuit breaker 'database' is open",
		},
		{
			label: "formats error message with different state",
			input: &CircuitBreakerError{Name: "redis", State: "half-open"},
			want:  "circuit breaker 'redis' is half-open",
		},
	} {
		t.Run(tc.label, func(t *testing.T) {
			if got := tc.input.Error(); got != tc.want {
				t.Errorf("Error() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestIsCircuitBreakerError checks that IsCircuitBreakerError recognises a
// *CircuitBreakerError and reports false for an unrelated error and for nil.
func TestIsCircuitBreakerError(t *testing.T) {
	for _, tc := range []struct {
		label string
		input error
		want  bool
	}{
		{
			label: "returns true for CircuitBreakerError",
			input: &CircuitBreakerError{Name: "test", State: "open"},
			want:  true,
		},
		{
			label: "returns false for regular error",
			input: errors.New("regular error"),
		},
		{
			// input is left as the nil error.
			label: "returns false for nil error",
		},
	} {
		t.Run(tc.label, func(t *testing.T) {
			if got := IsCircuitBreakerError(tc.input); got != tc.want {
				t.Errorf("IsCircuitBreakerError() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestCircuitBreaker_Concurrency calls a single breaker from 100 goroutines.
// Even-index goroutines wrap a succeeding function; odd-index ones wrap a
// failing function. The test is mainly meant to be run under the race
// detector (go test -race). It also checks three invariants:
//   - every call returned;
//   - every failing call surfaced an error;
//   - the breaker ends in one of its known states.
func TestCircuitBreaker_Concurrency(t *testing.T) {
	const goroutines = 100

	cb := NewCircuitBreaker("test", 10, 1*time.Second)

	var (
		wg           sync.WaitGroup
		countMutex   sync.Mutex
		successCount int
		errorCount   int
	)

	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func(index int) {
			defer wg.Done()
			fn := func() error {
				// Alternate between success and failure.
				if index%2 == 0 {
					return nil
				}
				return errors.New("error")
			}
			err := Call(cb, fn)

			countMutex.Lock()
			defer countMutex.Unlock()
			if err == nil {
				successCount++
			} else {
				errorCount++
			}
		}(i)
	}
	// wg.Wait establishes happens-before with every goroutine's writes, so the
	// counters can be read below without taking countMutex.
	wg.Wait()

	// Verify that all operations completed.
	if successCount+errorCount != goroutines {
		t.Errorf("total operations = %v, want 100", successCount+errorCount)
	}

	// Each odd-index call wraps a failing function. Whether that function ran
	// or the breaker rejected the call, Call must report a non-nil error.
	// Therefore at least half of all calls are errors, regardless of
	// scheduling order or when the breaker trips.
	if minErrors := goroutines / 2; errorCount < minErrors {
		t.Errorf("errorCount = %v, want at least %v", errorCount, minErrors)
	}

	// Verify the circuit breaker is in a valid state.
	state := GetState(cb)
	if state != StateClosed && state != StateOpen && state != StateHalfOpen {
		t.Errorf("invalid state = %v", state)
	}
}
// TestCircuitBreaker_OpenCircuitRejectsImmediately trips a breaker with a
// threshold of 1. It then checks that a further Call returns a
// CircuitBreakerError without ever invoking the wrapped function.
func TestCircuitBreaker_OpenCircuitRejectsImmediately(t *testing.T) {
	breaker := NewCircuitBreaker("test", 1, 1*time.Second)

	// With a threshold of 1, this single failure should open the circuit.
	Call(breaker, func() error { return errors.New("error") })
	if state := GetState(breaker); state != StateOpen {
		t.Fatalf("state = %v, want %v", state, StateOpen)
	}

	// ran flips to true only if Call actually invokes the wrapped function.
	ran := false
	err := Call(breaker, func() error {
		ran = true
		return nil
	})

	if !IsCircuitBreakerError(err) {
		t.Errorf("expected CircuitBreakerError, got %v", err)
	}
	if ran {
		t.Error("function should not have executed when circuit is open")
	}
}
// BenchmarkCircuitBreaker_Call_Success measures Call on a freshly created
// breaker whose wrapped function always succeeds.
func BenchmarkCircuitBreaker_Call_Success(b *testing.B) {
	breaker := NewCircuitBreaker("test", 5, 1*time.Second)
	noop := func() error { return nil }

	// Exclude breaker construction from the timed region.
	b.ResetTimer()
	for remaining := b.N; remaining > 0; remaining-- {
		Call(breaker, noop)
	}
}
// BenchmarkCircuitBreaker_Call_Open measures the rejection path, i.e. Call on
// a breaker that has already been tripped.
//
// The setup is verified before timing starts. Without that check, a regression
// in the trip logic would silently turn this into a second benchmark of the
// success path.
func BenchmarkCircuitBreaker_Call_Open(b *testing.B) {
	cb := NewCircuitBreaker("test", 1, 1*time.Second)
	// The default reset window is finite (30s, as checked elsewhere in this
	// file). After it elapses, a trial call is let through, and the succeeding
	// fn below would close the breaker mid-run on a long -benchtime. Pin it far
	// beyond any realistic benchmark duration.
	cb.resetTimeout = time.Hour

	// With a threshold of 1, this single failure opens the circuit.
	Call(cb, func() error {
		return errors.New("error")
	})
	if state := GetState(cb); state != StateOpen {
		b.Fatalf("state = %v, want %v", state, StateOpen)
	}

	fn := func() error {
		return nil
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		Call(cb, fn)
	}
}