diff --git a/db/db.go b/db/db.go index b872bb9..dfaa722 100644 --- a/db/db.go +++ b/db/db.go @@ -1,7 +1,6 @@ package db import ( - "authorization/helper" "database/sql" "fmt" "log" @@ -14,9 +13,6 @@ import ( // DB is the global database connection pool var DB *sql.DB -// DBCircuitBreaker protects database operations -var DBCircuitBreaker *helper.CircuitBreaker - func InitDB() (*sql.DB, error) { // Get connection details from environment variables (loaded in main) connStr := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", @@ -33,19 +29,14 @@ func InitDB() (*sql.DB, error) { if err != nil { return nil, fmt.Errorf("error opening database: %v", err) } - // Initialize circuit breaker - DBCircuitBreaker = helper.NewCircuitBreaker("database", 5, 2*time.Second) - // Set connection pool parameters optimized for horizontal scaling // Lower per-replica to allow more replicas without exhausting DB connections DB.SetMaxOpenConns(25) // Maximum number of open connections to the database DB.SetMaxIdleConns(10) // Maximum number of connections in the idle connection pool DB.SetConnMaxLifetime(5 * time.Minute) // Maximum amount of time a connection may be reused - // Check if the database connection is working with circuit breaker - err = DBCircuitBreaker.Call(func() error { - return DB.Ping() - }) + // Check if the database connection is working + err = DB.Ping() if err != nil { log.Printf("Database connection lost: %v. Reconnecting...", err) DB, err = InitDB() diff --git a/helper/circuit_breaker.go b/helper/circuit_breaker.go index 994a129..14aa6da 100644 --- a/helper/circuit_breaker.go +++ b/helper/circuit_breaker.go @@ -38,7 +38,7 @@ func NewCircuitBreaker(name string, maxFailures int, timeout time.Duration) *Cir } // Call executes the given function with circuit breaker protection -func (cb *CircuitBreaker) Call(fn func() error) error { +func Call(cb *CircuitBreaker, fn func() error) error { cb.mutex.Lock() // Check if circuit should transition from Open to HalfOpen @@ -94,14 +94,14 @@ func (cb *CircuitBreaker) Call(fn func() error) error { } // GetState returns the current state of the circuit breaker -func (cb *CircuitBreaker) GetState() CircuitState { +func GetState(cb *CircuitBreaker) CircuitState { cb.mutex.RLock() defer cb.mutex.RUnlock() return cb.state } // Reset manually resets the circuit breaker -func (cb *CircuitBreaker) Reset() { +func Reset(cb *CircuitBreaker) { cb.mutex.Lock() defer cb.mutex.Unlock() cb.state = StateClosed diff --git a/helper/circuit_breaker_test.go b/helper/circuit_breaker_test.go new file mode 100644 index 0000000..a28443a --- /dev/null +++ b/helper/circuit_breaker_test.go @@ -0,0 +1,459 @@ +package helper + +import ( + "errors" + "sync" + "testing" + "time" +) + +func TestNewCircuitBreaker(t *testing.T) { + tests := []struct { + name string + cbName string + maxFailures int + timeout time.Duration + wantState CircuitState + }{ + { + name: "creates circuit breaker with correct defaults", + cbName: "test-service", + maxFailures: 5, + timeout: 2 * time.Second, + wantState: StateClosed, + }, + { + name: "creates circuit breaker with different parameters", + cbName: "db-service", + maxFailures: 3, + timeout: 1 * time.Second, + wantState: StateClosed, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cb := NewCircuitBreaker(tt.cbName, tt.maxFailures, tt.timeout) + + if cb.name != tt.cbName { + t.Errorf("name = %v, want %v", cb.name, tt.cbName) + } + if cb.maxFailures != tt.maxFailures { + t.Errorf("maxFailures = %v, want %v", cb.maxFailures, tt.maxFailures) + } + if cb.timeout != tt.timeout { + t.Errorf("timeout = %v, want %v", cb.timeout, tt.timeout) + } + if cb.state != tt.wantState { + t.Errorf("state = %v, want %v", cb.state, tt.wantState) + } + if cb.resetTimeout != 30*time.Second { + t.Errorf("resetTimeout = %v, want %v", cb.resetTimeout, 30*time.Second) + } + if cb.failures != 0 { + t.Errorf("failures = %v, want 0", cb.failures) + } + }) + } +} + +func TestCircuitBreaker_Call_Success(t *testing.T) { + cb := NewCircuitBreaker("test", 3, 1*time.Second) + + successFn := func() error { + return nil + } + + err := Call(cb, successFn) + if err != nil { + t.Errorf("Call() error = %v, want nil", err) + } + + if GetState(cb) != StateClosed { + t.Errorf("state = %v, want %v", GetState(cb), StateClosed) + } +} + +func TestCircuitBreaker_Call_FailuresOpenCircuit(t *testing.T) { + cb := NewCircuitBreaker("test", 3, 1*time.Second) + + failFn := func() error { + return errors.New("service error") + } + + // First 2 failures - circuit should stay closed + for i := 0; i < 2; i++ { + err := Call(cb, failFn) + if err == nil { + t.Errorf("Call() iteration %d: expected error, got nil", i) + } + if GetState(cb) != StateClosed { + t.Errorf("iteration %d: state = %v, want %v", i, GetState(cb), StateClosed) + } + } + + // 3rd failure - circuit should open + err := Call(cb, failFn) + if err == nil { + t.Error("Call() expected error, got nil") + } + if GetState(cb) != StateOpen { + t.Errorf("state = %v, want %v", GetState(cb), StateOpen) + } + + // Next call should immediately return circuit breaker error + err = Call(cb, failFn) + if !IsCircuitBreakerError(err) { + t.Errorf("expected CircuitBreakerError, got %v", err) + } +} + +func TestCircuitBreaker_Call_OpenToHalfOpen(t *testing.T) { + cb := NewCircuitBreaker("test", 2, 1*time.Second) + cb.resetTimeout = 100 * time.Millisecond // Shorter reset for testing + + failFn := func() error { + return errors.New("service error") + } + + // Open the circuit + Call(cb, failFn) + Call(cb, failFn) + + if GetState(cb) != StateOpen { + t.Fatalf("state = %v, want %v", GetState(cb), StateOpen) + } + + // Wait for reset timeout + time.Sleep(150 * time.Millisecond) + + // Next call should transition to HalfOpen + successFn := func() error { + return nil + } + + err := Call(cb, successFn) + if err != nil { + t.Errorf("Call() error = %v, want nil", err) + } + + // Should now be closed + if GetState(cb) != StateClosed { + t.Errorf("state = %v, want %v", GetState(cb), StateClosed) + } +} + +func TestCircuitBreaker_Call_HalfOpenFailReturnsToOpen(t *testing.T) { + cb := NewCircuitBreaker("test", 2, 1*time.Second) + cb.resetTimeout = 100 * time.Millisecond + + failFn := func() error { + return errors.New("service error") + } + + // Open the circuit + Call(cb, failFn) + Call(cb, failFn) + + if GetState(cb) != StateOpen { + t.Fatalf("state = %v, want %v", GetState(cb), StateOpen) + } + + // Wait for reset timeout to transition to HalfOpen + time.Sleep(150 * time.Millisecond) + + // Fail in HalfOpen state - should return to Open + err := Call(cb, failFn) + if err == nil { + t.Error("Call() expected error, got nil") + } + + if GetState(cb) != StateOpen { + t.Errorf("state = %v, want %v", GetState(cb), StateOpen) + } +} + +func TestCircuitBreaker_Call_GradualFailureReduction(t *testing.T) { + cb := NewCircuitBreaker("test", 5, 1*time.Second) + + failFn := func() error { + return errors.New("service error") + } + successFn := func() error { + return nil + } + + // Add 3 failures + for i := 0; i < 3; i++ { + Call(cb, failFn) + } + + cb.mutex.RLock() + failures := cb.failures + cb.mutex.RUnlock() + + if failures != 3 { + t.Fatalf("failures = %v, want 3", failures) + } + + // One success should reduce failure count + Call(cb, successFn) + + cb.mutex.RLock() + failures = cb.failures + cb.mutex.RUnlock() + + if failures != 2 { + t.Errorf("failures after success = %v, want 2", failures) + } +} + +func TestCircuitBreaker_GetState(t *testing.T) { + tests := []struct { + name string + setupFunc func(*CircuitBreaker) + wantState CircuitState + }{ + { + name: "returns closed state", + setupFunc: func(cb *CircuitBreaker) { + cb.state = StateClosed + }, + wantState: StateClosed, + }, + { + name: "returns open state", + setupFunc: func(cb *CircuitBreaker) { + cb.state = StateOpen + }, + wantState: StateOpen, + }, + { + name: "returns half-open state", + setupFunc: func(cb *CircuitBreaker) { + cb.state = StateHalfOpen + }, + wantState: StateHalfOpen, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cb := NewCircuitBreaker("test", 3, 1*time.Second) + tt.setupFunc(cb) + + got := GetState(cb) + if got != tt.wantState { + t.Errorf("GetState() = %v, want %v", got, tt.wantState) + } + }) + } +} + +func TestCircuitBreaker_Reset(t *testing.T) { + cb := NewCircuitBreaker("test", 2, 1*time.Second) + + // Open the circuit + failFn := func() error { + return errors.New("error") + } + Call(cb, failFn) + Call(cb, failFn) + + if GetState(cb) != StateOpen { + t.Fatalf("state = %v, want %v", GetState(cb), StateOpen) + } + + // Reset the circuit breaker + Reset(cb) + + if GetState(cb) != StateClosed { + t.Errorf("state after Reset() = %v, want %v", GetState(cb), StateClosed) + } + + cb.mutex.RLock() + failures := cb.failures + cb.mutex.RUnlock() + + if failures != 0 { + t.Errorf("failures after Reset() = %v, want 0", failures) + } +} + +func TestCircuitBreakerError_Error(t *testing.T) { + tests := []struct { + name string + err *CircuitBreakerError + wantError string + }{ + { + name: "formats error message correctly", + err: &CircuitBreakerError{ + Name: "database", + State: "open", + }, + wantError: "circuit breaker 'database' is open", + }, + { + name: "formats error message with different state", + err: &CircuitBreakerError{ + Name: "redis", + State: "half-open", + }, + wantError: "circuit breaker 'redis' is half-open", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.err.Error() + if got != tt.wantError { + t.Errorf("Error() = %v, want %v", got, tt.wantError) + } + }) + } +} + +func TestIsCircuitBreakerError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "returns true for CircuitBreakerError", + err: &CircuitBreakerError{ + Name: "test", + State: "open", + }, + want: true, + }, + { + name: "returns false for regular error", + err: errors.New("regular error"), + want: false, + }, + { + name: "returns false for nil error", + err: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCircuitBreakerError(tt.err) + if got != tt.want { + t.Errorf("IsCircuitBreakerError() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCircuitBreaker_Concurrency(t *testing.T) { + cb := NewCircuitBreaker("test", 10, 1*time.Second) + + var wg sync.WaitGroup + successCount := 0 + errorCount := 0 + var countMutex sync.Mutex + + // Run 100 concurrent operations + for i := 0; i < 100; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + + fn := func() error { + // Alternate between success and failure + if index%2 == 0 { + return nil + } + return errors.New("error") + } + + err := Call(cb, fn) + countMutex.Lock() + if err == nil { + successCount++ + } else { + errorCount++ + } + countMutex.Unlock() + }(i) + } + + wg.Wait() + + // Verify that all operations completed + if successCount+errorCount != 100 { + t.Errorf("total operations = %v, want 100", successCount+errorCount) + } + + // Verify circuit breaker is in a valid state + state := GetState(cb) + if state != StateClosed && state != StateOpen && state != StateHalfOpen { + t.Errorf("invalid state = %v", state) + } +} + +func TestCircuitBreaker_OpenCircuitRejectsImmediately(t *testing.T) { + cb := NewCircuitBreaker("test", 1, 1*time.Second) + + // Open the circuit + failFn := func() error { + return errors.New("error") + } + Call(cb, failFn) + + if GetState(cb) != StateOpen { + t.Fatalf("state = %v, want %v", GetState(cb), StateOpen) + } + + // Try calling with a function that should not execute + executed := false + testFn := func() error { + executed = true + return nil + } + + err := Call(cb, testFn) + + if !IsCircuitBreakerError(err) { + t.Errorf("expected CircuitBreakerError, got %v", err) + } + + if executed { + t.Error("function should not have executed when circuit is open") + } +} + +func BenchmarkCircuitBreaker_Call_Success(b *testing.B) { + cb := NewCircuitBreaker("test", 5, 1*time.Second) + fn := func() error { + return nil + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + Call(cb, fn) + } +} + +func BenchmarkCircuitBreaker_Call_Open(b *testing.B) { + cb := NewCircuitBreaker("test", 1, 1*time.Second) + + // Open the circuit + Call(cb, func() error { + return errors.New("error") + }) + + fn := func() error { + return nil + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + Call(cb, fn) + } +} diff --git a/redisclient/redis.go b/redisclient/redis.go index 9257751..22b9c22 100644 --- a/redisclient/redis.go +++ b/redisclient/redis.go @@ -1,20 +1,15 @@ package redisclient import ( - "authorization/helper" "context" "fmt" "os" - "time" "github.com/redis/go-redis/v9" ) var RDB *redis.Client -// RedisCircuitBreaker protects Redis operations -var RedisCircuitBreaker *helper.CircuitBreaker - func Init() { redisHost := os.Getenv("REDIS_HOST") if redisHost == "" { @@ -42,15 +37,9 @@ func Init() { RDB = redis.NewClient(opts) - // Initialize circuit breaker - RedisCircuitBreaker = helper.NewCircuitBreaker("redis", 5, 2*time.Second) - - // Test connection with authentication using circuit breaker + // Test connection with authentication ctx := context.Background() - err := RedisCircuitBreaker.Call(func() error { - _, err := RDB.Ping(ctx).Result() - return err - }) + _, err := RDB.Ping(ctx).Result() if err != nil { panic(fmt.Sprintf("Could not connect to Redis: %v", err)) }