Files
coder/aibridge/circuitbreaker/circuitbreaker_test.go
T
Paweł Banaszewski e00e85765b chore: move aibridge library code into coder repo (#24190)
This PR merges code from the `coder/aibridge` repository into `coder/coder`.
It was split into 4 PRs for easier review, but the stacked PRs need to
be merged into this PR so that all checks pass.

* https://github.com/coder/coder/pull/24190 -> raw code copy (this PR;
before the PRs stacked on top of it were merged, it was a single commit:
https://github.com/coder/coder/commit/70d33f33200c7e77df910957595715f81f9bec24)
* https://github.com/coder/coder/pull/24570 -> update imports in
`coder/coder` to use copied code
* https://github.com/coder/coder/pull/24586 -> linter fixes and CI
integration (also added README.md)
* https://github.com/coder/coder/pull/24571 -> added exclude to
scripts/check_emdash.sh check

Original PR message (before PR squash):
Moves coder/aibridge code into coder/coder repository.

Omitted files:

- `go.mod`, `go.sum`, `.gitignore`, `.github/workflows/ci.yml`,
`Makefile`, `LICENSE`, `README.md` (a modified README.md is added later)
- `.github`, `example`, `buildinfo`, `scripts` directories

Simple verification script (will list omitted files)

```
tmp=$(mktemp -d)
echo "$tmp"
git clone --depth=1 https://github.com/coder/aibridge "$tmp/aibridge"
git clone --depth=1 --branch pb/aibridge-code-move https://github.com/coder/coder "$tmp/coder"
diff -rq --exclude=.git "$tmp/aibridge" "$tmp/coder/aibridge"
# rm -rf "$tmp"
```
2026-04-22 17:01:01 +02:00

224 lines
6.5 KiB
Go

package circuitbreaker_test
import (
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/sony/gobreaker/v2"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/aibridge/circuitbreaker"
"github.com/coder/coder/v2/aibridge/config"
)
// TestExecute_PerModelIsolation verifies that breakers are keyed by model:
// tripping the breaker for one model on an endpoint must not block a
// different model on that same endpoint.
func TestExecute_PerModelIsolation(t *testing.T) {
	t.Parallel()

	const (
		endpoint    = "/v1/messages"
		sonnetModel = "claude-sonnet-4-20250514"
		haikuModel  = "claude-3-5-haiku-20241022"
	)

	var sonnetCalls, haikuCalls atomic.Int32

	cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
		FailureThreshold: 1,
		Interval:         time.Minute,
		Timeout:          time.Minute,
		MaxRequests:      1,
	}, func(_, _ string, _, _ gobreaker.State) {}, nil)

	// respond builds an upstream handler that bumps counter and replies with status.
	respond := func(counter *atomic.Int32, status int) func(http.ResponseWriter) error {
		return func(rw http.ResponseWriter) error {
			counter.Add(1)
			rw.WriteHeader(status)
			return nil
		}
	}

	// A 429 from upstream is the failure that trips the sonnet breaker
	// (FailureThreshold is 1).
	rec := httptest.NewRecorder()
	err := cbs.Execute(endpoint, sonnetModel, rec, respond(&sonnetCalls, http.StatusTooManyRequests))
	assert.NoError(t, err)
	assert.Equal(t, int32(1), sonnetCalls.Load())

	// With the sonnet breaker open, the next request is short-circuited: the
	// handler does not run and the recorder gets a 503.
	rec = httptest.NewRecorder()
	err = cbs.Execute(endpoint, sonnetModel, rec, respond(&sonnetCalls, http.StatusOK))
	assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
	assert.Equal(t, int32(1), sonnetCalls.Load())
	assert.Equal(t, http.StatusServiceUnavailable, rec.Code)

	// Haiku shares the endpoint but has its own breaker, so it still goes through.
	rec = httptest.NewRecorder()
	err = cbs.Execute(endpoint, haikuModel, rec, respond(&haikuCalls, http.StatusOK))
	assert.NoError(t, err)
	assert.Equal(t, int32(1), haikuCalls.Load())
}
func TestExecute_PerEndpointIsolation(t *testing.T) {
t.Parallel()
messagesCalls := atomic.Int32{}
completionsCalls := atomic.Int32{}
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
MaxRequests: 1,
}, func(endpoint, model string, from, to gobreaker.State) {}, nil)
model := "test-model"
// Trip circuit on /v1/messages endpoint (returns 429)
w := httptest.NewRecorder()
err := cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error {
messagesCalls.Add(1)
rw.WriteHeader(http.StatusTooManyRequests)
return nil
})
assert.NoError(t, err)
assert.Equal(t, int32(1), messagesCalls.Load())
// Second /v1/messages request should be blocked
w = httptest.NewRecorder()
err = cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error {
messagesCalls.Add(1)
rw.WriteHeader(http.StatusOK)
return nil
})
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
assert.Equal(t, int32(1), messagesCalls.Load()) // No new call
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
// /v1/chat/completions on same model should still work (different endpoint)
w = httptest.NewRecorder()
err = cbs.Execute("/v1/chat/completions", model, w, func(rw http.ResponseWriter) error {
completionsCalls.Add(1)
rw.WriteHeader(http.StatusOK)
return nil
})
assert.NoError(t, err)
assert.Equal(t, int32(1), completionsCalls.Load())
}
// TestExecute_CustomIsFailure verifies that a caller-supplied IsFailure
// predicate decides which upstream status codes count toward tripping the
// breaker. Here the predicate accepts only 502 Bad Gateway.
func TestExecute_CustomIsFailure(t *testing.T) {
	t.Parallel()

	const (
		endpoint = "/v1/messages"
		model    = "test-model"
	)

	onlyBadGateway := func(statusCode int) bool {
		return statusCode == http.StatusBadGateway
	}

	cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
		FailureThreshold: 1,
		Interval:         time.Minute,
		Timeout:          time.Minute,
		MaxRequests:      1,
		IsFailure:        onlyBadGateway,
	}, func(_, _ string, _, _ gobreaker.State) {}, nil)

	var calls atomic.Int32

	// upstream returns a handler that counts its invocation and replies with status.
	upstream := func(status int) func(http.ResponseWriter) error {
		return func(rw http.ResponseWriter) error {
			calls.Add(1)
			rw.WriteHeader(status)
			return nil
		}
	}

	// 502 matches the custom predicate, so one response is enough to open the breaker.
	err := cbs.Execute(endpoint, model, httptest.NewRecorder(), upstream(http.StatusBadGateway))
	assert.NoError(t, err)
	assert.Equal(t, int32(1), calls.Load())

	// With the breaker open, the follow-up request is short-circuited.
	rec := httptest.NewRecorder()
	err = cbs.Execute(endpoint, model, rec, upstream(http.StatusOK))
	assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
	assert.Equal(t, int32(1), calls.Load()) // upstream not reached
	assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
}
// TestExecute_OnStateChange verifies that the state-change callback passed to
// NewProviderCircuitBreakers receives the endpoint and model of the breaker
// that transitioned, along with its previous and new state.
func TestExecute_OnStateChange(t *testing.T) {
	t.Parallel()

	// stateChange records the arguments of a single callback invocation.
	type stateChange struct {
		endpoint string
		model    string
		from     gobreaker.State
		to       gobreaker.State
	}

	// NOTE(review): appended to without synchronization. This assumes the
	// callback runs on the goroutine calling Execute; guard it with a mutex if
	// that ever changes.
	var stateChanges []stateChange

	cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
		FailureThreshold: 1,
		Interval:         time.Minute,
		Timeout:          time.Minute,
		MaxRequests:      1,
	}, func(endpoint, model string, from, to gobreaker.State) {
		stateChanges = append(stateChanges, stateChange{endpoint, model, from, to})
	}, nil)

	endpoint := "/v1/messages"
	model := "claude-sonnet-4-20250514"

	// Trip circuit: with FailureThreshold 1, a single 429 is enough.
	w := httptest.NewRecorder()
	err := cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error {
		rw.WriteHeader(http.StatusTooManyRequests)
		return nil
	})
	assert.NoError(t, err)

	// Compare the full recorded sequence in one assertion. A missing or extra
	// callback is then reported as a diff rather than as an index-out-of-range
	// panic.
	assert.Equal(t, []stateChange{{
		endpoint: endpoint,
		model:    model,
		from:     gobreaker.StateClosed,
		to:       gobreaker.StateOpen,
	}}, stateChanges)
}
func TestDefaultIsFailure(t *testing.T) {
t.Parallel()
tests := []struct {
statusCode int
isFailure bool
}{
{http.StatusOK, false},
{http.StatusBadRequest, false},
{http.StatusUnauthorized, false},
{http.StatusTooManyRequests, true}, // 429
{http.StatusInternalServerError, false},
{http.StatusBadGateway, false},
{http.StatusServiceUnavailable, true}, // 503
{http.StatusGatewayTimeout, true}, // 504
}
for _, tt := range tests {
assert.Equal(t, tt.isFailure, circuitbreaker.DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
}
}
func TestStateToGaugeValue(t *testing.T) {
t.Parallel()
assert.Equal(t, float64(0), circuitbreaker.StateToGaugeValue(gobreaker.StateClosed))
assert.Equal(t, float64(0.5), circuitbreaker.StateToGaugeValue(gobreaker.StateHalfOpen))
assert.Equal(t, float64(1), circuitbreaker.StateToGaugeValue(gobreaker.StateOpen))
}