mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
e00e85765b
This PR merges code from the `coder/aibridge` repository into `coder/coder`. It was split into 4 PRs for easier review, but the stacked PRs need to be merged into this PR so that all checks pass.

* https://github.com/coder/coder/pull/24190 -> raw code copy (this PR; before the stacked PRs were merged on top of it, it was just 1 commit: https://github.com/coder/coder/commit/70d33f33200c7e77df910957595715f81f9bec24)
* https://github.com/coder/coder/pull/24570 -> update imports in `coder/coder` to use the copied code
* https://github.com/coder/coder/pull/24586 -> linter fixes and CI integration (also added README.md)
* https://github.com/coder/coder/pull/24571 -> added an exclude to the scripts/check_emdash.sh check

Original PR message (before PR squash):

Moves coder/aibridge code into coder/coder repository.

Omitted files:

- `go.mod`, `go.sum`, `.gitignore`, `.github/workflows/ci.yml`, `Makefile`, `LICENSE`, `README.md` (modified README.md is added later)
- `.github`, `example`, `buildinfo`, `scripts` directories

Simple verification script (will list omitted files):

```
tmp=$(mktemp -d)
echo "$tmp"
git clone --depth=1 https://github.com/coder/aibridge "$tmp/aibridge"
git clone --depth=1 --branch pb/aibridge-code-move https://github.com/coder/coder "$tmp/coder"
diff -rq --exclude=.git "$tmp/aibridge" "$tmp/coder/aibridge"
# rm -rf "$tmp"
```
218 lines
6.7 KiB
Go
218 lines
6.7 KiB
Go
package circuitbreaker
|
|
|
|
import (
|
|
"bufio"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sony/gobreaker/v2"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/aibridge/config"
|
|
"github.com/coder/coder/v2/aibridge/metrics"
|
|
)
|
|
|
|
// ErrCircuitOpen is the sentinel error returned by Execute when the circuit
// breaker rejected the request without calling the handler. Execute returns it
// both when the breaker is open (gobreaker.ErrOpenState) and when a half-open
// breaker refuses an extra request (gobreaker.ErrTooManyRequests).
var ErrCircuitOpen = xerrors.New("circuit breaker is open")
|
|
|
|
// DefaultIsFailure reports whether statusCode is one of the HTTP statuses that
// typically signal an overloaded upstream: 429 Too Many Requests,
// 503 Service Unavailable, or 504 Gateway Timeout. Every other code,
// including other 5xx responses, is not counted as a failure.
func DefaultIsFailure(statusCode int) bool {
	return statusCode == http.StatusTooManyRequests || // 429
		statusCode == http.StatusServiceUnavailable || // 503
		statusCode == http.StatusGatewayTimeout // 504
}
|
|
|
|
// ProviderCircuitBreakers manages per-endpoint/model circuit breakers for a single provider.
|
|
type ProviderCircuitBreakers struct {
|
|
provider string
|
|
config config.CircuitBreaker
|
|
breakers sync.Map // "endpoint:model" -> *gobreaker.CircuitBreaker[struct{}]
|
|
onChange func(endpoint, model string, from, to gobreaker.State)
|
|
metrics *metrics.Metrics
|
|
}
|
|
|
|
// NewProviderCircuitBreakers creates circuit breakers for a single provider.
|
|
// Returns nil if cfg is nil (no circuit breaker protection).
|
|
// onChange is called when circuit state changes.
|
|
// metrics is used to record circuit breaker reject counts (can be nil).
|
|
func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers {
|
|
if cfg == nil {
|
|
return nil
|
|
}
|
|
return &ProviderCircuitBreakers{
|
|
provider: provider,
|
|
config: *cfg,
|
|
onChange: onChange,
|
|
metrics: m,
|
|
}
|
|
}
|
|
|
|
// isFailure checks if the status code should count as a failure.
|
|
// Falls back to DefaultIsFailure if no custom function is configured.
|
|
func (p *ProviderCircuitBreakers) isFailure(statusCode int) bool {
|
|
if p.config.IsFailure != nil {
|
|
return p.config.IsFailure(statusCode)
|
|
}
|
|
return DefaultIsFailure(statusCode)
|
|
}
|
|
|
|
// openErrBody returns the error response body when the circuit is open.
|
|
func (p *ProviderCircuitBreakers) openErrBody() []byte {
|
|
if p.config.OpenErrorResponse != nil {
|
|
return p.config.OpenErrorResponse()
|
|
}
|
|
return []byte(`{"error":"circuit breaker is open"}`)
|
|
}
|
|
|
|
// Get returns the circuit breaker for an endpoint/model tuple, creating it if needed.
|
|
func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.CircuitBreaker[struct{}] {
|
|
key := endpoint + ":" + model
|
|
if v, ok := p.breakers.Load(key); ok {
|
|
return v.(*gobreaker.CircuitBreaker[struct{}]) //nolint:forcetypeassert // sync.Map always stores this type
|
|
}
|
|
|
|
settings := gobreaker.Settings{
|
|
Name: p.provider + ":" + key,
|
|
MaxRequests: p.config.MaxRequests,
|
|
Interval: p.config.Interval,
|
|
Timeout: p.config.Timeout,
|
|
ReadyToTrip: func(counts gobreaker.Counts) bool {
|
|
return counts.ConsecutiveFailures >= p.config.FailureThreshold
|
|
},
|
|
OnStateChange: func(_ string, from, to gobreaker.State) {
|
|
if p.onChange != nil {
|
|
p.onChange(endpoint, model, from, to)
|
|
}
|
|
},
|
|
}
|
|
|
|
cb := gobreaker.NewCircuitBreaker[struct{}](settings)
|
|
actual, _ := p.breakers.LoadOrStore(key, cb)
|
|
return actual.(*gobreaker.CircuitBreaker[struct{}]) //nolint:forcetypeassert // sync.Map always stores this type
|
|
}
|
|
|
|
// statusCapturingWriter decorates an http.ResponseWriter so that the status
// code of the response can be inspected after the handler has run.
// It implements http.Flusher to support streaming and http.Hijacker to
// satisfy the FullResponseWriter lint rule.
type statusCapturingWriter struct {
	// ResponseWriter is the wrapped writer; every call is forwarded to it.
	http.ResponseWriter

	// statusCode is the captured response status.
	statusCode int
	// headerWritten is set once a status has been captured, either explicitly
	// through WriteHeader or implicitly through the first Write.
	headerWritten bool
}
|
|
|
|
func (w *statusCapturingWriter) WriteHeader(code int) {
|
|
if !w.headerWritten {
|
|
w.statusCode = code
|
|
w.headerWritten = true
|
|
}
|
|
w.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
func (w *statusCapturingWriter) Write(b []byte) (int, error) {
|
|
if !w.headerWritten {
|
|
w.statusCode = http.StatusOK
|
|
w.headerWritten = true
|
|
}
|
|
return w.ResponseWriter.Write(b)
|
|
}
|
|
|
|
func (w *statusCapturingWriter) Flush() {
|
|
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
}
|
|
|
|
func (w *statusCapturingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
h, ok := w.ResponseWriter.(http.Hijacker)
|
|
if !ok {
|
|
return nil, nil, xerrors.New("upstream ResponseWriter does not support hijacking")
|
|
}
|
|
return h.Hijack()
|
|
}
|
|
|
|
// Unwrap returns the underlying ResponseWriter for interface checks.
|
|
func (w *statusCapturingWriter) Unwrap() http.ResponseWriter {
|
|
return w.ResponseWriter
|
|
}
|
|
|
|
// Execute runs the given handler function within circuit breaker protection.
|
|
// If the circuit is open, the request is rejected with a 503 response, metrics are recorded,
|
|
// and ErrCircuitOpen is returned.
|
|
// Otherwise, it returns the handler's error (or nil on success).
|
|
// The handler receives a wrapped ResponseWriter that captures the status code.
|
|
// If the receiver is nil (no circuit breaker configured), the handler is called directly.
|
|
func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error {
|
|
if p == nil {
|
|
return handler(w)
|
|
}
|
|
|
|
cb := p.Get(endpoint, model)
|
|
|
|
// Wrap response writer to capture status code
|
|
sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
|
|
|
var handlerErr error
|
|
_, err := cb.Execute(func() (struct{}, error) {
|
|
handlerErr = handler(sw)
|
|
if p.isFailure(sw.statusCode) {
|
|
return struct{}{}, xerrors.Errorf("upstream error: %d", sw.statusCode)
|
|
}
|
|
return struct{}{}, nil
|
|
})
|
|
|
|
if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) {
|
|
if p.metrics != nil {
|
|
p.metrics.CircuitBreakerRejects.WithLabelValues(p.provider, endpoint, model).Inc()
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds())))
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
_, _ = w.Write(p.openErrBody())
|
|
return ErrCircuitOpen
|
|
}
|
|
|
|
return handlerErr
|
|
}
|
|
|
|
// Timeout returns the configured timeout duration for this circuit breaker.
|
|
func (p *ProviderCircuitBreakers) Timeout() time.Duration {
|
|
return p.config.Timeout
|
|
}
|
|
|
|
// Provider returns the provider name for this circuit breaker.
|
|
func (p *ProviderCircuitBreakers) Provider() string {
|
|
return p.provider
|
|
}
|
|
|
|
// OpenErrorResponse returns the error response body when the circuit is open.
|
|
// This is exposed for handlers to use when responding to rejected requests.
|
|
func (p *ProviderCircuitBreakers) OpenErrorResponse() []byte {
|
|
return p.openErrBody()
|
|
}
|
|
|
|
// StateToGaugeValue converts gobreaker.State to a gauge value.
|
|
// closed=0, half-open=0.5, open=1
|
|
func StateToGaugeValue(s gobreaker.State) float64 {
|
|
switch s {
|
|
case gobreaker.StateClosed:
|
|
return 0
|
|
case gobreaker.StateHalfOpen:
|
|
return 0.5
|
|
case gobreaker.StateOpen:
|
|
return 1
|
|
default:
|
|
return 0
|
|
}
|
|
}
|