diff --git a/breaker.go b/breaker.go index 7800159..669e68d 100644 --- a/breaker.go +++ b/breaker.go @@ -108,13 +108,26 @@ func (e *EWMABreaker) observe(halfOpen, failure bool) stateChange { value = 1.0 } - // Unconditionally setting via swap and maybe overwriting is faster in the initial case. - failureRate := fromStore(e.failureRate.Swap(toStore(value))) - if failureRate == math.SmallestNonzeroFloat64 { - failureRate = value - } else { - failureRate = (value * e.decay) + (failureRate * (1 - e.decay)) - e.failureRate.Store(toStore(failureRate)) + // Use CompareAndSwap loop to atomically update the EWMA to avoid race conditions + // where concurrent observations could read raw values instead of the EWMA. + var failureRate float64 + for { + oldBits := e.failureRate.Load() + oldRate := fromStore(oldBits) + + if oldRate == math.SmallestNonzeroFloat64 { + // First observation - initialize with the current value + failureRate = value + } else { + // Compute EWMA + failureRate = (value * e.decay) + (oldRate * (1 - e.decay)) + } + + // Try to swap in the new rate atomically + if e.failureRate.CompareAndSwap(oldBits, toStore(failureRate)) { + break + } + // If CAS failed, another goroutine updated it; retry } if failureRate > e.threshold { @@ -192,6 +205,9 @@ func (s *SlidingWindowBreaker) observe(halfOpen, failure bool) stateChange { // overwrite the last counts to some near zero value. if sinceStart > s.windowSize && firstCallInNewWindow { sinceStart = 0 + // Atomically move current window counts to last window (lines below). + // Note: after these swaps, other goroutines may increment the current counters with their observations, + // which is correct - those observations will belong to the new window that just started. lastFailureCount = s.lastFailureCount.Swap(s.currentFailureCount.Swap(0)) lastSuccessCount = s.lastSuccessCount.Swap(s.currentSuccessCount.Swap(0)) } else { diff --git a/go.mod b/go.mod index 6a58c4a..d1bd987 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect - google.golang.org/protobuf v1.34.2 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 17f860e..e4f1d94 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,6 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/hoglet.go b/hoglet.go index fb17a0a..3fe5f22 100644 --- a/hoglet.go +++ b/hoglet.go @@ -138,16 +138,31 @@ func (c *Circuit[IN, OUT]) State() State { // stateForCall returns the state of the circuit meant for the next call. // It wraps [State] to keep the mutable part outside of the external API. func (c *Circuit[IN, OUT]) stateForCall() State { - state := c.State() + for { + oa := c.openedAt.Load() - if state == StateHalfOpen { - // We reset openedAt to block further calls to pass through when half-open. A success will cause the breaker to - // close. This is slightly racy: multiple goroutines may reach this point concurrently since we do not lock the - // breaker. - c.reopen() - } + if oa == 0 { + // closed + return StateClosed + } + + if c.halfOpenDelay == 0 || time.Since(time.UnixMicro(oa)) < c.halfOpenDelay { + // open + return StateOpen + } - return state + // half-open: try to atomically transition to reopened state + // Only one goroutine should succeed, limiting concurrent calls in half-open to ~1 + reopenedAt := time.Now().UnixMicro() + if c.openedAt.CompareAndSwap(oa, reopenedAt) { + // This goroutine won the race and can proceed with a call + return StateHalfOpen + } + + // Another goroutine already transitioned from half-open; re-check the new state + // (should typically return StateOpen since we just reopened, or StateClosed if the call succeeded) + // Loop to avoid stack overflow in high-contention scenarios + } } // open marks the circuit as open, if it not already. diff --git a/race_test.go b/race_test.go new file mode 100644 index 0000000..2920cc2 --- /dev/null +++ b/race_test.go @@ -0,0 +1,233 @@ +package hoglet + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEWMABreaker_ConcurrentObservations tests that concurrent observations +// don't cause incorrect EWMA calculations due to race conditions. +func TestEWMABreaker_ConcurrentObservations(t *testing.T) { + breaker := NewEWMABreaker(10, 0.5) + + // Run many concurrent observations + const numGoroutines = 100 + const observationsPerGoroutine = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + // Alternate between success and failure + for j := 0; j < observationsPerGoroutine; j++ { + failure := (id+j)%2 == 0 + breaker.observe(false, failure) + } + }(i) + } + + wg.Wait() + + // With 50% failures and threshold of 0.5, the breaker should eventually stabilize + // The exact value depends on the EWMA calculation, but it should be close to 0.5 + // and not have corrupted values + finalRate := fromStore(breaker.failureRate.Load()) + + // The rate should be between 0 and 1 + assert.GreaterOrEqual(t, finalRate, 0.0, "failure rate should be >= 0") + assert.LessOrEqual(t, finalRate, 1.0, "failure rate should be <= 1") + + // With many observations at ~50% failure rate, it should converge near 0.5 + // Allow some variance due to the EWMA nature + assert.InDelta(t, 0.5, finalRate, 0.3, "failure rate should converge near 50%") +} + +// TestSlidingWindowBreaker_ConcurrentWindowSwap tests that concurrent calls +// during window swapping don't lose counts or produce incorrect results. +func TestSlidingWindowBreaker_ConcurrentWindowSwap(t *testing.T) { + windowSize := 100 * time.Millisecond + breaker := NewSlidingWindowBreaker(windowSize, 0.5) + + // Start with some initial failures in the first window + for i := 0; i < 50; i++ { + breaker.observe(false, true) + } + for i := 0; i < 50; i++ { + breaker.observe(false, false) + } + + // Sleep to ensure we're past the window + time.Sleep(windowSize + 10*time.Millisecond) + + // Now trigger concurrent observations that should cause a window swap + const numGoroutines = 50 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + var successCount, failureCount atomic.Int64 + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + failure := id%2 == 0 + breaker.observe(false, failure) + if failure { + failureCount.Add(1) + } else { + successCount.Add(1) + } + }(i) + } + + wg.Wait() + + // Verify counts are consistent (no lost observations) + currentSuccess := breaker.currentSuccessCount.Load() + currentFailure := breaker.currentFailureCount.Load() + lastSuccess := breaker.lastSuccessCount.Load() + lastFailure := breaker.lastFailureCount.Load() + + totalInBreaker := currentSuccess + currentFailure + lastSuccess + lastFailure + totalObserved := successCount.Load() + failureCount.Load() + + // The breaker should have tracked all observations (some might be in old window) + // At minimum, current window should have the observations + assert.GreaterOrEqual(t, totalInBreaker, totalObserved-100, + "breaker should track most observations, current+last=%d, observed=%d", + totalInBreaker, totalObserved) +} + +// TestCircuit_HalfOpenConcurrency tests that the half-open state properly limits +// concurrent calls to ~1, not allowing many calls through simultaneously. +func TestCircuit_HalfOpenConcurrency(t *testing.T) { + var callsInProgress atomic.Int32 + var maxConcurrent atomic.Int32 + var callsCompleted atomic.Int32 + + slowFunc := func(ctx context.Context, in int) (int, error) { + current := callsInProgress.Add(1) + defer callsInProgress.Add(-1) + + // Update max concurrent + for { + max := maxConcurrent.Load() + if current <= max || maxConcurrent.CompareAndSwap(max, current) { + break + } + } + + // Slow down the call to give time for concurrent calls + time.Sleep(50 * time.Millisecond) + callsCompleted.Add(1) + return in, nil + } + + // Create a breaker that opens immediately on first failure + breaker := NewEWMABreaker(1, 0.01) + c, err := NewCircuit(slowFunc, breaker, WithHalfOpenDelay(100*time.Millisecond)) + require.NoError(t, err) + + // Make it fail to open the circuit + c.Call(context.Background(), -1) + c.open() + assert.Equal(t, StateOpen, c.State()) + + // Wait for half-open + time.Sleep(150 * time.Millisecond) + assert.Equal(t, StateHalfOpen, c.State()) + + // Try to make many concurrent calls in half-open state + const numConcurrent = 20 + var wg sync.WaitGroup + wg.Add(numConcurrent) + + for i := 0; i < numConcurrent; i++ { + go func(id int) { + defer wg.Done() + c.Call(context.Background(), id) + }(i) + } + + wg.Wait() + + maxConcurrentCalls := maxConcurrent.Load() + completedCalls := callsCompleted.Load() + + t.Logf("Max concurrent calls in half-open: %d", maxConcurrentCalls) + t.Logf("Completed calls: %d", completedCalls) + + // In half-open state, we should limit to ~1 call, definitely not all 20 + // The comment says "limited (~1)", so we allow a small number due to race conditions + // But definitely should not be close to numConcurrent + assert.LessOrEqual(t, maxConcurrentCalls, int32(5), + "half-open should limit concurrent calls to ~1, not %d", maxConcurrentCalls) +} + +// TestCircuit_ConcurrentStateChanges tests that concurrent calls don't cause +// incorrect state changes that would affect unrelated calls. +func TestCircuit_ConcurrentStateChanges(t *testing.T) { + var successCount, failureCount, circuitOpenCount atomic.Int32 + + testFunc := func(ctx context.Context, shouldFail bool) (bool, error) { + if shouldFail { + failureCount.Add(1) + return false, assert.AnError + } + successCount.Add(1) + return true, nil + } + + // Breaker that opens quickly (low threshold, small sample) + breaker := NewEWMABreaker(5, 0.3) + c, err := NewCircuit(testFunc, breaker, WithHalfOpenDelay(50*time.Millisecond)) + require.NoError(t, err) + + const numGoroutines = 100 + const callsPerGoroutine = 10 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < callsPerGoroutine; j++ { + // Mix successful and failing calls + shouldFail := (id*callsPerGoroutine+j)%3 == 0 + _, err := c.Call(context.Background(), shouldFail) + if err == ErrCircuitOpen { + circuitOpenCount.Add(1) + } + } + }(i) + } + + wg.Wait() + + totalSuccesses := successCount.Load() + totalFailures := failureCount.Load() + totalCircuitOpen := circuitOpenCount.Load() + totalAttempts := int32(numGoroutines * callsPerGoroutine) + + t.Logf("Successes: %d, Failures: %d, Circuit Open: %d, Total: %d", + totalSuccesses, totalFailures, totalCircuitOpen, totalAttempts) + + // Verify accounting: all attempts should be accounted for + assert.Equal(t, totalAttempts, totalSuccesses+totalFailures+totalCircuitOpen, + "all calls should be accounted for") + + // With ~33% failures, the circuit should open at some point + assert.Greater(t, totalCircuitOpen, int32(0), "circuit should have opened") + + // But not all calls should be blocked (circuit should close again eventually) + assert.Less(t, totalCircuitOpen, totalAttempts, + "not all calls should be blocked - circuit should recover") +}