mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
e00e85765b
This PR merges code from the `coder/aibridge` repository into `coder/coder`. It was split into 4 PRs for easier review, but the stacked PRs will need to be merged into this PR so all checks pass. * https://github.com/coder/coder/pull/24190 -> raw code copy (this PR; before the other PRs were merged on top of it, it was just 1 commit: https://github.com/coder/coder/commit/70d33f33200c7e77df910957595715f81f9bec24) * https://github.com/coder/coder/pull/24570 -> update imports in `coder/coder` to use copied code * https://github.com/coder/coder/pull/24586 -> linter fixes and CI integration (also added README.md) * https://github.com/coder/coder/pull/24571 -> added exclude to scripts/check_emdash.sh check Original PR message (before PR squash): Moves coder/aibridge code into coder/coder repository. Omitted files: - `go.mod`, `go.sum`, `.gitignore`, `.github/workflows/ci.yml`, `Makefile`, `LICENSE`, `README.md` (modified README.md is added later) - `.github`, `example`, `buildinfo`, `scripts` directories Simple verification script (will list omitted files) ``` tmp=$(mktemp -d) echo "$tmp" git clone --depth=1 https://github.com/coder/aibridge "$tmp/aibridge" git clone --depth=1 --branch pb/aibridge-code-move https://github.com/coder/coder "$tmp/coder" diff -rq --exclude=.git "$tmp/aibridge" "$tmp/coder/aibridge" # rm -rf "$tmp" ```
224 lines
6.5 KiB
Go
224 lines
6.5 KiB
Go
package circuitbreaker_test
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sony/gobreaker/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/coder/coder/v2/aibridge/circuitbreaker"
|
|
"github.com/coder/coder/v2/aibridge/config"
|
|
)
|
|
|
|
func TestExecute_PerModelIsolation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sonnetCalls := atomic.Int32{}
|
|
haikuCalls := atomic.Int32{}
|
|
|
|
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
|
FailureThreshold: 1,
|
|
Interval: time.Minute,
|
|
Timeout: time.Minute,
|
|
MaxRequests: 1,
|
|
}, func(endpoint, model string, from, to gobreaker.State) {}, nil)
|
|
|
|
endpoint := "/v1/messages"
|
|
sonnetModel := "claude-sonnet-4-20250514"
|
|
haikuModel := "claude-3-5-haiku-20241022"
|
|
|
|
// Trip circuit on sonnet model (returns 429)
|
|
w := httptest.NewRecorder()
|
|
err := cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error {
|
|
sonnetCalls.Add(1)
|
|
rw.WriteHeader(http.StatusTooManyRequests)
|
|
return nil
|
|
})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int32(1), sonnetCalls.Load())
|
|
|
|
// Second sonnet request should be blocked by circuit breaker
|
|
w = httptest.NewRecorder()
|
|
err = cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error {
|
|
sonnetCalls.Add(1)
|
|
rw.WriteHeader(http.StatusOK)
|
|
return nil
|
|
})
|
|
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
|
|
assert.Equal(t, int32(1), sonnetCalls.Load()) // No new call
|
|
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
|
|
|
// Haiku model on same endpoint should still work (independent circuit)
|
|
w = httptest.NewRecorder()
|
|
err = cbs.Execute(endpoint, haikuModel, w, func(rw http.ResponseWriter) error {
|
|
haikuCalls.Add(1)
|
|
rw.WriteHeader(http.StatusOK)
|
|
return nil
|
|
})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int32(1), haikuCalls.Load())
|
|
}
|
|
|
|
func TestExecute_PerEndpointIsolation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
messagesCalls := atomic.Int32{}
|
|
completionsCalls := atomic.Int32{}
|
|
|
|
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
|
FailureThreshold: 1,
|
|
Interval: time.Minute,
|
|
Timeout: time.Minute,
|
|
MaxRequests: 1,
|
|
}, func(endpoint, model string, from, to gobreaker.State) {}, nil)
|
|
|
|
model := "test-model"
|
|
|
|
// Trip circuit on /v1/messages endpoint (returns 429)
|
|
w := httptest.NewRecorder()
|
|
err := cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error {
|
|
messagesCalls.Add(1)
|
|
rw.WriteHeader(http.StatusTooManyRequests)
|
|
return nil
|
|
})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int32(1), messagesCalls.Load())
|
|
|
|
// Second /v1/messages request should be blocked
|
|
w = httptest.NewRecorder()
|
|
err = cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error {
|
|
messagesCalls.Add(1)
|
|
rw.WriteHeader(http.StatusOK)
|
|
return nil
|
|
})
|
|
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
|
|
assert.Equal(t, int32(1), messagesCalls.Load()) // No new call
|
|
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
|
|
|
// /v1/chat/completions on same model should still work (different endpoint)
|
|
w = httptest.NewRecorder()
|
|
err = cbs.Execute("/v1/chat/completions", model, w, func(rw http.ResponseWriter) error {
|
|
completionsCalls.Add(1)
|
|
rw.WriteHeader(http.StatusOK)
|
|
return nil
|
|
})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int32(1), completionsCalls.Load())
|
|
}
|
|
|
|
func TestExecute_CustomIsFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var calls atomic.Int32
|
|
|
|
// Custom IsFailure that treats 502 as failure
|
|
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
|
FailureThreshold: 1,
|
|
Interval: time.Minute,
|
|
Timeout: time.Minute,
|
|
MaxRequests: 1,
|
|
IsFailure: func(statusCode int) bool {
|
|
return statusCode == http.StatusBadGateway
|
|
},
|
|
}, func(endpoint, model string, from, to gobreaker.State) {}, nil)
|
|
|
|
// First request returns 502, trips circuit
|
|
w := httptest.NewRecorder()
|
|
err := cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error {
|
|
calls.Add(1)
|
|
rw.WriteHeader(http.StatusBadGateway)
|
|
return nil
|
|
})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int32(1), calls.Load())
|
|
|
|
// Second request should be blocked
|
|
w = httptest.NewRecorder()
|
|
err = cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error {
|
|
calls.Add(1)
|
|
rw.WriteHeader(http.StatusOK)
|
|
return nil
|
|
})
|
|
assert.True(t, errors.Is(err, circuitbreaker.ErrCircuitOpen))
|
|
assert.Equal(t, int32(1), calls.Load()) // No new call
|
|
assert.Equal(t, http.StatusServiceUnavailable, w.Code)
|
|
}
|
|
|
|
func TestExecute_OnStateChange(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var stateChanges []struct {
|
|
endpoint string
|
|
model string
|
|
from gobreaker.State
|
|
to gobreaker.State
|
|
}
|
|
|
|
cbs := circuitbreaker.NewProviderCircuitBreakers("test", &config.CircuitBreaker{
|
|
FailureThreshold: 1,
|
|
Interval: time.Minute,
|
|
Timeout: time.Minute,
|
|
MaxRequests: 1,
|
|
}, func(endpoint, model string, from, to gobreaker.State) {
|
|
stateChanges = append(stateChanges, struct {
|
|
endpoint string
|
|
model string
|
|
from gobreaker.State
|
|
to gobreaker.State
|
|
}{endpoint, model, from, to})
|
|
}, nil)
|
|
|
|
endpoint := "/v1/messages"
|
|
model := "claude-sonnet-4-20250514"
|
|
|
|
// Trip circuit
|
|
w := httptest.NewRecorder()
|
|
err := cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error {
|
|
rw.WriteHeader(http.StatusTooManyRequests)
|
|
return nil
|
|
})
|
|
assert.NoError(t, err)
|
|
|
|
// Verify state change callback was called with correct parameters
|
|
assert.Len(t, stateChanges, 1)
|
|
assert.Equal(t, endpoint, stateChanges[0].endpoint)
|
|
assert.Equal(t, model, stateChanges[0].model)
|
|
assert.Equal(t, gobreaker.StateClosed, stateChanges[0].from)
|
|
assert.Equal(t, gobreaker.StateOpen, stateChanges[0].to)
|
|
}
|
|
|
|
func TestDefaultIsFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
statusCode int
|
|
isFailure bool
|
|
}{
|
|
{http.StatusOK, false},
|
|
{http.StatusBadRequest, false},
|
|
{http.StatusUnauthorized, false},
|
|
{http.StatusTooManyRequests, true}, // 429
|
|
{http.StatusInternalServerError, false},
|
|
{http.StatusBadGateway, false},
|
|
{http.StatusServiceUnavailable, true}, // 503
|
|
{http.StatusGatewayTimeout, true}, // 504
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
assert.Equal(t, tt.isFailure, circuitbreaker.DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode)
|
|
}
|
|
}
|
|
|
|
func TestStateToGaugeValue(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
assert.Equal(t, float64(0), circuitbreaker.StateToGaugeValue(gobreaker.StateClosed))
|
|
assert.Equal(t, float64(0.5), circuitbreaker.StateToGaugeValue(gobreaker.StateHalfOpen))
|
|
assert.Equal(t, float64(1), circuitbreaker.StateToGaugeValue(gobreaker.StateOpen))
|
|
}
|