chore: add scaletesting tools for aibridge (#21279)

This pull request adds scaletesting tools for aibridge.

See
https://www.notion.so/Scale-tests-2c5d579be5928088b565d15dd8bdea41?source=copy_link
for information and instructions.

closes: https://github.com/coder/internal/issues/1156
closes: https://github.com/coder/internal/issues/1155
closes: https://github.com/coder/internal/issues/1158
This commit is contained in:
Sas Swart
2026-01-15 17:05:46 +02:00
committed by GitHub
parent 3894edbcc3
commit 0ebe8e57ad
10 changed files with 1808 additions and 1 deletions
+2
View File
@@ -68,6 +68,8 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
r.scaletestTaskStatus(),
r.scaletestSMTP(),
r.scaletestPrebuilds(),
r.scaletestBridge(),
r.scaletestLLMMock(),
},
}
+278
View File
@@ -0,0 +1,278 @@
//go:build !slim
package cli
import (
"fmt"
"net/http"
"os/signal"
"strconv"
"text/tabwriter"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/bridge"
"github.com/coder/coder/v2/scaletest/createusers"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/serpent"
)
// scaletestBridge builds the `scaletest bridge` command, which generates load
// either through the Coder AI Bridge ("bridge" mode) or directly against an
// upstream URL ("direct" mode).
//
// The handler registers bridge metrics on a fresh Prometheus registry, adds
// one bridge.Runner per concurrent user to the test harness, runs the harness,
// writes the results to the configured outputs, and cleans up unless
// --no-cleanup is set. It returns an error when the harness fails, when the
// command is interrupted, or when any run reports a failure.
func (r *RootCmd) scaletestBridge() *serpent.Command {
var (
concurrentUsers int64
noCleanup bool
mode string
upstreamURL string
provider string
requestsPerUser int64
useStreamingAPI bool
requestPayloadSize int64
numMessages int64
httpTimeout time.Duration
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
Use: "bridge",
Short: "Generate load on the AI Bridge service.",
Long: `Generate load for AI Bridge testing. Supports two modes: 'bridge' mode routes requests through the Coder AI Bridge, 'direct' mode makes requests directly to an upstream URL (useful for baseline comparisons).
Examples:
# Test OpenAI API through bridge
coder scaletest bridge --mode bridge --provider openai --concurrent-users 10 --request-count 5 --num-messages 10
# Test Anthropic API through bridge
coder scaletest bridge --mode bridge --provider anthropic --concurrent-users 10 --request-count 5 --num-messages 10
# Test directly against mock server
coder scaletest bridge --mode direct --provider openai --upstream-url http://localhost:8080/v1/chat/completions
`,
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
// Replace the HTTP client so every request made through it carries the
// rate-limit bypass header.
client.HTTPClient = &http.Client{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
}
reg := prometheus.NewRegistry()
metrics := bridge.NewMetrics(reg)
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
// Deferred after prometheusSrvClose, so it runs first: the handler
// sleeps for prometheusFlags.Wait before prometheusSrvClose is invoked.
defer func() {
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
// From here on ctx is canceled when one of StopSignals is received.
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
defer stop()
ctx = notifyCtx
var userConfig createusers.Config
if bridge.RequestMode(mode) == bridge.RequestModeBridge {
// Bridge mode needs requireAdmin to succeed; users are configured with
// the first organization of the returned user.
me, err := requireAdmin(ctx, client)
if err != nil {
return err
}
if len(me.OrganizationIDs) == 0 {
return xerrors.Errorf("admin user must have at least one organization")
}
userConfig = createusers.Config{
OrganizationID: me.OrganizationIDs[0],
}
_, _ = fmt.Fprintln(inv.Stderr, "Bridge mode: creating users and making requests through AI Bridge...")
} else {
_, _ = fmt.Fprintf(inv.Stderr, "Direct mode: making requests directly to %s\n", upstreamURL)
}
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("parse output flags: %w", err)
}
config := bridge.Config{
Mode: bridge.RequestMode(mode),
Metrics: metrics,
Provider: provider,
RequestCount: int(requestsPerUser),
Stream: useStreamingAPI,
RequestPayloadSize: int(requestPayloadSize),
NumMessages: int(numMessages),
HTTPTimeout: httpTimeout,
UpstreamURL: upstreamURL,
User: userConfig,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
// Serialize the request body once, before the runners are created, so
// each runner's copy of config carries the same payload.
if err := config.PrepareRequestBody(); err != nil {
return xerrors.Errorf("prepare request body: %w", err)
}
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
// One runner per concurrent user; all runners receive the same client.
for i := range concurrentUsers {
id := strconv.Itoa(int(i))
name := fmt.Sprintf("bridge-%s", id)
var runner harness.Runnable = bridge.NewRunner(client, config)
th.AddRun(name, id, runner)
}
// Echo every non-hidden option that has a value source, noting where
// non-default values came from.
_, _ = fmt.Fprintln(inv.Stderr, "Bridge scaletest configuration:")
tw := tabwriter.NewWriter(inv.Stderr, 0, 0, 2, ' ', 0)
for _, opt := range inv.Command.Options {
if opt.Hidden || opt.ValueSource == serpent.ValueSourceNone {
continue
}
_, _ = fmt.Fprintf(tw, " %s:\t%s", opt.Name, opt.Value.String())
if opt.ValueSource != serpent.ValueSourceDefault {
_, _ = fmt.Fprintf(tw, "\t(from %s)", opt.ValueSource)
}
_, _ = fmt.Fprintln(tw)
}
_ = tw.Flush()
_, _ = fmt.Fprintln(inv.Stderr, "\nRunning bridge scaletest...")
testCtx, testCancel := timeoutStrategy.toContext(ctx)
defer testCancel()
// An error here is a harness failure; failed runs are reported through
// res.TotalFail further down.
err = th.Run(testCtx)
if err != nil {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
// Cleanup uses a context obtained from the cleanup strategy rather than
// the test context.
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
}
// Report failure only after outputs were written and cleanup ran.
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "concurrent-users",
FlagShorthand: "c",
Env: "CODER_SCALETEST_BRIDGE_CONCURRENT_USERS",
Description: "Required: Number of concurrent users.",
Value: serpent.Validate(serpent.Int64Of(&concurrentUsers), func(value *serpent.Int64) error {
if value == nil || value.Value() <= 0 {
return xerrors.Errorf("--concurrent-users must be greater than 0")
}
return nil
}),
Required: true,
},
{
Flag: "mode",
Env: "CODER_SCALETEST_BRIDGE_MODE",
Default: "direct",
Description: "Request mode: 'bridge' (create users and use AI Bridge) or 'direct' (make requests directly to upstream-url).",
Value: serpent.EnumOf(&mode, string(bridge.RequestModeBridge), string(bridge.RequestModeDirect)),
},
{
Flag: "upstream-url",
Env: "CODER_SCALETEST_BRIDGE_UPSTREAM_URL",
Description: "URL to make requests to directly (required in direct mode, e.g., http://localhost:8080/v1/chat/completions).",
Value: serpent.StringOf(&upstreamURL),
},
{
Flag: "provider",
Env: "CODER_SCALETEST_BRIDGE_PROVIDER",
Default: "openai",
Description: "API provider to use.",
Value: serpent.EnumOf(&provider, "openai", "anthropic"),
},
{
Flag: "request-count",
Env: "CODER_SCALETEST_BRIDGE_REQUEST_COUNT",
Default: "1",
Description: "Number of sequential requests to make per runner.",
Value: serpent.Validate(serpent.Int64Of(&requestsPerUser), func(value *serpent.Int64) error {
if value == nil || value.Value() <= 0 {
return xerrors.Errorf("--request-count must be greater than 0")
}
return nil
}),
},
{
Flag: "stream",
Env: "CODER_SCALETEST_BRIDGE_STREAM",
Description: "Enable streaming requests.",
Value: serpent.BoolOf(&useStreamingAPI),
},
{
Flag: "request-payload-size",
Env: "CODER_SCALETEST_BRIDGE_REQUEST_PAYLOAD_SIZE",
Default: "1024",
Description: "Size in bytes of the request payload (user message content). If 0, uses default message content.",
Value: serpent.Int64Of(&requestPayloadSize),
},
{
Flag: "num-messages",
Env: "CODER_SCALETEST_BRIDGE_NUM_MESSAGES",
Default: "1",
Description: "Number of messages to include in the conversation.",
Value: serpent.Int64Of(&numMessages),
},
{
Flag: "no-cleanup",
Env: "CODER_SCALETEST_NO_CLEANUP",
Description: "Do not clean up resources after the test completes.",
Value: serpent.BoolOf(&noCleanup),
},
{
Flag: "http-timeout",
Env: "CODER_SCALETEST_BRIDGE_HTTP_TIMEOUT",
Default: "30s",
Description: "Timeout for individual HTTP requests to the upstream provider.",
Value: serpent.DurationOf(&httpTimeout),
},
}
// Let the timeout, cleanup, output, and prometheus flag groups attach to
// the command's option set.
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
+118
View File
@@ -0,0 +1,118 @@
//go:build !slim
package cli
import (
"fmt"
"os/signal"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/sloghuman"
"github.com/coder/coder/v2/scaletest/llmmock"
"github.com/coder/serpent"
)
// scaletestLLMMock builds the `scaletest llm-mock` command. The command starts
// an llmmock.Server with the configured address, latency and payload size,
// prints the OpenAI and Anthropic endpoint URLs to stdout, and then blocks
// until the invocation context is canceled or a stop signal arrives.
func (*RootCmd) scaletestLLMMock() *serpent.Command {
	var (
		address             string
		artificialLatency   time.Duration
		responsePayloadSize int64
		pprofEnable         bool
		pprofAddress        string
		traceEnable         bool
	)

	run := func(inv *serpent.Invocation) error {
		ctx, cancel := signal.NotifyContext(inv.Context(), StopSignals...)
		defer cancel()

		log := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelInfo)

		// The pprof listener is optional and is torn down when the handler
		// returns.
		if pprofEnable {
			stopPprof := ServeHandler(ctx, log, nil, pprofAddress, "pprof")
			defer stopPprof()
			log.Info(ctx, "pprof server started", slog.F("address", pprofAddress))
		}

		mock := &llmmock.Server{}
		startErr := mock.Start(ctx, llmmock.Config{
			Address:             address,
			Logger:              log,
			ArtificialLatency:   artificialLatency,
			ResponsePayloadSize: int(responsePayloadSize),
			PprofEnable:         pprofEnable,
			PprofAddress:        pprofAddress,
			TraceEnable:         traceEnable,
		})
		if startErr != nil {
			return xerrors.Errorf("start mock LLM server: %w", startErr)
		}
		defer func() {
			_ = mock.Stop()
		}()

		_, _ = fmt.Fprintf(inv.Stdout, "Mock LLM API server started on %s\n", mock.APIAddress())
		_, _ = fmt.Fprintf(inv.Stdout, "  OpenAI endpoint:    %s/v1/chat/completions\n", mock.APIAddress())
		_, _ = fmt.Fprintf(inv.Stdout, "  Anthropic endpoint: %s/v1/messages\n", mock.APIAddress())

		// Serve until the context is done.
		<-ctx.Done()
		return nil
	}

	return &serpent.Command{
		Use:     "llm-mock",
		Short:   "Start a mock LLM API server for testing",
		Long:    `Start a mock LLM API server that simulates OpenAI and Anthropic APIs`,
		Handler: run,
		Options: []serpent.Option{
			{
				Flag:        "address",
				Env:         "CODER_SCALETEST_LLM_MOCK_ADDRESS",
				Default:     "localhost",
				Description: "Address to bind the mock LLM API server. Can include a port (e.g., 'localhost:8080' or ':8080'). Uses a random port if no port is specified.",
				Value:       serpent.StringOf(&address),
			},
			{
				Flag:        "artificial-latency",
				Env:         "CODER_SCALETEST_LLM_MOCK_ARTIFICIAL_LATENCY",
				Default:     "0s",
				Description: "Artificial latency to add to each response (e.g., 100ms, 1s). Simulates slow upstream processing.",
				Value:       serpent.DurationOf(&artificialLatency),
			},
			{
				Flag:        "response-payload-size",
				Env:         "CODER_SCALETEST_LLM_MOCK_RESPONSE_PAYLOAD_SIZE",
				Default:     "0",
				Description: "Size in bytes of the response payload. If 0, uses default context-aware responses.",
				Value:       serpent.Int64Of(&responsePayloadSize),
			},
			{
				Flag:        "pprof-enable",
				Env:         "CODER_SCALETEST_LLM_MOCK_PPROF_ENABLE",
				Default:     "false",
				Description: "Serve pprof metrics on the address defined by pprof-address.",
				Value:       serpent.BoolOf(&pprofEnable),
			},
			{
				Flag:        "pprof-address",
				Env:         "CODER_SCALETEST_LLM_MOCK_PPROF_ADDRESS",
				Default:     "127.0.0.1:6060",
				Description: "The bind address to serve pprof.",
				Value:       serpent.StringOf(&pprofAddress),
			},
			{
				Flag:        "trace-enable",
				Env:         "CODER_SCALETEST_LLM_MOCK_TRACE_ENABLE",
				Default:     "false",
				Description: "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md.",
				Value:       serpent.BoolOf(&traceEnable),
			},
		},
	}
}
+1 -1
View File
@@ -437,7 +437,7 @@ require (
go.opentelemetry.io/collector/pdata/pprofile v0.121.0 // indirect
go.opentelemetry.io/collector/semconv v0.123.0 // indirect
go.opentelemetry.io/contrib v1.19.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
+150
View File
@@ -0,0 +1,150 @@
package bridge
import (
"encoding/json"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/createusers"
)
// RequestMode selects how the load generator reaches the LLM API.
type RequestMode string
const (
// RequestModeBridge routes requests through the Coder AI Bridge.
RequestModeBridge RequestMode = "bridge"
// RequestModeDirect sends requests straight to the configured upstream URL.
RequestModeDirect RequestMode = "direct"
)
// Config holds the settings for the bridge load generator. Call Validate to
// check it and PrepareRequestBody to fill RequestBody before building Runners.
type Config struct {
// Mode determines how requests are made.
// "bridge": Create users in Coder and use their session tokens to make requests through AI Bridge.
// "direct": Make requests directly to UpstreamURL without user creation.
Mode RequestMode `json:"mode"`
// User is the configuration for the user to create.
// Required in bridge mode.
User createusers.Config `json:"user"`
// UpstreamURL is the URL to make requests to directly.
// Only used in direct mode.
UpstreamURL string `json:"upstream_url"`
// Provider is the API provider to use: "openai" or "anthropic".
Provider string `json:"provider"`
// RequestCount is the number of requests to make per runner.
RequestCount int `json:"request_count"`
// Stream indicates whether to use streaming requests.
Stream bool `json:"stream"`
// RequestPayloadSize is the size in bytes of the request payload (user message content).
// If 0, uses default message content.
RequestPayloadSize int `json:"request_payload_size"`
// NumMessages is the number of messages to include in the conversation.
// Messages alternate between user and assistant roles, always ending with user.
// Must be greater than 0.
NumMessages int `json:"num_messages"`
// HTTPTimeout is the timeout for individual HTTP requests to the upstream
// provider. This is separate from the job timeout which controls the overall
// test execution.
HTTPTimeout time.Duration `json:"http_timeout"`
// Metrics receives request, error, duration and token observations.
// Validate rejects a nil value. It is excluded from JSON output.
Metrics *Metrics `json:"-"`
// RequestBody is the pre-serialized JSON request body. This is generated
// once by PrepareRequestBody and shared across all runners and requests.
RequestBody []byte `json:"-"`
}
// Validate reports the first problem found in the configuration, or nil when
// the configuration is usable.
//
// Checks that apply to both modes (metrics, mode, request count, provider and
// message count) run first. Direct mode additionally requires UpstreamURL;
// bridge mode additionally requires a valid User config with an organization.
func (c Config) Validate() error {
	if c.Metrics == nil {
		return xerrors.New("metrics must be set")
	}

	// Validate mode
	if c.Mode != RequestModeBridge && c.Mode != RequestModeDirect {
		return xerrors.New("mode must be either 'bridge' or 'direct'")
	}

	if c.RequestCount <= 0 {
		return xerrors.New("request_count must be greater than 0")
	}

	// Validate provider
	if c.Provider != "openai" && c.Provider != "anthropic" {
		return xerrors.New("provider must be either 'openai' or 'anthropic'")
	}

	// NumMessages is consumed by PrepareRequestBody regardless of mode, so it
	// must be checked before the direct-mode early return below. Previously
	// this check was only reachable in bridge mode.
	if c.NumMessages <= 0 {
		return xerrors.New("num_messages must be greater than 0")
	}

	if c.Mode == RequestModeDirect {
		// In direct mode, UpstreamURL must be set.
		if c.UpstreamURL == "" {
			return xerrors.New("upstream_url must be set in direct mode")
		}
		return nil
	}

	// In bridge mode, User config is required.
	if c.User.OrganizationID == uuid.Nil {
		return xerrors.New("user organization_id must be set in bridge mode")
	}
	if err := c.User.Validate(); err != nil {
		return xerrors.Errorf("user config: %w", err)
	}

	return nil
}
// NewStrategy returns the request-mode strategy matching c.Mode: a direct
// strategy targeting UpstreamURL for RequestModeDirect, and a bridge strategy
// built from the client, provider, metrics and user config for any other mode.
func (c Config) NewStrategy(client *codersdk.Client) requestModeStrategy {
	switch c.Mode {
	case RequestModeDirect:
		directCfg := directStrategyConfig{UpstreamURL: c.UpstreamURL}
		return newDirectStrategy(directCfg)
	default:
		bridgeCfg := bridgeStrategyConfig{
			Client:   client,
			Provider: c.Provider,
			Metrics:  c.Metrics,
			User:     c.User,
		}
		return newBridgeStrategy(bridgeCfg)
	}
}
// PrepareRequestBody builds the provider-specific request once and stores its
// JSON encoding in c.RequestBody. Call it before constructing Runners so that
// every Runner copies a Config that already carries the payload.
//
// When RequestPayloadSize is positive the conversation is synthesized by
// generateConversation; otherwise a single fixed user message is used. An
// error is returned only if JSON encoding fails.
func (c *Config) PrepareRequestBody() error {
	strategy := NewProviderStrategy(c.Provider)
	modelName := strategy.DefaultModel()

	var conversation []any
	switch {
	case c.RequestPayloadSize > 0:
		conversation = generateConversation(strategy, c.RequestPayloadSize, c.NumMessages)
	default:
		greeting := message{
			Role:    "user",
			Content: "Hello from the bridge load generator.",
		}
		conversation = strategy.formatMessages([]message{greeting})
	}

	encoded, err := json.Marshal(strategy.buildRequestBody(modelName, conversation, c.Stream))
	if err != nil {
		return xerrors.Errorf("marshal request body: %w", err)
	}

	c.RequestBody = encoded
	return nil
}
+72
View File
@@ -0,0 +1,72 @@
package bridge
import (
"github.com/prometheus/client_golang/prometheus"
)
// Metrics groups the Prometheus collectors used by the bridge load generator.
// Instances are created by NewMetrics.
type Metrics struct {
// bridgeErrors counts errors, labelled by "action".
bridgeErrors *prometheus.CounterVec
// bridgeRequests counts requests, labelled by "status".
bridgeRequests *prometheus.CounterVec
// bridgeDuration observes request durations in seconds.
bridgeDuration prometheus.Histogram
// bridgeTokensTotal counts response tokens, labelled by "type".
bridgeTokensTotal *prometheus.CounterVec
}
// NewMetrics creates the bridge scaletest collectors and registers them with
// reg. A nil reg falls back to prometheus.DefaultRegisterer. Registration uses
// MustRegister, so a registration failure panics.
func NewMetrics(reg prometheus.Registerer) *Metrics {
	if reg == nil {
		reg = prometheus.DefaultRegisterer
	}

	m := &Metrics{
		bridgeErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "coderd",
			Subsystem: "scaletest",
			Name:      "bridge_errors_total",
			Help:      "Total number of bridge errors",
		}, []string{"action"}),
		bridgeRequests: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "coderd",
			Subsystem: "scaletest",
			Name:      "bridge_requests_total",
			Help:      "Total number of bridge requests",
		}, []string{"status"}),
		bridgeDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
			Namespace: "coderd",
			Subsystem: "scaletest",
			Name:      "bridge_request_duration_seconds",
			Help:      "Duration of bridge requests in seconds",
			Buckets:   prometheus.DefBuckets,
		}),
		bridgeTokensTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "coderd",
			Subsystem: "scaletest",
			Name:      "bridge_response_tokens_total",
			Help:      "Total number of tokens in bridge responses",
		}, []string{"type"}),
	}

	reg.MustRegister(m.bridgeErrors, m.bridgeRequests, m.bridgeDuration, m.bridgeTokensTotal)
	return m
}
// AddError increments the error counter for the given action label.
func (m *Metrics) AddError(action string) {
	counter := m.bridgeErrors.WithLabelValues(action)
	counter.Inc()
}
// AddRequest increments the request counter for the given status label.
func (m *Metrics) AddRequest(status string) {
	counter := m.bridgeRequests.WithLabelValues(status)
	counter.Inc()
}
// ObserveDuration records one request duration, expressed in seconds, in the
// duration histogram.
func (m *Metrics) ObserveDuration(duration float64) {
	histogram := m.bridgeDuration
	histogram.Observe(duration)
}
// AddTokens adds count to the token counter for the given token type label.
func (m *Metrics) AddTokens(tokenType string, count int64) {
	counter := m.bridgeTokensTotal.WithLabelValues(tokenType)
	counter.Add(float64(count))
}
+134
View File
@@ -0,0 +1,134 @@
package bridge
import (
"encoding/json"
"strings"
)
// ProviderStrategy handles provider-specific message formatting for LLM APIs.
type ProviderStrategy interface {
// DefaultModel returns the model name placed in generated requests.
DefaultModel() string
// formatMessages converts role/content pairs into the provider's message
// representation, one element per input message.
formatMessages(messages []message) []any
// buildRequestBody assembles the top-level request object that is later
// JSON-encoded.
buildRequestBody(model string, messages []any, stream bool) map[string]any
}
// message is a provider-neutral conversation turn; ProviderStrategy
// implementations turn it into the provider's wire format.
type message struct {
// Role is the speaker of the turn; this package uses "user" and "assistant".
Role string
// Content is the text of the turn.
Content string
}
// NewProviderStrategy returns the Anthropic strategy for "anthropic" and the
// OpenAI strategy for every other provider name.
func NewProviderStrategy(provider string) ProviderStrategy {
	if provider == "anthropic" {
		return &anthropicProvider{}
	}
	return &openAIProvider{}
}
// openAIProvider implements ProviderStrategy for OpenAI-format requests.
type openAIProvider struct{}
// DefaultModel returns the model name used in OpenAI-format requests.
func (*openAIProvider) DefaultModel() string {
	const model = "gpt-4"
	return model
}
// formatMessages maps each message to a {"role", "content"} string map. The
// result has one element per input message, in the same order.
func (*openAIProvider) formatMessages(messages []message) []any {
	out := make([]any, len(messages))
	for i := range messages {
		out[i] = map[string]string{
			"role":    messages[i].Role,
			"content": messages[i].Content,
		}
	}
	return out
}
// buildRequestBody returns a request object with the "model", "messages" and
// "stream" keys.
func (*openAIProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any {
	body := make(map[string]any, 3)
	body["model"] = model
	body["messages"] = messages
	body["stream"] = stream
	return body
}
// anthropicProvider implements ProviderStrategy for Anthropic-format requests.
type anthropicProvider struct{}
// DefaultModel returns the model name used in Anthropic-format requests.
func (*anthropicProvider) DefaultModel() string {
	const model = "claude-3-opus-20240229"
	return model
}
// formatMessages maps each message to an object whose "content" is a list
// holding a single text block. The result has one element per input message,
// in the same order.
func (*anthropicProvider) formatMessages(messages []message) []any {
	out := make([]any, len(messages))
	for i := range messages {
		textBlock := map[string]string{
			"type": "text",
			"text": messages[i].Content,
		}
		out[i] = map[string]any{
			"role":    messages[i].Role,
			"content": []map[string]string{textBlock},
		}
	}
	return out
}
// buildRequestBody returns a request object with the "model", "messages",
// "max_tokens" (fixed at 1024) and "stream" keys.
func (*anthropicProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any {
	body := make(map[string]any, 4)
	body["model"] = model
	body["messages"] = messages
	body["max_tokens"] = 1024
	body["stream"] = stream
	return body
}
// generateConversation builds numMessages turns (at least one) whose roles
// alternate starting with "user"; the final turn is always forced to "user".
// Turn contents are runs of 'x' sized so that the JSON encoding of the
// formatted conversation is approximately targetSize bytes: the size of the
// formatted conversation with empty contents is subtracted from targetSize and
// the rest is split evenly across turns, with any remainder going to the last
// turn. It returns nil when targetSize is not positive.
func generateConversation(provider ProviderStrategy, targetSize int, numMessages int) []any {
	if targetSize <= 0 {
		return nil
	}

	count := numMessages
	if count < 1 {
		count = 1
	}

	turns := make([]message, count)
	for idx := range turns {
		if idx%2 == 0 {
			turns[idx].Role = "user"
		} else {
			turns[idx].Role = "assistant"
		}
	}
	// Ensure last message is from user (required for LLM APIs).
	last := count - 1
	turns[last].Role = "user"

	// Content budget is whatever remains after the empty-content encoding.
	budget := targetSize - measureJSONSize(provider.formatMessages(turns))
	if budget < 0 {
		budget = 0
	}
	share, leftover := budget/count, budget%count

	for idx := range turns {
		length := share
		if idx == last {
			length += leftover
		}
		turns[idx].Content = strings.Repeat("x", length)
	}

	return provider.formatMessages(turns)
}
// measureJSONSize returns the length in bytes of v's JSON encoding, or 0 when
// v cannot be encoded.
func measureJSONSize(v any) int {
	if encoded, err := json.Marshal(v); err == nil {
		return len(encoded)
	}
	return 0
}
+391
View File
@@ -0,0 +1,391 @@
package bridge
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"time"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.14.0"
"go.opentelemetry.io/otel/semconv/v1.14.0/httpconv"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/sloghuman"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/loadtestutil"
"github.com/coder/quartz"
)
type (
// tracingContextKey is the context key under which makeRequest stores a
// tracingContext and tracingTransport.RoundTrip looks it up.
tracingContextKey struct{}
// tracingContext carries per-request metadata that RoundTrip copies onto
// the span as "aibridge.*" attributes.
tracingContext struct {
provider string
model string
stream bool
// requestNum is set by makeRequest to the zero-based index plus one.
requestNum int
mode RequestMode
}
)
// tracingTransport is an http.RoundTripper that delegates to underlying and,
// when the request context carries a tracingContext, annotates the span with
// that metadata.
type tracingTransport struct {
// cfg is stored by newTracingTransport; RoundTrip does not read it.
cfg Config
// underlying performs the actual round trip.
underlying http.RoundTripper
}
// newTracingTransport wraps underlying in an otelhttp transport and returns a
// tracingTransport around it. A nil underlying falls back to
// http.DefaultTransport.
func newTracingTransport(cfg Config, underlying http.RoundTripper) *tracingTransport {
	base := underlying
	if base == nil {
		base = http.DefaultTransport
	}

	transport := &tracingTransport{cfg: cfg}
	transport.underlying = otelhttp.NewTransport(base)
	return transport
}
// RoundTrip performs the request via the underlying transport and returns its
// response and error unchanged. If the request context carries a
// tracingContext, the span found in the response's request context (or in the
// original request context when that is unavailable) is annotated with the
// "aibridge.*" attributes, provided the span is recording.
func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	meta, annotated := req.Context().Value(tracingContextKey{}).(tracingContext)

	resp, err := t.underlying.RoundTrip(req)
	if !annotated {
		return resp, err
	}

	// Prefer the context attached to the response's request when present.
	spanCtx := req.Context()
	if resp != nil && resp.Request != nil {
		spanCtx = resp.Request.Context()
	}

	if span := trace.SpanFromContext(spanCtx); span.IsRecording() {
		span.SetAttributes(
			attribute.String("aibridge.provider", meta.provider),
			attribute.String("aibridge.model", meta.model),
			attribute.Bool("aibridge.stream", meta.stream),
			attribute.Int("aibridge.request_num", meta.requestNum),
			attribute.String("aibridge.mode", string(meta.mode)),
		)
	}
	return resp, err
}
// Runner issues a configured number of LLM API requests and tracks the
// outcome. It is constructed by NewRunner.
type Runner struct {
client *codersdk.Client
cfg Config
// strategy supplies the request URL and token (Setup) and handles Cleanup.
strategy requestModeStrategy
// providerStrategy supplies the default model name.
providerStrategy ProviderStrategy
// clock is used to time requests; WithClock replaces it.
clock quartz.Clock
httpClient *http.Client
// The fields below are per-runner tallies updated while Run executes and
// reported by GetMetrics.
requestCount int64
successCount int64
failureCount int64
totalDuration time.Duration
totalTokens int64
}
// NewRunner builds a Runner for cfg. The request strategy and provider
// strategy are derived from cfg, a real clock is installed, and the HTTP
// client uses a tracing transport over http.DefaultTransport with a timeout of
// cfg.HTTPTimeout, or 30 seconds when that value is not positive.
func NewRunner(client *codersdk.Client, cfg Config) *Runner {
	timeout := 30 * time.Second
	if cfg.HTTPTimeout > 0 {
		timeout = cfg.HTTPTimeout
	}

	runner := &Runner{
		client:           client,
		cfg:              cfg,
		strategy:         cfg.NewStrategy(client),
		providerStrategy: NewProviderStrategy(cfg.Provider),
		clock:            quartz.NewReal(),
	}
	runner.httpClient = &http.Client{
		Timeout:   timeout,
		Transport: newTracingTransport(cfg, http.DefaultTransport),
	}
	return runner
}
// WithClock replaces the clock used to time requests and returns the same
// Runner so calls can be chained.
func (r *Runner) WithClock(clock quartz.Clock) *Runner {
r.clock = clock
return r
}
// Compile-time checks that *Runner implements the harness interfaces.
var (
_ harness.Runnable = &Runner{}
_ harness.Cleanable = &Runner{}
_ harness.Collectable = &Runner{}
)
// Run performs the strategy setup and then issues cfg.RequestCount sequential
// requests (at least one). A failed request is logged and counted, and the
// remaining requests are still attempted. Run returns an error if setup fails
// or if any request failed.
//
// Every attempt is counted in requestCount, so that it equals
// successCount + failureCount. Previously only successful requests were
// counted, which made "total_requests" and the GetMetrics request count
// identical to the success count.
func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
	ctx, span := tracing.StartSpan(ctx)
	defer span.End()

	logs = loadtestutil.NewSyncWriter(logs)
	logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)

	requestURL, token, err := r.strategy.Setup(ctx, id, logs)
	if err != nil {
		return xerrors.Errorf("strategy setup: %w", err)
	}

	requestCount := r.cfg.RequestCount
	if requestCount <= 0 {
		requestCount = 1
	}

	model := r.providerStrategy.DefaultModel()

	logger.Info(ctx, "bridge runner is ready",
		slog.F("request_count", requestCount),
		slog.F("model", model),
		slog.F("stream", r.cfg.Stream),
	)

	for i := 0; i < requestCount; i++ {
		// Count the attempt up front so failures are included in the total.
		r.requestCount++

		if err := r.makeRequest(ctx, logger, requestURL, token, model, i); err != nil {
			logger.Warn(ctx, "bridge request failed",
				slog.F("request_num", i+1),
				slog.F("error_type", "request_failed"),
				slog.Error(err),
			)
			r.cfg.Metrics.AddError("request")
			r.cfg.Metrics.AddRequest("failure")
			r.failureCount++

			// Continue making requests even if one fails
			continue
		}

		r.successCount++
		r.cfg.Metrics.AddRequest("success")
	}

	logger.Info(ctx, "bridge runner completed",
		slog.F("total_requests", r.requestCount),
		slog.F("success", r.successCount),
		slog.F("failure", r.failureCount),
	)

	// Fail the run if any request failed
	if r.failureCount > 0 {
		return xerrors.Errorf("bridge runner failed: %d out of %d requests failed", r.failureCount, requestCount)
	}

	return nil
}
// makeRequest sends one POST carrying the pre-built cfg.RequestBody to url and
// processes the response.
//
// token, when non-empty, is sent as a Bearer Authorization header. requestNum
// is zero-based; logs and trace attributes report requestNum+1. An error is
// returned when the request cannot be created or executed, when the status is
// not 200 (the response body is included in the error), or when handling the
// response body fails.
func (r *Runner) makeRequest(ctx context.Context, logger slog.Logger, url, token, model string, requestNum int) error {
start := r.clock.Now()
// Stash request metadata in the context so tracingTransport.RoundTrip can
// attach it to the span.
ctx = context.WithValue(ctx, tracingContextKey{}, tracingContext{
provider: r.cfg.Provider,
model: model,
stream: r.cfg.Stream,
requestNum: requestNum + 1,
mode: r.cfg.Mode,
})
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(r.cfg.RequestBody))
if err != nil {
return xerrors.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
logger.Debug(ctx, "making bridge request",
slog.F("url", url),
slog.F("request_num", requestNum+1),
slog.F("model", model),
)
resp, err := r.httpClient.Do(req)
if err != nil {
span := trace.SpanFromContext(req.Context())
if span.IsRecording() {
span.RecordError(err)
}
logger.Warn(ctx, "request failed during execution",
slog.F("request_num", requestNum+1),
slog.Error(err),
)
return xerrors.Errorf("execute request: %w", err)
}
defer resp.Body.Close()
span := trace.SpanFromContext(req.Context())
if span.IsRecording() {
span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode))
span.SetStatus(httpconv.ClientStatus(resp.StatusCode))
}
// The duration is taken once Do has returned, before the body is read
// below, and is recorded for every status code.
duration := r.clock.Since(start)
r.totalDuration += duration
r.cfg.Metrics.ObserveDuration(duration.Seconds())
if resp.StatusCode != http.StatusOK {
// Read errors are ignored here; whatever was read is put in the message.
body, _ := io.ReadAll(resp.Body)
err := xerrors.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
span.RecordError(err)
return err
}
if r.cfg.Stream {
err := r.handleStreamingResponse(ctx, logger, resp)
if err != nil {
span.RecordError(err)
return err
}
return nil
}
return r.handleNonStreamingResponse(ctx, logger, resp)
}
// handleNonStreamingResponse dispatches to the Anthropic handler when the
// configured provider is "anthropic" and to the OpenAI handler otherwise.
func (r *Runner) handleNonStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
	switch r.cfg.Provider {
	case "anthropic":
		return r.handleAnthropicResponse(ctx, logger, resp)
	default:
		return r.handleOpenAIResponse(ctx, logger, resp)
	}
}
// handleOpenAIResponse decodes an OpenAI-format JSON response body. It logs
// the first choice's content length when a choice is present and, when the
// reported total token count is positive, adds it to the runner's token total
// and records prompt/completion tokens as "input"/"output" metrics. It returns
// an error only when decoding fails.
func (r *Runner) handleOpenAIResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
	type choice struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	}
	type usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	}
	var payload struct {
		ID      string   `json:"id"`
		Model   string   `json:"model"`
		Choices []choice `json:"choices"`
		Usage   usage    `json:"usage"`
	}

	decodeErr := json.NewDecoder(resp.Body).Decode(&payload)
	if decodeErr != nil {
		return xerrors.Errorf("decode response: %w", decodeErr)
	}

	if len(payload.Choices) > 0 {
		logger.Debug(ctx, "received response",
			slog.F("response_id", payload.ID),
			slog.F("content_length", len(payload.Choices[0].Message.Content)),
		)
	}

	// Responses without a positive total contribute nothing to token totals.
	if payload.Usage.TotalTokens <= 0 {
		return nil
	}
	r.totalTokens += int64(payload.Usage.TotalTokens)
	r.cfg.Metrics.AddTokens("input", int64(payload.Usage.PromptTokens))
	r.cfg.Metrics.AddTokens("output", int64(payload.Usage.CompletionTokens))
	return nil
}
// handleAnthropicResponse decodes an Anthropic-format JSON response body. It
// logs the first content block's text length when a block is present and,
// when input plus output tokens is positive, adds that sum to the runner's
// token total and records the two counts as "input"/"output" metrics. It
// returns an error only when decoding fails.
func (r *Runner) handleAnthropicResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
	type contentBlock struct {
		Type string `json:"type"`
		Text string `json:"text"`
	}
	type usage struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	}
	var payload struct {
		ID      string         `json:"id"`
		Model   string         `json:"model"`
		Content []contentBlock `json:"content"`
		Usage   usage          `json:"usage"`
	}

	decodeErr := json.NewDecoder(resp.Body).Decode(&payload)
	if decodeErr != nil {
		return xerrors.Errorf("decode response: %w", decodeErr)
	}

	if len(payload.Content) > 0 {
		logger.Debug(ctx, "received response",
			slog.F("response_id", payload.ID),
			slog.F("content_length", len(payload.Content[0].Text)),
		)
	}

	in, out := payload.Usage.InputTokens, payload.Usage.OutputTokens
	if sum := in + out; sum > 0 {
		r.totalTokens += int64(sum)
		r.cfg.Metrics.AddTokens("input", int64(in))
		r.cfg.Metrics.AddTokens("output", int64(out))
	}
	return nil
}
// handleStreamingResponse reads the response body to EOF in 4096-byte chunks,
// counting bytes and discarding the data. The context is checked before every
// read. It returns an error when the context is done or when a read fails with
// anything other than io.EOF.
func (*Runner) handleStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
	var (
		chunk    = make([]byte, 4096)
		received int
	)

	for {
		// Check for context cancellation before each read
		if ctxErr := ctx.Err(); ctxErr != nil {
			logger.Warn(ctx, "streaming response canceled",
				slog.F("bytes_read", received),
				slog.Error(ctxErr),
			)
			return xerrors.Errorf("stream canceled: %w", ctxErr)
		}

		n, readErr := resp.Body.Read(chunk)
		if n > 0 {
			received += n
		}

		switch {
		case readErr == nil:
			continue
		case readErr == io.EOF:
			logger.Debug(ctx, "received streaming response", slog.F("bytes_read", received))
			return nil
		case xerrors.Is(readErr, context.Canceled) || xerrors.Is(readErr, context.DeadlineExceeded):
			// The read itself was interrupted by cancellation or a deadline.
			logger.Warn(ctx, "streaming response read canceled",
				slog.F("bytes_read", received),
				slog.Error(readErr),
			)
			return xerrors.Errorf("stream read canceled: %w", readErr)
		default:
			logger.Warn(ctx, "streaming response read error",
				slog.F("bytes_read", received),
				slog.Error(readErr),
			)
			return xerrors.Errorf("read stream: %w", readErr)
		}
	}
}
// Cleanup delegates to the request-mode strategy's Cleanup and returns its
// result.
func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error {
return r.strategy.Cleanup(ctx, id, logs)
}
// GetMetrics returns the runner's tallies as a map with the keys
// "request_count", "success_count", "failure_count", "total_duration",
// "avg_duration" and "total_tokens". Durations are rendered as strings; the
// average is totalDuration divided by requestCount, or zero when no request
// has been counted.
func (r *Runner) GetMetrics() map[string]any {
	var mean time.Duration
	if r.requestCount > 0 {
		mean = r.totalDuration / time.Duration(r.requestCount)
	}

	summary := make(map[string]any, 6)
	summary["request_count"] = r.requestCount
	summary["success_count"] = r.successCount
	summary["failure_count"] = r.failureCount
	summary["total_duration"] = r.totalDuration.String()
	summary["avg_duration"] = mean.String()
	summary["total_tokens"] = r.totalTokens
	return summary
}
+117
View File
@@ -0,0 +1,117 @@
package bridge
import (
"context"
"fmt"
"io"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/sloghuman"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/createusers"
)
// requestModeStrategy abstracts the mode-specific parts of a run.
type requestModeStrategy interface {
// Setup prepares the run and returns the URL requests are sent to and a
// token. The Runner sends a non-empty token as a Bearer Authorization
// header and omits the header when the token is empty.
Setup(ctx context.Context, id string, logs io.Writer) (url string, token string, err error)
// Cleanup releases whatever Setup created.
Cleanup(ctx context.Context, id string, logs io.Writer) error
}
// bridgeStrategy creates users via Coder and routes requests through AI Bridge.
type bridgeStrategy struct {
client *codersdk.Client
// provider selects the AI Bridge endpoint path ("anthropic" or, for any
// other value, the OpenAI path).
provider string
metrics *Metrics
userConfig createusers.Config
// createUserRunner is assigned by Setup; while it is nil, Cleanup does
// nothing.
createUserRunner *createusers.Runner
}
// bridgeStrategyConfig carries the inputs newBridgeStrategy copies into a
// bridgeStrategy.
type bridgeStrategyConfig struct {
Client *codersdk.Client
Provider string
Metrics *Metrics
User createusers.Config
}
// newBridgeStrategy copies the client, provider, metrics and user settings
// from cfg into a fresh bridgeStrategy. The user runner is left nil; Setup
// creates it.
func newBridgeStrategy(cfg bridgeStrategyConfig) *bridgeStrategy {
	strategy := new(bridgeStrategy)
	strategy.client = cfg.Client
	strategy.provider = cfg.Provider
	strategy.metrics = cfg.Metrics
	strategy.userConfig = cfg.User
	return strategy
}
// Setup creates a user through the Coder client and returns the AI Bridge URL
// for the configured provider together with that user's session token.
//
// The client is switched to a debug-level logger writing to logs, with body
// logging enabled, before the user is created. A failure to create the user is
// counted in the metrics under "create_user" and returned wrapped.
func (s *bridgeStrategy) Setup(ctx context.Context, id string, logs io.Writer) (requestURL string, token string, err error) {
	log := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
	s.client.SetLogger(log)
	s.client.SetLogBodies(true)

	// Keep the runner on the strategy so Cleanup can remove the user later.
	s.createUserRunner = createusers.NewRunner(s.client, s.userConfig)
	created, createErr := s.createUserRunner.RunReturningUser(ctx, id, logs)
	if createErr != nil {
		s.metrics.AddError("create_user")
		return "", "", xerrors.Errorf("create user: %w", createErr)
	}

	user := created.User
	sessionToken := created.SessionToken
	log.Info(ctx, "runner user created",
		slog.F("username", user.Username),
		slog.F("user_id", user.ID.String()),
	)

	var endpoint string
	switch s.provider {
	case "anthropic":
		endpoint = fmt.Sprintf("%s/api/v2/aibridge/anthropic/v1/messages", s.client.URL)
	default:
		// Every provider other than "anthropic" uses the OpenAI route.
		endpoint = fmt.Sprintf("%s/api/v2/aibridge/openai/v1/chat/completions", s.client.URL)
	}

	log.Info(ctx, "bridge runner in bridge mode",
		slog.F("url", endpoint),
		slog.F("provider", s.provider),
	)
	return endpoint, sessionToken, nil
}
// Cleanup removes the user created by Setup by delegating to the stored user
// runner. It returns nil without doing anything when Setup never assigned a
// runner; a runner failure is returned wrapped.
func (s *bridgeStrategy) Cleanup(ctx context.Context, id string, logs io.Writer) error {
	runner := s.createUserRunner
	if runner == nil {
		return nil
	}

	// The write error is deliberately ignored: this line is informational.
	_, _ = fmt.Fprintln(logs, "Cleaning up user...")

	cleanupErr := runner.Cleanup(ctx, id, logs)
	if cleanupErr != nil {
		return xerrors.Errorf("cleanup user: %w", cleanupErr)
	}
	return nil
}
// directStrategy makes requests directly to an upstream URL.
type directStrategy struct {
// upstreamURL is returned unchanged by Setup as the request URL.
upstreamURL string
}
// directStrategyConfig carries the settings newDirectStrategy copies into a
// directStrategy.
type directStrategyConfig struct {
// UpstreamURL is the URL requests are sent to in direct mode.
UpstreamURL string
}
// newDirectStrategy returns a directStrategy that targets cfg.UpstreamURL.
func newDirectStrategy(cfg directStrategyConfig) *directStrategy {
	strategy := new(directStrategy)
	strategy.upstreamURL = cfg.UpstreamURL
	return strategy
}
// Setup logs that the runner is in direct mode and returns the configured
// upstream URL with an empty token. It never returns an error.
func (s *directStrategy) Setup(ctx context.Context, _ string, logs io.Writer) (requestURL string, _ string, err error) {
	target := s.upstreamURL
	log := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
	log.Info(ctx, "bridge runner in direct mode", slog.F("url", target))
	return target, "", nil
}
// Cleanup always returns nil: direct mode creates nothing in Setup, so there
// are no resources to release.
func (*directStrategy) Cleanup(context.Context, string, io.Writer) error {
	return nil
}
+545
View File
@@ -0,0 +1,545 @@
package llmmock
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.14.0"
"go.opentelemetry.io/otel/semconv/v1.14.0/httpconv"
"go.opentelemetry.io/otel/semconv/v1.14.0/netconv"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/pproflabel"
"github.com/coder/coder/v2/coderd/tracing"
)
// Server is a mock LLM provider. It serves canned responses on an
// OpenAI-style chat completions endpoint and an Anthropic-style messages
// endpoint, optionally wrapped in tracing middleware.
type Server struct {
// httpServer and httpListener are created by startAPIServer.
httpServer *http.Server
httpListener net.Listener
logger slog.Logger
// address is the TCP address passed to net.Listen.
address string
// artificialLatency, when positive, delays every response.
artificialLatency time.Duration
// responsePayloadSize, when positive, is the length of the generated
// response text.
responsePayloadSize int
// tracerProvider and closeTracing are set only when tracing was enabled and
// initialized successfully in Start.
tracerProvider trace.TracerProvider
closeTracing func(context.Context) error
}
// Config holds the settings passed to Server.Start.
type Config struct {
// Address is the TCP address the HTTP server listens on.
Address string
// Logger receives the server's log output.
Logger slog.Logger
// ArtificialLatency, when positive, is a delay applied to each request
// before its response is produced.
ArtificialLatency time.Duration
// ResponsePayloadSize, when positive, replaces the default response text
// with that many "x" characters.
ResponsePayloadSize int
// PprofEnable and PprofAddress are not read by Start in the visible part
// of this file; presumably they are consumed by the caller — TODO confirm.
PprofEnable bool
PprofAddress string
// TraceEnable turns on tracing setup and the tracing middleware.
TraceEnable bool
}
// llmRequest is the part of an incoming request body that both handlers
// decode: the model name, echoed back in the response, and the stream flag,
// which selects the server-sent-events path.
type llmRequest struct {
Model string `json:"model"`
Stream bool `json:"stream,omitempty"`
}
// openAIMessage is a single chat message inside an openAIResponse choice.
type openAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// openAIResponse is the JSON body written for non-streaming requests to the
// OpenAI endpoint. The streaming path reads its ID, Created, Model and first
// choice to build the chunks it sends.
type openAIResponse struct {
ID string `json:"id"`
Object string `json:"object"`
// Created is a Unix timestamp in seconds.
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message openAIMessage `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
// Usage holds token counts; the handler fills them with fixed values.
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// anthropicResponse is the JSON body written for non-streaming requests to
// the Anthropic endpoint. The streaming path reads its fields to build the
// events it sends.
type anthropicResponse struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
Model string `json:"model"`
StopReason string `json:"stop_reason"`
// StopSequence is never set by the handler, so it encodes as JSON null.
StopSequence *string `json:"stop_sequence"`
// Usage holds token counts; the handler fills them with fixed values.
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
}
// Start copies cfg onto the server, optionally sets up tracing, and starts the
// HTTP API server.
//
// When cfg.TraceEnable is set, a global text-map propagator (trace context
// plus baggage) is installed and a tracer provider is requested. A tracing
// initialization failure is logged as a warning and does not fail Start; the
// server then runs without the tracing middleware. An error is returned only
// when the API server cannot be started.
func (s *Server) Start(ctx context.Context, cfg Config) error {
	s.address, s.logger = cfg.Address, cfg.Logger
	s.artificialLatency, s.responsePayloadSize = cfg.ArtificialLatency, cfg.ResponsePayloadSize

	if cfg.TraceEnable {
		propagator := propagation.NewCompositeTextMapPropagator(
			propagation.TraceContext{},
			propagation.Baggage{},
		)
		otel.SetTextMapPropagator(propagator)

		opts := tracing.TracerOpts{Default: cfg.TraceEnable}
		provider, shutdown, traceErr := tracing.TracerProvider(ctx, "llm-mock", opts)
		if traceErr == nil {
			s.tracerProvider, s.closeTracing = provider, shutdown
		} else {
			s.logger.Warn(ctx, "failed to initialize tracing", slog.Error(traceErr))
		}
	}

	startErr := s.startAPIServer(ctx)
	if startErr != nil {
		return xerrors.Errorf("start API server: %w", startErr)
	}
	return nil
}
// Stop shuts down the HTTP server and then the tracing pipeline, giving each
// step its own 5-second deadline.
//
// Tracing is closed even when the HTTP shutdown fails (for example because
// in-flight requests outlive the deadline); previously that error returned
// early and left the tracer provider open. A tracing shutdown failure is only
// logged. The returned error is the wrapped HTTP shutdown error, or nil.
func (s *Server) Stop() error {
	var shutdownErr error

	if s.httpServer != nil {
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		err := s.httpServer.Shutdown(shutdownCtx)
		cancel()
		if err != nil {
			// Remember the failure but keep going so tracing is still closed.
			shutdownErr = xerrors.Errorf("shutdown HTTP server: %w", err)
		}
	}

	if s.closeTracing != nil {
		traceCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		if err := s.closeTracing(traceCtx); err != nil {
			s.logger.Warn(traceCtx, "failed to close tracing", slog.Error(err))
		}
		cancel()
	}

	return shutdownErr
}
// APIAddress returns the "http://host:port" URL of the listener created by
// startAPIServer. It dereferences the listener, so it must only be called
// after the server has started successfully.
func (s *Server) APIAddress() string {
	bound := s.httpListener.Addr()
	return fmt.Sprintf("http://%s", bound.String())
}
// startAPIServer registers the two mock endpoints, wraps them in the tracing
// middleware when a tracer provider is present, binds the configured address
// and serves in a background goroutine labelled for pprof.
//
// It returns an error only when the address cannot be bound. Serve errors
// other than http.ErrServerClosed are logged from the goroutine.
func (s *Server) startAPIServer(ctx context.Context) error {
	router := http.NewServeMux()
	router.HandleFunc("POST /v1/chat/completions", s.handleOpenAI)
	router.HandleFunc("POST /v1/messages", s.handleAnthropic)

	var root http.Handler = router
	if s.tracerProvider != nil {
		root = s.tracingMiddleware(root)
	}

	// The server is stored before listening, so it is set even if the bind
	// below fails.
	s.httpServer = &http.Server{
		Handler:           root,
		ReadHeaderTimeout: 10 * time.Second,
	}

	ln, listenErr := net.Listen("tcp", s.address)
	if listenErr != nil {
		return xerrors.Errorf("listen on %s: %w", s.address, listenErr)
	}
	s.httpListener = ln

	serve := func(ctx context.Context) {
		serveErr := s.httpServer.Serve(ln)
		if serveErr == nil || errors.Is(serveErr, http.ErrServerClosed) {
			return
		}
		s.logger.Error(ctx, "http API server error", slog.Error(serveErr))
	}
	pproflabel.Go(ctx, pproflabel.Service("llm-mock"), serve)

	return nil
}
// handleOpenAI is the registered handler for the OpenAI endpoint. It runs
// handleOpenAIWithLabels under the "llm-mock" pprof service label, passing the
// labelled context on through the request.
func (s *Server) handleOpenAI(w http.ResponseWriter, r *http.Request) {
	labelled := func(ctx context.Context) {
		s.handleOpenAIWithLabels(w, r.WithContext(ctx))
	}
	pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), labelled)
}
// handleOpenAIWithLabels serves a mock OpenAI chat completion.
//
// The request body is decoded only for its model and stream fields; a body
// that fails to decode gets a 400. After the optional artificial latency the
// handler builds a single-choice response whose content is either the default
// text or responsePayloadSize "x" characters, with fixed usage numbers. It is
// then sent as a server-sent-events stream when the request asked for
// streaming, or as one JSON document otherwise.
//
// Changes from the previous version: the latency wait stops as soon as the
// request context is canceled instead of sleeping the full duration; the JSON
// body is marshalled only on the non-streaming path, where it is used; and a
// marshal failure is reported as a 500 instead of being ignored.
func (s *Server) handleOpenAIWithLabels(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	s.logger.Debug(ctx, "handling OpenAI request")
	defer s.logger.Debug(ctx, "handled OpenAI request")

	requestID := uuid.New()
	now := time.Now()

	var req llmRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.logger.Error(ctx, "failed to parse OpenAI request", slog.Error(err))
		http.Error(w, "invalid request body", http.StatusBadRequest)
		return
	}

	if s.artificialLatency > 0 {
		// Wait on a timer rather than time.Sleep so a disconnected client
		// does not keep this goroutine parked for the whole latency.
		timer := time.NewTimer(s.artificialLatency)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			s.logger.Debug(ctx, "OpenAI request canceled during artificial latency",
				slog.F("request_id", requestID),
			)
			return
		}
	}

	responseContent := "This is a mock response from OpenAI."
	if s.responsePayloadSize > 0 {
		responseContent = strings.Repeat("x", s.responsePayloadSize)
	}

	var resp openAIResponse
	// Only the first 8 characters of the UUID are used in the response ID.
	resp.ID = fmt.Sprintf("chatcmpl-%s", requestID.String()[:8])
	resp.Object = "chat.completion"
	resp.Created = now.Unix()
	resp.Model = req.Model
	resp.Choices = []struct {
		Index        int           `json:"index"`
		Message      openAIMessage `json:"message"`
		FinishReason string        `json:"finish_reason"`
	}{
		{
			Index: 0,
			Message: openAIMessage{
				Role:    "assistant",
				Content: responseContent,
			},
			FinishReason: "stop",
		},
	}
	resp.Usage.PromptTokens = 10
	resp.Usage.CompletionTokens = 5
	resp.Usage.TotalTokens = 15

	if req.Stream {
		s.sendOpenAIStream(ctx, w, resp)
		return
	}

	responseBody, err := json.Marshal(resp)
	if err != nil {
		s.logger.Error(ctx, "failed to marshal OpenAI response",
			slog.F("request_id", requestID),
			slog.Error(err),
		)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(responseBody); err != nil {
		s.logger.Error(ctx, "failed to write OpenAI response",
			slog.F("request_id", requestID),
			slog.Error(err),
			slog.F("error_type", "write_error"),
			slog.F("likely_cause", "network_error"),
		)
	}
}
// handleAnthropic is the registered handler for the Anthropic endpoint. It
// runs handleAnthropicWithLabels under the "llm-mock" pprof service label,
// passing the labelled context on through the request.
func (s *Server) handleAnthropic(w http.ResponseWriter, r *http.Request) {
	labelled := func(ctx context.Context) {
		s.handleAnthropicWithLabels(w, r.WithContext(ctx))
	}
	pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), labelled)
}
// handleAnthropicWithLabels serves a mock Anthropic message.
//
// The request body is decoded only for its model and stream fields; a body
// that fails to decode gets a 400. After the optional artificial latency the
// handler builds a single text content block holding either the default text
// or responsePayloadSize "x" characters, with fixed usage numbers. It is then
// sent as a server-sent-events stream when the request asked for streaming,
// or as one JSON document otherwise.
//
// Changes from the previous version: the latency wait stops as soon as the
// request context is canceled instead of sleeping the full duration; the JSON
// body is marshalled only on the non-streaming path, where it is used; and a
// marshal failure is reported as a 500 instead of being ignored.
func (s *Server) handleAnthropicWithLabels(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	requestID := uuid.New()

	var req llmRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.logger.Error(ctx, "failed to parse LLM request", slog.Error(err))
		http.Error(w, "invalid request body", http.StatusBadRequest)
		return
	}

	if s.artificialLatency > 0 {
		// Wait on a timer rather than time.Sleep so a disconnected client
		// does not keep this goroutine parked for the whole latency.
		timer := time.NewTimer(s.artificialLatency)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			s.logger.Debug(ctx, "Anthropic request canceled during artificial latency",
				slog.F("request_id", requestID),
			)
			return
		}
	}

	responseText := "This is a mock response from Anthropic."
	if s.responsePayloadSize > 0 {
		responseText = strings.Repeat("x", s.responsePayloadSize)
	}

	var resp anthropicResponse
	// Only the first 8 characters of the UUID are used in the response ID.
	resp.ID = fmt.Sprintf("msg_%s", requestID.String()[:8])
	resp.Type = "message"
	resp.Role = "assistant"
	resp.Content = []struct {
		Type string `json:"type"`
		Text string `json:"text"`
	}{
		{
			Type: "text",
			Text: responseText,
		},
	}
	resp.Model = req.Model
	resp.StopReason = "end_turn"
	resp.Usage.InputTokens = 10
	resp.Usage.OutputTokens = 5

	if req.Stream {
		s.sendAnthropicStream(ctx, w, resp)
		return
	}

	responseBody, err := json.Marshal(resp)
	if err != nil {
		s.logger.Error(ctx, "failed to marshal Anthropic response",
			slog.F("request_id", requestID),
			slog.Error(err),
		)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("anthropic-version", "2023-06-01")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write(responseBody); err != nil {
		s.logger.Error(ctx, "failed to write Anthropic response",
			slog.F("request_id", requestID),
			slog.Error(err),
			slog.F("error_type", "write_error"),
			slog.F("likely_cause", "network_error"),
		)
	}
}
// sendOpenAIStream writes resp as an OpenAI-style server-sent-events stream:
// one chunk carrying the assistant role and the full content, one chunk
// carrying only the finish reason, then the "[DONE]" sentinel. Each frame is
// flushed immediately. Streaming stops at the first write or marshal failure,
// which is logged.
//
// resp must contain at least one choice; the only caller in this file always
// supplies exactly one.
//
// Changes from the previous version: flushing support is checked before the
// 200 header is committed, so an unsupported writer now produces a 500
// instead of an empty 200 event stream; marshal errors are no longer
// discarded; and each frame is formatted once, straight into the writer,
// instead of being built as a string and then copied.
func (s *Server) sendOpenAIStream(ctx context.Context, w http.ResponseWriter, resp openAIResponse) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		s.logger.Error(ctx, "responseWriter does not support flushing",
			slog.F("response_id", resp.ID),
		)
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.WriteHeader(http.StatusOK)

	// writeData emits one "data: <payload>" frame and flushes it.
	writeData := func(payload []byte) bool {
		if _, err := fmt.Fprintf(w, "data: %s\n\n", payload); err != nil {
			s.logger.Error(ctx, "failed to write OpenAI stream chunk",
				slog.F("response_id", resp.ID),
				slog.Error(err),
				slog.F("error_type", "write_error"),
				slog.F("likely_cause", "network_error"),
			)
			return false
		}
		flusher.Flush()
		return true
	}

	// writeJSON marshals chunk and emits it as one frame.
	writeJSON := func(chunk map[string]any) bool {
		payload, err := json.Marshal(chunk)
		if err != nil {
			s.logger.Error(ctx, "failed to marshal OpenAI stream chunk",
				slog.F("response_id", resp.ID),
				slog.Error(err),
			)
			return false
		}
		return writeData(payload)
	}

	// Content chunk: role and the complete content in a single delta.
	contentChunk := map[string]any{
		"id":      resp.ID,
		"object":  "chat.completion.chunk",
		"created": resp.Created,
		"model":   resp.Model,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": resp.Choices[0].Message.Content,
				},
				"finish_reason": nil,
			},
		},
	}
	if !writeJSON(contentChunk) {
		return
	}

	// Final chunk: empty delta carrying only the finish reason.
	finalChunk := map[string]any{
		"id":      resp.ID,
		"object":  "chat.completion.chunk",
		"created": resp.Created,
		"model":   resp.Model,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": resp.Choices[0].FinishReason,
			},
		},
	}
	if !writeJSON(finalChunk) {
		return
	}

	writeData([]byte("[DONE]"))
}
// sendAnthropicStream writes resp as an Anthropic-style server-sent-events
// stream, in this order: message_start, content_block_start,
// content_block_delta, content_block_stop, message_delta, message_stop. Each
// frame is flushed immediately. Streaming stops at the first write or marshal
// failure, which is logged.
//
// resp must contain at least one content block; the only caller in this file
// always supplies exactly one.
//
// Changes from the previous version: every frame now carries an
// "event: <type>" line ahead of its "data:" line, matching the named-event
// framing of the Anthropic streaming API (the JSON payloads are unchanged);
// flushing support is checked before the 200 header is committed, so an
// unsupported writer produces a 500 instead of an empty 200 event stream;
// and marshal errors are no longer discarded.
func (s *Server) sendAnthropicStream(ctx context.Context, w http.ResponseWriter, resp anthropicResponse) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		s.logger.Error(ctx, "responseWriter does not support flushing",
			slog.F("response_id", resp.ID),
		)
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("anthropic-version", "2023-06-01")
	w.WriteHeader(http.StatusOK)

	// sendEvent stamps fields with the event type, so the SSE event name and
	// the JSON "type" can never disagree, then writes and flushes one frame.
	sendEvent := func(eventType string, fields map[string]any) bool {
		fields["type"] = eventType
		payload, err := json.Marshal(fields)
		if err != nil {
			s.logger.Error(ctx, "failed to marshal Anthropic stream event",
				slog.F("response_id", resp.ID),
				slog.F("event_type", eventType),
				slog.Error(err),
			)
			return false
		}
		if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, payload); err != nil {
			s.logger.Error(ctx, "failed to write Anthropic stream chunk",
				slog.F("response_id", resp.ID),
				slog.Error(err),
				slog.F("error_type", "write_error"),
				slog.F("likely_cause", "network_error"),
			)
			return false
		}
		flusher.Flush()
		return true
	}

	text := resp.Content[0].Text

	events := []struct {
		name   string
		fields map[string]any
	}{
		{"message_start", map[string]any{
			"message": map[string]any{
				"id":    resp.ID,
				"type":  resp.Type,
				"role":  resp.Role,
				"model": resp.Model,
			},
		}},
		// NOTE(review): the full text is sent here and again in the delta
		// below, as in the original mock; the real API is understood to start
		// the block with empty text — confirm whether the duplication is
		// intentional for payload sizing.
		{"content_block_start", map[string]any{
			"index": 0,
			"content_block": map[string]any{
				"type": "text",
				"text": text,
			},
		}},
		{"content_block_delta", map[string]any{
			"index": 0,
			"delta": map[string]any{
				"type": "text_delta",
				"text": text,
			},
		}},
		{"content_block_stop", map[string]any{
			"index": 0,
		}},
		{"message_delta", map[string]any{
			"delta": map[string]any{
				"stop_reason":   resp.StopReason,
				"stop_sequence": resp.StopSequence,
			},
			"usage": resp.Usage,
		}},
		{"message_stop", map[string]any{}},
	}

	for _, ev := range events {
		if !sendEvent(ev.name, ev.fields) {
			return
		}
	}
}
// tracingMiddleware wraps next so that every request runs inside a span from
// the server's "llm-mock" tracer.
//
// For each request it extracts any incoming trace context from the request
// headers, starts a span, exposes the trace and span IDs in the X-Trace-ID and
// X-Span-ID response headers (and injects the propagation headers), runs next
// with a status-recording writer, and finally renames the span to
// "<method> <path>" and attaches request, route and status attributes. A
// recorded status of 0 is reported as 200.
func (s *Server) tracingMiddleware(next http.Handler) http.Handler {
	tracer := s.tracerProvider.Tracer("llm-mock")

	traced := func(rw http.ResponseWriter, r *http.Request) {
		// The wrapper records the status code written by the handler.
		statusWriter := &tracing.StatusWriter{ResponseWriter: rw}

		// Continue the caller's trace when the request carries one.
		propagator := otel.GetTextMapPropagator()
		parentCtx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))

		// Provisional span name; replaced once the handler has run.
		spanCtx, span := tracer.Start(parentCtx, fmt.Sprintf("%s %s", r.Method, r.RequestURI))
		defer span.End()
		r = r.WithContext(spanCtx)

		// Headers must be set before the handler writes the response.
		if sc := span.SpanContext(); sc.HasTraceID() && sc.HasSpanID() {
			respHeaders := rw.Header()
			respHeaders.Set("X-Trace-ID", sc.TraceID().String())
			respHeaders.Set("X-Span-ID", sc.SpanID().String())
			propagator.Inject(spanCtx, propagation.HeaderCarrier(respHeaders))
		}

		next.ServeHTTP(statusWriter, r)

		path := r.URL.Path
		span.SetName(fmt.Sprintf("%s %s", r.Method, path))
		span.SetAttributes(netconv.Transport("tcp"))
		span.SetAttributes(httpconv.ServerRequest("llm-mock", r)...)
		span.SetAttributes(semconv.HTTPRouteKey.String(path))

		code := statusWriter.Status
		if code == 0 {
			code = http.StatusOK
		}
		span.SetAttributes(semconv.HTTPStatusCodeKey.Int(code))
		span.SetStatus(httpconv.ServerStatus(code))
	}

	return http.HandlerFunc(traced)
}