mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore: add scaletesting tools for aibridge (#21279)
This pull request adds scaletesting tools for aibridge. See https://www.notion.so/Scale-tests-2c5d579be5928088b565d15dd8bdea41?source=copy_link for information and instructions. closes: https://github.com/coder/internal/issues/1156 closes: https://github.com/coder/internal/issues/1155 closes: https://github.com/coder/internal/issues/1158
This commit is contained in:
@@ -68,6 +68,8 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
|
||||
r.scaletestTaskStatus(),
|
||||
r.scaletestSMTP(),
|
||||
r.scaletestPrebuilds(),
|
||||
r.scaletestBridge(),
|
||||
r.scaletestLLMMock(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/bridge"
|
||||
"github.com/coder/coder/v2/scaletest/createusers"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) scaletestBridge() *serpent.Command {
|
||||
var (
|
||||
concurrentUsers int64
|
||||
noCleanup bool
|
||||
mode string
|
||||
upstreamURL string
|
||||
provider string
|
||||
requestsPerUser int64
|
||||
useStreamingAPI bool
|
||||
requestPayloadSize int64
|
||||
numMessages int64
|
||||
httpTimeout time.Duration
|
||||
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
cleanupStrategy = newScaletestCleanupStrategy()
|
||||
output = &scaletestOutputFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "bridge",
|
||||
Short: "Generate load on the AI Bridge service.",
|
||||
Long: `Generate load for AI Bridge testing. Supports two modes: 'bridge' mode routes requests through the Coder AI Bridge, 'direct' mode makes requests directly to an upstream URL (useful for baseline comparisons).
|
||||
|
||||
Examples:
|
||||
# Test OpenAI API through bridge
|
||||
coder scaletest bridge --mode bridge --provider openai --concurrent-users 10 --request-count 5 --num-messages 10
|
||||
|
||||
# Test Anthropic API through bridge
|
||||
coder scaletest bridge --mode bridge --provider anthropic --concurrent-users 10 --request-count 5 --num-messages 10
|
||||
|
||||
# Test directly against mock server
|
||||
coder scaletest bridge --mode direct --provider openai --upstream-url http://localhost:8080/v1/chat/completions
|
||||
`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
}
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := bridge.NewMetrics(reg)
|
||||
|
||||
logger := inv.Logger
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
|
||||
defer func() {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
}()
|
||||
|
||||
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
ctx = notifyCtx
|
||||
|
||||
var userConfig createusers.Config
|
||||
if bridge.RequestMode(mode) == bridge.RequestModeBridge {
|
||||
me, err := requireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(me.OrganizationIDs) == 0 {
|
||||
return xerrors.Errorf("admin user must have at least one organization")
|
||||
}
|
||||
userConfig = createusers.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
}
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Bridge mode: creating users and making requests through AI Bridge...")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Direct mode: making requests directly to %s\n", upstreamURL)
|
||||
}
|
||||
|
||||
outputs, err := output.parse()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse output flags: %w", err)
|
||||
}
|
||||
|
||||
config := bridge.Config{
|
||||
Mode: bridge.RequestMode(mode),
|
||||
Metrics: metrics,
|
||||
Provider: provider,
|
||||
RequestCount: int(requestsPerUser),
|
||||
Stream: useStreamingAPI,
|
||||
RequestPayloadSize: int(requestPayloadSize),
|
||||
NumMessages: int(numMessages),
|
||||
HTTPTimeout: httpTimeout,
|
||||
UpstreamURL: upstreamURL,
|
||||
User: userConfig,
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config: %w", err)
|
||||
}
|
||||
if err := config.PrepareRequestBody(); err != nil {
|
||||
return xerrors.Errorf("prepare request body: %w", err)
|
||||
}
|
||||
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
|
||||
for i := range concurrentUsers {
|
||||
id := strconv.Itoa(int(i))
|
||||
name := fmt.Sprintf("bridge-%s", id)
|
||||
var runner harness.Runnable = bridge.NewRunner(client, config)
|
||||
th.AddRun(name, id, runner)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Bridge scaletest configuration:")
|
||||
tw := tabwriter.NewWriter(inv.Stderr, 0, 0, 2, ' ', 0)
|
||||
for _, opt := range inv.Command.Options {
|
||||
if opt.Hidden || opt.ValueSource == serpent.ValueSourceNone {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(tw, " %s:\t%s", opt.Name, opt.Value.String())
|
||||
if opt.ValueSource != serpent.ValueSourceDefault {
|
||||
_, _ = fmt.Fprintf(tw, "\t(from %s)", opt.ValueSource)
|
||||
}
|
||||
_, _ = fmt.Fprintln(tw)
|
||||
}
|
||||
_ = tw.Flush()
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nRunning bridge scaletest...")
|
||||
testCtx, testCancel := timeoutStrategy.toContext(ctx)
|
||||
defer testCancel()
|
||||
err = th.Run(testCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
|
||||
}
|
||||
|
||||
// If the command was interrupted, skip stats.
|
||||
if notifyCtx.Err() != nil {
|
||||
return notifyCtx.Err()
|
||||
}
|
||||
|
||||
res := th.Results()
|
||||
|
||||
for _, o := range outputs {
|
||||
err = o.write(res, inv.Stdout)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !noCleanup {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
|
||||
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
|
||||
defer cleanupCancel()
|
||||
err = th.Cleanup(cleanupCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cleanup tests: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if res.TotalFail > 0 {
|
||||
return xerrors.New("load test failed, see above for more details")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "concurrent-users",
|
||||
FlagShorthand: "c",
|
||||
Env: "CODER_SCALETEST_BRIDGE_CONCURRENT_USERS",
|
||||
Description: "Required: Number of concurrent users.",
|
||||
Value: serpent.Validate(serpent.Int64Of(&concurrentUsers), func(value *serpent.Int64) error {
|
||||
if value == nil || value.Value() <= 0 {
|
||||
return xerrors.Errorf("--concurrent-users must be greater than 0")
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Flag: "mode",
|
||||
Env: "CODER_SCALETEST_BRIDGE_MODE",
|
||||
Default: "direct",
|
||||
Description: "Request mode: 'bridge' (create users and use AI Bridge) or 'direct' (make requests directly to upstream-url).",
|
||||
Value: serpent.EnumOf(&mode, string(bridge.RequestModeBridge), string(bridge.RequestModeDirect)),
|
||||
},
|
||||
{
|
||||
Flag: "upstream-url",
|
||||
Env: "CODER_SCALETEST_BRIDGE_UPSTREAM_URL",
|
||||
Description: "URL to make requests to directly (required in direct mode, e.g., http://localhost:8080/v1/chat/completions).",
|
||||
Value: serpent.StringOf(&upstreamURL),
|
||||
},
|
||||
{
|
||||
Flag: "provider",
|
||||
Env: "CODER_SCALETEST_BRIDGE_PROVIDER",
|
||||
Default: "openai",
|
||||
Description: "API provider to use.",
|
||||
Value: serpent.EnumOf(&provider, "openai", "anthropic"),
|
||||
},
|
||||
{
|
||||
Flag: "request-count",
|
||||
Env: "CODER_SCALETEST_BRIDGE_REQUEST_COUNT",
|
||||
Default: "1",
|
||||
Description: "Number of sequential requests to make per runner.",
|
||||
Value: serpent.Validate(serpent.Int64Of(&requestsPerUser), func(value *serpent.Int64) error {
|
||||
if value == nil || value.Value() <= 0 {
|
||||
return xerrors.Errorf("--request-count must be greater than 0")
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
{
|
||||
Flag: "stream",
|
||||
Env: "CODER_SCALETEST_BRIDGE_STREAM",
|
||||
Description: "Enable streaming requests.",
|
||||
Value: serpent.BoolOf(&useStreamingAPI),
|
||||
},
|
||||
{
|
||||
Flag: "request-payload-size",
|
||||
Env: "CODER_SCALETEST_BRIDGE_REQUEST_PAYLOAD_SIZE",
|
||||
Default: "1024",
|
||||
Description: "Size in bytes of the request payload (user message content). If 0, uses default message content.",
|
||||
Value: serpent.Int64Of(&requestPayloadSize),
|
||||
},
|
||||
{
|
||||
Flag: "num-messages",
|
||||
Env: "CODER_SCALETEST_BRIDGE_NUM_MESSAGES",
|
||||
Default: "1",
|
||||
Description: "Number of messages to include in the conversation.",
|
||||
Value: serpent.Int64Of(&numMessages),
|
||||
},
|
||||
{
|
||||
Flag: "no-cleanup",
|
||||
Env: "CODER_SCALETEST_NO_CLEANUP",
|
||||
Description: "Do not clean up resources after the test completes.",
|
||||
Value: serpent.BoolOf(&noCleanup),
|
||||
},
|
||||
{
|
||||
Flag: "http-timeout",
|
||||
Env: "CODER_SCALETEST_BRIDGE_HTTP_TIMEOUT",
|
||||
Default: "30s",
|
||||
Description: "Timeout for individual HTTP requests to the upstream provider.",
|
||||
Value: serpent.DurationOf(&httpTimeout),
|
||||
},
|
||||
}
|
||||
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
cleanupStrategy.attach(&cmd.Options)
|
||||
output.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/scaletest/llmmock"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (*RootCmd) scaletestLLMMock() *serpent.Command {
|
||||
var (
|
||||
address string
|
||||
artificialLatency time.Duration
|
||||
responsePayloadSize int64
|
||||
|
||||
pprofEnable bool
|
||||
pprofAddress string
|
||||
|
||||
traceEnable bool
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "llm-mock",
|
||||
Short: "Start a mock LLM API server for testing",
|
||||
Long: `Start a mock LLM API server that simulates OpenAI and Anthropic APIs`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx, stop := signal.NotifyContext(inv.Context(), StopSignals...)
|
||||
defer stop()
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelInfo)
|
||||
|
||||
if pprofEnable {
|
||||
closePprof := ServeHandler(ctx, logger, nil, pprofAddress, "pprof")
|
||||
defer closePprof()
|
||||
logger.Info(ctx, "pprof server started", slog.F("address", pprofAddress))
|
||||
}
|
||||
|
||||
config := llmmock.Config{
|
||||
Address: address,
|
||||
Logger: logger,
|
||||
ArtificialLatency: artificialLatency,
|
||||
ResponsePayloadSize: int(responsePayloadSize),
|
||||
PprofEnable: pprofEnable,
|
||||
PprofAddress: pprofAddress,
|
||||
TraceEnable: traceEnable,
|
||||
}
|
||||
srv := new(llmmock.Server)
|
||||
|
||||
if err := srv.Start(ctx, config); err != nil {
|
||||
return xerrors.Errorf("start mock LLM server: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = srv.Stop()
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Mock LLM API server started on %s\n", srv.APIAddress())
|
||||
_, _ = fmt.Fprintf(inv.Stdout, " OpenAI endpoint: %s/v1/chat/completions\n", srv.APIAddress())
|
||||
_, _ = fmt.Fprintf(inv.Stdout, " Anthropic endpoint: %s/v1/messages\n", srv.APIAddress())
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = []serpent.Option{
|
||||
{
|
||||
Flag: "address",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_ADDRESS",
|
||||
Default: "localhost",
|
||||
Description: "Address to bind the mock LLM API server. Can include a port (e.g., 'localhost:8080' or ':8080'). Uses a random port if no port is specified.",
|
||||
Value: serpent.StringOf(&address),
|
||||
},
|
||||
{
|
||||
Flag: "artificial-latency",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_ARTIFICIAL_LATENCY",
|
||||
Default: "0s",
|
||||
Description: "Artificial latency to add to each response (e.g., 100ms, 1s). Simulates slow upstream processing.",
|
||||
Value: serpent.DurationOf(&artificialLatency),
|
||||
},
|
||||
{
|
||||
Flag: "response-payload-size",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_RESPONSE_PAYLOAD_SIZE",
|
||||
Default: "0",
|
||||
Description: "Size in bytes of the response payload. If 0, uses default context-aware responses.",
|
||||
Value: serpent.Int64Of(&responsePayloadSize),
|
||||
},
|
||||
{
|
||||
Flag: "pprof-enable",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ENABLE",
|
||||
Default: "false",
|
||||
Description: "Serve pprof metrics on the address defined by pprof-address.",
|
||||
Value: serpent.BoolOf(&pprofEnable),
|
||||
},
|
||||
{
|
||||
Flag: "pprof-address",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ADDRESS",
|
||||
Default: "127.0.0.1:6060",
|
||||
Description: "The bind address to serve pprof.",
|
||||
Value: serpent.StringOf(&pprofAddress),
|
||||
},
|
||||
{
|
||||
Flag: "trace-enable",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_TRACE_ENABLE",
|
||||
Default: "false",
|
||||
Description: "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md.",
|
||||
Value: serpent.BoolOf(&traceEnable),
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -437,7 +437,7 @@ require (
|
||||
go.opentelemetry.io/collector/pdata/pprofile v0.121.0 // indirect
|
||||
go.opentelemetry.io/collector/semconv v0.123.0 // indirect
|
||||
go.opentelemetry.io/contrib v1.19.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.7.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/createusers"
|
||||
)
|
||||
|
||||
// RequestMode selects how the bridge scaletest issues its requests.
type RequestMode string

const (
	// RequestModeBridge is the "bridge" request mode.
	RequestModeBridge RequestMode = "bridge"
	// RequestModeDirect is the "direct" request mode.
	RequestModeDirect RequestMode = "direct"
)
|
||||
|
||||
type Config struct {
|
||||
// Mode determines how requests are made.
|
||||
// "bridge": Create users in Coder and use their session tokens to make requests through AI Bridge.
|
||||
// "direct": Make requests directly to UpstreamURL without user creation.
|
||||
Mode RequestMode `json:"mode"`
|
||||
|
||||
// User is the configuration for the user to create.
|
||||
// Required in bridge mode.
|
||||
User createusers.Config `json:"user"`
|
||||
|
||||
// UpstreamURL is the URL to make requests to directly.
|
||||
// Only used in direct mode.
|
||||
UpstreamURL string `json:"upstream_url"`
|
||||
|
||||
// Provider is the API provider to use: "openai" or "anthropic".
|
||||
Provider string `json:"provider"`
|
||||
|
||||
// RequestCount is the number of requests to make per runner.
|
||||
RequestCount int `json:"request_count"`
|
||||
|
||||
// Stream indicates whether to use streaming requests.
|
||||
Stream bool `json:"stream"`
|
||||
|
||||
// RequestPayloadSize is the size in bytes of the request payload (user message content).
|
||||
// If 0, uses default message content.
|
||||
RequestPayloadSize int `json:"request_payload_size"`
|
||||
|
||||
// NumMessages is the number of messages to include in the conversation.
|
||||
// Messages alternate between user and assistant roles, always ending with user.
|
||||
// Must be greater than 0.
|
||||
NumMessages int `json:"num_messages"`
|
||||
|
||||
// HTTPTimeout is the timeout for individual HTTP requests to the upstream
|
||||
// provider. This is separate from the job timeout which controls the overall
|
||||
// test execution.
|
||||
HTTPTimeout time.Duration `json:"http_timeout"`
|
||||
|
||||
Metrics *Metrics `json:"-"`
|
||||
|
||||
// RequestBody is the pre-serialized JSON request body. This is generated
|
||||
// once by PrepareRequestBody and shared across all runners and requests.
|
||||
RequestBody []byte `json:"-"`
|
||||
}
|
||||
|
||||
func (c Config) Validate() error {
|
||||
if c.Metrics == nil {
|
||||
return xerrors.New("metrics must be set")
|
||||
}
|
||||
|
||||
// Validate mode
|
||||
if c.Mode != RequestModeBridge && c.Mode != RequestModeDirect {
|
||||
return xerrors.New("mode must be either 'bridge' or 'direct'")
|
||||
}
|
||||
|
||||
if c.RequestCount <= 0 {
|
||||
return xerrors.New("request_count must be greater than 0")
|
||||
}
|
||||
|
||||
// Validate provider
|
||||
if c.Provider != "openai" && c.Provider != "anthropic" {
|
||||
return xerrors.New("provider must be either 'openai' or 'anthropic'")
|
||||
}
|
||||
|
||||
if c.Mode == RequestModeDirect {
|
||||
// In direct mode, UpstreamURL must be set.
|
||||
if c.UpstreamURL == "" {
|
||||
return xerrors.New("upstream_url must be set in direct mode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// In bridge mode, User config is required.
|
||||
if c.User.OrganizationID == uuid.Nil {
|
||||
return xerrors.New("user organization_id must be set in bridge mode")
|
||||
}
|
||||
|
||||
if err := c.User.Validate(); err != nil {
|
||||
return xerrors.Errorf("user config: %w", err)
|
||||
}
|
||||
|
||||
if c.NumMessages <= 0 {
|
||||
return xerrors.New("num_messages must be greater than 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) NewStrategy(client *codersdk.Client) requestModeStrategy {
|
||||
if c.Mode == RequestModeDirect {
|
||||
return newDirectStrategy(directStrategyConfig{
|
||||
UpstreamURL: c.UpstreamURL,
|
||||
})
|
||||
}
|
||||
|
||||
return newBridgeStrategy(bridgeStrategyConfig{
|
||||
Client: client,
|
||||
Provider: c.Provider,
|
||||
Metrics: c.Metrics,
|
||||
User: c.User,
|
||||
})
|
||||
}
|
||||
|
||||
// PrepareRequestBody generates the conversation and serializes the full request
|
||||
// body once. This should be called before creating Runners so that all runners
|
||||
// share the same pre-generated payload.
|
||||
func (c *Config) PrepareRequestBody() error {
|
||||
provider := NewProviderStrategy(c.Provider)
|
||||
model := provider.DefaultModel()
|
||||
|
||||
var formattedMessages []any
|
||||
if c.RequestPayloadSize > 0 {
|
||||
formattedMessages = generateConversation(provider, c.RequestPayloadSize, c.NumMessages)
|
||||
} else {
|
||||
messages := []message{{
|
||||
Role: "user",
|
||||
Content: "Hello from the bridge load generator.",
|
||||
}}
|
||||
formattedMessages = provider.formatMessages(messages)
|
||||
}
|
||||
|
||||
reqBody := provider.buildRequestBody(model, formattedMessages, c.Stream)
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
|
||||
c.RequestBody = bodyBytes
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
bridgeErrors *prometheus.CounterVec
|
||||
bridgeRequests *prometheus.CounterVec
|
||||
bridgeDuration prometheus.Histogram
|
||||
bridgeTokensTotal *prometheus.CounterVec
|
||||
}
|
||||
|
||||
func NewMetrics(reg prometheus.Registerer) *Metrics {
|
||||
if reg == nil {
|
||||
reg = prometheus.DefaultRegisterer
|
||||
}
|
||||
|
||||
errors := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "bridge_errors_total",
|
||||
Help: "Total number of bridge errors",
|
||||
}, []string{"action"})
|
||||
|
||||
requests := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "bridge_requests_total",
|
||||
Help: "Total number of bridge requests",
|
||||
}, []string{"status"})
|
||||
|
||||
duration := prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "bridge_request_duration_seconds",
|
||||
Help: "Duration of bridge requests in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
})
|
||||
|
||||
tokens := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "bridge_response_tokens_total",
|
||||
Help: "Total number of tokens in bridge responses",
|
||||
}, []string{"type"})
|
||||
|
||||
reg.MustRegister(errors, requests, duration, tokens)
|
||||
|
||||
return &Metrics{
|
||||
bridgeErrors: errors,
|
||||
bridgeRequests: requests,
|
||||
bridgeDuration: duration,
|
||||
bridgeTokensTotal: tokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Metrics) AddError(action string) {
|
||||
m.bridgeErrors.WithLabelValues(action).Inc()
|
||||
}
|
||||
|
||||
func (m *Metrics) AddRequest(status string) {
|
||||
m.bridgeRequests.WithLabelValues(status).Inc()
|
||||
}
|
||||
|
||||
func (m *Metrics) ObserveDuration(duration float64) {
|
||||
m.bridgeDuration.Observe(duration)
|
||||
}
|
||||
|
||||
func (m *Metrics) AddTokens(tokenType string, count int64) {
|
||||
m.bridgeTokensTotal.WithLabelValues(tokenType).Add(float64(count))
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProviderStrategy handles provider-specific message formatting for LLM APIs.
|
||||
type ProviderStrategy interface {
|
||||
DefaultModel() string
|
||||
formatMessages(messages []message) []any
|
||||
buildRequestBody(model string, messages []any, stream bool) map[string]any
|
||||
}
|
||||
|
||||
// message is a provider-neutral conversation entry, turned into a provider's
// own encoding by ProviderStrategy.formatMessages.
type message struct {
	// Role is the speaker, e.g. "user" or "assistant".
	Role string
	// Content is the text of the message.
	Content string
}
|
||||
|
||||
func NewProviderStrategy(provider string) ProviderStrategy {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return &anthropicProvider{}
|
||||
default:
|
||||
return &openAIProvider{}
|
||||
}
|
||||
}
|
||||
|
||||
type openAIProvider struct{}
|
||||
|
||||
func (*openAIProvider) DefaultModel() string {
|
||||
return "gpt-4"
|
||||
}
|
||||
|
||||
func (*openAIProvider) formatMessages(messages []message) []any {
|
||||
formatted := make([]any, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
formatted = append(formatted, map[string]string{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
})
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
|
||||
func (*openAIProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any {
|
||||
return map[string]any{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
}
|
||||
}
|
||||
|
||||
type anthropicProvider struct{}
|
||||
|
||||
func (*anthropicProvider) DefaultModel() string {
|
||||
return "claude-3-opus-20240229"
|
||||
}
|
||||
|
||||
func (*anthropicProvider) formatMessages(messages []message) []any {
|
||||
formatted := make([]any, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
formatted = append(formatted, map[string]any{
|
||||
"role": msg.Role,
|
||||
"content": []map[string]string{
|
||||
{
|
||||
"type": "text",
|
||||
"text": msg.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
|
||||
func (*anthropicProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any {
|
||||
return map[string]any{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": 1024,
|
||||
"stream": stream,
|
||||
}
|
||||
}
|
||||
|
||||
// generateConversation creates a conversation with alternating user/assistant
|
||||
// messages. The content is filled with repeated 'x' characters to reach
|
||||
// approximately the target size. The last message is always from "user" as
|
||||
// required by LLM APIs.
|
||||
func generateConversation(provider ProviderStrategy, targetSize int, numMessages int) []any {
|
||||
if targetSize <= 0 {
|
||||
return nil
|
||||
}
|
||||
if numMessages < 1 {
|
||||
numMessages = 1
|
||||
}
|
||||
|
||||
roles := []string{"user", "assistant"}
|
||||
messages := make([]message, numMessages)
|
||||
for i := range messages {
|
||||
messages[i].Role = roles[i%2]
|
||||
}
|
||||
// Ensure last message is from user (required for LLM APIs).
|
||||
if messages[len(messages)-1].Role != "user" {
|
||||
messages[len(messages)-1].Role = "user"
|
||||
}
|
||||
|
||||
overhead := measureJSONSize(provider.formatMessages(messages))
|
||||
|
||||
bytesPerMessage := targetSize - overhead
|
||||
if bytesPerMessage < 0 {
|
||||
bytesPerMessage = 0
|
||||
}
|
||||
|
||||
perMessage := bytesPerMessage / len(messages)
|
||||
remainder := bytesPerMessage % len(messages)
|
||||
|
||||
for i := range messages {
|
||||
size := perMessage
|
||||
if i == len(messages)-1 {
|
||||
size += remainder
|
||||
}
|
||||
messages[i].Content = strings.Repeat("x", size)
|
||||
}
|
||||
|
||||
return provider.formatMessages(messages)
|
||||
}
|
||||
|
||||
// measureJSONSize returns the length in bytes of v's JSON encoding, or 0 when
// v cannot be marshaled.
func measureJSONSize(v any) int {
	if encoded, err := json.Marshal(v); err == nil {
		return len(encoded)
	}
	return 0
}
|
||||
@@ -0,0 +1,391 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.14.0"
|
||||
"go.opentelemetry.io/otel/semconv/v1.14.0/httpconv"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type (
|
||||
tracingContextKey struct{}
|
||||
tracingContext struct {
|
||||
provider string
|
||||
model string
|
||||
stream bool
|
||||
requestNum int
|
||||
mode RequestMode
|
||||
}
|
||||
)
|
||||
|
||||
type tracingTransport struct {
|
||||
cfg Config
|
||||
underlying http.RoundTripper
|
||||
}
|
||||
|
||||
func newTracingTransport(cfg Config, underlying http.RoundTripper) *tracingTransport {
|
||||
if underlying == nil {
|
||||
underlying = http.DefaultTransport
|
||||
}
|
||||
return &tracingTransport{
|
||||
cfg: cfg,
|
||||
underlying: otelhttp.NewTransport(underlying),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
aibridgeCtx, hasAIBridgeCtx := req.Context().Value(tracingContextKey{}).(tracingContext)
|
||||
|
||||
resp, err := t.underlying.RoundTrip(req)
|
||||
|
||||
if hasAIBridgeCtx {
|
||||
ctx := req.Context()
|
||||
if resp != nil && resp.Request != nil {
|
||||
ctx = resp.Request.Context()
|
||||
}
|
||||
span := trace.SpanFromContext(ctx)
|
||||
if span.IsRecording() {
|
||||
span.SetAttributes(
|
||||
attribute.String("aibridge.provider", aibridgeCtx.provider),
|
||||
attribute.String("aibridge.model", aibridgeCtx.model),
|
||||
attribute.Bool("aibridge.stream", aibridgeCtx.stream),
|
||||
attribute.Int("aibridge.request_num", aibridgeCtx.requestNum),
|
||||
attribute.String("aibridge.mode", string(aibridgeCtx.mode)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
client *codersdk.Client
|
||||
cfg Config
|
||||
strategy requestModeStrategy
|
||||
providerStrategy ProviderStrategy
|
||||
|
||||
clock quartz.Clock
|
||||
httpClient *http.Client
|
||||
|
||||
requestCount int64
|
||||
successCount int64
|
||||
failureCount int64
|
||||
totalDuration time.Duration
|
||||
totalTokens int64
|
||||
}
|
||||
|
||||
func NewRunner(client *codersdk.Client, cfg Config) *Runner {
|
||||
httpTimeout := cfg.HTTPTimeout
|
||||
if httpTimeout <= 0 {
|
||||
httpTimeout = 30 * time.Second
|
||||
}
|
||||
return &Runner{
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
strategy: cfg.NewStrategy(client),
|
||||
providerStrategy: NewProviderStrategy(cfg.Provider),
|
||||
clock: quartz.NewReal(),
|
||||
httpClient: &http.Client{
|
||||
Timeout: httpTimeout,
|
||||
Transport: newTracingTransport(cfg, http.DefaultTransport),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) WithClock(clock quartz.Clock) *Runner {
|
||||
r.clock = clock
|
||||
return r
|
||||
}
|
||||
|
||||
var (
|
||||
_ harness.Runnable = &Runner{}
|
||||
_ harness.Cleanable = &Runner{}
|
||||
_ harness.Collectable = &Runner{}
|
||||
)
|
||||
|
||||
func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
logs = loadtestutil.NewSyncWriter(logs)
|
||||
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
|
||||
|
||||
requestURL, token, err := r.strategy.Setup(ctx, id, logs)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("strategy setup: %w", err)
|
||||
}
|
||||
|
||||
requestCount := r.cfg.RequestCount
|
||||
if requestCount <= 0 {
|
||||
requestCount = 1
|
||||
}
|
||||
|
||||
model := r.providerStrategy.DefaultModel()
|
||||
|
||||
logger.Info(ctx, "bridge runner is ready",
|
||||
slog.F("request_count", requestCount),
|
||||
slog.F("model", model),
|
||||
slog.F("stream", r.cfg.Stream),
|
||||
)
|
||||
|
||||
for i := 0; i < requestCount; i++ {
|
||||
if err := r.makeRequest(ctx, logger, requestURL, token, model, i); err != nil {
|
||||
logger.Warn(ctx, "bridge request failed",
|
||||
slog.F("request_num", i+1),
|
||||
slog.F("error_type", "request_failed"),
|
||||
slog.Error(err),
|
||||
)
|
||||
r.cfg.Metrics.AddError("request")
|
||||
r.cfg.Metrics.AddRequest("failure")
|
||||
r.failureCount++
|
||||
|
||||
// Continue making requests even if one fails
|
||||
continue
|
||||
}
|
||||
r.successCount++
|
||||
r.cfg.Metrics.AddRequest("success")
|
||||
r.requestCount++
|
||||
}
|
||||
|
||||
logger.Info(ctx, "bridge runner completed",
|
||||
slog.F("total_requests", r.requestCount),
|
||||
slog.F("success", r.successCount),
|
||||
slog.F("failure", r.failureCount),
|
||||
)
|
||||
|
||||
// Fail the run if any request failed
|
||||
if r.failureCount > 0 {
|
||||
return xerrors.Errorf("bridge runner failed: %d out of %d requests failed", r.failureCount, requestCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) makeRequest(ctx context.Context, logger slog.Logger, url, token, model string, requestNum int) error {
|
||||
start := r.clock.Now()
|
||||
|
||||
ctx = context.WithValue(ctx, tracingContextKey{}, tracingContext{
|
||||
provider: r.cfg.Provider,
|
||||
model: model,
|
||||
stream: r.cfg.Stream,
|
||||
requestNum: requestNum + 1,
|
||||
mode: r.cfg.Mode,
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(r.cfg.RequestBody))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "making bridge request",
|
||||
slog.F("url", url),
|
||||
slog.F("request_num", requestNum+1),
|
||||
slog.F("model", model),
|
||||
)
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
span := trace.SpanFromContext(req.Context())
|
||||
if span.IsRecording() {
|
||||
span.RecordError(err)
|
||||
}
|
||||
logger.Warn(ctx, "request failed during execution",
|
||||
slog.F("request_num", requestNum+1),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
span := trace.SpanFromContext(req.Context())
|
||||
if span.IsRecording() {
|
||||
span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode))
|
||||
span.SetStatus(httpconv.ClientStatus(resp.StatusCode))
|
||||
}
|
||||
|
||||
duration := r.clock.Since(start)
|
||||
r.totalDuration += duration
|
||||
r.cfg.Metrics.ObserveDuration(duration.Seconds())
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
err := xerrors.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
span.RecordError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if r.cfg.Stream {
|
||||
err := r.handleStreamingResponse(ctx, logger, resp)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.handleNonStreamingResponse(ctx, logger, resp)
|
||||
}
|
||||
|
||||
func (r *Runner) handleNonStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
||||
if r.cfg.Provider == "anthropic" {
|
||||
return r.handleAnthropicResponse(ctx, logger, resp)
|
||||
}
|
||||
return r.handleOpenAIResponse(ctx, logger, resp)
|
||||
}
|
||||
|
||||
func (r *Runner) handleOpenAIResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
||||
var response struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return xerrors.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(response.Choices) > 0 {
|
||||
assistantContent := response.Choices[0].Message.Content
|
||||
logger.Debug(ctx, "received response",
|
||||
slog.F("response_id", response.ID),
|
||||
slog.F("content_length", len(assistantContent)),
|
||||
)
|
||||
}
|
||||
|
||||
if response.Usage.TotalTokens > 0 {
|
||||
r.totalTokens += int64(response.Usage.TotalTokens)
|
||||
r.cfg.Metrics.AddTokens("input", int64(response.Usage.PromptTokens))
|
||||
r.cfg.Metrics.AddTokens("output", int64(response.Usage.CompletionTokens))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) handleAnthropicResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
||||
var response struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return xerrors.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
var assistantContent string
|
||||
if len(response.Content) > 0 {
|
||||
assistantContent = response.Content[0].Text
|
||||
logger.Debug(ctx, "received response",
|
||||
slog.F("response_id", response.ID),
|
||||
slog.F("content_length", len(assistantContent)),
|
||||
)
|
||||
}
|
||||
|
||||
totalTokens := response.Usage.InputTokens + response.Usage.OutputTokens
|
||||
if totalTokens > 0 {
|
||||
r.totalTokens += int64(totalTokens)
|
||||
r.cfg.Metrics.AddTokens("input", int64(response.Usage.InputTokens))
|
||||
r.cfg.Metrics.AddTokens("output", int64(response.Usage.OutputTokens))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*Runner) handleStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
||||
buf := make([]byte, 4096)
|
||||
totalRead := 0
|
||||
for {
|
||||
// Check for context cancellation before each read
|
||||
if ctx.Err() != nil {
|
||||
logger.Warn(ctx, "streaming response canceled",
|
||||
slog.F("bytes_read", totalRead),
|
||||
slog.Error(ctx.Err()),
|
||||
)
|
||||
return xerrors.Errorf("stream canceled: %w", ctx.Err())
|
||||
}
|
||||
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
totalRead += n
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
// Check if error is due to context cancellation
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
logger.Warn(ctx, "streaming response read canceled",
|
||||
slog.F("bytes_read", totalRead),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("stream read canceled: %w", err)
|
||||
}
|
||||
logger.Warn(ctx, "streaming response read error",
|
||||
slog.F("bytes_read", totalRead),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("read stream: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "received streaming response", slog.F("bytes_read", totalRead))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error {
|
||||
return r.strategy.Cleanup(ctx, id, logs)
|
||||
}
|
||||
|
||||
func (r *Runner) GetMetrics() map[string]any {
|
||||
avgDuration := time.Duration(0)
|
||||
if r.requestCount > 0 {
|
||||
avgDuration = r.totalDuration / time.Duration(r.requestCount)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"request_count": r.requestCount,
|
||||
"success_count": r.successCount,
|
||||
"failure_count": r.failureCount,
|
||||
"total_duration": r.totalDuration.String(),
|
||||
"avg_duration": avgDuration.String(),
|
||||
"total_tokens": r.totalTokens,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/createusers"
|
||||
)
|
||||
|
||||
// requestModeStrategy abstracts how a runner obtains its request target.
type requestModeStrategy interface {
	// Setup prepares the strategy and returns the URL to send requests to
	// and the token to authenticate with (which may be empty).
	Setup(ctx context.Context, id string, logs io.Writer) (url, token string, err error)
	// Cleanup releases whatever Setup created.
	Cleanup(ctx context.Context, id string, logs io.Writer) error
}
|
||||
|
||||
// bridgeStrategy creates users via Coder and routes requests through AI Bridge.
|
||||
type bridgeStrategy struct {
|
||||
client *codersdk.Client
|
||||
provider string
|
||||
metrics *Metrics
|
||||
|
||||
userConfig createusers.Config
|
||||
createUserRunner *createusers.Runner
|
||||
}
|
||||
|
||||
type bridgeStrategyConfig struct {
|
||||
Client *codersdk.Client
|
||||
Provider string
|
||||
Metrics *Metrics
|
||||
User createusers.Config
|
||||
}
|
||||
|
||||
func newBridgeStrategy(cfg bridgeStrategyConfig) *bridgeStrategy {
|
||||
return &bridgeStrategy{
|
||||
client: cfg.Client,
|
||||
provider: cfg.Provider,
|
||||
metrics: cfg.Metrics,
|
||||
userConfig: cfg.User,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bridgeStrategy) Setup(ctx context.Context, id string, logs io.Writer) (requestURL string, token string, err error) {
|
||||
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
|
||||
|
||||
s.client.SetLogger(logger)
|
||||
s.client.SetLogBodies(true)
|
||||
|
||||
s.createUserRunner = createusers.NewRunner(s.client, s.userConfig)
|
||||
newUserAndToken, err := s.createUserRunner.RunReturningUser(ctx, id, logs)
|
||||
if err != nil {
|
||||
s.metrics.AddError("create_user")
|
||||
return "", "", xerrors.Errorf("create user: %w", err)
|
||||
}
|
||||
newUser := newUserAndToken.User
|
||||
token = newUserAndToken.SessionToken
|
||||
|
||||
logger.Info(ctx, "runner user created",
|
||||
slog.F("username", newUser.Username),
|
||||
slog.F("user_id", newUser.ID.String()),
|
||||
)
|
||||
|
||||
if s.provider == "anthropic" {
|
||||
requestURL = fmt.Sprintf("%s/api/v2/aibridge/anthropic/v1/messages", s.client.URL)
|
||||
} else {
|
||||
requestURL = fmt.Sprintf("%s/api/v2/aibridge/openai/v1/chat/completions", s.client.URL)
|
||||
}
|
||||
logger.Info(ctx, "bridge runner in bridge mode",
|
||||
slog.F("url", requestURL),
|
||||
slog.F("provider", s.provider),
|
||||
)
|
||||
|
||||
return requestURL, token, nil
|
||||
}
|
||||
|
||||
func (s *bridgeStrategy) Cleanup(ctx context.Context, id string, logs io.Writer) error {
|
||||
if s.createUserRunner == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(logs, "Cleaning up user...")
|
||||
if err := s.createUserRunner.Cleanup(ctx, id, logs); err != nil {
|
||||
return xerrors.Errorf("cleanup user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// directStrategy makes requests directly to an upstream URL.
type directStrategy struct {
	// upstreamURL is returned verbatim by Setup as the request URL.
	upstreamURL string
}
|
||||
|
||||
// directStrategyConfig carries the input for newDirectStrategy.
type directStrategyConfig struct {
	// UpstreamURL is the URL requests are sent to.
	UpstreamURL string
}
|
||||
|
||||
func newDirectStrategy(cfg directStrategyConfig) *directStrategy {
|
||||
return &directStrategy{
|
||||
upstreamURL: cfg.UpstreamURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *directStrategy) Setup(ctx context.Context, _ string, logs io.Writer) (requestURL string, _ string, err error) {
|
||||
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
|
||||
|
||||
logger.Info(ctx, "bridge runner in direct mode", slog.F("url", s.upstreamURL))
|
||||
return s.upstreamURL, "", err
|
||||
}
|
||||
|
||||
func (*directStrategy) Cleanup(_ context.Context, _ string, _ io.Writer) error {
|
||||
// Direct mode has no resources to clean up.
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,545 @@
|
||||
package llmmock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.14.0"
|
||||
"go.opentelemetry.io/otel/semconv/v1.14.0/httpconv"
|
||||
"go.opentelemetry.io/otel/semconv/v1.14.0/netconv"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/pproflabel"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
)
|
||||
|
||||
// Server wraps the LLM mock server and provides an HTTP API to retrieve requests.
|
||||
type Server struct {
|
||||
httpServer *http.Server
|
||||
httpListener net.Listener
|
||||
logger slog.Logger
|
||||
|
||||
address string
|
||||
artificialLatency time.Duration
|
||||
responsePayloadSize int
|
||||
|
||||
tracerProvider trace.TracerProvider
|
||||
closeTracing func(context.Context) error
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Address string
|
||||
Logger slog.Logger
|
||||
ArtificialLatency time.Duration
|
||||
ResponsePayloadSize int
|
||||
|
||||
PprofEnable bool
|
||||
PprofAddress string
|
||||
|
||||
TraceEnable bool
|
||||
}
|
||||
|
||||
// llmRequest is the subset of an incoming request body the mock reads: the
// model name, which is echoed back, and whether to stream the response.
type llmRequest struct {
	Model  string `json:"model"`
	Stream bool   `json:"stream,omitempty"`
}
|
||||
|
||||
// openAIMessage is one chat message in an OpenAI-style response.
type openAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
||||
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message openAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
// anthropicResponse is the JSON body the mock returns for a non-streaming
// Anthropic messages request.
type anthropicResponse struct {
	ID      string `json:"id"`
	Type    string `json:"type"`
	Role    string `json:"role"`
	Content []struct {
		Type string `json:"type"`
		Text string `json:"text"`
	} `json:"content"`
	Model      string `json:"model"`
	StopReason string `json:"stop_reason"`
	// StopSequence is a pointer so that an unset value encodes as JSON null.
	StopSequence *string `json:"stop_sequence"`
	Usage        struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	} `json:"usage"`
}
|
||||
|
||||
func (s *Server) Start(ctx context.Context, cfg Config) error {
|
||||
s.address = cfg.Address
|
||||
s.logger = cfg.Logger
|
||||
s.artificialLatency = cfg.ArtificialLatency
|
||||
s.responsePayloadSize = cfg.ResponsePayloadSize
|
||||
|
||||
if cfg.TraceEnable {
|
||||
otel.SetTextMapPropagator(
|
||||
propagation.NewCompositeTextMapPropagator(
|
||||
propagation.TraceContext{},
|
||||
propagation.Baggage{},
|
||||
),
|
||||
)
|
||||
|
||||
tracerProvider, closeTracing, err := tracing.TracerProvider(ctx, "llm-mock", tracing.TracerOpts{
|
||||
Default: cfg.TraceEnable,
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to initialize tracing", slog.Error(err))
|
||||
} else {
|
||||
s.tracerProvider = tracerProvider
|
||||
s.closeTracing = closeTracing
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.startAPIServer(ctx); err != nil {
|
||||
return xerrors.Errorf("start API server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
if s.httpServer != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
return xerrors.Errorf("shutdown HTTP server: %w", err)
|
||||
}
|
||||
}
|
||||
if s.closeTracing != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.closeTracing(shutdownCtx); err != nil {
|
||||
s.logger.Warn(shutdownCtx, "failed to close tracing", slog.Error(err))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) APIAddress() string {
|
||||
return fmt.Sprintf("http://%s", s.httpListener.Addr().String())
|
||||
}
|
||||
|
||||
func (s *Server) startAPIServer(ctx context.Context) error {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("POST /v1/chat/completions", s.handleOpenAI)
|
||||
mux.HandleFunc("POST /v1/messages", s.handleAnthropic)
|
||||
|
||||
var handler http.Handler = mux
|
||||
if s.tracerProvider != nil {
|
||||
handler = s.tracingMiddleware(handler)
|
||||
}
|
||||
|
||||
s.httpServer = &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", s.address)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("listen on %s: %w", s.address, err)
|
||||
}
|
||||
s.httpListener = listener
|
||||
|
||||
pproflabel.Go(ctx, pproflabel.Service("llm-mock"), func(ctx context.Context) {
|
||||
if err := s.httpServer.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.logger.Error(ctx, "http API server error", slog.Error(err))
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleOpenAI(w http.ResponseWriter, r *http.Request) {
|
||||
pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) {
|
||||
s.handleOpenAIWithLabels(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleOpenAIWithLabels(w http.ResponseWriter, r *http.Request) {
|
||||
s.logger.Debug(r.Context(), "handling OpenAI request")
|
||||
defer s.logger.Debug(r.Context(), "handled OpenAI request")
|
||||
|
||||
ctx := r.Context()
|
||||
requestID := uuid.New()
|
||||
now := time.Now()
|
||||
|
||||
var req llmRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.logger.Error(ctx, "failed to parse OpenAI request", slog.Error(err))
|
||||
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if s.artificialLatency > 0 {
|
||||
time.Sleep(s.artificialLatency)
|
||||
}
|
||||
|
||||
var resp openAIResponse
|
||||
resp.ID = fmt.Sprintf("chatcmpl-%s", requestID.String()[:8])
|
||||
resp.Object = "chat.completion"
|
||||
resp.Created = now.Unix()
|
||||
resp.Model = req.Model
|
||||
|
||||
var responseContent string
|
||||
if s.responsePayloadSize > 0 {
|
||||
pattern := "x"
|
||||
repeated := strings.Repeat(pattern, s.responsePayloadSize)
|
||||
responseContent = repeated[:s.responsePayloadSize]
|
||||
} else {
|
||||
responseContent = "This is a mock response from OpenAI."
|
||||
}
|
||||
|
||||
resp.Choices = []struct {
|
||||
Index int `json:"index"`
|
||||
Message openAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}{
|
||||
{
|
||||
Index: 0,
|
||||
Message: openAIMessage{
|
||||
Role: "assistant",
|
||||
Content: responseContent,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
}
|
||||
|
||||
resp.Usage.PromptTokens = 10
|
||||
resp.Usage.CompletionTokens = 5
|
||||
resp.Usage.TotalTokens = 15
|
||||
|
||||
responseBody, _ := json.Marshal(resp)
|
||||
|
||||
if req.Stream {
|
||||
s.sendOpenAIStream(ctx, w, resp)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write(responseBody); err != nil {
|
||||
s.logger.Error(ctx, "failed to write OpenAI response",
|
||||
slog.F("request_id", requestID),
|
||||
slog.Error(err),
|
||||
slog.F("error_type", "write_error"),
|
||||
slog.F("likely_cause", "network_error"),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAnthropic(w http.ResponseWriter, r *http.Request) {
|
||||
pproflabel.Do(r.Context(), pproflabel.Service("llm-mock"), func(ctx context.Context) {
|
||||
s.handleAnthropicWithLabels(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleAnthropicWithLabels(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
requestID := uuid.New()
|
||||
|
||||
var req llmRequest
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
s.logger.Error(ctx, "failed to parse LLM request", slog.Error(err))
|
||||
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if s.artificialLatency > 0 {
|
||||
time.Sleep(s.artificialLatency)
|
||||
}
|
||||
|
||||
var resp anthropicResponse
|
||||
resp.ID = fmt.Sprintf("msg_%s", requestID.String()[:8])
|
||||
resp.Type = "message"
|
||||
resp.Role = "assistant"
|
||||
|
||||
var responseText string
|
||||
if s.responsePayloadSize > 0 {
|
||||
pattern := "x"
|
||||
repeated := strings.Repeat(pattern, s.responsePayloadSize)
|
||||
responseText = repeated[:s.responsePayloadSize]
|
||||
} else {
|
||||
responseText = "This is a mock response from Anthropic."
|
||||
}
|
||||
|
||||
resp.Content = []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}{
|
||||
{
|
||||
Type: "text",
|
||||
Text: responseText,
|
||||
},
|
||||
}
|
||||
resp.Model = req.Model
|
||||
resp.StopReason = "end_turn"
|
||||
resp.Usage.InputTokens = 10
|
||||
resp.Usage.OutputTokens = 5
|
||||
|
||||
responseBody, _ := json.Marshal(resp)
|
||||
|
||||
if req.Stream {
|
||||
s.sendAnthropicStream(ctx, w, resp)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("anthropic-version", "2023-06-01")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write(responseBody); err != nil {
|
||||
s.logger.Error(ctx, "failed to write Anthropic response",
|
||||
slog.F("request_id", requestID),
|
||||
slog.Error(err),
|
||||
slog.F("error_type", "write_error"),
|
||||
slog.F("likely_cause", "network_error"),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) sendOpenAIStream(ctx context.Context, w http.ResponseWriter, resp openAIResponse) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error(ctx, "responseWriter does not support flushing",
|
||||
slog.F("response_id", resp.ID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
writeChunk := func(data string) bool {
|
||||
if _, err := fmt.Fprintf(w, "%s", data); err != nil {
|
||||
s.logger.Error(ctx, "failed to write OpenAI stream chunk",
|
||||
slog.F("response_id", resp.ID),
|
||||
slog.Error(err),
|
||||
slog.F("error_type", "write_error"),
|
||||
slog.F("likely_cause", "network_error"),
|
||||
)
|
||||
return false
|
||||
}
|
||||
flusher.Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
// Send initial chunk
|
||||
chunk := map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": resp.Created,
|
||||
"model": resp.Model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": resp.Choices[0].Message.Content,
|
||||
},
|
||||
"finish_reason": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
chunkBytes, _ := json.Marshal(chunk)
|
||||
if !writeChunk(fmt.Sprintf("data: %s\n\n", chunkBytes)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send final chunk
|
||||
finalChunk := map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": resp.Created,
|
||||
"model": resp.Model,
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{},
|
||||
"finish_reason": resp.Choices[0].FinishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
finalChunkBytes, _ := json.Marshal(finalChunk)
|
||||
if !writeChunk(fmt.Sprintf("data: %s\n\n", finalChunkBytes)) {
|
||||
return
|
||||
}
|
||||
writeChunk("data: [DONE]\n\n")
|
||||
}
|
||||
|
||||
func (s *Server) sendAnthropicStream(ctx context.Context, w http.ResponseWriter, resp anthropicResponse) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("anthropic-version", "2023-06-01")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error(ctx, "responseWriter does not support flushing",
|
||||
slog.F("response_id", resp.ID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
writeChunk := func(data string) bool {
|
||||
if _, err := fmt.Fprintf(w, "%s", data); err != nil {
|
||||
s.logger.Error(ctx, "failed to write Anthropic stream chunk",
|
||||
slog.F("response_id", resp.ID),
|
||||
slog.Error(err),
|
||||
slog.F("error_type", "write_error"),
|
||||
slog.F("likely_cause", "network_error"),
|
||||
)
|
||||
return false
|
||||
}
|
||||
flusher.Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
startEvent := map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"type": resp.Type,
|
||||
"role": resp.Role,
|
||||
"model": resp.Model,
|
||||
},
|
||||
}
|
||||
startBytes, _ := json.Marshal(startEvent)
|
||||
if !writeChunk(fmt.Sprintf("data: %s\n\n", startBytes)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send content_block_start event
|
||||
contentStartEvent := map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": resp.Content[0].Text,
|
||||
},
|
||||
}
|
||||
contentStartBytes, _ := json.Marshal(contentStartEvent)
|
||||
if !writeChunk(fmt.Sprintf("data: %s\n\n", contentStartBytes)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send content_block_delta event
|
||||
deltaEvent := map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "text_delta",
|
||||
"text": resp.Content[0].Text,
|
||||
},
|
||||
}
|
||||
deltaBytes, _ := json.Marshal(deltaEvent)
|
||||
if !writeChunk(fmt.Sprintf("data: %s\n\n", deltaBytes)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send content_block_stop event
|
||||
contentStopEvent := map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
}
|
||||
contentStopBytes, _ := json.Marshal(contentStopEvent)
|
||||
if !writeChunk(fmt.Sprintf("data: %s\n\n", contentStopBytes)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send message_delta event
|
||||
deltaMsgEvent := map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": resp.StopReason,
|
||||
"stop_sequence": resp.StopSequence,
|
||||
},
|
||||
"usage": resp.Usage,
|
||||
}
|
||||
deltaMsgBytes, _ := json.Marshal(deltaMsgEvent)
|
||||
if !writeChunk(fmt.Sprintf("data: %s\n\n", deltaMsgBytes)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send message_stop event
|
||||
stopEvent := map[string]interface{}{
|
||||
"type": "message_stop",
|
||||
}
|
||||
stopBytes, _ := json.Marshal(stopEvent)
|
||||
writeChunk(fmt.Sprintf("data: %s\n\n", stopBytes))
|
||||
}
|
||||
|
||||
func (s *Server) tracingMiddleware(next http.Handler) http.Handler {
|
||||
tracer := s.tracerProvider.Tracer("llm-mock")
|
||||
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
// Wrap response writer with StatusWriter for tracing
|
||||
sw := &tracing.StatusWriter{ResponseWriter: rw}
|
||||
|
||||
// Extract trace context from headers
|
||||
propagator := otel.GetTextMapPropagator()
|
||||
hc := propagation.HeaderCarrier(r.Header)
|
||||
ctx := propagator.Extract(r.Context(), hc)
|
||||
|
||||
// Start span with initial name (will be updated after handler)
|
||||
ctx, span := tracer.Start(ctx, fmt.Sprintf("%s %s", r.Method, r.RequestURI))
|
||||
defer span.End()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// Inject trace context into response headers
|
||||
if span.SpanContext().HasTraceID() && span.SpanContext().HasSpanID() {
|
||||
rw.Header().Set("X-Trace-ID", span.SpanContext().TraceID().String())
|
||||
rw.Header().Set("X-Span-ID", span.SpanContext().SpanID().String())
|
||||
|
||||
hc := propagation.HeaderCarrier(rw.Header())
|
||||
propagator.Inject(ctx, hc)
|
||||
}
|
||||
|
||||
// Execute the handler
|
||||
next.ServeHTTP(sw, r)
|
||||
|
||||
// Update span with final route and response information
|
||||
route := r.URL.Path
|
||||
span.SetName(fmt.Sprintf("%s %s", r.Method, route))
|
||||
span.SetAttributes(netconv.Transport("tcp"))
|
||||
span.SetAttributes(httpconv.ServerRequest("llm-mock", r)...)
|
||||
span.SetAttributes(semconv.HTTPRouteKey.String(route))
|
||||
|
||||
status := sw.Status
|
||||
if status == 0 {
|
||||
status = http.StatusOK
|
||||
}
|
||||
span.SetAttributes(semconv.HTTPStatusCodeKey.Int(status))
|
||||
span.SetStatus(httpconv.ServerStatus(status))
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user