mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
c49170b6b3
Surface the io.ReadAll error in the error message when an HTTP request fails with a non-200 status, instead of silently discarding it.
455 lines
12 KiB
Go
455 lines
12 KiB
Go
package bridge
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
semconv "go.opentelemetry.io/otel/semconv/v1.14.0"
|
|
"go.opentelemetry.io/otel/semconv/v1.14.0/httpconv"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"cdr.dev/slog/v3/sloggers/sloghuman"
|
|
"github.com/coder/coder/v2/coderd/tracing"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/scaletest/harness"
|
|
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
type (
|
|
tracingContextKey struct{}
|
|
tracingContext struct {
|
|
provider string
|
|
model string
|
|
stream bool
|
|
requestNum int
|
|
mode RequestMode
|
|
}
|
|
)
|
|
|
|
type tracingTransport struct {
|
|
cfg Config
|
|
underlying http.RoundTripper
|
|
}
|
|
|
|
func newTracingTransport(cfg Config, underlying http.RoundTripper) *tracingTransport {
|
|
if underlying == nil {
|
|
underlying = http.DefaultTransport
|
|
}
|
|
return &tracingTransport{
|
|
cfg: cfg,
|
|
underlying: otelhttp.NewTransport(underlying),
|
|
}
|
|
}
|
|
|
|
func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
aibridgeCtx, hasAIBridgeCtx := req.Context().Value(tracingContextKey{}).(tracingContext)
|
|
|
|
resp, err := t.underlying.RoundTrip(req)
|
|
|
|
if hasAIBridgeCtx {
|
|
ctx := req.Context()
|
|
if resp != nil && resp.Request != nil {
|
|
ctx = resp.Request.Context()
|
|
}
|
|
span := trace.SpanFromContext(ctx)
|
|
if span.IsRecording() {
|
|
span.SetAttributes(
|
|
attribute.String("aibridge.provider", aibridgeCtx.provider),
|
|
attribute.String("aibridge.model", aibridgeCtx.model),
|
|
attribute.Bool("aibridge.stream", aibridgeCtx.stream),
|
|
attribute.Int("aibridge.request_num", aibridgeCtx.requestNum),
|
|
attribute.String("aibridge.mode", string(aibridgeCtx.mode)),
|
|
)
|
|
}
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
|
|
type Runner struct {
|
|
client *codersdk.Client
|
|
cfg Config
|
|
strategy requestModeStrategy
|
|
providerStrategy ProviderStrategy
|
|
|
|
clock quartz.Clock
|
|
httpClient *http.Client
|
|
|
|
requestCount int64
|
|
successCount int64
|
|
failureCount int64
|
|
totalDuration time.Duration
|
|
totalTokens int64
|
|
}
|
|
|
|
func NewRunner(client *codersdk.Client, cfg Config) *Runner {
|
|
return &Runner{
|
|
client: client,
|
|
cfg: cfg,
|
|
strategy: cfg.NewStrategy(client),
|
|
providerStrategy: NewProviderStrategy(cfg.Provider),
|
|
clock: quartz.NewReal(),
|
|
httpClient: &http.Client{
|
|
Transport: newTracingTransport(cfg, http.DefaultTransport),
|
|
},
|
|
}
|
|
}
|
|
|
|
func (r *Runner) WithClock(clock quartz.Clock) *Runner {
|
|
r.clock = clock
|
|
return r
|
|
}
|
|
|
|
var (
|
|
_ harness.Runnable = &Runner{}
|
|
_ harness.Cleanable = &Runner{}
|
|
_ harness.Collectable = &Runner{}
|
|
)
|
|
|
|
func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
|
|
ctx, span := tracing.StartSpan(ctx)
|
|
defer span.End()
|
|
|
|
logs = loadtestutil.NewSyncWriter(logs)
|
|
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
|
|
|
|
requestURL, token, err := r.strategy.Setup(ctx, id, logs)
|
|
if err != nil {
|
|
return xerrors.Errorf("strategy setup: %w", err)
|
|
}
|
|
|
|
requestCount := r.cfg.RequestCount
|
|
if requestCount <= 0 {
|
|
requestCount = 1
|
|
}
|
|
|
|
model := r.providerStrategy.DefaultModel()
|
|
|
|
logger.Info(ctx, "bridge runner is ready",
|
|
slog.F("request_count", requestCount),
|
|
slog.F("model", model),
|
|
slog.F("stream", r.cfg.Stream),
|
|
)
|
|
|
|
for i := 0; i < requestCount; i++ {
|
|
if err := r.makeRequest(ctx, logger, requestURL, token, model, i); err != nil {
|
|
logger.Warn(ctx, "bridge request failed",
|
|
slog.F("request_num", i+1),
|
|
slog.F("error_type", "request_failed"),
|
|
slog.Error(err),
|
|
)
|
|
r.cfg.Metrics.AddError("request")
|
|
r.cfg.Metrics.AddRequest("failure")
|
|
r.failureCount++
|
|
|
|
// Continue making requests even if one fails
|
|
continue
|
|
}
|
|
r.successCount++
|
|
r.cfg.Metrics.AddRequest("success")
|
|
r.requestCount++
|
|
}
|
|
|
|
logger.Info(ctx, "bridge runner completed",
|
|
slog.F("total_requests", r.requestCount),
|
|
slog.F("success", r.successCount),
|
|
slog.F("failure", r.failureCount),
|
|
)
|
|
|
|
// Fail the run if any request failed
|
|
if r.failureCount > 0 {
|
|
return xerrors.Errorf("bridge runner failed: %d out of %d requests failed", r.failureCount, requestCount)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *Runner) makeRequest(ctx context.Context, logger slog.Logger, url, token, model string, requestNum int) error {
|
|
start := r.clock.Now()
|
|
|
|
ctx = context.WithValue(ctx, tracingContextKey{}, tracingContext{
|
|
provider: r.cfg.Provider,
|
|
model: model,
|
|
stream: r.cfg.Stream,
|
|
requestNum: requestNum + 1,
|
|
mode: r.cfg.Mode,
|
|
})
|
|
|
|
// Set timeout per request
|
|
ctx, cancel := context.WithTimeout(ctx, r.cfg.HTTPTimeout)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(r.cfg.RequestBody))
|
|
if err != nil {
|
|
return xerrors.Errorf("create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
logger.Debug(ctx, "making bridge request",
|
|
slog.F("url", url),
|
|
slog.F("request_num", requestNum+1),
|
|
slog.F("model", model),
|
|
)
|
|
|
|
resp, err := r.httpClient.Do(req)
|
|
if err != nil {
|
|
span := trace.SpanFromContext(req.Context())
|
|
if span.IsRecording() {
|
|
span.RecordError(err)
|
|
}
|
|
logger.Warn(ctx, "request failed during execution",
|
|
slog.F("request_num", requestNum+1),
|
|
slog.Error(err),
|
|
)
|
|
return xerrors.Errorf("execute request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
span := trace.SpanFromContext(req.Context())
|
|
if span.IsRecording() {
|
|
span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode))
|
|
span.SetStatus(httpconv.ClientStatus(resp.StatusCode))
|
|
}
|
|
|
|
duration := r.clock.Since(start)
|
|
r.totalDuration += duration
|
|
r.cfg.Metrics.ObserveDuration(duration.Seconds())
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, readErr := io.ReadAll(resp.Body)
|
|
if readErr != nil {
|
|
body = []byte(fmt.Sprintf("<failed to read response body: %s>", readErr))
|
|
}
|
|
err := xerrors.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
|
|
span.RecordError(err)
|
|
return err
|
|
}
|
|
|
|
if r.cfg.Stream {
|
|
err := r.handleStreamingResponse(ctx, logger, resp)
|
|
if err != nil {
|
|
span.RecordError(err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
return r.handleNonStreamingResponse(ctx, logger, resp)
|
|
}
|
|
|
|
func (r *Runner) handleNonStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
|
switch r.cfg.Provider {
|
|
case "messages":
|
|
return r.handleMessagesResponse(ctx, logger, resp)
|
|
case "responses":
|
|
return r.handleResponsesResponse(ctx, logger, resp)
|
|
case "completions":
|
|
return r.handleCompletionsResponse(ctx, logger, resp)
|
|
default:
|
|
return xerrors.Errorf("unsupported provider: %s", r.cfg.Provider)
|
|
}
|
|
}
|
|
|
|
func (r *Runner) handleCompletionsResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
|
var response struct {
|
|
ID string `json:"id"`
|
|
Model string `json:"model"`
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
} `json:"choices"`
|
|
Usage struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
CompletionTokens int `json:"completion_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
|
return xerrors.Errorf("decode response: %w", err)
|
|
}
|
|
|
|
if len(response.Choices) > 0 {
|
|
assistantContent := response.Choices[0].Message.Content
|
|
logger.Debug(ctx, "received response",
|
|
slog.F("response_id", response.ID),
|
|
slog.F("content_length", len(assistantContent)),
|
|
)
|
|
}
|
|
|
|
if response.Usage.TotalTokens > 0 {
|
|
r.totalTokens += int64(response.Usage.TotalTokens)
|
|
r.cfg.Metrics.AddTokens("input", int64(response.Usage.PromptTokens))
|
|
r.cfg.Metrics.AddTokens("output", int64(response.Usage.CompletionTokens))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *Runner) handleResponsesResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
|
var response struct {
|
|
ID string `json:"id"`
|
|
Model string `json:"model"`
|
|
Output []struct {
|
|
Type string `json:"type"`
|
|
Role string `json:"role"`
|
|
Content []struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text"`
|
|
} `json:"content"`
|
|
} `json:"output"`
|
|
Usage struct {
|
|
InputTokens int `json:"input_tokens"`
|
|
OutputTokens int `json:"output_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
|
return xerrors.Errorf("decode response: %w", err)
|
|
}
|
|
|
|
var assistantContent string
|
|
var contentBuilder strings.Builder
|
|
for _, item := range response.Output {
|
|
if item.Role != "assistant" {
|
|
continue
|
|
}
|
|
for _, content := range item.Content {
|
|
if content.Type != "output_text" {
|
|
continue
|
|
}
|
|
_, _ = contentBuilder.WriteString(content.Text)
|
|
}
|
|
}
|
|
assistantContent = contentBuilder.String()
|
|
if assistantContent != "" {
|
|
logger.Debug(ctx, "received response",
|
|
slog.F("response_id", response.ID),
|
|
slog.F("content_length", len(assistantContent)),
|
|
)
|
|
}
|
|
|
|
if response.Usage.TotalTokens > 0 {
|
|
r.totalTokens += int64(response.Usage.TotalTokens)
|
|
r.cfg.Metrics.AddTokens("input", int64(response.Usage.InputTokens))
|
|
r.cfg.Metrics.AddTokens("output", int64(response.Usage.OutputTokens))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *Runner) handleMessagesResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
|
var response struct {
|
|
ID string `json:"id"`
|
|
Model string `json:"model"`
|
|
Content []struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text"`
|
|
} `json:"content"`
|
|
Usage struct {
|
|
InputTokens int `json:"input_tokens"`
|
|
OutputTokens int `json:"output_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
|
return xerrors.Errorf("decode response: %w", err)
|
|
}
|
|
|
|
var assistantContent string
|
|
if len(response.Content) > 0 {
|
|
assistantContent = response.Content[0].Text
|
|
logger.Debug(ctx, "received response",
|
|
slog.F("response_id", response.ID),
|
|
slog.F("content_length", len(assistantContent)),
|
|
)
|
|
}
|
|
|
|
totalTokens := response.Usage.InputTokens + response.Usage.OutputTokens
|
|
if totalTokens > 0 {
|
|
r.totalTokens += int64(totalTokens)
|
|
r.cfg.Metrics.AddTokens("input", int64(response.Usage.InputTokens))
|
|
r.cfg.Metrics.AddTokens("output", int64(response.Usage.OutputTokens))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (*Runner) handleStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
|
buf := make([]byte, 4096)
|
|
totalRead := 0
|
|
for {
|
|
// Check for context cancellation before each read
|
|
if ctx.Err() != nil {
|
|
logger.Warn(ctx, "streaming response canceled",
|
|
slog.F("bytes_read", totalRead),
|
|
slog.Error(ctx.Err()),
|
|
)
|
|
return xerrors.Errorf("stream canceled: %w", ctx.Err())
|
|
}
|
|
|
|
n, err := resp.Body.Read(buf)
|
|
if n > 0 {
|
|
totalRead += n
|
|
}
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
// Check if error is due to context cancellation
|
|
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
|
logger.Warn(ctx, "streaming response read canceled",
|
|
slog.F("bytes_read", totalRead),
|
|
slog.Error(err),
|
|
)
|
|
return xerrors.Errorf("stream read canceled: %w", err)
|
|
}
|
|
logger.Warn(ctx, "streaming response read error",
|
|
slog.F("bytes_read", totalRead),
|
|
slog.Error(err),
|
|
)
|
|
return xerrors.Errorf("read stream: %w", err)
|
|
}
|
|
}
|
|
|
|
logger.Debug(ctx, "received streaming response", slog.F("bytes_read", totalRead))
|
|
return nil
|
|
}
|
|
|
|
func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error {
|
|
return r.strategy.Cleanup(ctx, id, logs)
|
|
}
|
|
|
|
func (r *Runner) GetMetrics() map[string]any {
|
|
avgDuration := time.Duration(0)
|
|
if r.requestCount > 0 {
|
|
avgDuration = r.totalDuration / time.Duration(r.requestCount)
|
|
}
|
|
|
|
return map[string]any{
|
|
"request_count": r.requestCount,
|
|
"success_count": r.successCount,
|
|
"failure_count": r.failureCount,
|
|
"total_duration": r.totalDuration.String(),
|
|
"avg_duration": avgDuration.String(),
|
|
"total_tokens": r.totalTokens,
|
|
}
|
|
}
|