diff --git a/cli/exp_scaletest.go b/cli/exp_scaletest.go index 06af372e15..a4d5b14d65 100644 --- a/cli/exp_scaletest.go +++ b/cli/exp_scaletest.go @@ -70,6 +70,7 @@ func (r *RootCmd) scaletestCmd() *serpent.Command { r.scaletestSMTP(), r.scaletestPrebuilds(), r.scaletestBridge(), + r.scaletestChat(), r.scaletestLLMMock(), }, } @@ -404,13 +405,13 @@ func (f *workspaceTargetFlags) attach(opts *serpent.OptionSet) { Flag: "template", FlagShorthand: "t", Env: "CODER_SCALETEST_TEMPLATE", - Description: "Name or ID of the template. Traffic generation will be limited to workspaces created from this template.", + Description: "Name or ID of the template. Only workspaces created from this template are targeted.", Value: serpent.StringOf(&f.template), }, serpent.Option{ Flag: "target-workspaces", Env: "CODER_SCALETEST_TARGET_WORKSPACES", - Description: "Target a specific range of workspaces in the format [START]:[END] (exclusive). Example: 0:10 will target the 10 first alphabetically sorted workspaces (0-9).", + Description: "Target a specific range of matching workspaces in the format [START]:[END] (exclusive). Example: 0:10 targets the first 10 matching workspaces returned by the workspace query.", Value: serpent.StringOf(&f.targetWorkspaces), }, serpent.Option{ diff --git a/cli/exp_scaletest_chat.go b/cli/exp_scaletest_chat.go new file mode 100644 index 0000000000..bbde5f67ab --- /dev/null +++ b/cli/exp_scaletest_chat.go @@ -0,0 +1,254 @@ +//go:build !slim + +package cli + +import ( + "fmt" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/sloghuman" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/scaletest/chat" + "github.com/coder/coder/v2/scaletest/harness" + "github.com/coder/coder/v2/scaletest/loadtestutil" + "github.com/coder/serpent" +) + +func (r *RootCmd) scaletestChat() *serpent.Command { + var ( + chatsPerWorkspace int64 + prompt string + turns int64 + turnStartDelay time.Duration + llmMockURL string + targetFlags = &workspaceTargetFlags{} + tracingFlags = &scaletestTracingFlags{} + prometheusFlags = &scaletestPrometheusFlags{} + timeoutStrategy = &timeoutFlags{} + cleanupStrategy = newScaletestCleanupStrategy() + output = &scaletestOutputFlags{} + ) + + cmd := &serpent.Command{ + Use: "chat", + Short: "Generate Coder Agents load.", + Handler: func(inv *serpent.Invocation) error { + baseCtx := inv.Context() + ctx, stop := inv.SignalNotifyContext(baseCtx, StopSignals...) + defer stop() + + outputs, err := output.parse() + if err != nil { + return xerrors.Errorf("could not parse --output flags: %w", err) + } + switch { + case turns < 1: + return xerrors.Errorf("--turns must be at least 1") + case chatsPerWorkspace < 1: + return xerrors.Errorf("--chats-per-workspace must be at least 1") + } + + client, err := r.InitClient(inv) + if err != nil { + return err + } + me, err := RequireAdmin(ctx, client) + if err != nil { + return err + } + client.HTTPClient.Transport = &codersdk.HeaderTransport{ + Transport: client.HTTPClient.Transport, + Header: BypassHeader, + } + + workspaces, err := targetFlags.getTargetedWorkspaces(ctx, client, me.OrganizationIDs, inv.Stdout) + if err != nil { + return err + } + + logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelDebug) + modelConfigID, err := chat.EnsureScaletestModelConfig(ctx, codersdk.NewExperimentalClient(client), logger, llmMockURL) + if err != nil { + return err + } + + // Start metrics and tracing before creating runners. + reg := prometheus.NewRegistry() + metrics := chat.NewMetrics(reg) + + prometheusSrvClose := ServeHandler(baseCtx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus") + + tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(baseCtx) + if err != nil { + prometheusSrvClose() + return xerrors.Errorf("create tracer provider: %w", err) + } + defer func() { + if tracingEnabled { + _, _ = fmt.Fprintln(inv.Stderr, "Uploading traces...") + } + if err := closeTracing(baseCtx); err != nil { + _, _ = fmt.Fprintf(inv.Stderr, "Error uploading traces: %+v\n", err) + } + _, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait) + <-time.After(prometheusFlags.Wait) + prometheusSrvClose() + }() + + tracer := tracerProvider.Tracer(scaletestTracerName) + + var turnStartReadyWaitGroup *sync.WaitGroup + var startTurnsChan chan struct{} + if turnStartDelay > 0 && turns > 1 { + turnStartReadyWaitGroup = &sync.WaitGroup{} + startTurnsChan = make(chan struct{}) + } + + chatHarness := harness.NewTestHarness( + timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), + cleanupStrategy.toStrategy(), + ) + for workspaceIndex, targetWorkspace := range workspaces { + for chatIndex := int64(0); chatIndex < chatsPerWorkspace; chatIndex++ { + if turnStartReadyWaitGroup != nil { + turnStartReadyWaitGroup.Add(1) + } + + cfg := chat.Config{ + OrganizationID: targetWorkspace.OrganizationID, + WorkspaceID: targetWorkspace.ID, + Prompt: prompt, + ModelConfigID: modelConfigID, + Turns: int(turns), + TurnStartDelay: turnStartDelay, + TurnStartReadyWaitGroup: turnStartReadyWaitGroup, + StartTurnsChan: startTurnsChan, + Metrics: metrics, + } + if err := cfg.Validate(); err != nil { + return xerrors.Errorf("validate config for workspace %d chat %d: %w", workspaceIndex, chatIndex, err) + } + + runnerClient, err := loadtestutil.DupClientCopyingHeaders(client, BypassHeader) + if err != nil { + return xerrors.Errorf("duplicate client for workspace %d chat %d: %w", workspaceIndex, chatIndex, err) + } + var runner harness.Runnable = chat.NewRunner(runnerClient, cfg) + if tracingEnabled { + runner = &runnableTraceWrapper{ + tracer: tracer, + runner: runner, + spanName: fmt.Sprintf("chat/workspace-%d-chat-%d", workspaceIndex, chatIndex), + } + } + chatHarness.AddRun("chat", fmt.Sprintf("workspace-%d-chat-%d", workspaceIndex, chatIndex), runner) + } + } + + // Run the chat harness in the background so the CLI can release the + // follow-up turns after every runner finishes its initial turn. + totalChats := int64(len(workspaces)) * chatsPerWorkspace + _, _ = fmt.Fprintf(inv.Stderr, "Starting chat scale test with %d chats across %d workspaces...\n", totalChats, len(workspaces)) + testCtx, testCancel := timeoutStrategy.toContext(ctx) + defer testCancel() + testDone := make(chan error, 1) + go func() { + testDone <- chatHarness.Run(testCtx) + }() + + if turnStartReadyWaitGroup != nil { + initialTurnsDone := make(chan struct{}) + go func() { + turnStartReadyWaitGroup.Wait() + close(initialTurnsDone) + }() + + select { + case <-testCtx.Done(): + return testCtx.Err() + case <-initialTurnsDone: + } + + _, _ = fmt.Fprintf(inv.Stderr, "All %d initial turns completed, waiting %s before starting the follow-up turns...\n", totalChats, turnStartDelay) + select { + case <-testCtx.Done(): + return testCtx.Err() + case <-time.After(turnStartDelay): + } + + close(startTurnsChan) + } + + if err := <-testDone; err != nil { + return xerrors.Errorf("run harness: %w", err) + } + + results := chatHarness.Results() + for _, o := range outputs { + if err := o.write(results, inv.Stdout); err != nil { + return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err) + } + } + + _, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up (archiving chats)...") + cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx) + defer cleanupCancel() + if err := chatHarness.Cleanup(cleanupCtx); err != nil { + return xerrors.Errorf("cleanup chats: %w", err) + } + + if results.TotalFail > 0 { + return xerrors.Errorf("scale test failed: %d/%d runs failed", results.TotalFail, results.TotalRuns) + } + + _, _ = fmt.Fprintf(inv.Stderr, "Scale test passed: %d/%d runs succeeded\n", results.TotalPass, results.TotalRuns) + return nil + }, + } + + cmd.Options = serpent.OptionSet{ + { + Flag: "chats-per-workspace", + Description: "Number of chats to run against each targeted workspace. Required and must be greater than 0.", + Value: serpent.Int64Of(&chatsPerWorkspace), + Required: true, + }, + { + Flag: "prompt", + Description: "Text prompt to send on every turn in each chat.", + Default: "Reply with one short sentence.", + Value: serpent.StringOf(&prompt), + }, + { + Flag: "turns", + Description: "Number of user to assistant exchanges per chat conversation.", + Default: "10", + Value: serpent.Int64Of(&turns), + }, + { + Flag: "turn-start-delay", + Description: "Delay between every chat completing its initial turn and starting the follow-up turns. Use this to separate initial-turn load from follow-up-turn load.", + Default: "0s", + Value: serpent.DurationOf(&turnStartDelay), + }, + { + Flag: "llm-mock-url", + Description: "URL of the mock LLM server (e.g. http://127.0.0.1:8080/v1). Creates or updates the Scaletest LLM Mock openai-compat provider and model config to point at this URL.", + Value: serpent.StringOf(&llmMockURL), + Required: true, + }, + } + targetFlags.attach(&cmd.Options) + output.attach(&cmd.Options) + tracingFlags.attach(&cmd.Options) + prometheusFlags.attach(&cmd.Options) + timeoutStrategy.attach(&cmd.Options) + cleanupStrategy.attach(&cmd.Options) + return cmd +} diff --git a/scaletest/chat/client.go b/scaletest/chat/client.go new file mode 100644 index 0000000000..552bbd87e1 --- /dev/null +++ b/scaletest/chat/client.go @@ -0,0 +1,54 @@ +package chat + +import ( + "context" + "io" + + "github.com/google/uuid" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/codersdk" +) + +type chatClient interface { + SetLogger(logger slog.Logger) + SetLogBodies(logBodies bool) + CreateChat(ctx context.Context, req codersdk.CreateChatRequest) (codersdk.Chat, error) + StreamChat(ctx context.Context, chatID uuid.UUID, opts *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error) + CreateChatMessage(ctx context.Context, chatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) + UpdateChat(ctx context.Context, chatID uuid.UUID, req codersdk.UpdateChatRequest) error +} + +type sdkChatClient struct { + client *codersdk.ExperimentalClient +} + +func newChatClient(client *codersdk.Client) chatClient { + return &sdkChatClient{client: codersdk.NewExperimentalClient(client)} +} + +func (c *sdkChatClient) SetLogger(logger slog.Logger) { + c.client.SetLogger(logger) +} + +func (c *sdkChatClient) SetLogBodies(logBodies bool) { + c.client.SetLogBodies(logBodies) +} + +func (c *sdkChatClient) CreateChat(ctx context.Context, req codersdk.CreateChatRequest) (codersdk.Chat, error) { + return c.client.CreateChat(ctx, req) +} + +func (c *sdkChatClient) StreamChat(ctx context.Context, chatID uuid.UUID, opts *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error) { + return c.client.StreamChat(ctx, chatID, opts) +} + +func (c *sdkChatClient) CreateChatMessage(ctx context.Context, chatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) { + return c.client.CreateChatMessage(ctx, chatID, req) +} + +func (c *sdkChatClient) UpdateChat(ctx context.Context, chatID uuid.UUID, req codersdk.UpdateChatRequest) error { + return c.client.UpdateChat(ctx, chatID, req) +} + +var _ chatClient = (*sdkChatClient)(nil) diff --git a/scaletest/chat/config.go b/scaletest/chat/config.go new file mode 100644 index 0000000000..5b6b36baa2 --- /dev/null +++ b/scaletest/chat/config.go @@ -0,0 +1,78 @@ +package chat + +import ( + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" +) + +// Config describes a single chat runner within a scaletest invocation. +type Config struct { + // OrganizationID is the organization that owns the target workspace. + OrganizationID uuid.UUID `json:"organization_id"` + + // WorkspaceID is the pre-existing workspace to use for this chat run. + WorkspaceID uuid.UUID `json:"workspace_id"` + + // Prompt is the text content sent on every turn. + Prompt string `json:"prompt"` + + // ModelConfigID is the scaletest mock LLM model config. + ModelConfigID uuid.UUID `json:"model_config_id"` + + // Turns is the total number of user to assistant exchanges per chat. + // Must be at least 1. + Turns int `json:"turns"` + + // TurnStartDelay is the shared delay between every runner completing + // its initial turn and the release of the follow-up turns. Set + // to 0 to send all turns without an inter-phase pause. + TurnStartDelay time.Duration `json:"turn_start_delay"` + + // TurnStartReadyWaitGroup coordinates the gap between the initial turn + // finishing and the follow-up turns. Each runner signals exactly + // once after its first turn reaches a terminal status, or when it + // knows it will never reach that point. + TurnStartReadyWaitGroup *sync.WaitGroup `json:"-"` + + // StartTurnsChan blocks follow-up turns until the CLI layer releases them. + StartTurnsChan chan struct{} `json:"-"` + + Metrics *Metrics `json:"-"` +} + +func (c Config) Validate() error { + if c.OrganizationID == uuid.Nil { + return xerrors.Errorf("validate organization_id: must not be empty") + } + if c.WorkspaceID == uuid.Nil { + return xerrors.Errorf("validate workspace_id: must not be empty") + } + if c.Prompt == "" { + return xerrors.Errorf("validate prompt: must not be empty") + } + if c.ModelConfigID == uuid.Nil { + return xerrors.Errorf("validate model_config_id: must not be empty") + } + if c.Turns < 1 { + return xerrors.Errorf("validate turns: must be at least 1") + } + if c.TurnStartDelay < 0 { + return xerrors.Errorf("validate turn_start_delay: must not be negative") + } + if c.TurnStartDelay > 0 && c.Turns > 1 { + if c.TurnStartReadyWaitGroup == nil { + return xerrors.Errorf("validate turn_start_ready_wait_group: must not be nil when turn start delay is enabled for more than one turn") + } + if c.StartTurnsChan == nil { + return xerrors.Errorf("validate start_turns_chan: must not be nil when turn start delay is enabled for more than one turn") + } + } + if c.Metrics == nil { + return xerrors.Errorf("validate metrics: must not be nil") + } + + return nil +} diff --git a/scaletest/chat/metrics.go b/scaletest/chat/metrics.go new file mode 100644 index 0000000000..829931cd81 --- /dev/null +++ b/scaletest/chat/metrics.go @@ -0,0 +1,137 @@ +package chat + +import "github.com/prometheus/client_golang/prometheus" + +const ( + metricLabelPhase = "phase" + metricLabelStatus = "status" + metricLabelStage = "stage" + + phaseInitial = "initial" + phaseFollowUp = "follow_up" + + failureStageCreateChat = "create_chat" + failureStageCreateMessage = "create_message" + failureStageStreamOpen = "stream_open" + failureStageStreamEndedEarly = "stream_ended_early" + failureStageStatusError = "status_error" +) + +var ( + chatRequestLatencyBuckets = prometheus.ExponentialBucketsRange(0.05, 120, 18) + chatProcessingLatencyBuckets = prometheus.ExponentialBucketsRange(0.1, 300, 18) +) + +// Metrics holds the Prometheus metrics emitted by the chat scaletest. +type Metrics struct { + ChatCreateLatencySeconds prometheus.Histogram + ChatMessageLatencySeconds *prometheus.HistogramVec + ChatConversationDurationSeconds prometheus.Histogram + ChatTimeToRunningSeconds *prometheus.HistogramVec + ChatTimeToFirstOutputSeconds *prometheus.HistogramVec + ChatTimeToTerminalStatusSeconds *prometheus.HistogramVec + ChatStageFailuresTotal *prometheus.CounterVec + ChatTerminalStatusTotal *prometheus.CounterVec + ChatTurnsCompletedTotal prometheus.Counter + ChatRetryEventsTotal prometheus.Counter + ActiveChatStreams prometheus.Gauge +} + +func NewMetrics(reg prometheus.Registerer) *Metrics { + if reg == nil { + reg = prometheus.DefaultRegisterer + } + + phaseLabelNames := []string{metricLabelPhase} + terminalStatusLabelNames := []string{metricLabelStatus} + failureStageLabelNames := []string{metricLabelStage} + + m := &Metrics{ + ChatCreateLatencySeconds: prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_create_latency_seconds", + Help: "Time in seconds to create a chat and enqueue the initial turn.", + Buckets: chatRequestLatencyBuckets, + }), + ChatMessageLatencySeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_message_latency_seconds", + Help: "Time in seconds to add a follow-up message to an existing chat.", + Buckets: chatRequestLatencyBuckets, + }, phaseLabelNames), + ChatConversationDurationSeconds: prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_conversation_duration_seconds", + Help: "Time in seconds from chat creation start until the conversation finishes or errors.", + Buckets: chatProcessingLatencyBuckets, + }), + ChatTimeToRunningSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_time_to_running_seconds", + Help: "Time in seconds from the start of a chat turn until the chat enters running status.", + Buckets: chatProcessingLatencyBuckets, + }, phaseLabelNames), + ChatTimeToFirstOutputSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_time_to_first_output_seconds", + Help: "Time in seconds from the start of a chat turn until the first output is received.", + Buckets: chatProcessingLatencyBuckets, + }, phaseLabelNames), + ChatTimeToTerminalStatusSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_time_to_terminal_status_seconds", + Help: "Time in seconds from the start of a chat turn until a terminal status is received.", + Buckets: chatProcessingLatencyBuckets, + }, phaseLabelNames), + ChatStageFailuresTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_stage_failures_total", + Help: "Total number of terminal stage-specific chat runner failures.", + }, failureStageLabelNames), + ChatTerminalStatusTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_terminal_status_total", + Help: "Total number of terminal chat statuses observed.", + }, terminalStatusLabelNames), + ChatTurnsCompletedTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_turns_completed_total", + Help: "Total number of chat turns completed successfully.", + }), + ChatRetryEventsTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "chat_retry_events_total", + Help: "Total number of chat retry events observed.", + }), + ActiveChatStreams: prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "coderd", + Subsystem: "scaletest", + Name: "active_chat_streams", + Help: "Current number of active chat streams.", + }), + } + + reg.MustRegister(m.ChatCreateLatencySeconds) + reg.MustRegister(m.ChatMessageLatencySeconds) + reg.MustRegister(m.ChatConversationDurationSeconds) + reg.MustRegister(m.ChatTimeToRunningSeconds) + reg.MustRegister(m.ChatTimeToFirstOutputSeconds) + reg.MustRegister(m.ChatTimeToTerminalStatusSeconds) + reg.MustRegister(m.ChatStageFailuresTotal) + reg.MustRegister(m.ChatTerminalStatusTotal) + reg.MustRegister(m.ChatTurnsCompletedTotal) + reg.MustRegister(m.ChatRetryEventsTotal) + reg.MustRegister(m.ActiveChatStreams) + + return m +} diff --git a/scaletest/chat/provider.go b/scaletest/chat/provider.go new file mode 100644 index 0000000000..ba946d7db2 --- /dev/null +++ b/scaletest/chat/provider.go @@ -0,0 +1,148 @@ +package chat + +import ( + "context" + "net/http" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/codersdk" +) + +const ( + scaletestProviderType = "openai-compat" + scaletestProviderDisplayName = "Scaletest LLM Mock" + scaletestModelName = "scaletest-model" + scaletestModelDisplayName = "Scaletest Model" +) + +type scaletestProviderAction string + +const ( + scaletestProviderActionCreated scaletestProviderAction = "created" + scaletestProviderActionUpdated scaletestProviderAction = "updated" + scaletestProviderActionReused scaletestProviderAction = "reused" +) + +// EnsureScaletestModelConfig bootstraps the shared chat provider and model +// config used by chat scaletests. +func EnsureScaletestModelConfig(ctx context.Context, client *codersdk.ExperimentalClient, logger slog.Logger, llmMockURL string) (uuid.UUID, error) { + logger.Info(ctx, "bootstrapping mock LLM provider", slog.F("llm_mock_url", llmMockURL)) + + provider, providerAction, err := ensureScaletestProvider(ctx, client, llmMockURL) + if err != nil { + return uuid.Nil, err + } + + switch providerAction { + case scaletestProviderActionCreated: + logger.Info(ctx, "created mock LLM provider", + slog.F("provider_type", scaletestProviderType), + slog.F("llm_mock_url", llmMockURL), + ) + case scaletestProviderActionUpdated: + logger.Info(ctx, "updated mock LLM provider", + slog.F("provider_type", scaletestProviderType), + slog.F("provider_id", provider.ID), + slog.F("llm_mock_url", llmMockURL), + ) + case scaletestProviderActionReused: + logger.Info(ctx, "reusing mock LLM provider", + slog.F("provider_type", scaletestProviderType), + slog.F("provider_id", provider.ID), + ) + } + + modelConfigs, err := client.ListChatModelConfigs(ctx) + if err != nil { + return uuid.Nil, xerrors.Errorf("list chat model configs: %w", err) + } + + for i := range modelConfigs { + if modelConfigs[i].Provider != provider.Provider || modelConfigs[i].Model != scaletestModelName { + continue + } + if !modelConfigs[i].Enabled { + return uuid.Nil, xerrors.Errorf("existing scaletest chat model config %s is disabled; re-enable or delete it before running scaletests", modelConfigs[i].ID) + } + modelConfigID := modelConfigs[i].ID + logger.Info(ctx, "reusing scaletest model config", slog.F("model_config_id", modelConfigID)) + return modelConfigID, nil + } + + enabled := true + isDefault := false + contextLimit := int64(4096) + created, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: provider.Provider, + Model: scaletestModelName, + DisplayName: scaletestModelDisplayName, + Enabled: &enabled, + IsDefault: &isDefault, + ContextLimit: &contextLimit, + }) + if err != nil { + return uuid.Nil, xerrors.Errorf("create scaletest chat model config: %w", err) + } + logger.Info(ctx, "created scaletest model config", slog.F("model_config_id", created.ID)) + return created.ID, nil +} + +func ensureScaletestProvider(ctx context.Context, client *codersdk.ExperimentalClient, llmMockURL string) (codersdk.ChatProviderConfig, scaletestProviderAction, error) { + enabled := true + mockProviderToken := uuid.NewString() + created, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: scaletestProviderType, + DisplayName: scaletestProviderDisplayName, + APIKey: mockProviderToken, + BaseURL: llmMockURL, + Enabled: &enabled, + }) + if err == nil { + return created, scaletestProviderActionCreated, nil + } + + var sdkErr *codersdk.Error + if !xerrors.As(err, &sdkErr) || sdkErr.StatusCode() != http.StatusConflict { + return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("create scaletest chat provider: %w", err) + } + + providers, err := client.ListChatProviders(ctx) + if err != nil { + return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("list chat providers: %w", err) + } + + var existing *codersdk.ChatProviderConfig + for i := range providers { + if providers[i].Provider == scaletestProviderType { + existing = &providers[i] + break + } + } + if existing == nil { + return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("find existing %s provider after conflict: not found", scaletestProviderType) + } + if existing.DisplayName != scaletestProviderDisplayName { + return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("refusing to overwrite existing %s provider %s with display name %q", scaletestProviderType, existing.ID, existing.DisplayName) + } + + if !existing.Enabled { + return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("existing scaletest chat provider %s is disabled; re-enable or delete it before running scaletests", existing.ID) + } + if existing.BaseURL == llmMockURL { + return *existing, scaletestProviderActionReused, nil + } + + updated, err := client.UpdateChatProvider(ctx, existing.ID, codersdk.UpdateChatProviderConfigRequest{ + DisplayName: scaletestProviderDisplayName, + APIKey: &mockProviderToken, + BaseURL: &llmMockURL, + Enabled: &enabled, + }) + if err != nil { + return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("update scaletest chat provider: %w", err) + } + return updated, scaletestProviderActionUpdated, nil +} diff --git a/scaletest/chat/run.go b/scaletest/chat/run.go new file mode 100644 index 0000000000..b2e591fab6 --- /dev/null +++ b/scaletest/chat/run.go @@ -0,0 +1,413 @@ +package chat + +import ( + "context" + "io" + "sync" + "time" + + "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/sloghuman" + "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/scaletest/harness" + "github.com/coder/coder/v2/scaletest/loadtestutil" +) + +// Runner executes a single chat conversation as part of a scaletest run. +type Runner struct { + client chatClient + cfg Config + + chatID uuid.UUID + result runnerResult + + conversationStart time.Time + turnStartTime time.Time + currentPhase string + lastStreamError string + lastStatus codersdk.ChatStatus + sawTurnRunning bool + sawTurnFirstOutput bool + markTurnStartReady func() +} + +type runnerResult struct { + finalStatus string + failureStage string + totalDuration time.Duration + sawFirstOutput bool + retryCount int + eventCount int + turnsCompleted int +} + +var ( + _ harness.Runnable = &Runner{} + _ harness.Cleanable = &Runner{} + _ harness.Collectable = &Runner{} +) + +func NewRunner(client *codersdk.Client, cfg Config) *Runner { + return &Runner{ + client: newChatClient(client), + cfg: cfg, + } +} + +func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + logs = loadtestutil.NewSyncWriter(logs) + logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug).Named(id) + r.client.SetLogger(logger) + r.client.SetLogBodies(true) + + span.SetAttributes( + attribute.String("chat.runner_id", id), + attribute.String("chat.workspace_id", r.cfg.WorkspaceID.String()), + attribute.Int("chat.turns_requested", r.cfg.Turns), + attribute.Int64("chat.turn_start_delay_ms", r.cfg.TurnStartDelay.Milliseconds()), + ) + span.SetAttributes(attribute.String("chat.model_config_id", r.cfg.ModelConfigID.String())) + + markTurnStartReady := func() {} + if r.cfg.TurnStartReadyWaitGroup != nil { + markTurnStartReady = sync.OnceFunc(r.cfg.TurnStartReadyWaitGroup.Done) + } + r.markTurnStartReady = markTurnStartReady + defer r.markTurnStartReady() + + defer func() { + if !r.conversationStart.IsZero() { + r.result.totalDuration = time.Since(r.conversationStart) + r.cfg.Metrics.ChatConversationDurationSeconds.Observe(r.result.totalDuration.Seconds()) + } + span.SetAttributes( + attribute.String("chat.final_status", r.result.finalStatus), + attribute.String("chat.failure_stage", r.result.failureStage), + attribute.Int("chat.retry_count", r.result.retryCount), + attribute.Int("chat.turns_completed", r.result.turnsCompleted), + attribute.Bool("chat.saw_first_output", r.result.sawFirstOutput), + ) + if r.result.totalDuration > 0 { + span.SetAttributes(attribute.Float64("chat.total_duration_seconds", r.result.totalDuration.Seconds())) + } + }() + + workspaceID := r.cfg.WorkspaceID + modelConfigID := r.cfg.ModelConfigID + logger = logger.With(slog.F("workspace_id", workspaceID)) + logger.Info(ctx, "starting chat runner") + + r.resetConversation(time.Now(), markTurnStartReady) + + createStartedAt := time.Now() + chat, err := r.client.CreateChat(ctx, codersdk.CreateChatRequest{ + OrganizationID: r.cfg.OrganizationID, + WorkspaceID: &workspaceID, + ModelConfigID: &modelConfigID, + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: r.cfg.Prompt, + }}, + }) + if err != nil { + r.result.failureStage = failureStageCreateChat + r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc() + return xerrors.Errorf("create chat: %w", err) + } + r.cfg.Metrics.ChatCreateLatencySeconds.Observe(time.Since(createStartedAt).Seconds()) + + r.chatID = chat.ID + span.SetAttributes(attribute.String("chat.chat_id", chat.ID.String())) + logger = logger.With(slog.F("chat_id", chat.ID)) + logger.Info(ctx, "created chat session", slog.F("duration", time.Since(createStartedAt))) + + // CreateChat already queues the first prompt for processing on the + // server, so the initial turn is in flight as soon as CreateChat + // returns. Open the stream immediately and let the conversation loop + // drive the gate at the natural phase boundary (after the first turn + // reaches a terminal Waiting status), rather than fencing here on a + // turn that has already started running. + events, closer, err := r.client.StreamChat(ctx, chat.ID, nil) + if err != nil { + r.result.failureStage = failureStageStreamOpen + r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc() + return xerrors.Errorf("stream chat: %w", err) + } + + r.cfg.Metrics.ActiveChatStreams.Inc() + defer func() { + r.cfg.Metrics.ActiveChatStreams.Dec() + _ = closer.Close() + }() + + logger.Info(ctx, "streaming chat events") + + return r.runConversation(ctx, chat.ID, logger, events) +} + +func (r *Runner) resetConversation(conversationStart time.Time, markTurnStartReady func()) { + if markTurnStartReady == nil { + markTurnStartReady = func() {} + } + + r.result = runnerResult{} + r.conversationStart = conversationStart + r.turnStartTime = conversationStart + r.currentPhase = phaseInitial + r.lastStreamError = "" + r.lastStatus = "" + r.sawTurnRunning = false + r.sawTurnFirstOutput = false + r.markTurnStartReady = markTurnStartReady +} + +func (r *Runner) runConversation(ctx context.Context, chatID uuid.UUID, logger slog.Logger, events <-chan codersdk.ChatStreamEvent) error { + r.chatID = chatID + + for event := range events { + r.result.eventCount++ + + switch event.Type { + case codersdk.ChatStreamEventTypeStatus: + if event.Status == nil { + continue + } + done, err := r.handleStatusEvent(ctx, chatID, logger, event.Status.Status) + if err != nil { + return err + } + if done { + return nil + } + case codersdk.ChatStreamEventTypeMessagePart: + r.handleMessagePartEvent(ctx, logger) + case codersdk.ChatStreamEventTypeMessage: + // StreamChat replays persisted rows as message events, not + // message_part deltas, when a turn finished server-side before + // the stream attached. Route assistant rows through the same + // first-output path; skip user rows so persisted prompts do not + // count as model output. + if event.Message == nil || event.Message.Role != codersdk.ChatMessageRoleAssistant { + continue + } + r.handleMessagePartEvent(ctx, logger) + case codersdk.ChatStreamEventTypeRetry: + r.handleRetryEvent(ctx, logger, event.Retry) + case codersdk.ChatStreamEventTypeError: + r.handleErrorEvent(ctx, logger, event.Error) + } + } + + if ctx.Err() != nil { + return ctx.Err() + } + + r.result.failureStage = failureStageStreamEndedEarly + r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc() + if r.lastStreamError != "" { + return xerrors.Errorf("chat %s stream ended before completing %d of %d turns: %s", chatID, r.result.turnsCompleted, r.cfg.Turns, r.lastStreamError) + } + return xerrors.Errorf("chat %s stream ended before completing %d of %d turns", chatID, r.result.turnsCompleted, r.cfg.Turns) +} + +func (r *Runner) handleStatusEvent(ctx context.Context, chatID uuid.UUID, logger slog.Logger, status codersdk.ChatStatus) (bool, error) { + if status == r.lastStatus { + return false, nil + } + if status == codersdk.ChatStatusWaiting && + !r.sawTurnFirstOutput && + (r.sawTurnRunning || r.result.turnsCompleted > 0) { + return false, nil + } + r.lastStatus = status + + switch status { + case codersdk.ChatStatusRunning: + r.sawTurnRunning = true + r.cfg.Metrics.ChatTimeToRunningSeconds.WithLabelValues(r.currentPhase).Observe(time.Since(r.turnStartTime).Seconds()) + logger.Info(ctx, "chat reached running status", + slog.F("phase", r.currentPhase), + ) + return false, nil + case codersdk.ChatStatusWaiting: + r.result.turnsCompleted++ + turnDuration := time.Since(r.turnStartTime) + r.cfg.Metrics.ChatTimeToTerminalStatusSeconds.WithLabelValues(r.currentPhase).Observe(turnDuration.Seconds()) + r.cfg.Metrics.ChatTerminalStatusTotal.WithLabelValues(string(codersdk.ChatStatusWaiting)).Inc() + r.cfg.Metrics.ChatTurnsCompletedTotal.Inc() + logger.Info(ctx, "chat completed turn", + slog.F("turn", r.result.turnsCompleted), + slog.F("turns", r.cfg.Turns), + slog.F("duration", turnDuration), + ) + if r.result.turnsCompleted >= r.cfg.Turns { + r.result.finalStatus = string(codersdk.ChatStatusWaiting) + conversationDuration := time.Since(r.conversationStart) + logger.Info(ctx, "chat reached terminal status", + slog.F("status", codersdk.ChatStatusWaiting), + slog.F("duration", conversationDuration), + slog.F("turns_completed", r.result.turnsCompleted), + ) + return true, nil + } + + // After the very first turn completes, mark this runner ready + // for the CLI-coordinated turn-start gate. The inter-phase + // delay measures the gap between every chat actually finishing its + // initial turn and the start of the follow-up turns, not the gap + // between CreateChat returning and the next turn. + if r.result.turnsCompleted == 1 { + r.markTurnStartReady() + if r.cfg.StartTurnsChan != nil { + logger.Info(ctx, "chat waiting for turn start release", + slog.F("turn_start_delay", r.cfg.TurnStartDelay), + ) + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-r.cfg.StartTurnsChan: + } + } + } + + nextTurn := r.result.turnsCompleted + 1 + r.currentPhase = phaseFollowUp + r.turnStartTime = time.Now() + r.lastStreamError = "" + r.lastStatus = "" + r.sawTurnRunning = false + r.sawTurnFirstOutput = false + if err := r.sendNextTurn(ctx, chatID, logger, nextTurn, r.currentPhase); err != nil { + r.result.failureStage = failureStageCreateMessage + r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc() + return false, err + } + return false, nil + case codersdk.ChatStatusError: + r.result.finalStatus = string(codersdk.ChatStatusError) + r.result.failureStage = failureStageStatusError + turnDuration := time.Since(r.turnStartTime) + r.cfg.Metrics.ChatTimeToTerminalStatusSeconds.WithLabelValues(r.currentPhase).Observe(turnDuration.Seconds()) + r.cfg.Metrics.ChatTerminalStatusTotal.WithLabelValues(string(codersdk.ChatStatusError)).Inc() + r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc() + + errMessage := r.lastStreamError + if errMessage == "" { + errMessage = "chat reached error status" + } + logger.Error(ctx, "chat reached terminal status", + slog.F("status", codersdk.ChatStatusError), + slog.F("turns_completed", r.result.turnsCompleted), + slog.F("turns", r.cfg.Turns), + slog.F("error", errMessage), + ) + return false, xerrors.Errorf("chat %s reached error status: %s", chatID, errMessage) + default: + return false, nil + } +} + +func (r *Runner) sendNextTurn(ctx context.Context, chatID uuid.UUID, logger slog.Logger, nextTurn int, phase string) error { + messageStartedAt := time.Now() + modelConfigID := r.cfg.ModelConfigID + _, err := r.client.CreateChatMessage(ctx, chatID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: r.cfg.Prompt, + }}, + ModelConfigID: &modelConfigID, + }) + if err != nil { + return xerrors.Errorf("create chat message for turn %d: %w", nextTurn, err) + } + + r.cfg.Metrics.ChatMessageLatencySeconds.WithLabelValues(phase).Observe(time.Since(messageStartedAt).Seconds()) + logger.Info(ctx, "chat sent message", + slog.F("turn", nextTurn), + slog.F("turns", r.cfg.Turns), + ) + return nil +} + +func (r *Runner) handleMessagePartEvent(ctx context.Context, logger slog.Logger) { + if r.sawTurnFirstOutput { + return + } + r.sawTurnFirstOutput = true + r.result.sawFirstOutput = true + firstOutputDuration := time.Since(r.turnStartTime) + r.cfg.Metrics.ChatTimeToFirstOutputSeconds.WithLabelValues(r.currentPhase).Observe(firstOutputDuration.Seconds()) + logger.Info(ctx, "chat received first output", + slog.F("phase", r.currentPhase), + slog.F("duration", firstOutputDuration), + ) +} + +func (r *Runner) handleRetryEvent(ctx context.Context, logger slog.Logger, retry *codersdk.ChatStreamRetry) { + r.result.retryCount++ + r.cfg.Metrics.ChatRetryEventsTotal.Inc() + if retry != nil { + logger.Warn(ctx, "chat retry event", + slog.F("attempt", retry.Attempt), + slog.F("delay_ms", retry.DelayMs), + slog.F("error", retry.Error), + ) + return + } + logger.Warn(ctx, "chat retry event") +} + +func (r *Runner) handleErrorEvent(ctx context.Context, logger slog.Logger, eventErr *codersdk.ChatError) { + if eventErr != nil && eventErr.Message != "" { + r.lastStreamError = eventErr.Message + logger.Warn(ctx, "chat stream error", + slog.F("error", r.lastStreamError), + ) + return + } + logger.Warn(ctx, "chat stream error event") +} + +func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error { + if r.chatID == uuid.Nil { + return nil + } + + logs = loadtestutil.NewSyncWriter(logs) + logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug).Named(id).With(slog.F("chat_id", r.chatID)) + r.client.SetLogger(logger) + r.client.SetLogBodies(true) + + archived := true + logger.Info(ctx, "archiving chat session") + if err := r.client.UpdateChat(ctx, r.chatID, codersdk.UpdateChatRequest{Archived: &archived}); err != nil { + logger.Error(ctx, "failed to archive chat", slog.Error(err)) + return xerrors.Errorf("archive chat: %w", err) + } + logger.Info(ctx, "archived chat session") + return nil +} + +func (r *Runner) GetMetrics() map[string]any { + return map[string]any{ + "workspace_id": r.cfg.WorkspaceID.String(), + "turn_start_delay_ms": r.cfg.TurnStartDelay.Milliseconds(), + "chat_id": r.chatID.String(), + "final_status": r.result.finalStatus, + "failure_stage": r.result.failureStage, + "total_duration_seconds": r.result.totalDuration.Seconds(), + "saw_first_output": r.result.sawFirstOutput, + "retry_count": r.result.retryCount, + "event_count": r.result.eventCount, + "turns_requested": r.cfg.Turns, + "turns_completed": r.result.turnsCompleted, + } +} diff --git a/scaletest/chat/run_internal_test.go b/scaletest/chat/run_internal_test.go new file mode 100644 index 0000000000..2d93737fae --- /dev/null +++ b/scaletest/chat/run_internal_test.go @@ -0,0 +1,391 @@ +package chat + +import ( + "bytes" + "context" + "io" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/sloghuman" + "github.com/coder/coder/v2/codersdk" +) + +func TestRunnerRunConversation(t *testing.T) { + t.Parallel() + + chatID := uuid.MustParse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + noopMarkTurnStartReady := func() {} + + t.Run("OneTurnHappyPath", func(t *testing.T) { + t.Parallel() + + runner := newTestRunner(t, newRunConfig(t)) + events := make(chan codersdk.ChatStreamEvent, 3) + events <- statusEvent(chatID, codersdk.ChatStatusRunning) + events <- messagePartEvent(chatID) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + close(events) + + err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady) + require.NoError(t, err) + result := runner.result + require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus) + require.Empty(t, result.failureStage) + require.True(t, result.sawFirstOutput) + require.Equal(t, 1, result.turnsCompleted) + require.Equal(t, 3, result.eventCount) + }) + + t.Run("DuplicateWaitingDoesNotAdvanceTurn", func(t *testing.T) { + t.Parallel() + + cfg := newRunConfig(t) + cfg.Turns = 2 + + events := make(chan codersdk.ChatStreamEvent, 7) + events <- statusEvent(chatID, codersdk.ChatStatusRunning) + events <- messagePartEvent(chatID) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + + var sendCount atomic.Int64 + runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() { + sendCount.Add(1) + events <- statusEvent(chatID, codersdk.ChatStatusRunning) + events <- messagePartEvent(chatID) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + close(events) + }) + + err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady) + require.NoError(t, err) + result := runner.result + require.Equal(t, int64(1), sendCount.Load()) + require.Equal(t, 2, result.turnsCompleted) + require.Equal(t, 7, result.eventCount) + require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus) + }) + + t.Run("StaleWaitingAfterNextTurnRunningDoesNotAdvanceTurn", func(t *testing.T) { + t.Parallel() + + cfg := newRunConfig(t) + cfg.Turns = 2 + + events := make(chan codersdk.ChatStreamEvent, 7) + events <- statusEvent(chatID, codersdk.ChatStatusRunning) + events <- messagePartEvent(chatID) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + + var sendCount atomic.Int64 + runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() { + sendCount.Add(1) + events <- statusEvent(chatID, codersdk.ChatStatusRunning) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + events <- messagePartEvent(chatID) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + close(events) + }) + + err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady) + require.NoError(t, err) + result := runner.result + require.Equal(t, int64(1), sendCount.Load()) + require.Equal(t, 2, result.turnsCompleted) + require.Equal(t, 7, result.eventCount) + require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus) + }) + + t.Run("FirstTurnGatesFollowUpStorm", func(t *testing.T) { + t.Parallel() + + // Reproduces the contract that the turn-start gate is checked + // after the first turn finishes, not before it begins. The runner + // must mark itself ready, wait for the release channel, and only + // then send turn 2. + cfg := newRunConfig(t) + cfg.Turns = 2 + readyWG := &sync.WaitGroup{} + readyWG.Add(1) + releaseChan := make(chan struct{}) + cfg.TurnStartReadyWaitGroup = readyWG + cfg.StartTurnsChan = releaseChan + + events := make(chan codersdk.ChatStreamEvent, 4) + events <- statusEvent(chatID, codersdk.ChatStatusRunning) + events <- messagePartEvent(chatID) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + + ready := make(chan struct{}) + go func() { + readyWG.Wait() + close(ready) + }() + + errCh := make(chan error, 1) + var sendCount atomic.Int64 + runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() { + sendCount.Add(1) + events <- statusEvent(chatID, codersdk.ChatStatusRunning) + events <- messagePartEvent(chatID) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + close(events) + }) + + runner.resetConversation(time.Now(), sync.OnceFunc(readyWG.Done)) + + go func() { + runErr := runner.runConversation(context.Background(), chatID, testLogger(), events) + errCh <- runErr + }() + + select { + case <-ready: + case <-time.After(2 * time.Second): + t.Fatal("runner did not mark turn-start gate ready after first turn") + } + + require.Equal(t, int64(0), sendCount.Load(), "next turn was sent before turn-start release") + + close(releaseChan) + + select { + case err := <-errCh: + require.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("runner did not finish after turn-start release") + } + require.Equal(t, int64(1), sendCount.Load()) + }) + + t.Run("FirstOutputFromAssistantMessageEvent", func(t *testing.T) { + t.Parallel() + + // Snapshot race: when a turn finishes before stream attach, + // StreamChat replays rows as message events, never as + // message_part deltas; the assistant row must record first output. + runner := newTestRunner(t, newRunConfig(t)) + events := make(chan codersdk.ChatStreamEvent, 3) + events <- messageEvent(chatID, codersdk.ChatMessageRoleUser) + events <- messageEvent(chatID, codersdk.ChatMessageRoleAssistant) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + close(events) + + err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady) + require.NoError(t, err) + result := runner.result + require.True(t, result.sawFirstOutput, "first output not recorded from assistant message event") + require.Equal(t, 1, result.turnsCompleted) + require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus) + }) + + t.Run("ImmediateWaitingCountsNextTurn", func(t *testing.T) { + t.Parallel() + + cfg := newRunConfig(t) + cfg.Turns = 2 + + events := make(chan codersdk.ChatStreamEvent, 3) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + + var sendCount atomic.Int64 + runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() { + sendCount.Add(1) + events <- statusEvent(chatID, codersdk.ChatStatusRunning) + events <- messagePartEvent(chatID) + events <- statusEvent(chatID, codersdk.ChatStatusWaiting) + close(events) + }) + + err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady) + require.NoError(t, err) + result := runner.result + require.Equal(t, int64(1), sendCount.Load()) + require.Equal(t, 2, result.turnsCompleted) + require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus) + }) +} + +func runTestConversation(t *testing.T, runner *Runner, chatID uuid.UUID, events <-chan codersdk.ChatStreamEvent, markTurnStartReady func()) error { + t.Helper() + runner.resetConversation(time.Now(), markTurnStartReady) + return runner.runConversation(context.Background(), chatID, testLogger(), events) +} + +func TestRunnerCleanup(t *testing.T) { + t.Parallel() + + chatID := uuid.MustParse("22222222-2222-2222-2222-222222222222") + + t.Run("ArchivesChat", func(t *testing.T) { + t.Parallel() + + runner, archived := newTestRunnerWithChatArchive(t, chatID, nil) + + logs := bytes.NewBuffer(nil) + err := runner.Cleanup(context.Background(), "runner-1", logs) + require.NoError(t, err) + require.True(t, archived()) + require.Contains(t, logs.String(), "archived chat") + }) + + t.Run("ArchiveErrorIsReturned", func(t *testing.T) { + t.Parallel() + + runner, archived := newTestRunnerWithChatArchive(t, chatID, xerrors.New("boom")) + + err := runner.Cleanup(context.Background(), "runner-1", bytes.NewBuffer(nil)) + require.Error(t, err) + require.ErrorContains(t, err, "archive chat") + require.True(t, archived()) + }) +} + +func testLogger() slog.Logger { + return slog.Make(sloghuman.Sink(io.Discard)).Leveled(slog.LevelDebug) +} + +func newRunConfig(t *testing.T) Config { + t.Helper() + reg := prometheus.NewRegistry() + return Config{ + OrganizationID: uuid.MustParse("22222222-2222-2222-2222-222222222222"), + WorkspaceID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), + ModelConfigID: uuid.MustParse("33333333-3333-3333-3333-333333333333"), + Prompt: "Reply with one short sentence.", + Turns: 1, + Metrics: NewMetrics(reg), + } +} + +type fakeChatClient struct { + createChatFunc func(context.Context, codersdk.CreateChatRequest) (codersdk.Chat, error) + streamChatFunc func(context.Context, uuid.UUID, *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error) + createChatMessageFunc func(context.Context, uuid.UUID, codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) + updateChatFunc func(context.Context, uuid.UUID, codersdk.UpdateChatRequest) error +} + +func newFakeChatClient(t *testing.T) *fakeChatClient { + t.Helper() + return &fakeChatClient{} +} + +func (*fakeChatClient) SetLogger(logger slog.Logger) {} + +func (*fakeChatClient) SetLogBodies(logBodies bool) {} + +func (f *fakeChatClient) CreateChat(ctx context.Context, req codersdk.CreateChatRequest) (codersdk.Chat, error) { + if f.createChatFunc == nil { + return codersdk.Chat{}, xerrors.New("unexpected CreateChat call") + } + return f.createChatFunc(ctx, req) +} + +func (f *fakeChatClient) StreamChat(ctx context.Context, chatID uuid.UUID, opts *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error) { + if f.streamChatFunc == nil { + return nil, nil, xerrors.New("unexpected StreamChat call") + } + return f.streamChatFunc(ctx, chatID, opts) +} + +func (f *fakeChatClient) CreateChatMessage(ctx context.Context, chatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) { + if f.createChatMessageFunc == nil { + return codersdk.CreateChatMessageResponse{}, xerrors.New("unexpected CreateChatMessage call") + } + return f.createChatMessageFunc(ctx, chatID, req) +} + +func (f *fakeChatClient) UpdateChat(ctx context.Context, chatID uuid.UUID, req codersdk.UpdateChatRequest) error { + if f.updateChatFunc == nil { + return xerrors.New("unexpected UpdateChat call") + } + return f.updateChatFunc(ctx, chatID, req) +} + +var _ chatClient = (*fakeChatClient)(nil) + +func newTestRunner(t *testing.T, cfg Config) *Runner { + t.Helper() + return &Runner{client: newFakeChatClient(t), cfg: cfg} +} + +func newTestRunnerWithChatArchive(t *testing.T, chatID uuid.UUID, updateErr error) (*Runner, func() bool) { + t.Helper() + + var archived atomic.Bool + client := newFakeChatClient(t) + client.updateChatFunc = func(ctx context.Context, gotChatID uuid.UUID, req codersdk.UpdateChatRequest) error { + if gotChatID != chatID { + return xerrors.Errorf("unexpected chat archive ID: %s", gotChatID) + } + if req.Archived == nil || !*req.Archived { + return xerrors.Errorf("unexpected archived value: %v", req.Archived) + } + archived.Store(true) + return updateErr + } + runner := &Runner{client: client, cfg: Config{}, chatID: chatID} + return runner, archived.Load +} + +func newTestRunnerWithChatMessage(t *testing.T, cfg Config, chatID uuid.UUID, onMessage func()) *Runner { + t.Helper() + + client := newFakeChatClient(t) + client.createChatMessageFunc = func(ctx context.Context, gotChatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) { + if gotChatID != chatID { + return codersdk.CreateChatMessageResponse{}, xerrors.Errorf("unexpected chat message ID: %s", gotChatID) + } + if err := validatePromptParts(req.Content, cfg.Prompt); err != nil { + return codersdk.CreateChatMessageResponse{}, err + } + if req.ModelConfigID == nil || *req.ModelConfigID != cfg.ModelConfigID { + return codersdk.CreateChatMessageResponse{}, xerrors.Errorf("unexpected chat message model config ID: %v", req.ModelConfigID) + } + + if onMessage != nil { + onMessage() + } + return codersdk.CreateChatMessageResponse{Queued: true}, nil + } + return &Runner{client: client, cfg: cfg} +} + +func validatePromptParts(parts []codersdk.ChatInputPart, prompt string) error { + if len(parts) != 1 || parts[0].Type != codersdk.ChatInputPartTypeText || parts[0].Text != prompt { + return xerrors.Errorf("unexpected chat message content: %#v", parts) + } + return nil +} + +func statusEvent(chatID uuid.UUID, status codersdk.ChatStatus) codersdk.ChatStreamEvent { + return codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeStatus, + ChatID: chatID, + Status: &codersdk.ChatStreamStatus{Status: status}, + } +} + +func messagePartEvent(chatID uuid.UUID) codersdk.ChatStreamEvent { + return codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + ChatID: chatID, + } +} + +func messageEvent(chatID uuid.UUID, role codersdk.ChatMessageRole) codersdk.ChatStreamEvent { + return codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessage, + ChatID: chatID, + Message: &codersdk.ChatMessage{Role: role}, + } +}