feat(scaletest): add chat scaletest command (#25553)

Adds `coder exp scaletest chat`, a harness for creating Coder Agents
chat load.
Start the mock LLM separately, prepare the scaletest workspaces you want
to target, then run the chat scaletest against the existing
`scaletest-*` fleet selected by the shared workspace targeting flags:

```sh
coder exp scaletest llm-mock --address 127.0.0.1:18080

coder exp scaletest chat --llm-mock-url http://127.0.0.1:18080/v1 --chats-per-workspace 10 --turns 1
coder exp scaletest chat --llm-mock-url http://127.0.0.1:18080/v1 --template docker --target-workspaces 0:10 --chats-per-workspace 1 --turns 10 --turn-start-delay 30s
```

This is the same pattern used by the `workspace-traffic` load generator.

Keeping the fake LLM as a separate process is intentional so it can be
scaled independently from the Coder deployment, which will likely be
necessary as we scale up and up.

This PR is the starting point: it provides the command, mock
provider/model bootstrap, existing workspace selection, chat streaming,
follow-up turns, metrics, and cleanup. Follow-up PRs will add multi-step
turns via tool calls. I'm still a bit iffy on the mechanism I have for
that. It'll likely involve having the runner send some magic strings
that the mock will recognise.


Relates to CODAGT-307
Relates to GRU-48
Relates to https://github.com/coder/scaletest/issues/124

Generated by Mux, but reviewed by a human
This commit is contained in:
Ethan
2026-05-26 14:19:36 +10:00
committed by GitHub
parent fe13bb2a20
commit 4f1043a50a
8 changed files with 1478 additions and 2 deletions
+3 -2
View File
@@ -70,6 +70,7 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
r.scaletestSMTP(),
r.scaletestPrebuilds(),
r.scaletestBridge(),
r.scaletestChat(),
r.scaletestLLMMock(),
},
}
@@ -404,13 +405,13 @@ func (f *workspaceTargetFlags) attach(opts *serpent.OptionSet) {
Flag: "template",
FlagShorthand: "t",
Env: "CODER_SCALETEST_TEMPLATE",
Description: "Name or ID of the template. Traffic generation will be limited to workspaces created from this template.",
Description: "Name or ID of the template. Only workspaces created from this template are targeted.",
Value: serpent.StringOf(&f.template),
},
serpent.Option{
Flag: "target-workspaces",
Env: "CODER_SCALETEST_TARGET_WORKSPACES",
Description: "Target a specific range of workspaces in the format [START]:[END] (exclusive). Example: 0:10 will target the 10 first alphabetically sorted workspaces (0-9).",
Description: "Target a specific range of matching workspaces in the format [START]:[END] (exclusive). Example: 0:10 targets the first 10 matching workspaces returned by the workspace query.",
Value: serpent.StringOf(&f.targetWorkspaces),
},
serpent.Option{
+254
View File
@@ -0,0 +1,254 @@
//go:build !slim
package cli
import (
"fmt"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/sloghuman"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/chat"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/loadtestutil"
"github.com/coder/serpent"
)
func (r *RootCmd) scaletestChat() *serpent.Command {
var (
chatsPerWorkspace int64
prompt string
turns int64
turnStartDelay time.Duration
llmMockURL string
targetFlags = &workspaceTargetFlags{}
tracingFlags = &scaletestTracingFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
)
cmd := &serpent.Command{
Use: "chat",
Short: "Generate Coder Agents load.",
Handler: func(inv *serpent.Invocation) error {
baseCtx := inv.Context()
ctx, stop := inv.SignalNotifyContext(baseCtx, StopSignals...)
defer stop()
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("could not parse --output flags: %w", err)
}
switch {
case turns < 1:
return xerrors.Errorf("--turns must be at least 1")
case chatsPerWorkspace < 1:
return xerrors.Errorf("--chats-per-workspace must be at least 1")
}
client, err := r.InitClient(inv)
if err != nil {
return err
}
me, err := RequireAdmin(ctx, client)
if err != nil {
return err
}
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: BypassHeader,
}
workspaces, err := targetFlags.getTargetedWorkspaces(ctx, client, me.OrganizationIDs, inv.Stdout)
if err != nil {
return err
}
logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelDebug)
modelConfigID, err := chat.EnsureScaletestModelConfig(ctx, codersdk.NewExperimentalClient(client), logger, llmMockURL)
if err != nil {
return err
}
// Start metrics and tracing before creating runners.
reg := prometheus.NewRegistry()
metrics := chat.NewMetrics(reg)
prometheusSrvClose := ServeHandler(baseCtx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(baseCtx)
if err != nil {
prometheusSrvClose()
return xerrors.Errorf("create tracer provider: %w", err)
}
defer func() {
if tracingEnabled {
_, _ = fmt.Fprintln(inv.Stderr, "Uploading traces...")
}
if err := closeTracing(baseCtx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "Error uploading traces: %+v\n", err)
}
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
prometheusSrvClose()
}()
tracer := tracerProvider.Tracer(scaletestTracerName)
var turnStartReadyWaitGroup *sync.WaitGroup
var startTurnsChan chan struct{}
if turnStartDelay > 0 && turns > 1 {
turnStartReadyWaitGroup = &sync.WaitGroup{}
startTurnsChan = make(chan struct{})
}
chatHarness := harness.NewTestHarness(
timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}),
cleanupStrategy.toStrategy(),
)
for workspaceIndex, targetWorkspace := range workspaces {
for chatIndex := int64(0); chatIndex < chatsPerWorkspace; chatIndex++ {
if turnStartReadyWaitGroup != nil {
turnStartReadyWaitGroup.Add(1)
}
cfg := chat.Config{
OrganizationID: targetWorkspace.OrganizationID,
WorkspaceID: targetWorkspace.ID,
Prompt: prompt,
ModelConfigID: modelConfigID,
Turns: int(turns),
TurnStartDelay: turnStartDelay,
TurnStartReadyWaitGroup: turnStartReadyWaitGroup,
StartTurnsChan: startTurnsChan,
Metrics: metrics,
}
if err := cfg.Validate(); err != nil {
return xerrors.Errorf("validate config for workspace %d chat %d: %w", workspaceIndex, chatIndex, err)
}
runnerClient, err := loadtestutil.DupClientCopyingHeaders(client, BypassHeader)
if err != nil {
return xerrors.Errorf("duplicate client for workspace %d chat %d: %w", workspaceIndex, chatIndex, err)
}
var runner harness.Runnable = chat.NewRunner(runnerClient, cfg)
if tracingEnabled {
runner = &runnableTraceWrapper{
tracer: tracer,
runner: runner,
spanName: fmt.Sprintf("chat/workspace-%d-chat-%d", workspaceIndex, chatIndex),
}
}
chatHarness.AddRun("chat", fmt.Sprintf("workspace-%d-chat-%d", workspaceIndex, chatIndex), runner)
}
}
// Run the chat harness in the background so the CLI can release the
// follow-up turns after every runner finishes its initial turn.
totalChats := int64(len(workspaces)) * chatsPerWorkspace
_, _ = fmt.Fprintf(inv.Stderr, "Starting chat scale test with %d chats across %d workspaces...\n", totalChats, len(workspaces))
testCtx, testCancel := timeoutStrategy.toContext(ctx)
defer testCancel()
testDone := make(chan error, 1)
go func() {
testDone <- chatHarness.Run(testCtx)
}()
if turnStartReadyWaitGroup != nil {
initialTurnsDone := make(chan struct{})
go func() {
turnStartReadyWaitGroup.Wait()
close(initialTurnsDone)
}()
select {
case <-testCtx.Done():
return testCtx.Err()
case <-initialTurnsDone:
}
_, _ = fmt.Fprintf(inv.Stderr, "All %d initial turns completed, waiting %s before starting the follow-up turns...\n", totalChats, turnStartDelay)
select {
case <-testCtx.Done():
return testCtx.Err()
case <-time.After(turnStartDelay):
}
close(startTurnsChan)
}
if err := <-testDone; err != nil {
return xerrors.Errorf("run harness: %w", err)
}
results := chatHarness.Results()
for _, o := range outputs {
if err := o.write(results, inv.Stdout); err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up (archiving chats)...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
if err := chatHarness.Cleanup(cleanupCtx); err != nil {
return xerrors.Errorf("cleanup chats: %w", err)
}
if results.TotalFail > 0 {
return xerrors.Errorf("scale test failed: %d/%d runs failed", results.TotalFail, results.TotalRuns)
}
_, _ = fmt.Fprintf(inv.Stderr, "Scale test passed: %d/%d runs succeeded\n", results.TotalPass, results.TotalRuns)
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "chats-per-workspace",
Description: "Number of chats to run against each targeted workspace. Required and must be greater than 0.",
Value: serpent.Int64Of(&chatsPerWorkspace),
Required: true,
},
{
Flag: "prompt",
Description: "Text prompt to send on every turn in each chat.",
Default: "Reply with one short sentence.",
Value: serpent.StringOf(&prompt),
},
{
Flag: "turns",
Description: "Number of user to assistant exchanges per chat conversation.",
Default: "10",
Value: serpent.Int64Of(&turns),
},
{
Flag: "turn-start-delay",
Description: "Delay between every chat completing its initial turn and starting the follow-up turns. Use this to separate initial-turn load from follow-up-turn load.",
Default: "0s",
Value: serpent.DurationOf(&turnStartDelay),
},
{
Flag: "llm-mock-url",
Description: "URL of the mock LLM server (e.g. http://127.0.0.1:8080/v1). Creates or updates the Scaletest LLM Mock openai-compat provider and model config to point at this URL.",
Value: serpent.StringOf(&llmMockURL),
Required: true,
},
}
targetFlags.attach(&cmd.Options)
output.attach(&cmd.Options)
tracingFlags.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
return cmd
}
+54
View File
@@ -0,0 +1,54 @@
package chat
import (
"context"
"io"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/codersdk"
)
type chatClient interface {
SetLogger(logger slog.Logger)
SetLogBodies(logBodies bool)
CreateChat(ctx context.Context, req codersdk.CreateChatRequest) (codersdk.Chat, error)
StreamChat(ctx context.Context, chatID uuid.UUID, opts *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error)
CreateChatMessage(ctx context.Context, chatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error)
UpdateChat(ctx context.Context, chatID uuid.UUID, req codersdk.UpdateChatRequest) error
}
type sdkChatClient struct {
client *codersdk.ExperimentalClient
}
func newChatClient(client *codersdk.Client) chatClient {
return &sdkChatClient{client: codersdk.NewExperimentalClient(client)}
}
func (c *sdkChatClient) SetLogger(logger slog.Logger) {
c.client.SetLogger(logger)
}
func (c *sdkChatClient) SetLogBodies(logBodies bool) {
c.client.SetLogBodies(logBodies)
}
func (c *sdkChatClient) CreateChat(ctx context.Context, req codersdk.CreateChatRequest) (codersdk.Chat, error) {
return c.client.CreateChat(ctx, req)
}
func (c *sdkChatClient) StreamChat(ctx context.Context, chatID uuid.UUID, opts *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error) {
return c.client.StreamChat(ctx, chatID, opts)
}
func (c *sdkChatClient) CreateChatMessage(ctx context.Context, chatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) {
return c.client.CreateChatMessage(ctx, chatID, req)
}
func (c *sdkChatClient) UpdateChat(ctx context.Context, chatID uuid.UUID, req codersdk.UpdateChatRequest) error {
return c.client.UpdateChat(ctx, chatID, req)
}
var _ chatClient = (*sdkChatClient)(nil)
+78
View File
@@ -0,0 +1,78 @@
package chat
import (
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
)
// Config describes a single chat runner within a scaletest invocation.
type Config struct {
// OrganizationID is the organization that owns the target workspace.
OrganizationID uuid.UUID `json:"organization_id"`
// WorkspaceID is the pre-existing workspace to use for this chat run.
WorkspaceID uuid.UUID `json:"workspace_id"`
// Prompt is the text content sent on every turn.
Prompt string `json:"prompt"`
// ModelConfigID is the scaletest mock LLM model config.
ModelConfigID uuid.UUID `json:"model_config_id"`
// Turns is the total number of user to assistant exchanges per chat.
// Must be at least 1.
Turns int `json:"turns"`
// TurnStartDelay is the shared delay between every runner completing
// its initial turn and the release of the follow-up turns. Set
// to 0 to send all turns without an inter-phase pause.
TurnStartDelay time.Duration `json:"turn_start_delay"`
// TurnStartReadyWaitGroup coordinates the gap between the initial turn
// finishing and the follow-up turns. Each runner signals exactly
// once after its first turn reaches a terminal status, or when it
// knows it will never reach that point.
TurnStartReadyWaitGroup *sync.WaitGroup `json:"-"`
// StartTurnsChan blocks follow-up turns until the CLI layer releases them.
StartTurnsChan chan struct{} `json:"-"`
Metrics *Metrics `json:"-"`
}
func (c Config) Validate() error {
if c.OrganizationID == uuid.Nil {
return xerrors.Errorf("validate organization_id: must not be empty")
}
if c.WorkspaceID == uuid.Nil {
return xerrors.Errorf("validate workspace_id: must not be empty")
}
if c.Prompt == "" {
return xerrors.Errorf("validate prompt: must not be empty")
}
if c.ModelConfigID == uuid.Nil {
return xerrors.Errorf("validate model_config_id: must not be empty")
}
if c.Turns < 1 {
return xerrors.Errorf("validate turns: must be at least 1")
}
if c.TurnStartDelay < 0 {
return xerrors.Errorf("validate turn_start_delay: must not be negative")
}
if c.TurnStartDelay > 0 && c.Turns > 1 {
if c.TurnStartReadyWaitGroup == nil {
return xerrors.Errorf("validate turn_start_ready_wait_group: must not be nil when turn start delay is enabled for more than one turn")
}
if c.StartTurnsChan == nil {
return xerrors.Errorf("validate start_turns_chan: must not be nil when turn start delay is enabled for more than one turn")
}
}
if c.Metrics == nil {
return xerrors.Errorf("validate metrics: must not be nil")
}
return nil
}
+137
View File
@@ -0,0 +1,137 @@
package chat
import "github.com/prometheus/client_golang/prometheus"
const (
metricLabelPhase = "phase"
metricLabelStatus = "status"
metricLabelStage = "stage"
phaseInitial = "initial"
phaseFollowUp = "follow_up"
failureStageCreateChat = "create_chat"
failureStageCreateMessage = "create_message"
failureStageStreamOpen = "stream_open"
failureStageStreamEndedEarly = "stream_ended_early"
failureStageStatusError = "status_error"
)
var (
chatRequestLatencyBuckets = prometheus.ExponentialBucketsRange(0.05, 120, 18)
chatProcessingLatencyBuckets = prometheus.ExponentialBucketsRange(0.1, 300, 18)
)
// Metrics holds the Prometheus metrics emitted by the chat scaletest.
type Metrics struct {
ChatCreateLatencySeconds prometheus.Histogram
ChatMessageLatencySeconds *prometheus.HistogramVec
ChatConversationDurationSeconds prometheus.Histogram
ChatTimeToRunningSeconds *prometheus.HistogramVec
ChatTimeToFirstOutputSeconds *prometheus.HistogramVec
ChatTimeToTerminalStatusSeconds *prometheus.HistogramVec
ChatStageFailuresTotal *prometheus.CounterVec
ChatTerminalStatusTotal *prometheus.CounterVec
ChatTurnsCompletedTotal prometheus.Counter
ChatRetryEventsTotal prometheus.Counter
ActiveChatStreams prometheus.Gauge
}
func NewMetrics(reg prometheus.Registerer) *Metrics {
if reg == nil {
reg = prometheus.DefaultRegisterer
}
phaseLabelNames := []string{metricLabelPhase}
terminalStatusLabelNames := []string{metricLabelStatus}
failureStageLabelNames := []string{metricLabelStage}
m := &Metrics{
ChatCreateLatencySeconds: prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_create_latency_seconds",
Help: "Time in seconds to create a chat and enqueue the initial turn.",
Buckets: chatRequestLatencyBuckets,
}),
ChatMessageLatencySeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_message_latency_seconds",
Help: "Time in seconds to add a follow-up message to an existing chat.",
Buckets: chatRequestLatencyBuckets,
}, phaseLabelNames),
ChatConversationDurationSeconds: prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_conversation_duration_seconds",
Help: "Time in seconds from chat creation start until the conversation finishes or errors.",
Buckets: chatProcessingLatencyBuckets,
}),
ChatTimeToRunningSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_time_to_running_seconds",
Help: "Time in seconds from the start of a chat turn until the chat enters running status.",
Buckets: chatProcessingLatencyBuckets,
}, phaseLabelNames),
ChatTimeToFirstOutputSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_time_to_first_output_seconds",
Help: "Time in seconds from the start of a chat turn until the first output is received.",
Buckets: chatProcessingLatencyBuckets,
}, phaseLabelNames),
ChatTimeToTerminalStatusSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_time_to_terminal_status_seconds",
Help: "Time in seconds from the start of a chat turn until a terminal status is received.",
Buckets: chatProcessingLatencyBuckets,
}, phaseLabelNames),
ChatStageFailuresTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_stage_failures_total",
Help: "Total number of terminal stage-specific chat runner failures.",
}, failureStageLabelNames),
ChatTerminalStatusTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_terminal_status_total",
Help: "Total number of terminal chat statuses observed.",
}, terminalStatusLabelNames),
ChatTurnsCompletedTotal: prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_turns_completed_total",
Help: "Total number of chat turns completed successfully.",
}),
ChatRetryEventsTotal: prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "chat_retry_events_total",
Help: "Total number of chat retry events observed.",
}),
ActiveChatStreams: prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "coderd",
Subsystem: "scaletest",
Name: "active_chat_streams",
Help: "Current number of active chat streams.",
}),
}
reg.MustRegister(m.ChatCreateLatencySeconds)
reg.MustRegister(m.ChatMessageLatencySeconds)
reg.MustRegister(m.ChatConversationDurationSeconds)
reg.MustRegister(m.ChatTimeToRunningSeconds)
reg.MustRegister(m.ChatTimeToFirstOutputSeconds)
reg.MustRegister(m.ChatTimeToTerminalStatusSeconds)
reg.MustRegister(m.ChatStageFailuresTotal)
reg.MustRegister(m.ChatTerminalStatusTotal)
reg.MustRegister(m.ChatTurnsCompletedTotal)
reg.MustRegister(m.ChatRetryEventsTotal)
reg.MustRegister(m.ActiveChatStreams)
return m
}
+148
View File
@@ -0,0 +1,148 @@
package chat
import (
"context"
"net/http"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/codersdk"
)
const (
scaletestProviderType = "openai-compat"
scaletestProviderDisplayName = "Scaletest LLM Mock"
scaletestModelName = "scaletest-model"
scaletestModelDisplayName = "Scaletest Model"
)
type scaletestProviderAction string
const (
scaletestProviderActionCreated scaletestProviderAction = "created"
scaletestProviderActionUpdated scaletestProviderAction = "updated"
scaletestProviderActionReused scaletestProviderAction = "reused"
)
// EnsureScaletestModelConfig bootstraps the shared chat provider and model
// config used by chat scaletests.
func EnsureScaletestModelConfig(ctx context.Context, client *codersdk.ExperimentalClient, logger slog.Logger, llmMockURL string) (uuid.UUID, error) {
logger.Info(ctx, "bootstrapping mock LLM provider", slog.F("llm_mock_url", llmMockURL))
provider, providerAction, err := ensureScaletestProvider(ctx, client, llmMockURL)
if err != nil {
return uuid.Nil, err
}
switch providerAction {
case scaletestProviderActionCreated:
logger.Info(ctx, "created mock LLM provider",
slog.F("provider_type", scaletestProviderType),
slog.F("llm_mock_url", llmMockURL),
)
case scaletestProviderActionUpdated:
logger.Info(ctx, "updated mock LLM provider",
slog.F("provider_type", scaletestProviderType),
slog.F("provider_id", provider.ID),
slog.F("llm_mock_url", llmMockURL),
)
case scaletestProviderActionReused:
logger.Info(ctx, "reusing mock LLM provider",
slog.F("provider_type", scaletestProviderType),
slog.F("provider_id", provider.ID),
)
}
modelConfigs, err := client.ListChatModelConfigs(ctx)
if err != nil {
return uuid.Nil, xerrors.Errorf("list chat model configs: %w", err)
}
for i := range modelConfigs {
if modelConfigs[i].Provider != provider.Provider || modelConfigs[i].Model != scaletestModelName {
continue
}
if !modelConfigs[i].Enabled {
return uuid.Nil, xerrors.Errorf("existing scaletest chat model config %s is disabled; re-enable or delete it before running scaletests", modelConfigs[i].ID)
}
modelConfigID := modelConfigs[i].ID
logger.Info(ctx, "reusing scaletest model config", slog.F("model_config_id", modelConfigID))
return modelConfigID, nil
}
enabled := true
isDefault := false
contextLimit := int64(4096)
created, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: scaletestModelName,
DisplayName: scaletestModelDisplayName,
Enabled: &enabled,
IsDefault: &isDefault,
ContextLimit: &contextLimit,
})
if err != nil {
return uuid.Nil, xerrors.Errorf("create scaletest chat model config: %w", err)
}
logger.Info(ctx, "created scaletest model config", slog.F("model_config_id", created.ID))
return created.ID, nil
}
func ensureScaletestProvider(ctx context.Context, client *codersdk.ExperimentalClient, llmMockURL string) (codersdk.ChatProviderConfig, scaletestProviderAction, error) {
enabled := true
mockProviderToken := uuid.NewString()
created, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: scaletestProviderType,
DisplayName: scaletestProviderDisplayName,
APIKey: mockProviderToken,
BaseURL: llmMockURL,
Enabled: &enabled,
})
if err == nil {
return created, scaletestProviderActionCreated, nil
}
var sdkErr *codersdk.Error
if !xerrors.As(err, &sdkErr) || sdkErr.StatusCode() != http.StatusConflict {
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("create scaletest chat provider: %w", err)
}
providers, err := client.ListChatProviders(ctx)
if err != nil {
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("list chat providers: %w", err)
}
var existing *codersdk.ChatProviderConfig
for i := range providers {
if providers[i].Provider == scaletestProviderType {
existing = &providers[i]
break
}
}
if existing == nil {
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("find existing %s provider after conflict: not found", scaletestProviderType)
}
if existing.DisplayName != scaletestProviderDisplayName {
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("refusing to overwrite existing %s provider %s with display name %q", scaletestProviderType, existing.ID, existing.DisplayName)
}
if !existing.Enabled {
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("existing scaletest chat provider %s is disabled; re-enable or delete it before running scaletests", existing.ID)
}
if existing.BaseURL == llmMockURL {
return *existing, scaletestProviderActionReused, nil
}
updated, err := client.UpdateChatProvider(ctx, existing.ID, codersdk.UpdateChatProviderConfigRequest{
DisplayName: scaletestProviderDisplayName,
APIKey: &mockProviderToken,
BaseURL: &llmMockURL,
Enabled: &enabled,
})
if err != nil {
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("update scaletest chat provider: %w", err)
}
return updated, scaletestProviderActionUpdated, nil
}
+413
View File
@@ -0,0 +1,413 @@
package chat
import (
"context"
"io"
"sync"
"time"
"github.com/google/uuid"
"go.opentelemetry.io/otel/attribute"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/sloghuman"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/loadtestutil"
)
// Runner executes a single chat conversation as part of a scaletest run.
type Runner struct {
client chatClient
cfg Config
chatID uuid.UUID
result runnerResult
conversationStart time.Time
turnStartTime time.Time
currentPhase string
lastStreamError string
lastStatus codersdk.ChatStatus
sawTurnRunning bool
sawTurnFirstOutput bool
markTurnStartReady func()
}
type runnerResult struct {
finalStatus string
failureStage string
totalDuration time.Duration
sawFirstOutput bool
retryCount int
eventCount int
turnsCompleted int
}
var (
_ harness.Runnable = &Runner{}
_ harness.Cleanable = &Runner{}
_ harness.Collectable = &Runner{}
)
func NewRunner(client *codersdk.Client, cfg Config) *Runner {
return &Runner{
client: newChatClient(client),
cfg: cfg,
}
}
func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
logs = loadtestutil.NewSyncWriter(logs)
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug).Named(id)
r.client.SetLogger(logger)
r.client.SetLogBodies(true)
span.SetAttributes(
attribute.String("chat.runner_id", id),
attribute.String("chat.workspace_id", r.cfg.WorkspaceID.String()),
attribute.Int("chat.turns_requested", r.cfg.Turns),
attribute.Int64("chat.turn_start_delay_ms", r.cfg.TurnStartDelay.Milliseconds()),
)
span.SetAttributes(attribute.String("chat.model_config_id", r.cfg.ModelConfigID.String()))
markTurnStartReady := func() {}
if r.cfg.TurnStartReadyWaitGroup != nil {
markTurnStartReady = sync.OnceFunc(r.cfg.TurnStartReadyWaitGroup.Done)
}
r.markTurnStartReady = markTurnStartReady
defer r.markTurnStartReady()
defer func() {
if !r.conversationStart.IsZero() {
r.result.totalDuration = time.Since(r.conversationStart)
r.cfg.Metrics.ChatConversationDurationSeconds.Observe(r.result.totalDuration.Seconds())
}
span.SetAttributes(
attribute.String("chat.final_status", r.result.finalStatus),
attribute.String("chat.failure_stage", r.result.failureStage),
attribute.Int("chat.retry_count", r.result.retryCount),
attribute.Int("chat.turns_completed", r.result.turnsCompleted),
attribute.Bool("chat.saw_first_output", r.result.sawFirstOutput),
)
if r.result.totalDuration > 0 {
span.SetAttributes(attribute.Float64("chat.total_duration_seconds", r.result.totalDuration.Seconds()))
}
}()
workspaceID := r.cfg.WorkspaceID
modelConfigID := r.cfg.ModelConfigID
logger = logger.With(slog.F("workspace_id", workspaceID))
logger.Info(ctx, "starting chat runner")
r.resetConversation(time.Now(), markTurnStartReady)
createStartedAt := time.Now()
chat, err := r.client.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: r.cfg.OrganizationID,
WorkspaceID: &workspaceID,
ModelConfigID: &modelConfigID,
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: r.cfg.Prompt,
}},
})
if err != nil {
r.result.failureStage = failureStageCreateChat
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
return xerrors.Errorf("create chat: %w", err)
}
r.cfg.Metrics.ChatCreateLatencySeconds.Observe(time.Since(createStartedAt).Seconds())
r.chatID = chat.ID
span.SetAttributes(attribute.String("chat.chat_id", chat.ID.String()))
logger = logger.With(slog.F("chat_id", chat.ID))
logger.Info(ctx, "created chat session", slog.F("duration", time.Since(createStartedAt)))
// CreateChat already queues the first prompt for processing on the
// server, so the initial turn is in flight as soon as CreateChat
// returns. Open the stream immediately and let the conversation loop
// drive the gate at the natural phase boundary (after the first turn
// reaches a terminal Waiting status), rather than fencing here on a
// turn that has already started running.
events, closer, err := r.client.StreamChat(ctx, chat.ID, nil)
if err != nil {
r.result.failureStage = failureStageStreamOpen
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
return xerrors.Errorf("stream chat: %w", err)
}
r.cfg.Metrics.ActiveChatStreams.Inc()
defer func() {
r.cfg.Metrics.ActiveChatStreams.Dec()
_ = closer.Close()
}()
logger.Info(ctx, "streaming chat events")
return r.runConversation(ctx, chat.ID, logger, events)
}
func (r *Runner) resetConversation(conversationStart time.Time, markTurnStartReady func()) {
if markTurnStartReady == nil {
markTurnStartReady = func() {}
}
r.result = runnerResult{}
r.conversationStart = conversationStart
r.turnStartTime = conversationStart
r.currentPhase = phaseInitial
r.lastStreamError = ""
r.lastStatus = ""
r.sawTurnRunning = false
r.sawTurnFirstOutput = false
r.markTurnStartReady = markTurnStartReady
}
func (r *Runner) runConversation(ctx context.Context, chatID uuid.UUID, logger slog.Logger, events <-chan codersdk.ChatStreamEvent) error {
r.chatID = chatID
for event := range events {
r.result.eventCount++
switch event.Type {
case codersdk.ChatStreamEventTypeStatus:
if event.Status == nil {
continue
}
done, err := r.handleStatusEvent(ctx, chatID, logger, event.Status.Status)
if err != nil {
return err
}
if done {
return nil
}
case codersdk.ChatStreamEventTypeMessagePart:
r.handleMessagePartEvent(ctx, logger)
case codersdk.ChatStreamEventTypeMessage:
// StreamChat replays persisted rows as message events, not
// message_part deltas, when a turn finished server-side before
// the stream attached. Route assistant rows through the same
// first-output path; skip user rows so persisted prompts do not
// count as model output.
if event.Message == nil || event.Message.Role != codersdk.ChatMessageRoleAssistant {
continue
}
r.handleMessagePartEvent(ctx, logger)
case codersdk.ChatStreamEventTypeRetry:
r.handleRetryEvent(ctx, logger, event.Retry)
case codersdk.ChatStreamEventTypeError:
r.handleErrorEvent(ctx, logger, event.Error)
}
}
if ctx.Err() != nil {
return ctx.Err()
}
r.result.failureStage = failureStageStreamEndedEarly
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
if r.lastStreamError != "" {
return xerrors.Errorf("chat %s stream ended before completing %d of %d turns: %s", chatID, r.result.turnsCompleted, r.cfg.Turns, r.lastStreamError)
}
return xerrors.Errorf("chat %s stream ended before completing %d of %d turns", chatID, r.result.turnsCompleted, r.cfg.Turns)
}
func (r *Runner) handleStatusEvent(ctx context.Context, chatID uuid.UUID, logger slog.Logger, status codersdk.ChatStatus) (bool, error) {
if status == r.lastStatus {
return false, nil
}
if status == codersdk.ChatStatusWaiting &&
!r.sawTurnFirstOutput &&
(r.sawTurnRunning || r.result.turnsCompleted > 0) {
return false, nil
}
r.lastStatus = status
switch status {
case codersdk.ChatStatusRunning:
r.sawTurnRunning = true
r.cfg.Metrics.ChatTimeToRunningSeconds.WithLabelValues(r.currentPhase).Observe(time.Since(r.turnStartTime).Seconds())
logger.Info(ctx, "chat reached running status",
slog.F("phase", r.currentPhase),
)
return false, nil
case codersdk.ChatStatusWaiting:
r.result.turnsCompleted++
turnDuration := time.Since(r.turnStartTime)
r.cfg.Metrics.ChatTimeToTerminalStatusSeconds.WithLabelValues(r.currentPhase).Observe(turnDuration.Seconds())
r.cfg.Metrics.ChatTerminalStatusTotal.WithLabelValues(string(codersdk.ChatStatusWaiting)).Inc()
r.cfg.Metrics.ChatTurnsCompletedTotal.Inc()
logger.Info(ctx, "chat completed turn",
slog.F("turn", r.result.turnsCompleted),
slog.F("turns", r.cfg.Turns),
slog.F("duration", turnDuration),
)
if r.result.turnsCompleted >= r.cfg.Turns {
r.result.finalStatus = string(codersdk.ChatStatusWaiting)
conversationDuration := time.Since(r.conversationStart)
logger.Info(ctx, "chat reached terminal status",
slog.F("status", codersdk.ChatStatusWaiting),
slog.F("duration", conversationDuration),
slog.F("turns_completed", r.result.turnsCompleted),
)
return true, nil
}
// After the very first turn completes, mark this runner ready
// for the CLI-coordinated turn-start gate. The inter-phase
// delay measures the gap between every chat actually finishing its
// initial turn and the start of the follow-up turns, not the gap
// between CreateChat returning and the next turn.
if r.result.turnsCompleted == 1 {
r.markTurnStartReady()
if r.cfg.StartTurnsChan != nil {
logger.Info(ctx, "chat waiting for turn start release",
slog.F("turn_start_delay", r.cfg.TurnStartDelay),
)
select {
case <-ctx.Done():
return false, ctx.Err()
case <-r.cfg.StartTurnsChan:
}
}
}
nextTurn := r.result.turnsCompleted + 1
r.currentPhase = phaseFollowUp
r.turnStartTime = time.Now()
r.lastStreamError = ""
r.lastStatus = ""
r.sawTurnRunning = false
r.sawTurnFirstOutput = false
if err := r.sendNextTurn(ctx, chatID, logger, nextTurn, r.currentPhase); err != nil {
r.result.failureStage = failureStageCreateMessage
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
return false, err
}
return false, nil
case codersdk.ChatStatusError:
r.result.finalStatus = string(codersdk.ChatStatusError)
r.result.failureStage = failureStageStatusError
turnDuration := time.Since(r.turnStartTime)
r.cfg.Metrics.ChatTimeToTerminalStatusSeconds.WithLabelValues(r.currentPhase).Observe(turnDuration.Seconds())
r.cfg.Metrics.ChatTerminalStatusTotal.WithLabelValues(string(codersdk.ChatStatusError)).Inc()
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
errMessage := r.lastStreamError
if errMessage == "" {
errMessage = "chat reached error status"
}
logger.Error(ctx, "chat reached terminal status",
slog.F("status", codersdk.ChatStatusError),
slog.F("turns_completed", r.result.turnsCompleted),
slog.F("turns", r.cfg.Turns),
slog.F("error", errMessage),
)
return false, xerrors.Errorf("chat %s reached error status: %s", chatID, errMessage)
default:
return false, nil
}
}
func (r *Runner) sendNextTurn(ctx context.Context, chatID uuid.UUID, logger slog.Logger, nextTurn int, phase string) error {
messageStartedAt := time.Now()
modelConfigID := r.cfg.ModelConfigID
_, err := r.client.CreateChatMessage(ctx, chatID, codersdk.CreateChatMessageRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: r.cfg.Prompt,
}},
ModelConfigID: &modelConfigID,
})
if err != nil {
return xerrors.Errorf("create chat message for turn %d: %w", nextTurn, err)
}
r.cfg.Metrics.ChatMessageLatencySeconds.WithLabelValues(phase).Observe(time.Since(messageStartedAt).Seconds())
logger.Info(ctx, "chat sent message",
slog.F("turn", nextTurn),
slog.F("turns", r.cfg.Turns),
)
return nil
}
func (r *Runner) handleMessagePartEvent(ctx context.Context, logger slog.Logger) {
if r.sawTurnFirstOutput {
return
}
r.sawTurnFirstOutput = true
r.result.sawFirstOutput = true
firstOutputDuration := time.Since(r.turnStartTime)
r.cfg.Metrics.ChatTimeToFirstOutputSeconds.WithLabelValues(r.currentPhase).Observe(firstOutputDuration.Seconds())
logger.Info(ctx, "chat received first output",
slog.F("phase", r.currentPhase),
slog.F("duration", firstOutputDuration),
)
}
func (r *Runner) handleRetryEvent(ctx context.Context, logger slog.Logger, retry *codersdk.ChatStreamRetry) {
r.result.retryCount++
r.cfg.Metrics.ChatRetryEventsTotal.Inc()
if retry != nil {
logger.Warn(ctx, "chat retry event",
slog.F("attempt", retry.Attempt),
slog.F("delay_ms", retry.DelayMs),
slog.F("error", retry.Error),
)
return
}
logger.Warn(ctx, "chat retry event")
}
func (r *Runner) handleErrorEvent(ctx context.Context, logger slog.Logger, eventErr *codersdk.ChatError) {
if eventErr != nil && eventErr.Message != "" {
r.lastStreamError = eventErr.Message
logger.Warn(ctx, "chat stream error",
slog.F("error", r.lastStreamError),
)
return
}
logger.Warn(ctx, "chat stream error event")
}
func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error {
if r.chatID == uuid.Nil {
return nil
}
logs = loadtestutil.NewSyncWriter(logs)
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug).Named(id).With(slog.F("chat_id", r.chatID))
r.client.SetLogger(logger)
r.client.SetLogBodies(true)
archived := true
logger.Info(ctx, "archiving chat session")
if err := r.client.UpdateChat(ctx, r.chatID, codersdk.UpdateChatRequest{Archived: &archived}); err != nil {
logger.Error(ctx, "failed to archive chat", slog.Error(err))
return xerrors.Errorf("archive chat: %w", err)
}
logger.Info(ctx, "archived chat session")
return nil
}
func (r *Runner) GetMetrics() map[string]any {
return map[string]any{
"workspace_id": r.cfg.WorkspaceID.String(),
"turn_start_delay_ms": r.cfg.TurnStartDelay.Milliseconds(),
"chat_id": r.chatID.String(),
"final_status": r.result.finalStatus,
"failure_stage": r.result.failureStage,
"total_duration_seconds": r.result.totalDuration.Seconds(),
"saw_first_output": r.result.sawFirstOutput,
"retry_count": r.result.retryCount,
"event_count": r.result.eventCount,
"turns_requested": r.cfg.Turns,
"turns_completed": r.result.turnsCompleted,
}
}
+391
View File
@@ -0,0 +1,391 @@
package chat
import (
"bytes"
"context"
"io"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/sloghuman"
"github.com/coder/coder/v2/codersdk"
)
func TestRunnerRunConversation(t *testing.T) {
t.Parallel()
chatID := uuid.MustParse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
noopMarkTurnStartReady := func() {}
t.Run("OneTurnHappyPath", func(t *testing.T) {
t.Parallel()
runner := newTestRunner(t, newRunConfig(t))
events := make(chan codersdk.ChatStreamEvent, 3)
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
events <- messagePartEvent(chatID)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
close(events)
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
require.NoError(t, err)
result := runner.result
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
require.Empty(t, result.failureStage)
require.True(t, result.sawFirstOutput)
require.Equal(t, 1, result.turnsCompleted)
require.Equal(t, 3, result.eventCount)
})
t.Run("DuplicateWaitingDoesNotAdvanceTurn", func(t *testing.T) {
t.Parallel()
cfg := newRunConfig(t)
cfg.Turns = 2
events := make(chan codersdk.ChatStreamEvent, 7)
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
events <- messagePartEvent(chatID)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
var sendCount atomic.Int64
runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() {
sendCount.Add(1)
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
events <- messagePartEvent(chatID)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
close(events)
})
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
require.NoError(t, err)
result := runner.result
require.Equal(t, int64(1), sendCount.Load())
require.Equal(t, 2, result.turnsCompleted)
require.Equal(t, 7, result.eventCount)
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
})
t.Run("StaleWaitingAfterNextTurnRunningDoesNotAdvanceTurn", func(t *testing.T) {
t.Parallel()
cfg := newRunConfig(t)
cfg.Turns = 2
events := make(chan codersdk.ChatStreamEvent, 7)
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
events <- messagePartEvent(chatID)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
var sendCount atomic.Int64
runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() {
sendCount.Add(1)
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
events <- messagePartEvent(chatID)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
close(events)
})
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
require.NoError(t, err)
result := runner.result
require.Equal(t, int64(1), sendCount.Load())
require.Equal(t, 2, result.turnsCompleted)
require.Equal(t, 7, result.eventCount)
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
})
t.Run("FirstTurnGatesFollowUpStorm", func(t *testing.T) {
t.Parallel()
// Reproduces the contract that the turn-start gate is checked
// after the first turn finishes, not before it begins. The runner
// must mark itself ready, wait for the release channel, and only
// then send turn 2.
cfg := newRunConfig(t)
cfg.Turns = 2
readyWG := &sync.WaitGroup{}
readyWG.Add(1)
releaseChan := make(chan struct{})
cfg.TurnStartReadyWaitGroup = readyWG
cfg.StartTurnsChan = releaseChan
events := make(chan codersdk.ChatStreamEvent, 4)
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
events <- messagePartEvent(chatID)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
ready := make(chan struct{})
go func() {
readyWG.Wait()
close(ready)
}()
errCh := make(chan error, 1)
var sendCount atomic.Int64
runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() {
sendCount.Add(1)
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
events <- messagePartEvent(chatID)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
close(events)
})
runner.resetConversation(time.Now(), sync.OnceFunc(readyWG.Done))
go func() {
runErr := runner.runConversation(context.Background(), chatID, testLogger(), events)
errCh <- runErr
}()
select {
case <-ready:
case <-time.After(2 * time.Second):
t.Fatal("runner did not mark turn-start gate ready after first turn")
}
require.Equal(t, int64(0), sendCount.Load(), "next turn was sent before turn-start release")
close(releaseChan)
select {
case err := <-errCh:
require.NoError(t, err)
case <-time.After(2 * time.Second):
t.Fatal("runner did not finish after turn-start release")
}
require.Equal(t, int64(1), sendCount.Load())
})
t.Run("FirstOutputFromAssistantMessageEvent", func(t *testing.T) {
t.Parallel()
// Snapshot race: when a turn finishes before stream attach,
// StreamChat replays rows as message events, never as
// message_part deltas; the assistant row must record first output.
runner := newTestRunner(t, newRunConfig(t))
events := make(chan codersdk.ChatStreamEvent, 3)
events <- messageEvent(chatID, codersdk.ChatMessageRoleUser)
events <- messageEvent(chatID, codersdk.ChatMessageRoleAssistant)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
close(events)
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
require.NoError(t, err)
result := runner.result
require.True(t, result.sawFirstOutput, "first output not recorded from assistant message event")
require.Equal(t, 1, result.turnsCompleted)
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
})
t.Run("ImmediateWaitingCountsNextTurn", func(t *testing.T) {
t.Parallel()
cfg := newRunConfig(t)
cfg.Turns = 2
events := make(chan codersdk.ChatStreamEvent, 3)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
var sendCount atomic.Int64
runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() {
sendCount.Add(1)
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
events <- messagePartEvent(chatID)
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
close(events)
})
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
require.NoError(t, err)
result := runner.result
require.Equal(t, int64(1), sendCount.Load())
require.Equal(t, 2, result.turnsCompleted)
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
})
}
func runTestConversation(t *testing.T, runner *Runner, chatID uuid.UUID, events <-chan codersdk.ChatStreamEvent, markTurnStartReady func()) error {
t.Helper()
runner.resetConversation(time.Now(), markTurnStartReady)
return runner.runConversation(context.Background(), chatID, testLogger(), events)
}
func TestRunnerCleanup(t *testing.T) {
t.Parallel()
chatID := uuid.MustParse("22222222-2222-2222-2222-222222222222")
t.Run("ArchivesChat", func(t *testing.T) {
t.Parallel()
runner, archived := newTestRunnerWithChatArchive(t, chatID, nil)
logs := bytes.NewBuffer(nil)
err := runner.Cleanup(context.Background(), "runner-1", logs)
require.NoError(t, err)
require.True(t, archived())
require.Contains(t, logs.String(), "archived chat")
})
t.Run("ArchiveErrorIsReturned", func(t *testing.T) {
t.Parallel()
runner, archived := newTestRunnerWithChatArchive(t, chatID, xerrors.New("boom"))
err := runner.Cleanup(context.Background(), "runner-1", bytes.NewBuffer(nil))
require.Error(t, err)
require.ErrorContains(t, err, "archive chat")
require.True(t, archived())
})
}
func testLogger() slog.Logger {
return slog.Make(sloghuman.Sink(io.Discard)).Leveled(slog.LevelDebug)
}
func newRunConfig(t *testing.T) Config {
t.Helper()
reg := prometheus.NewRegistry()
return Config{
OrganizationID: uuid.MustParse("22222222-2222-2222-2222-222222222222"),
WorkspaceID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
ModelConfigID: uuid.MustParse("33333333-3333-3333-3333-333333333333"),
Prompt: "Reply with one short sentence.",
Turns: 1,
Metrics: NewMetrics(reg),
}
}
type fakeChatClient struct {
createChatFunc func(context.Context, codersdk.CreateChatRequest) (codersdk.Chat, error)
streamChatFunc func(context.Context, uuid.UUID, *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error)
createChatMessageFunc func(context.Context, uuid.UUID, codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error)
updateChatFunc func(context.Context, uuid.UUID, codersdk.UpdateChatRequest) error
}
func newFakeChatClient(t *testing.T) *fakeChatClient {
t.Helper()
return &fakeChatClient{}
}
func (*fakeChatClient) SetLogger(logger slog.Logger) {}
func (*fakeChatClient) SetLogBodies(logBodies bool) {}
func (f *fakeChatClient) CreateChat(ctx context.Context, req codersdk.CreateChatRequest) (codersdk.Chat, error) {
if f.createChatFunc == nil {
return codersdk.Chat{}, xerrors.New("unexpected CreateChat call")
}
return f.createChatFunc(ctx, req)
}
func (f *fakeChatClient) StreamChat(ctx context.Context, chatID uuid.UUID, opts *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error) {
if f.streamChatFunc == nil {
return nil, nil, xerrors.New("unexpected StreamChat call")
}
return f.streamChatFunc(ctx, chatID, opts)
}
func (f *fakeChatClient) CreateChatMessage(ctx context.Context, chatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) {
if f.createChatMessageFunc == nil {
return codersdk.CreateChatMessageResponse{}, xerrors.New("unexpected CreateChatMessage call")
}
return f.createChatMessageFunc(ctx, chatID, req)
}
func (f *fakeChatClient) UpdateChat(ctx context.Context, chatID uuid.UUID, req codersdk.UpdateChatRequest) error {
if f.updateChatFunc == nil {
return xerrors.New("unexpected UpdateChat call")
}
return f.updateChatFunc(ctx, chatID, req)
}
var _ chatClient = (*fakeChatClient)(nil)
func newTestRunner(t *testing.T, cfg Config) *Runner {
t.Helper()
return &Runner{client: newFakeChatClient(t), cfg: cfg}
}
func newTestRunnerWithChatArchive(t *testing.T, chatID uuid.UUID, updateErr error) (*Runner, func() bool) {
t.Helper()
var archived atomic.Bool
client := newFakeChatClient(t)
client.updateChatFunc = func(ctx context.Context, gotChatID uuid.UUID, req codersdk.UpdateChatRequest) error {
if gotChatID != chatID {
return xerrors.Errorf("unexpected chat archive ID: %s", gotChatID)
}
if req.Archived == nil || !*req.Archived {
return xerrors.Errorf("unexpected archived value: %v", req.Archived)
}
archived.Store(true)
return updateErr
}
runner := &Runner{client: client, cfg: Config{}, chatID: chatID}
return runner, archived.Load
}
func newTestRunnerWithChatMessage(t *testing.T, cfg Config, chatID uuid.UUID, onMessage func()) *Runner {
t.Helper()
client := newFakeChatClient(t)
client.createChatMessageFunc = func(ctx context.Context, gotChatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) {
if gotChatID != chatID {
return codersdk.CreateChatMessageResponse{}, xerrors.Errorf("unexpected chat message ID: %s", gotChatID)
}
if err := validatePromptParts(req.Content, cfg.Prompt); err != nil {
return codersdk.CreateChatMessageResponse{}, err
}
if req.ModelConfigID == nil || *req.ModelConfigID != cfg.ModelConfigID {
return codersdk.CreateChatMessageResponse{}, xerrors.Errorf("unexpected chat message model config ID: %v", req.ModelConfigID)
}
if onMessage != nil {
onMessage()
}
return codersdk.CreateChatMessageResponse{Queued: true}, nil
}
return &Runner{client: client, cfg: cfg}
}
func validatePromptParts(parts []codersdk.ChatInputPart, prompt string) error {
if len(parts) != 1 || parts[0].Type != codersdk.ChatInputPartTypeText || parts[0].Text != prompt {
return xerrors.Errorf("unexpected chat message content: %#v", parts)
}
return nil
}
func statusEvent(chatID uuid.UUID, status codersdk.ChatStatus) codersdk.ChatStreamEvent {
return codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeStatus,
ChatID: chatID,
Status: &codersdk.ChatStreamStatus{Status: status},
}
}
func messagePartEvent(chatID uuid.UUID) codersdk.ChatStreamEvent {
return codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
ChatID: chatID,
}
}
func messageEvent(chatID uuid.UUID, role codersdk.ChatMessageRole) codersdk.ChatStreamEvent {
return codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &codersdk.ChatMessage{Role: role},
}
}