From 0ebe8e57adcda321e67d19c11b981bb882bee645 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 15 Jan 2026 17:05:46 +0200 Subject: [PATCH] chore: add scaletesting tools for aibridge (#21279) This pull request adds scaletesting tools for aibridge. See https://www.notion.so/Scale-tests-2c5d579be5928088b565d15dd8bdea41?source=copy_link for information and instructions. closes: https://github.com/coder/internal/issues/1156 closes: https://github.com/coder/internal/issues/1155 closes: https://github.com/coder/internal/issues/1158 --- cli/exp_scaletest.go | 2 + cli/exp_scaletest_bridge.go | 278 ++++++++++++++++++ cli/exp_scaletest_llmmock.go | 118 ++++++++ go.mod | 2 +- scaletest/bridge/config.go | 150 ++++++++++ scaletest/bridge/metrics.go | 72 +++++ scaletest/bridge/provider.go | 134 +++++++++ scaletest/bridge/run.go | 391 +++++++++++++++++++++++++ scaletest/bridge/strategy.go | 117 ++++++++ scaletest/llmmock/server.go | 545 +++++++++++++++++++++++++++++++++++ 10 files changed, 1808 insertions(+), 1 deletion(-) create mode 100644 cli/exp_scaletest_bridge.go create mode 100644 cli/exp_scaletest_llmmock.go create mode 100644 scaletest/bridge/config.go create mode 100644 scaletest/bridge/metrics.go create mode 100644 scaletest/bridge/provider.go create mode 100644 scaletest/bridge/run.go create mode 100644 scaletest/bridge/strategy.go create mode 100644 scaletest/llmmock/server.go diff --git a/cli/exp_scaletest.go b/cli/exp_scaletest.go index 5b4f16322f..02bd80763a 100644 --- a/cli/exp_scaletest.go +++ b/cli/exp_scaletest.go @@ -68,6 +68,8 @@ func (r *RootCmd) scaletestCmd() *serpent.Command { r.scaletestTaskStatus(), r.scaletestSMTP(), r.scaletestPrebuilds(), + r.scaletestBridge(), + r.scaletestLLMMock(), }, } diff --git a/cli/exp_scaletest_bridge.go b/cli/exp_scaletest_bridge.go new file mode 100644 index 0000000000..f7dda90471 --- /dev/null +++ b/cli/exp_scaletest_bridge.go @@ -0,0 +1,278 @@ +//go:build !slim + +package cli + +import ( + "fmt" + "net/http" + "os/signal" + "strconv" + "text/tabwriter" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/scaletest/bridge" + "github.com/coder/coder/v2/scaletest/createusers" + "github.com/coder/coder/v2/scaletest/harness" + "github.com/coder/serpent" +) + +func (r *RootCmd) scaletestBridge() *serpent.Command { + var ( + concurrentUsers int64 + noCleanup bool + mode string + upstreamURL string + provider string + requestsPerUser int64 + useStreamingAPI bool + requestPayloadSize int64 + numMessages int64 + httpTimeout time.Duration + + timeoutStrategy = &timeoutFlags{} + cleanupStrategy = newScaletestCleanupStrategy() + output = &scaletestOutputFlags{} + prometheusFlags = &scaletestPrometheusFlags{} + ) + + cmd := &serpent.Command{ + Use: "bridge", + Short: "Generate load on the AI Bridge service.", + Long: `Generate load for AI Bridge testing. Supports two modes: 'bridge' mode routes requests through the Coder AI Bridge, 'direct' mode makes requests directly to an upstream URL (useful for baseline comparisons). + +Examples: + # Test OpenAI API through bridge + coder scaletest bridge --mode bridge --provider openai --concurrent-users 10 --request-count 5 --num-messages 10 + + # Test Anthropic API through bridge + coder scaletest bridge --mode bridge --provider anthropic --concurrent-users 10 --request-count 5 --num-messages 10 + + # Test directly against mock server + coder scaletest bridge --mode direct --provider openai --upstream-url http://localhost:8080/v1/chat/completions +`, + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + client, err := r.InitClient(inv) + if err != nil { + return err + } + client.HTTPClient = &http.Client{ + Transport: &codersdk.HeaderTransport{ + Transport: http.DefaultTransport, + Header: map[string][]string{ + codersdk.BypassRatelimitHeader: {"true"}, + }, + }, + } + reg := prometheus.NewRegistry() + metrics := bridge.NewMetrics(reg) + + logger := inv.Logger + prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus") + defer prometheusSrvClose() + + defer func() { + _, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait) + <-time.After(prometheusFlags.Wait) + }() + + notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...) + defer stop() + ctx = notifyCtx + + var userConfig createusers.Config + if bridge.RequestMode(mode) == bridge.RequestModeBridge { + me, err := requireAdmin(ctx, client) + if err != nil { + return err + } + if len(me.OrganizationIDs) == 0 { + return xerrors.Errorf("admin user must have at least one organization") + } + userConfig = createusers.Config{ + OrganizationID: me.OrganizationIDs[0], + } + _, _ = fmt.Fprintln(inv.Stderr, "Bridge mode: creating users and making requests through AI Bridge...") + } else { + _, _ = fmt.Fprintf(inv.Stderr, "Direct mode: making requests directly to %s\n", upstreamURL) + } + + outputs, err := output.parse() + if err != nil { + return xerrors.Errorf("parse output flags: %w", err) + } + + config := bridge.Config{ + Mode: bridge.RequestMode(mode), + Metrics: metrics, + Provider: provider, + RequestCount: int(requestsPerUser), + Stream: useStreamingAPI, + RequestPayloadSize: int(requestPayloadSize), + NumMessages: int(numMessages), + HTTPTimeout: httpTimeout, + UpstreamURL: upstreamURL, + User: userConfig, + } + if err := config.Validate(); err != nil { + return xerrors.Errorf("validate config: %w", err) + } + if err := config.PrepareRequestBody(); err != nil { + return xerrors.Errorf("prepare request body: %w", err) + } + + th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy()) + + for i := range concurrentUsers { + id := strconv.Itoa(int(i)) + name := fmt.Sprintf("bridge-%s", id) + var runner harness.Runnable = bridge.NewRunner(client, config) + th.AddRun(name, id, runner) + } + + _, _ = fmt.Fprintln(inv.Stderr, "Bridge scaletest configuration:") + tw := tabwriter.NewWriter(inv.Stderr, 0, 0, 2, ' ', 0) + for _, opt := range inv.Command.Options { + if opt.Hidden || opt.ValueSource == serpent.ValueSourceNone { + continue + } + _, _ = fmt.Fprintf(tw, " %s:\t%s", opt.Name, opt.Value.String()) + if opt.ValueSource != serpent.ValueSourceDefault { + _, _ = fmt.Fprintf(tw, "\t(from %s)", opt.ValueSource) + } + _, _ = fmt.Fprintln(tw) + } + _ = tw.Flush() + + _, _ = fmt.Fprintln(inv.Stderr, "\nRunning bridge scaletest...") + testCtx, testCancel := timeoutStrategy.toContext(ctx) + defer testCancel() + err = th.Run(testCtx) + if err != nil { + return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err) + } + + // If the command was interrupted, skip stats. + if notifyCtx.Err() != nil { + return notifyCtx.Err() + } + + res := th.Results() + + for _, o := range outputs { + err = o.write(res, inv.Stdout) + if err != nil { + return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err) + } + } + + if !noCleanup { + _, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...") + cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx) + defer cleanupCancel() + err = th.Cleanup(cleanupCtx) + if err != nil { + return xerrors.Errorf("cleanup tests: %w", err) + } + } + + if res.TotalFail > 0 { + return xerrors.New("load test failed, see above for more details") + } + + return nil + }, + } + + cmd.Options = serpent.OptionSet{ + { + Flag: "concurrent-users", + FlagShorthand: "c", + Env: "CODER_SCALETEST_BRIDGE_CONCURRENT_USERS", + Description: "Required: Number of concurrent users.", + Value: serpent.Validate(serpent.Int64Of(&concurrentUsers), func(value *serpent.Int64) error { + if value == nil || value.Value() <= 0 { + return xerrors.Errorf("--concurrent-users must be greater than 0") + } + return nil + }), + Required: true, + }, + { + Flag: "mode", + Env: "CODER_SCALETEST_BRIDGE_MODE", + Default: "direct", + Description: "Request mode: 'bridge' (create users and use AI Bridge) or 'direct' (make requests directly to upstream-url).", + Value: serpent.EnumOf(&mode, string(bridge.RequestModeBridge), string(bridge.RequestModeDirect)), + }, + { + Flag: "upstream-url", + Env: "CODER_SCALETEST_BRIDGE_UPSTREAM_URL", + Description: "URL to make requests to directly (required in direct mode, e.g., http://localhost:8080/v1/chat/completions).", + Value: serpent.StringOf(&upstreamURL), + }, + { + Flag: "provider", + Env: "CODER_SCALETEST_BRIDGE_PROVIDER", + Default: "openai", + Description: "API provider to use.", + Value: serpent.EnumOf(&provider, "openai", "anthropic"), + }, + { + Flag: "request-count", + Env: "CODER_SCALETEST_BRIDGE_REQUEST_COUNT", + Default: "1", + Description: "Number of sequential requests to make per runner.", + Value: serpent.Validate(serpent.Int64Of(&requestsPerUser), func(value *serpent.Int64) error { + if value == nil || value.Value() <= 0 { + return xerrors.Errorf("--request-count must be greater than 0") + } + return nil + }), + }, + { + Flag: "stream", + Env: "CODER_SCALETEST_BRIDGE_STREAM", + Description: "Enable streaming requests.", + Value: serpent.BoolOf(&useStreamingAPI), + }, + { + Flag: "request-payload-size", + Env: "CODER_SCALETEST_BRIDGE_REQUEST_PAYLOAD_SIZE", + Default: "1024", + Description: "Size in bytes of the request payload (user message content). If 0, uses default message content.", + Value: serpent.Int64Of(&requestPayloadSize), + }, + { + Flag: "num-messages", + Env: "CODER_SCALETEST_BRIDGE_NUM_MESSAGES", + Default: "1", + Description: "Number of messages to include in the conversation.", + Value: serpent.Int64Of(&numMessages), + }, + { + Flag: "no-cleanup", + Env: "CODER_SCALETEST_NO_CLEANUP", + Description: "Do not clean up resources after the test completes.", + Value: serpent.BoolOf(&noCleanup), + }, + { + Flag: "http-timeout", + Env: "CODER_SCALETEST_BRIDGE_HTTP_TIMEOUT", + Default: "30s", + Description: "Timeout for individual HTTP requests to the upstream provider.", + Value: serpent.DurationOf(&httpTimeout), + }, + } + + timeoutStrategy.attach(&cmd.Options) + cleanupStrategy.attach(&cmd.Options) + output.attach(&cmd.Options) + prometheusFlags.attach(&cmd.Options) + return cmd +} diff --git a/cli/exp_scaletest_llmmock.go b/cli/exp_scaletest_llmmock.go new file mode 100644 index 0000000000..2cb6312407 --- /dev/null +++ b/cli/exp_scaletest_llmmock.go @@ -0,0 +1,118 @@ +//go:build !slim + +package cli + +import ( + "fmt" + "os/signal" + "time" + + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/sloghuman" + "github.com/coder/coder/v2/scaletest/llmmock" + "github.com/coder/serpent" +) + +func (*RootCmd) scaletestLLMMock() *serpent.Command { + var ( + address string + artificialLatency time.Duration + responsePayloadSize int64 + + pprofEnable bool + pprofAddress string + + traceEnable bool + ) + cmd := &serpent.Command{ + Use: "llm-mock", + Short: "Start a mock LLM API server for testing", + Long: `Start a mock LLM API server that simulates OpenAI and Anthropic APIs`, + Handler: func(inv *serpent.Invocation) error { + ctx, stop := signal.NotifyContext(inv.Context(), StopSignals...) + defer stop() + + logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelInfo) + + if pprofEnable { + closePprof := ServeHandler(ctx, logger, nil, pprofAddress, "pprof") + defer closePprof() + logger.Info(ctx, "pprof server started", slog.F("address", pprofAddress)) + } + + config := llmmock.Config{ + Address: address, + Logger: logger, + ArtificialLatency: artificialLatency, + ResponsePayloadSize: int(responsePayloadSize), + PprofEnable: pprofEnable, + PprofAddress: pprofAddress, + TraceEnable: traceEnable, + } + srv := new(llmmock.Server) + + if err := srv.Start(ctx, config); err != nil { + return xerrors.Errorf("start mock LLM server: %w", err) + } + defer func() { + _ = srv.Stop() + }() + + _, _ = fmt.Fprintf(inv.Stdout, "Mock LLM API server started on %s\n", srv.APIAddress()) + _, _ = fmt.Fprintf(inv.Stdout, " OpenAI endpoint: %s/v1/chat/completions\n", srv.APIAddress()) + _, _ = fmt.Fprintf(inv.Stdout, " Anthropic endpoint: %s/v1/messages\n", srv.APIAddress()) + + <-ctx.Done() + return nil + }, + } + + cmd.Options = []serpent.Option{ + { + Flag: "address", + Env: "CODER_SCALETEST_LLM_MOCK_ADDRESS", + Default: "localhost", + Description: "Address to bind the mock LLM API server. Can include a port (e.g., 'localhost:8080' or ':8080'). Uses a random port if no port is specified.", + Value: serpent.StringOf(&address), + }, + { + Flag: "artificial-latency", + Env: "CODER_SCALETEST_LLM_MOCK_ARTIFICIAL_LATENCY", + Default: "0s", + Description: "Artificial latency to add to each response (e.g., 100ms, 1s). Simulates slow upstream processing.", + Value: serpent.DurationOf(&artificialLatency), + }, + { + Flag: "response-payload-size", + Env: "CODER_SCALETEST_LLM_MOCK_RESPONSE_PAYLOAD_SIZE", + Default: "0", + Description: "Size in bytes of the response payload. If 0, uses default context-aware responses.", + Value: serpent.Int64Of(&responsePayloadSize), + }, + { + Flag: "pprof-enable", + Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ENABLE", + Default: "false", + Description: "Serve pprof metrics on the address defined by pprof-address.", + Value: serpent.BoolOf(&pprofEnable), + }, + { + Flag: "pprof-address", + Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ADDRESS", + Default: "127.0.0.1:6060", + Description: "The bind address to serve pprof.", + Value: serpent.StringOf(&pprofAddress), + }, + { + Flag: "trace-enable", + Env: "CODER_SCALETEST_LLM_MOCK_TRACE_ENABLE", + Default: "false", + Description: "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md.", + Value: serpent.BoolOf(&traceEnable), + }, + } + + return cmd +} diff --git a/go.mod b/go.mod index ec6e41ec3e..1e41616f2c 100644 --- a/go.mod +++ b/go.mod @@ -437,7 +437,7 @@ require ( go.opentelemetry.io/collector/pdata/pprofile v0.121.0 // indirect go.opentelemetry.io/collector/semconv v0.123.0 // indirect go.opentelemetry.io/contrib v1.19.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 go.opentelemetry.io/otel/metric v1.38.0 // indirect go.opentelemetry.io/proto/otlp v1.7.0 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/scaletest/bridge/config.go b/scaletest/bridge/config.go new file mode 100644 index 0000000000..92ff9b5739 --- /dev/null +++ b/scaletest/bridge/config.go @@ -0,0 +1,150 @@ +package bridge + +import ( + "encoding/json" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/scaletest/createusers" +) + +type RequestMode string + +const ( + RequestModeBridge RequestMode = "bridge" + RequestModeDirect RequestMode = "direct" +) + +type Config struct { + // Mode determines how requests are made. + // "bridge": Create users in Coder and use their session tokens to make requests through AI Bridge. + // "direct": Make requests directly to UpstreamURL without user creation. + Mode RequestMode `json:"mode"` + + // User is the configuration for the user to create. + // Required in bridge mode. + User createusers.Config `json:"user"` + + // UpstreamURL is the URL to make requests to directly. + // Only used in direct mode. + UpstreamURL string `json:"upstream_url"` + + // Provider is the API provider to use: "openai" or "anthropic". + Provider string `json:"provider"` + + // RequestCount is the number of requests to make per runner. + RequestCount int `json:"request_count"` + + // Stream indicates whether to use streaming requests. + Stream bool `json:"stream"` + + // RequestPayloadSize is the size in bytes of the request payload (user message content). + // If 0, uses default message content. + RequestPayloadSize int `json:"request_payload_size"` + + // NumMessages is the number of messages to include in the conversation. + // Messages alternate between user and assistant roles, always ending with user. + // Must be greater than 0. + NumMessages int `json:"num_messages"` + + // HTTPTimeout is the timeout for individual HTTP requests to the upstream + // provider. This is separate from the job timeout which controls the overall + // test execution. + HTTPTimeout time.Duration `json:"http_timeout"` + + Metrics *Metrics `json:"-"` + + // RequestBody is the pre-serialized JSON request body. This is generated + // once by PrepareRequestBody and shared across all runners and requests. + RequestBody []byte `json:"-"` +} + +func (c Config) Validate() error { + if c.Metrics == nil { + return xerrors.New("metrics must be set") + } + + // Validate mode + if c.Mode != RequestModeBridge && c.Mode != RequestModeDirect { + return xerrors.New("mode must be either 'bridge' or 'direct'") + } + + if c.RequestCount <= 0 { + return xerrors.New("request_count must be greater than 0") + } + + // Validate provider + if c.Provider != "openai" && c.Provider != "anthropic" { + return xerrors.New("provider must be either 'openai' or 'anthropic'") + } + + if c.Mode == RequestModeDirect { + // In direct mode, UpstreamURL must be set. + if c.UpstreamURL == "" { + return xerrors.New("upstream_url must be set in direct mode") + } + return nil + } + + // In bridge mode, User config is required. + if c.User.OrganizationID == uuid.Nil { + return xerrors.New("user organization_id must be set in bridge mode") + } + + if err := c.User.Validate(); err != nil { + return xerrors.Errorf("user config: %w", err) + } + + if c.NumMessages <= 0 { + return xerrors.New("num_messages must be greater than 0") + } + + return nil +} + +func (c Config) NewStrategy(client *codersdk.Client) requestModeStrategy { + if c.Mode == RequestModeDirect { + return newDirectStrategy(directStrategyConfig{ + UpstreamURL: c.UpstreamURL, + }) + } + + return newBridgeStrategy(bridgeStrategyConfig{ + Client: client, + Provider: c.Provider, + Metrics: c.Metrics, + User: c.User, + }) +} + +// PrepareRequestBody generates the conversation and serializes the full request +// body once. This should be called before creating Runners so that all runners +// share the same pre-generated payload. +func (c *Config) PrepareRequestBody() error { + provider := NewProviderStrategy(c.Provider) + model := provider.DefaultModel() + + var formattedMessages []any + if c.RequestPayloadSize > 0 { + formattedMessages = generateConversation(provider, c.RequestPayloadSize, c.NumMessages) + } else { + messages := []message{{ + Role: "user", + Content: "Hello from the bridge load generator.", + }} + formattedMessages = provider.formatMessages(messages) + } + + reqBody := provider.buildRequestBody(model, formattedMessages, c.Stream) + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return xerrors.Errorf("marshal request body: %w", err) + } + + c.RequestBody = bodyBytes + return nil +} diff --git a/scaletest/bridge/metrics.go b/scaletest/bridge/metrics.go new file mode 100644 index 0000000000..25a35f3e52 --- /dev/null +++ b/scaletest/bridge/metrics.go @@ -0,0 +1,72 @@ +package bridge + +import ( + "github.com/prometheus/client_golang/prometheus" +) + +type Metrics struct { + bridgeErrors *prometheus.CounterVec + bridgeRequests *prometheus.CounterVec + bridgeDuration prometheus.Histogram + bridgeTokensTotal *prometheus.CounterVec +} + +func NewMetrics(reg prometheus.Registerer) *Metrics { + if reg == nil { + reg = prometheus.DefaultRegisterer + } + + errors := prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "bridge_errors_total", + Help: "Total number of bridge errors", + }, []string{"action"}) + + requests := prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "bridge_requests_total", + Help: "Total number of bridge requests", + }, []string{"status"}) + + duration := prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "bridge_request_duration_seconds", + Help: "Duration of bridge requests in seconds", + Buckets: prometheus.DefBuckets, + }) + + tokens := prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "bridge_response_tokens_total", + Help: "Total number of tokens in bridge responses", + }, []string{"type"}) + + reg.MustRegister(errors, requests, duration, tokens) + + return &Metrics{ + bridgeErrors: errors, + bridgeRequests: requests, + bridgeDuration: duration, + bridgeTokensTotal: tokens, + } +} + +func (m *Metrics) AddError(action string) { + m.bridgeErrors.WithLabelValues(action).Inc() +} + +func (m *Metrics) AddRequest(status string) { + m.bridgeRequests.WithLabelValues(status).Inc() +} + +func (m *Metrics) ObserveDuration(duration float64) { + m.bridgeDuration.Observe(duration) +} + +func (m *Metrics) AddTokens(tokenType string, count int64) { + m.bridgeTokensTotal.WithLabelValues(tokenType).Add(float64(count)) +} diff --git a/scaletest/bridge/provider.go b/scaletest/bridge/provider.go new file mode 100644 index 0000000000..c4f827b7c7 --- /dev/null +++ b/scaletest/bridge/provider.go @@ -0,0 +1,134 @@ +package bridge + +import ( + "encoding/json" + "strings" +) + +// ProviderStrategy handles provider-specific message formatting for LLM APIs. +type ProviderStrategy interface { + DefaultModel() string + formatMessages(messages []message) []any + buildRequestBody(model string, messages []any, stream bool) map[string]any +} + +type message struct { + Role string + Content string +} + +func NewProviderStrategy(provider string) ProviderStrategy { + switch provider { + case "anthropic": + return &anthropicProvider{} + default: + return &openAIProvider{} + } +} + +type openAIProvider struct{} + +func (*openAIProvider) DefaultModel() string { + return "gpt-4" +} + +func (*openAIProvider) formatMessages(messages []message) []any { + formatted := make([]any, 0, len(messages)) + for _, msg := range messages { + formatted = append(formatted, map[string]string{ + "role": msg.Role, + "content": msg.Content, + }) + } + return formatted +} + +func (*openAIProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any { + return map[string]any{ + "model": model, + "messages": messages, + "stream": stream, + } +} + +type anthropicProvider struct{} + +func (*anthropicProvider) DefaultModel() string { + return "claude-3-opus-20240229" +} + +func (*anthropicProvider) formatMessages(messages []message) []any { + formatted := make([]any, 0, len(messages)) + for _, msg := range messages { + formatted = append(formatted, map[string]any{ + "role": msg.Role, + "content": []map[string]string{ + { + "type": "text", + "text": msg.Content, + }, + }, + }) + } + return formatted +} + +func (*anthropicProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any { + return map[string]any{ + "model": model, + "messages": messages, + "max_tokens": 1024, + "stream": stream, + } +} + +// generateConversation creates a conversation with alternating user/assistant +// messages. The content is filled with repeated 'x' characters to reach +// approximately the target size. The last message is always from "user" as +// required by LLM APIs. +func generateConversation(provider ProviderStrategy, targetSize int, numMessages int) []any { + if targetSize <= 0 { + return nil + } + if numMessages < 1 { + numMessages = 1 + } + + roles := []string{"user", "assistant"} + messages := make([]message, numMessages) + for i := range messages { + messages[i].Role = roles[i%2] + } + // Ensure last message is from user (required for LLM APIs). + if messages[len(messages)-1].Role != "user" { + messages[len(messages)-1].Role = "user" + } + + overhead := measureJSONSize(provider.formatMessages(messages)) + + bytesPerMessage := targetSize - overhead + if bytesPerMessage < 0 { + bytesPerMessage = 0 + } + + perMessage := bytesPerMessage / len(messages) + remainder := bytesPerMessage % len(messages) + + for i := range messages { + size := perMessage + if i == len(messages)-1 { + size += remainder + } + messages[i].Content = strings.Repeat("x", size) + } + + return provider.formatMessages(messages) +} + +func measureJSONSize(v any) int { + data, err := json.Marshal(v) + if err != nil { + return 0 + } + return len(data) +} diff --git a/scaletest/bridge/run.go b/scaletest/bridge/run.go new file mode 100644 index 0000000000..d9c8d47270 --- /dev/null +++ b/scaletest/bridge/run.go @@ -0,0 +1,391 @@ +package bridge + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "time" + + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.14.0" + "go.opentelemetry.io/otel/semconv/v1.14.0/httpconv" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/sloghuman" + "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/scaletest/harness" + "github.com/coder/coder/v2/scaletest/loadtestutil" + "github.com/coder/quartz" +) + +type ( + tracingContextKey struct{} + tracingContext struct { + provider string + model string + stream bool + requestNum int + mode RequestMode + } +) + +type tracingTransport struct { + cfg Config + underlying http.RoundTripper +} + +func newTracingTransport(cfg Config, underlying http.RoundTripper) *tracingTransport { + if underlying == nil { + underlying = http.DefaultTransport + } + return &tracingTransport{ + cfg: cfg, + underlying: otelhttp.NewTransport(underlying), + } +} + +func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + aibridgeCtx, hasAIBridgeCtx := req.Context().Value(tracingContextKey{}).(tracingContext) + + resp, err := t.underlying.RoundTrip(req) + + if hasAIBridgeCtx { + ctx := req.Context() + if resp != nil && resp.Request != nil { + ctx = resp.Request.Context() + } + span := trace.SpanFromContext(ctx) + if span.IsRecording() { + span.SetAttributes( + attribute.String("aibridge.provider", aibridgeCtx.provider), + attribute.String("aibridge.model", aibridgeCtx.model), + attribute.Bool("aibridge.stream", aibridgeCtx.stream), + attribute.Int("aibridge.request_num", aibridgeCtx.requestNum), + attribute.String("aibridge.mode", string(aibridgeCtx.mode)), + ) + } + } + + return resp, err +} + +type Runner struct { + client *codersdk.Client + cfg Config + strategy requestModeStrategy + providerStrategy ProviderStrategy + + clock quartz.Clock + httpClient *http.Client + + requestCount int64 + successCount int64 + failureCount int64 + totalDuration time.Duration + totalTokens int64 +} + +func NewRunner(client *codersdk.Client, cfg Config) *Runner { + httpTimeout := cfg.HTTPTimeout + if httpTimeout <= 0 { + httpTimeout = 30 * time.Second + } + return &Runner{ + client: client, + cfg: cfg, + strategy: cfg.NewStrategy(client), + providerStrategy: NewProviderStrategy(cfg.Provider), + clock: quartz.NewReal(), + httpClient: &http.Client{ + Timeout: httpTimeout, + Transport: newTracingTransport(cfg, http.DefaultTransport), + }, + } +} + +func (r *Runner) WithClock(clock quartz.Clock) *Runner { + r.clock = clock + return r +} + +var ( + _ harness.Runnable = &Runner{} + _ harness.Cleanable = &Runner{} + _ harness.Collectable = &Runner{} +) + +func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + logs = loadtestutil.NewSyncWriter(logs) + logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug) + + requestURL, token, err := r.strategy.Setup(ctx, id, logs) + if err != nil { + return xerrors.Errorf("strategy setup: %w", err) + } + + requestCount := r.cfg.RequestCount + if requestCount <= 0 { + requestCount = 1 + } + + model := r.providerStrategy.DefaultModel() + + logger.Info(ctx, "bridge runner is ready", + slog.F("request_count", requestCount), + slog.F("model", model), + slog.F("stream", r.cfg.Stream), + ) + + for i := 0; i < requestCount; i++ { + if err := r.makeRequest(ctx, logger, requestURL, token, model, i); err != nil { + logger.Warn(ctx, "bridge request failed", + slog.F("request_num", i+1), + slog.F("error_type", "request_failed"), + slog.Error(err), + ) + r.cfg.Metrics.AddError("request") + r.cfg.Metrics.AddRequest("failure") + r.failureCount++ + + // Continue making requests even if one fails + continue + } + r.successCount++ + r.cfg.Metrics.AddRequest("success") + r.requestCount++ + } + + logger.Info(ctx, "bridge runner completed", + slog.F("total_requests", r.requestCount), + slog.F("success", r.successCount), + slog.F("failure", r.failureCount), + ) + + // Fail the run if any request failed + if r.failureCount > 0 { + return xerrors.Errorf("bridge runner failed: %d out of %d requests failed", r.failureCount, requestCount) + } + + return nil +} + +func (r *Runner) makeRequest(ctx context.Context, logger slog.Logger, url, token, model string, requestNum int) error { + start := r.clock.Now() + + ctx = context.WithValue(ctx, tracingContextKey{}, tracingContext{ + provider: r.cfg.Provider, + model: model, + stream: r.cfg.Stream, + requestNum: requestNum + 1, + mode: r.cfg.Mode, + }) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(r.cfg.RequestBody)) + if err != nil { + return xerrors.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + logger.Debug(ctx, "making bridge request", + slog.F("url", url), + slog.F("request_num", requestNum+1), + slog.F("model", model), + ) + + resp, err := r.httpClient.Do(req) + if err != nil { + span := trace.SpanFromContext(req.Context()) + if span.IsRecording() { + span.RecordError(err) + } + logger.Warn(ctx, "request failed during execution", + slog.F("request_num", requestNum+1), + slog.Error(err), + ) + return xerrors.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + span := trace.SpanFromContext(req.Context()) + if span.IsRecording() { + span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode)) + span.SetStatus(httpconv.ClientStatus(resp.StatusCode)) + } + + duration := r.clock.Since(start) + r.totalDuration += duration + r.cfg.Metrics.ObserveDuration(duration.Seconds()) + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + err := xerrors.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) + span.RecordError(err) + return err + } + + if r.cfg.Stream { + err := r.handleStreamingResponse(ctx, logger, resp) + if err != nil { + span.RecordError(err) + return err + } + return nil + } + + return r.handleNonStreamingResponse(ctx, logger, resp) +} + +func (r *Runner) handleNonStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error { + if r.cfg.Provider == "anthropic" { + return r.handleAnthropicResponse(ctx, logger, resp) + } + return r.handleOpenAIResponse(ctx, logger, resp) +} + +func (r *Runner) handleOpenAIResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error { + var response struct { + ID string `json:"id"` + Model string `json:"model"` + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + } + + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return xerrors.Errorf("decode response: %w", err) + } + + if len(response.Choices) > 0 { + assistantContent := response.Choices[0].Message.Content + logger.Debug(ctx, "received response", + slog.F("response_id", response.ID), + slog.F("content_length", len(assistantContent)), + ) + } + + if response.Usage.TotalTokens > 0 { + r.totalTokens += int64(response.Usage.TotalTokens) + r.cfg.Metrics.AddTokens("input", int64(response.Usage.PromptTokens)) + r.cfg.Metrics.AddTokens("output", int64(response.Usage.CompletionTokens)) + } + + return nil +} + +func (r *Runner) handleAnthropicResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error { + var response struct { + ID string `json:"id"` + Model string `json:"model"` + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + } + + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return xerrors.Errorf("decode response: %w", err) + } + + var assistantContent string + if len(response.Content) > 0 { + assistantContent = response.Content[0].Text + logger.Debug(ctx, "received response", + slog.F("response_id", response.ID), + slog.F("content_length", len(assistantContent)), + ) + } + + totalTokens := response.Usage.InputTokens + response.Usage.OutputTokens + if totalTokens > 0 { + r.totalTokens += int64(totalTokens) + r.cfg.Metrics.AddTokens("input", int64(response.Usage.InputTokens)) + r.cfg.Metrics.AddTokens("output", int64(response.Usage.OutputTokens)) + } + + return nil +} + +func (*Runner) handleStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error { + buf := make([]byte, 4096) + totalRead := 0 + for { + // Check for context cancellation before each read + if ctx.Err() != nil { + logger.Warn(ctx, "streaming response canceled", + slog.F("bytes_read", totalRead), + slog.Error(ctx.Err()), + ) + return xerrors.Errorf("stream canceled: %w", ctx.Err()) + } + + n, err := resp.Body.Read(buf) + if n > 0 { + totalRead += n + } + if err == io.EOF { + break + } + if err != nil { + // Check if error is due to context cancellation + if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { + logger.Warn(ctx, "streaming response read canceled", + slog.F("bytes_read", totalRead), + slog.Error(err), + ) + return xerrors.Errorf("stream read canceled: %w", err) + } + logger.Warn(ctx, "streaming response read error", + slog.F("bytes_read", totalRead), + slog.Error(err), + ) + return xerrors.Errorf("read stream: %w", err) + } + } + + logger.Debug(ctx, "received streaming response", slog.F("bytes_read", totalRead)) + return nil +} + +func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error { + return r.strategy.Cleanup(ctx, id, logs) +} + +func (r *Runner) GetMetrics() map[string]any { + avgDuration := time.Duration(0) + if r.requestCount > 0 { + avgDuration = r.totalDuration / time.Duration(r.requestCount) + } + + return map[string]any{ + "request_count": r.requestCount, + "success_count": r.successCount, + "failure_count": r.failureCount, + "total_duration": r.totalDuration.String(), + "avg_duration": avgDuration.String(), + "total_tokens": r.totalTokens, + } +} diff --git a/scaletest/bridge/strategy.go b/scaletest/bridge/strategy.go new file mode 100644 index 0000000000..8b99c2bc2d --- /dev/null +++ b/scaletest/bridge/strategy.go @@ -0,0 +1,117 @@ +package bridge + +import ( + "context" + "fmt" + "io" + + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/sloghuman" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/scaletest/createusers" +) + +type requestModeStrategy interface { + Setup(ctx context.Context, id string, logs io.Writer) (url string, token string, err error) + Cleanup(ctx context.Context, id string, logs io.Writer) error +} + +// bridgeStrategy creates users via Coder and routes requests through AI Bridge. +type bridgeStrategy struct { + client *codersdk.Client + provider string + metrics *Metrics + + userConfig createusers.Config + createUserRunner *createusers.Runner +} + +type bridgeStrategyConfig struct { + Client *codersdk.Client + Provider string + Metrics *Metrics + User createusers.Config +} + +func newBridgeStrategy(cfg bridgeStrategyConfig) *bridgeStrategy { + return &bridgeStrategy{ + client: cfg.Client, + provider: cfg.Provider, + metrics: cfg.Metrics, + userConfig: cfg.User, + } +} + +func (s *bridgeStrategy) Setup(ctx context.Context, id string, logs io.Writer) (requestURL string, token string, err error) { + logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug) + + s.client.SetLogger(logger) + s.client.SetLogBodies(true) + + s.createUserRunner = createusers.NewRunner(s.client, s.userConfig) + newUserAndToken, err := s.createUserRunner.RunReturningUser(ctx, id, logs) + if err != nil { + s.metrics.AddError("create_user") + return "", "", xerrors.Errorf("create user: %w", err) + } + newUser := newUserAndToken.User + token = newUserAndToken.SessionToken + + logger.Info(ctx, "runner user created", + slog.F("username", newUser.Username), + slog.F("user_id", newUser.ID.String()), + ) + + if s.provider == "anthropic" { + requestURL = fmt.Sprintf("%s/api/v2/aibridge/anthropic/v1/messages", s.client.URL) + } else { + requestURL = fmt.Sprintf("%s/api/v2/aibridge/openai/v1/chat/completions", s.client.URL) + } + logger.Info(ctx, "bridge runner in bridge mode", + slog.F("url", requestURL), + slog.F("provider", s.provider), + ) + + return requestURL, token, nil +} + +func (s *bridgeStrategy) Cleanup(ctx context.Context, id string, logs io.Writer) error { + if s.createUserRunner == nil { + return nil + } + + _, _ = fmt.Fprintln(logs, "Cleaning up user...") + if err := s.createUserRunner.Cleanup(ctx, id, logs); err != nil { + return xerrors.Errorf("cleanup user: %w", err) + } + return nil +} + +// directStrategy makes requests directly to an upstream URL. +type directStrategy struct { + upstreamURL string +} + +type directStrategyConfig struct { + UpstreamURL string +} + +func newDirectStrategy(cfg directStrategyConfig) *directStrategy { + return &directStrategy{ + upstreamURL: cfg.UpstreamURL, + } +} + +func (s *directStrategy) Setup(ctx context.Context, _ string, logs io.Writer) (requestURL string, _ string, err error) { + logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug) + + logger.Info(ctx, "bridge runner in direct mode", slog.F("url", s.upstreamURL)) + return s.upstreamURL, "", err +} + +func (*directStrategy) Cleanup(_ context.Context, _ string, _ io.Writer) error { + // Direct mode has no resources to clean up. + return nil +} diff --git a/scaletest/llmmock/server.go b/scaletest/llmmock/server.go new file mode 100644 index 0000000000..18fca711aa --- /dev/null +++ b/scaletest/llmmock/server.go @@ -0,0 +1,545 @@ +package llmmock + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "strings" + "time" + + "github.com/google/uuid" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + semconv "go.opentelemetry.io/otel/semconv/v1.14.0" + "go.opentelemetry.io/otel/semconv/v1.14.0/httpconv" + "go.opentelemetry.io/otel/semconv/v1.14.0/netconv" + "go.opentelemetry.io/otel/trace" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/pproflabel" + "github.com/coder/coder/v2/coderd/tracing" +) + +// Server wraps the LLM mock server and provides an HTTP API to retrieve requests. +type Server struct { + httpServer *http.Server + httpListener net.Listener + logger slog.Logger + + address string + artificialLatency time.Duration + responsePayloadSize int + + tracerProvider trace.TracerProvider + closeTracing func(context.Context) error +} + +type Config struct { + Address string + Logger slog.Logger + ArtificialLatency time.Duration + ResponsePayloadSize int + + PprofEnable bool + PprofAddress string + + TraceEnable bool +} + +type llmRequest struct { + Model string `json:"model"` + Stream bool `json:"stream,omitempty"` +} + +type openAIMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type openAIResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Message openAIMessage `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +type anthropicResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` +} + +func (s *Server) Start(ctx context.Context, cfg Config) error { + s.address = cfg.Address + s.logger = cfg.Logger + s.artificialLatency = cfg.ArtificialLatency + s.responsePayloadSize = cfg.ResponsePayloadSize + + if cfg.TraceEnable { + otel.SetTextMapPropagator( + propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + ), + ) + + tracerProvider, closeTracing, err := tracing.TracerProvider(ctx, "llm-mock", tracing.TracerOpts{ + Default: cfg.TraceEnable, + }) + if err != nil { + s.logger.Warn(ctx, "failed to initialize tracing", slog.Error(err)) + } else { + s.tracerProvider = tracerProvider + s.closeTracing = closeTracing + } + } + + if err := s.startAPIServer(ctx); err != nil { + return xerrors.Errorf("start API server: %w", err) + } + + return nil +} + +func (s *Server) Stop() error { + if s.httpServer != nil { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.httpServer.Shutdown(shutdownCtx); err != nil { + return xerrors.Errorf("shutdown HTTP server: %w", err) + } + } + if s.closeTracing != nil { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.closeTracing(shutdownCtx); err != nil { + s.logger.Warn(shutdownCtx, "failed to close tracing", slog.Error(err)) + } + } + return nil +} + +func (s *Server) APIAddress() string { + return fmt.Sprintf("http://%s", s.httpListener.Addr().String()) +} + +func (s *Server) startAPIServer(ctx context.Context) error { + mux := http.NewServeMux() + + mux.HandleFunc("POST /v1/chat/completions", s.handleOpenAI) + mux.HandleFunc("POST /v1/messages", s.handleAnthropic) + + var handler http.Handler = mux + if s.tracerProvider != nil { + handler = s.tracingMiddleware(handler) + } + + s.httpServer = &http.Server{ + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + } + + listener, err := net.Listen("tcp", s.address) + if err != nil { + return xerrors.Errorf("listen on %s: %w", s.address, err) + } + s.httpListener = listener + + pproflabel.Go(ctx, pproflabel.Service("llm-mock"), func(ctx context.Context) { + if err := s.httpServer.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.logger.Error(ctx, "http API server error", slog.Error(err)) + } + }) + + return nil +} + +func (s *Server) handleOpenAI(w http.ResponseWriter, r *http.Request) { + pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) { + s.handleOpenAIWithLabels(w, r.WithContext(ctx)) + }) +} + +func (s *Server) handleOpenAIWithLabels(w http.ResponseWriter, r *http.Request) { + s.logger.Debug(r.Context(), "handling OpenAI request") + defer s.logger.Debug(r.Context(), "handled OpenAI request") + + ctx := r.Context() + requestID := uuid.New() + now := time.Now() + + var req llmRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.logger.Error(ctx, "failed to parse OpenAI request", slog.Error(err)) + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + if s.artificialLatency > 0 { + time.Sleep(s.artificialLatency) + } + + var resp openAIResponse + resp.ID = fmt.Sprintf("chatcmpl-%s", requestID.String()[:8]) + resp.Object = "chat.completion" + resp.Created = now.Unix() + resp.Model = req.Model + + var responseContent string + if s.responsePayloadSize > 0 { + pattern := "x" + repeated := strings.Repeat(pattern, s.responsePayloadSize) + responseContent = repeated[:s.responsePayloadSize] + } else { + responseContent = "This is a mock response from OpenAI." + } + + resp.Choices = []struct { + Index int `json:"index"` + Message openAIMessage `json:"message"` + FinishReason string `json:"finish_reason"` + }{ + { + Index: 0, + Message: openAIMessage{ + Role: "assistant", + Content: responseContent, + }, + FinishReason: "stop", + }, + } + + resp.Usage.PromptTokens = 10 + resp.Usage.CompletionTokens = 5 + resp.Usage.TotalTokens = 15 + + responseBody, _ := json.Marshal(resp) + + if req.Stream { + s.sendOpenAIStream(ctx, w, resp) + } else { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write(responseBody); err != nil { + s.logger.Error(ctx, "failed to write OpenAI response", + slog.F("request_id", requestID), + slog.Error(err), + slog.F("error_type", "write_error"), + slog.F("likely_cause", "network_error"), + ) + } + } +} + +func (s *Server) handleAnthropic(w http.ResponseWriter, r *http.Request) { + pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) { + s.handleAnthropicWithLabels(w, r.WithContext(ctx)) + }) +} + +func (s *Server) handleAnthropicWithLabels(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + requestID := uuid.New() + + var req llmRequest + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.logger.Error(ctx, "failed to parse LLM request", slog.Error(err)) + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + if s.artificialLatency > 0 { + time.Sleep(s.artificialLatency) + } + + var resp anthropicResponse + resp.ID = fmt.Sprintf("msg_%s", requestID.String()[:8]) + resp.Type = "message" + resp.Role = "assistant" + + var responseText string + if s.responsePayloadSize > 0 { + pattern := "x" + repeated := strings.Repeat(pattern, s.responsePayloadSize) + responseText = repeated[:s.responsePayloadSize] + } else { + responseText = "This is a mock response from Anthropic." + } + + resp.Content = []struct { + Type string `json:"type"` + Text string `json:"text"` + }{ + { + Type: "text", + Text: responseText, + }, + } + resp.Model = req.Model + resp.StopReason = "end_turn" + resp.Usage.InputTokens = 10 + resp.Usage.OutputTokens = 5 + + responseBody, _ := json.Marshal(resp) + + if req.Stream { + s.sendAnthropicStream(ctx, w, resp) + } else { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("anthropic-version", "2023-06-01") + w.WriteHeader(http.StatusOK) + if _, err := w.Write(responseBody); err != nil { + s.logger.Error(ctx, "failed to write Anthropic response", + slog.F("request_id", requestID), + slog.Error(err), + slog.F("error_type", "write_error"), + slog.F("likely_cause", "network_error"), + ) + } + } +} + +func (s *Server) sendOpenAIStream(ctx context.Context, w http.ResponseWriter, resp openAIResponse) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + if !ok { + s.logger.Error(ctx, "responseWriter does not support flushing", + slog.F("response_id", resp.ID), + ) + return + } + + writeChunk := func(data string) bool { + if _, err := fmt.Fprintf(w, "%s", data); err != nil { + s.logger.Error(ctx, "failed to write OpenAI stream chunk", + slog.F("response_id", resp.ID), + slog.Error(err), + slog.F("error_type", "write_error"), + slog.F("likely_cause", "network_error"), + ) + return false + } + flusher.Flush() + return true + } + + // Send initial chunk + chunk := map[string]interface{}{ + "id": resp.ID, + "object": "chat.completion.chunk", + "created": resp.Created, + "model": resp.Model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{ + "role": "assistant", + "content": resp.Choices[0].Message.Content, + }, + "finish_reason": nil, + }, + }, + } + chunkBytes, _ := json.Marshal(chunk) + if !writeChunk(fmt.Sprintf("data: %s\n\n", chunkBytes)) { + return + } + + // Send final chunk + finalChunk := map[string]interface{}{ + "id": resp.ID, + "object": "chat.completion.chunk", + "created": resp.Created, + "model": resp.Model, + "choices": []map[string]interface{}{ + { + "index": 0, + "delta": map[string]interface{}{}, + "finish_reason": resp.Choices[0].FinishReason, + }, + }, + } + finalChunkBytes, _ := json.Marshal(finalChunk) + if !writeChunk(fmt.Sprintf("data: %s\n\n", finalChunkBytes)) { + return + } + writeChunk("data: [DONE]\n\n") +} + +func (s *Server) sendAnthropicStream(ctx context.Context, w http.ResponseWriter, resp anthropicResponse) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("anthropic-version", "2023-06-01") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + if !ok { + s.logger.Error(ctx, "responseWriter does not support flushing", + slog.F("response_id", resp.ID), + ) + return + } + + writeChunk := func(data string) bool { + if _, err := fmt.Fprintf(w, "%s", data); err != nil { + s.logger.Error(ctx, "failed to write Anthropic stream chunk", + slog.F("response_id", resp.ID), + slog.Error(err), + slog.F("error_type", "write_error"), + slog.F("likely_cause", "network_error"), + ) + return false + } + flusher.Flush() + return true + } + + startEvent := map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": resp.ID, + "type": resp.Type, + "role": resp.Role, + "model": resp.Model, + }, + } + startBytes, _ := json.Marshal(startEvent) + if !writeChunk(fmt.Sprintf("data: %s\n\n", startBytes)) { + return + } + + // Send content_block_start event + contentStartEvent := map[string]interface{}{ + "type": "content_block_start", + "index": 0, + "content_block": map[string]interface{}{ + "type": "text", + "text": resp.Content[0].Text, + }, + } + contentStartBytes, _ := json.Marshal(contentStartEvent) + if !writeChunk(fmt.Sprintf("data: %s\n\n", contentStartBytes)) { + return + } + + // Send content_block_delta event + deltaEvent := map[string]interface{}{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": resp.Content[0].Text, + }, + } + deltaBytes, _ := json.Marshal(deltaEvent) + if !writeChunk(fmt.Sprintf("data: %s\n\n", deltaBytes)) { + return + } + + // Send content_block_stop event + contentStopEvent := map[string]interface{}{ + "type": "content_block_stop", + "index": 0, + } + contentStopBytes, _ := json.Marshal(contentStopEvent) + if !writeChunk(fmt.Sprintf("data: %s\n\n", contentStopBytes)) { + return + } + + // Send message_delta event + deltaMsgEvent := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": resp.StopReason, + "stop_sequence": resp.StopSequence, + }, + "usage": resp.Usage, + } + deltaMsgBytes, _ := json.Marshal(deltaMsgEvent) + if !writeChunk(fmt.Sprintf("data: %s\n\n", deltaMsgBytes)) { + return + } + + // Send message_stop event + stopEvent := map[string]interface{}{ + "type": "message_stop", + } + stopBytes, _ := json.Marshal(stopEvent) + writeChunk(fmt.Sprintf("data: %s\n\n", stopBytes)) +} + +func (s *Server) tracingMiddleware(next http.Handler) http.Handler { + tracer := s.tracerProvider.Tracer("llm-mock") + + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // Wrap response writer with StatusWriter for tracing + sw := &tracing.StatusWriter{ResponseWriter: rw} + + // Extract trace context from headers + propagator := otel.GetTextMapPropagator() + hc := propagation.HeaderCarrier(r.Header) + ctx := propagator.Extract(r.Context(), hc) + + // Start span with initial name (will be updated after handler) + ctx, span := tracer.Start(ctx, fmt.Sprintf("%s %s", r.Method, r.RequestURI)) + defer span.End() + r = r.WithContext(ctx) + + // Inject trace context into response headers + if span.SpanContext().HasTraceID() && span.SpanContext().HasSpanID() { + rw.Header().Set("X-Trace-ID", span.SpanContext().TraceID().String()) + rw.Header().Set("X-Span-ID", span.SpanContext().SpanID().String()) + + hc := propagation.HeaderCarrier(rw.Header()) + propagator.Inject(ctx, hc) + } + + // Execute the handler + next.ServeHTTP(sw, r) + + // Update span with final route and response information + route := r.URL.Path + span.SetName(fmt.Sprintf("%s %s", r.Method, route)) + span.SetAttributes(netconv.Transport("tcp")) + span.SetAttributes(httpconv.ServerRequest("llm-mock", r)...) + span.SetAttributes(semconv.HTTPRouteKey.String(route)) + + status := sw.Status + if status == 0 { + status = http.StatusOK + } + span.SetAttributes(semconv.HTTPStatusCodeKey.Int(status)) + span.SetStatus(httpconv.ServerStatus(status)) + }) +}