mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat(scaletest): add chat scaletest command (#25553)
Adds `coder exp scaletest chat`, a harness for creating Coder Agents chat load. Start the mock LLM separately, prepare the scaletest workspaces you want to target, then run the chat scaletest against the existing `scaletest-*` fleet selected by the shared workspace targeting flags: ```sh coder exp scaletest llm-mock --address 127.0.0.1:18080 coder exp scaletest chat --llm-mock-url http://127.0.0.1:18080/v1 --chats-per-workspace 10 --turns 1 coder exp scaletest chat --llm-mock-url http://127.0.0.1:18080/v1 --template docker --target-workspaces 0:10 --chats-per-workspace 1 --turns 10 --turn-start-delay 30s ``` This is the same pattern used by the `workspace-traffic` load generator. Keeping the fake LLM as a separate process is intentional so it can be scaled independently from the Coder deployment, which will likely be necessary as we scale up and up. This PR is the starting point: it provides the command, mock provider/model bootstrap, existing workspace selection, chat streaming, follow-up turns, metrics, and cleanup. Follow-up PRs will add multi-step turns via tool calls. I'm still a bit iffy on the mechanism I have for that. It'll likely involve having the runner send some magic strings that the mock will recognise. Relates to CODAGT-307 Relates to GRU-48 Relates to https://github.com/coder/scaletest/issues/124 Generated by Mux, but reviewed by a human
This commit is contained in:
@@ -70,6 +70,7 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
|
||||
r.scaletestSMTP(),
|
||||
r.scaletestPrebuilds(),
|
||||
r.scaletestBridge(),
|
||||
r.scaletestChat(),
|
||||
r.scaletestLLMMock(),
|
||||
},
|
||||
}
|
||||
@@ -404,13 +405,13 @@ func (f *workspaceTargetFlags) attach(opts *serpent.OptionSet) {
|
||||
Flag: "template",
|
||||
FlagShorthand: "t",
|
||||
Env: "CODER_SCALETEST_TEMPLATE",
|
||||
Description: "Name or ID of the template. Traffic generation will be limited to workspaces created from this template.",
|
||||
Description: "Name or ID of the template. Only workspaces created from this template are targeted.",
|
||||
Value: serpent.StringOf(&f.template),
|
||||
},
|
||||
serpent.Option{
|
||||
Flag: "target-workspaces",
|
||||
Env: "CODER_SCALETEST_TARGET_WORKSPACES",
|
||||
Description: "Target a specific range of workspaces in the format [START]:[END] (exclusive). Example: 0:10 will target the 10 first alphabetically sorted workspaces (0-9).",
|
||||
Description: "Target a specific range of matching workspaces in the format [START]:[END] (exclusive). Example: 0:10 targets the first 10 matching workspaces returned by the workspace query.",
|
||||
Value: serpent.StringOf(&f.targetWorkspaces),
|
||||
},
|
||||
serpent.Option{
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/chat"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) scaletestChat() *serpent.Command {
|
||||
var (
|
||||
chatsPerWorkspace int64
|
||||
prompt string
|
||||
turns int64
|
||||
turnStartDelay time.Duration
|
||||
llmMockURL string
|
||||
targetFlags = &workspaceTargetFlags{}
|
||||
tracingFlags = &scaletestTracingFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
cleanupStrategy = newScaletestCleanupStrategy()
|
||||
output = &scaletestOutputFlags{}
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "chat",
|
||||
Short: "Generate Coder Agents load.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
baseCtx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(baseCtx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
outputs, err := output.parse()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("could not parse --output flags: %w", err)
|
||||
}
|
||||
switch {
|
||||
case turns < 1:
|
||||
return xerrors.Errorf("--turns must be at least 1")
|
||||
case chatsPerWorkspace < 1:
|
||||
return xerrors.Errorf("--chats-per-workspace must be at least 1")
|
||||
}
|
||||
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
me, err := RequireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client.HTTPClient.Transport = &codersdk.HeaderTransport{
|
||||
Transport: client.HTTPClient.Transport,
|
||||
Header: BypassHeader,
|
||||
}
|
||||
|
||||
workspaces, err := targetFlags.getTargetedWorkspaces(ctx, client, me.OrganizationIDs, inv.Stdout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelDebug)
|
||||
modelConfigID, err := chat.EnsureScaletestModelConfig(ctx, codersdk.NewExperimentalClient(client), logger, llmMockURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start metrics and tracing before creating runners.
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := chat.NewMetrics(reg)
|
||||
|
||||
prometheusSrvClose := ServeHandler(baseCtx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
|
||||
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(baseCtx)
|
||||
if err != nil {
|
||||
prometheusSrvClose()
|
||||
return xerrors.Errorf("create tracer provider: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if tracingEnabled {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Uploading traces...")
|
||||
}
|
||||
if err := closeTracing(baseCtx); err != nil {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Error uploading traces: %+v\n", err)
|
||||
}
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
prometheusSrvClose()
|
||||
}()
|
||||
|
||||
tracer := tracerProvider.Tracer(scaletestTracerName)
|
||||
|
||||
var turnStartReadyWaitGroup *sync.WaitGroup
|
||||
var startTurnsChan chan struct{}
|
||||
if turnStartDelay > 0 && turns > 1 {
|
||||
turnStartReadyWaitGroup = &sync.WaitGroup{}
|
||||
startTurnsChan = make(chan struct{})
|
||||
}
|
||||
|
||||
chatHarness := harness.NewTestHarness(
|
||||
timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}),
|
||||
cleanupStrategy.toStrategy(),
|
||||
)
|
||||
for workspaceIndex, targetWorkspace := range workspaces {
|
||||
for chatIndex := int64(0); chatIndex < chatsPerWorkspace; chatIndex++ {
|
||||
if turnStartReadyWaitGroup != nil {
|
||||
turnStartReadyWaitGroup.Add(1)
|
||||
}
|
||||
|
||||
cfg := chat.Config{
|
||||
OrganizationID: targetWorkspace.OrganizationID,
|
||||
WorkspaceID: targetWorkspace.ID,
|
||||
Prompt: prompt,
|
||||
ModelConfigID: modelConfigID,
|
||||
Turns: int(turns),
|
||||
TurnStartDelay: turnStartDelay,
|
||||
TurnStartReadyWaitGroup: turnStartReadyWaitGroup,
|
||||
StartTurnsChan: startTurnsChan,
|
||||
Metrics: metrics,
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config for workspace %d chat %d: %w", workspaceIndex, chatIndex, err)
|
||||
}
|
||||
|
||||
runnerClient, err := loadtestutil.DupClientCopyingHeaders(client, BypassHeader)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("duplicate client for workspace %d chat %d: %w", workspaceIndex, chatIndex, err)
|
||||
}
|
||||
var runner harness.Runnable = chat.NewRunner(runnerClient, cfg)
|
||||
if tracingEnabled {
|
||||
runner = &runnableTraceWrapper{
|
||||
tracer: tracer,
|
||||
runner: runner,
|
||||
spanName: fmt.Sprintf("chat/workspace-%d-chat-%d", workspaceIndex, chatIndex),
|
||||
}
|
||||
}
|
||||
chatHarness.AddRun("chat", fmt.Sprintf("workspace-%d-chat-%d", workspaceIndex, chatIndex), runner)
|
||||
}
|
||||
}
|
||||
|
||||
// Run the chat harness in the background so the CLI can release the
|
||||
// follow-up turns after every runner finishes its initial turn.
|
||||
totalChats := int64(len(workspaces)) * chatsPerWorkspace
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Starting chat scale test with %d chats across %d workspaces...\n", totalChats, len(workspaces))
|
||||
testCtx, testCancel := timeoutStrategy.toContext(ctx)
|
||||
defer testCancel()
|
||||
testDone := make(chan error, 1)
|
||||
go func() {
|
||||
testDone <- chatHarness.Run(testCtx)
|
||||
}()
|
||||
|
||||
if turnStartReadyWaitGroup != nil {
|
||||
initialTurnsDone := make(chan struct{})
|
||||
go func() {
|
||||
turnStartReadyWaitGroup.Wait()
|
||||
close(initialTurnsDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-testCtx.Done():
|
||||
return testCtx.Err()
|
||||
case <-initialTurnsDone:
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "All %d initial turns completed, waiting %s before starting the follow-up turns...\n", totalChats, turnStartDelay)
|
||||
select {
|
||||
case <-testCtx.Done():
|
||||
return testCtx.Err()
|
||||
case <-time.After(turnStartDelay):
|
||||
}
|
||||
|
||||
close(startTurnsChan)
|
||||
}
|
||||
|
||||
if err := <-testDone; err != nil {
|
||||
return xerrors.Errorf("run harness: %w", err)
|
||||
}
|
||||
|
||||
results := chatHarness.Results()
|
||||
for _, o := range outputs {
|
||||
if err := o.write(results, inv.Stdout); err != nil {
|
||||
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up (archiving chats)...")
|
||||
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
|
||||
defer cleanupCancel()
|
||||
if err := chatHarness.Cleanup(cleanupCtx); err != nil {
|
||||
return xerrors.Errorf("cleanup chats: %w", err)
|
||||
}
|
||||
|
||||
if results.TotalFail > 0 {
|
||||
return xerrors.Errorf("scale test failed: %d/%d runs failed", results.TotalFail, results.TotalRuns)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Scale test passed: %d/%d runs succeeded\n", results.TotalPass, results.TotalRuns)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "chats-per-workspace",
|
||||
Description: "Number of chats to run against each targeted workspace. Required and must be greater than 0.",
|
||||
Value: serpent.Int64Of(&chatsPerWorkspace),
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Flag: "prompt",
|
||||
Description: "Text prompt to send on every turn in each chat.",
|
||||
Default: "Reply with one short sentence.",
|
||||
Value: serpent.StringOf(&prompt),
|
||||
},
|
||||
{
|
||||
Flag: "turns",
|
||||
Description: "Number of user to assistant exchanges per chat conversation.",
|
||||
Default: "10",
|
||||
Value: serpent.Int64Of(&turns),
|
||||
},
|
||||
{
|
||||
Flag: "turn-start-delay",
|
||||
Description: "Delay between every chat completing its initial turn and starting the follow-up turns. Use this to separate initial-turn load from follow-up-turn load.",
|
||||
Default: "0s",
|
||||
Value: serpent.DurationOf(&turnStartDelay),
|
||||
},
|
||||
{
|
||||
Flag: "llm-mock-url",
|
||||
Description: "URL of the mock LLM server (e.g. http://127.0.0.1:8080/v1). Creates or updates the Scaletest LLM Mock openai-compat provider and model config to point at this URL.",
|
||||
Value: serpent.StringOf(&llmMockURL),
|
||||
Required: true,
|
||||
},
|
||||
}
|
||||
targetFlags.attach(&cmd.Options)
|
||||
output.attach(&cmd.Options)
|
||||
tracingFlags.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
cleanupStrategy.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type chatClient interface {
|
||||
SetLogger(logger slog.Logger)
|
||||
SetLogBodies(logBodies bool)
|
||||
CreateChat(ctx context.Context, req codersdk.CreateChatRequest) (codersdk.Chat, error)
|
||||
StreamChat(ctx context.Context, chatID uuid.UUID, opts *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error)
|
||||
CreateChatMessage(ctx context.Context, chatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error)
|
||||
UpdateChat(ctx context.Context, chatID uuid.UUID, req codersdk.UpdateChatRequest) error
|
||||
}
|
||||
|
||||
type sdkChatClient struct {
|
||||
client *codersdk.ExperimentalClient
|
||||
}
|
||||
|
||||
func newChatClient(client *codersdk.Client) chatClient {
|
||||
return &sdkChatClient{client: codersdk.NewExperimentalClient(client)}
|
||||
}
|
||||
|
||||
func (c *sdkChatClient) SetLogger(logger slog.Logger) {
|
||||
c.client.SetLogger(logger)
|
||||
}
|
||||
|
||||
func (c *sdkChatClient) SetLogBodies(logBodies bool) {
|
||||
c.client.SetLogBodies(logBodies)
|
||||
}
|
||||
|
||||
func (c *sdkChatClient) CreateChat(ctx context.Context, req codersdk.CreateChatRequest) (codersdk.Chat, error) {
|
||||
return c.client.CreateChat(ctx, req)
|
||||
}
|
||||
|
||||
func (c *sdkChatClient) StreamChat(ctx context.Context, chatID uuid.UUID, opts *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error) {
|
||||
return c.client.StreamChat(ctx, chatID, opts)
|
||||
}
|
||||
|
||||
func (c *sdkChatClient) CreateChatMessage(ctx context.Context, chatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) {
|
||||
return c.client.CreateChatMessage(ctx, chatID, req)
|
||||
}
|
||||
|
||||
func (c *sdkChatClient) UpdateChat(ctx context.Context, chatID uuid.UUID, req codersdk.UpdateChatRequest) error {
|
||||
return c.client.UpdateChat(ctx, chatID, req)
|
||||
}
|
||||
|
||||
var _ chatClient = (*sdkChatClient)(nil)
|
||||
@@ -0,0 +1,78 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Config describes a single chat runner within a scaletest invocation.
|
||||
type Config struct {
|
||||
// OrganizationID is the organization that owns the target workspace.
|
||||
OrganizationID uuid.UUID `json:"organization_id"`
|
||||
|
||||
// WorkspaceID is the pre-existing workspace to use for this chat run.
|
||||
WorkspaceID uuid.UUID `json:"workspace_id"`
|
||||
|
||||
// Prompt is the text content sent on every turn.
|
||||
Prompt string `json:"prompt"`
|
||||
|
||||
// ModelConfigID is the scaletest mock LLM model config.
|
||||
ModelConfigID uuid.UUID `json:"model_config_id"`
|
||||
|
||||
// Turns is the total number of user to assistant exchanges per chat.
|
||||
// Must be at least 1.
|
||||
Turns int `json:"turns"`
|
||||
|
||||
// TurnStartDelay is the shared delay between every runner completing
|
||||
// its initial turn and the release of the follow-up turns. Set
|
||||
// to 0 to send all turns without an inter-phase pause.
|
||||
TurnStartDelay time.Duration `json:"turn_start_delay"`
|
||||
|
||||
// TurnStartReadyWaitGroup coordinates the gap between the initial turn
|
||||
// finishing and the follow-up turns. Each runner signals exactly
|
||||
// once after its first turn reaches a terminal status, or when it
|
||||
// knows it will never reach that point.
|
||||
TurnStartReadyWaitGroup *sync.WaitGroup `json:"-"`
|
||||
|
||||
// StartTurnsChan blocks follow-up turns until the CLI layer releases them.
|
||||
StartTurnsChan chan struct{} `json:"-"`
|
||||
|
||||
Metrics *Metrics `json:"-"`
|
||||
}
|
||||
|
||||
func (c Config) Validate() error {
|
||||
if c.OrganizationID == uuid.Nil {
|
||||
return xerrors.Errorf("validate organization_id: must not be empty")
|
||||
}
|
||||
if c.WorkspaceID == uuid.Nil {
|
||||
return xerrors.Errorf("validate workspace_id: must not be empty")
|
||||
}
|
||||
if c.Prompt == "" {
|
||||
return xerrors.Errorf("validate prompt: must not be empty")
|
||||
}
|
||||
if c.ModelConfigID == uuid.Nil {
|
||||
return xerrors.Errorf("validate model_config_id: must not be empty")
|
||||
}
|
||||
if c.Turns < 1 {
|
||||
return xerrors.Errorf("validate turns: must be at least 1")
|
||||
}
|
||||
if c.TurnStartDelay < 0 {
|
||||
return xerrors.Errorf("validate turn_start_delay: must not be negative")
|
||||
}
|
||||
if c.TurnStartDelay > 0 && c.Turns > 1 {
|
||||
if c.TurnStartReadyWaitGroup == nil {
|
||||
return xerrors.Errorf("validate turn_start_ready_wait_group: must not be nil when turn start delay is enabled for more than one turn")
|
||||
}
|
||||
if c.StartTurnsChan == nil {
|
||||
return xerrors.Errorf("validate start_turns_chan: must not be nil when turn start delay is enabled for more than one turn")
|
||||
}
|
||||
}
|
||||
if c.Metrics == nil {
|
||||
return xerrors.Errorf("validate metrics: must not be nil")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package chat
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
const (
|
||||
metricLabelPhase = "phase"
|
||||
metricLabelStatus = "status"
|
||||
metricLabelStage = "stage"
|
||||
|
||||
phaseInitial = "initial"
|
||||
phaseFollowUp = "follow_up"
|
||||
|
||||
failureStageCreateChat = "create_chat"
|
||||
failureStageCreateMessage = "create_message"
|
||||
failureStageStreamOpen = "stream_open"
|
||||
failureStageStreamEndedEarly = "stream_ended_early"
|
||||
failureStageStatusError = "status_error"
|
||||
)
|
||||
|
||||
var (
|
||||
chatRequestLatencyBuckets = prometheus.ExponentialBucketsRange(0.05, 120, 18)
|
||||
chatProcessingLatencyBuckets = prometheus.ExponentialBucketsRange(0.1, 300, 18)
|
||||
)
|
||||
|
||||
// Metrics holds the Prometheus metrics emitted by the chat scaletest.
|
||||
type Metrics struct {
|
||||
ChatCreateLatencySeconds prometheus.Histogram
|
||||
ChatMessageLatencySeconds *prometheus.HistogramVec
|
||||
ChatConversationDurationSeconds prometheus.Histogram
|
||||
ChatTimeToRunningSeconds *prometheus.HistogramVec
|
||||
ChatTimeToFirstOutputSeconds *prometheus.HistogramVec
|
||||
ChatTimeToTerminalStatusSeconds *prometheus.HistogramVec
|
||||
ChatStageFailuresTotal *prometheus.CounterVec
|
||||
ChatTerminalStatusTotal *prometheus.CounterVec
|
||||
ChatTurnsCompletedTotal prometheus.Counter
|
||||
ChatRetryEventsTotal prometheus.Counter
|
||||
ActiveChatStreams prometheus.Gauge
|
||||
}
|
||||
|
||||
func NewMetrics(reg prometheus.Registerer) *Metrics {
|
||||
if reg == nil {
|
||||
reg = prometheus.DefaultRegisterer
|
||||
}
|
||||
|
||||
phaseLabelNames := []string{metricLabelPhase}
|
||||
terminalStatusLabelNames := []string{metricLabelStatus}
|
||||
failureStageLabelNames := []string{metricLabelStage}
|
||||
|
||||
m := &Metrics{
|
||||
ChatCreateLatencySeconds: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_create_latency_seconds",
|
||||
Help: "Time in seconds to create a chat and enqueue the initial turn.",
|
||||
Buckets: chatRequestLatencyBuckets,
|
||||
}),
|
||||
ChatMessageLatencySeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_message_latency_seconds",
|
||||
Help: "Time in seconds to add a follow-up message to an existing chat.",
|
||||
Buckets: chatRequestLatencyBuckets,
|
||||
}, phaseLabelNames),
|
||||
ChatConversationDurationSeconds: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_conversation_duration_seconds",
|
||||
Help: "Time in seconds from chat creation start until the conversation finishes or errors.",
|
||||
Buckets: chatProcessingLatencyBuckets,
|
||||
}),
|
||||
ChatTimeToRunningSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_time_to_running_seconds",
|
||||
Help: "Time in seconds from the start of a chat turn until the chat enters running status.",
|
||||
Buckets: chatProcessingLatencyBuckets,
|
||||
}, phaseLabelNames),
|
||||
ChatTimeToFirstOutputSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_time_to_first_output_seconds",
|
||||
Help: "Time in seconds from the start of a chat turn until the first output is received.",
|
||||
Buckets: chatProcessingLatencyBuckets,
|
||||
}, phaseLabelNames),
|
||||
ChatTimeToTerminalStatusSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_time_to_terminal_status_seconds",
|
||||
Help: "Time in seconds from the start of a chat turn until a terminal status is received.",
|
||||
Buckets: chatProcessingLatencyBuckets,
|
||||
}, phaseLabelNames),
|
||||
ChatStageFailuresTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_stage_failures_total",
|
||||
Help: "Total number of terminal stage-specific chat runner failures.",
|
||||
}, failureStageLabelNames),
|
||||
ChatTerminalStatusTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_terminal_status_total",
|
||||
Help: "Total number of terminal chat statuses observed.",
|
||||
}, terminalStatusLabelNames),
|
||||
ChatTurnsCompletedTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_turns_completed_total",
|
||||
Help: "Total number of chat turns completed successfully.",
|
||||
}),
|
||||
ChatRetryEventsTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "chat_retry_events_total",
|
||||
Help: "Total number of chat retry events observed.",
|
||||
}),
|
||||
ActiveChatStreams: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "active_chat_streams",
|
||||
Help: "Current number of active chat streams.",
|
||||
}),
|
||||
}
|
||||
|
||||
reg.MustRegister(m.ChatCreateLatencySeconds)
|
||||
reg.MustRegister(m.ChatMessageLatencySeconds)
|
||||
reg.MustRegister(m.ChatConversationDurationSeconds)
|
||||
reg.MustRegister(m.ChatTimeToRunningSeconds)
|
||||
reg.MustRegister(m.ChatTimeToFirstOutputSeconds)
|
||||
reg.MustRegister(m.ChatTimeToTerminalStatusSeconds)
|
||||
reg.MustRegister(m.ChatStageFailuresTotal)
|
||||
reg.MustRegister(m.ChatTerminalStatusTotal)
|
||||
reg.MustRegister(m.ChatTurnsCompletedTotal)
|
||||
reg.MustRegister(m.ChatRetryEventsTotal)
|
||||
reg.MustRegister(m.ActiveChatStreams)
|
||||
|
||||
return m
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
const (
|
||||
scaletestProviderType = "openai-compat"
|
||||
scaletestProviderDisplayName = "Scaletest LLM Mock"
|
||||
scaletestModelName = "scaletest-model"
|
||||
scaletestModelDisplayName = "Scaletest Model"
|
||||
)
|
||||
|
||||
type scaletestProviderAction string
|
||||
|
||||
const (
|
||||
scaletestProviderActionCreated scaletestProviderAction = "created"
|
||||
scaletestProviderActionUpdated scaletestProviderAction = "updated"
|
||||
scaletestProviderActionReused scaletestProviderAction = "reused"
|
||||
)
|
||||
|
||||
// EnsureScaletestModelConfig bootstraps the shared chat provider and model
|
||||
// config used by chat scaletests.
|
||||
func EnsureScaletestModelConfig(ctx context.Context, client *codersdk.ExperimentalClient, logger slog.Logger, llmMockURL string) (uuid.UUID, error) {
|
||||
logger.Info(ctx, "bootstrapping mock LLM provider", slog.F("llm_mock_url", llmMockURL))
|
||||
|
||||
provider, providerAction, err := ensureScaletestProvider(ctx, client, llmMockURL)
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
switch providerAction {
|
||||
case scaletestProviderActionCreated:
|
||||
logger.Info(ctx, "created mock LLM provider",
|
||||
slog.F("provider_type", scaletestProviderType),
|
||||
slog.F("llm_mock_url", llmMockURL),
|
||||
)
|
||||
case scaletestProviderActionUpdated:
|
||||
logger.Info(ctx, "updated mock LLM provider",
|
||||
slog.F("provider_type", scaletestProviderType),
|
||||
slog.F("provider_id", provider.ID),
|
||||
slog.F("llm_mock_url", llmMockURL),
|
||||
)
|
||||
case scaletestProviderActionReused:
|
||||
logger.Info(ctx, "reusing mock LLM provider",
|
||||
slog.F("provider_type", scaletestProviderType),
|
||||
slog.F("provider_id", provider.ID),
|
||||
)
|
||||
}
|
||||
|
||||
modelConfigs, err := client.ListChatModelConfigs(ctx)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("list chat model configs: %w", err)
|
||||
}
|
||||
|
||||
for i := range modelConfigs {
|
||||
if modelConfigs[i].Provider != provider.Provider || modelConfigs[i].Model != scaletestModelName {
|
||||
continue
|
||||
}
|
||||
if !modelConfigs[i].Enabled {
|
||||
return uuid.Nil, xerrors.Errorf("existing scaletest chat model config %s is disabled; re-enable or delete it before running scaletests", modelConfigs[i].ID)
|
||||
}
|
||||
modelConfigID := modelConfigs[i].ID
|
||||
logger.Info(ctx, "reusing scaletest model config", slog.F("model_config_id", modelConfigID))
|
||||
return modelConfigID, nil
|
||||
}
|
||||
|
||||
enabled := true
|
||||
isDefault := false
|
||||
contextLimit := int64(4096)
|
||||
created, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: scaletestModelName,
|
||||
DisplayName: scaletestModelDisplayName,
|
||||
Enabled: &enabled,
|
||||
IsDefault: &isDefault,
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("create scaletest chat model config: %w", err)
|
||||
}
|
||||
logger.Info(ctx, "created scaletest model config", slog.F("model_config_id", created.ID))
|
||||
return created.ID, nil
|
||||
}
|
||||
|
||||
func ensureScaletestProvider(ctx context.Context, client *codersdk.ExperimentalClient, llmMockURL string) (codersdk.ChatProviderConfig, scaletestProviderAction, error) {
|
||||
enabled := true
|
||||
mockProviderToken := uuid.NewString()
|
||||
created, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: scaletestProviderType,
|
||||
DisplayName: scaletestProviderDisplayName,
|
||||
APIKey: mockProviderToken,
|
||||
BaseURL: llmMockURL,
|
||||
Enabled: &enabled,
|
||||
})
|
||||
if err == nil {
|
||||
return created, scaletestProviderActionCreated, nil
|
||||
}
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
if !xerrors.As(err, &sdkErr) || sdkErr.StatusCode() != http.StatusConflict {
|
||||
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("create scaletest chat provider: %w", err)
|
||||
}
|
||||
|
||||
providers, err := client.ListChatProviders(ctx)
|
||||
if err != nil {
|
||||
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("list chat providers: %w", err)
|
||||
}
|
||||
|
||||
var existing *codersdk.ChatProviderConfig
|
||||
for i := range providers {
|
||||
if providers[i].Provider == scaletestProviderType {
|
||||
existing = &providers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if existing == nil {
|
||||
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("find existing %s provider after conflict: not found", scaletestProviderType)
|
||||
}
|
||||
if existing.DisplayName != scaletestProviderDisplayName {
|
||||
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("refusing to overwrite existing %s provider %s with display name %q", scaletestProviderType, existing.ID, existing.DisplayName)
|
||||
}
|
||||
|
||||
if !existing.Enabled {
|
||||
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("existing scaletest chat provider %s is disabled; re-enable or delete it before running scaletests", existing.ID)
|
||||
}
|
||||
if existing.BaseURL == llmMockURL {
|
||||
return *existing, scaletestProviderActionReused, nil
|
||||
}
|
||||
|
||||
updated, err := client.UpdateChatProvider(ctx, existing.ID, codersdk.UpdateChatProviderConfigRequest{
|
||||
DisplayName: scaletestProviderDisplayName,
|
||||
APIKey: &mockProviderToken,
|
||||
BaseURL: &llmMockURL,
|
||||
Enabled: &enabled,
|
||||
})
|
||||
if err != nil {
|
||||
return codersdk.ChatProviderConfig{}, "", xerrors.Errorf("update scaletest chat provider: %w", err)
|
||||
}
|
||||
return updated, scaletestProviderActionUpdated, nil
|
||||
}
|
||||
@@ -0,0 +1,413 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
)
|
||||
|
||||
// Runner executes a single chat conversation as part of a scaletest run.
|
||||
type Runner struct {
|
||||
client chatClient
|
||||
cfg Config
|
||||
|
||||
chatID uuid.UUID
|
||||
result runnerResult
|
||||
|
||||
conversationStart time.Time
|
||||
turnStartTime time.Time
|
||||
currentPhase string
|
||||
lastStreamError string
|
||||
lastStatus codersdk.ChatStatus
|
||||
sawTurnRunning bool
|
||||
sawTurnFirstOutput bool
|
||||
markTurnStartReady func()
|
||||
}
|
||||
|
||||
type runnerResult struct {
|
||||
finalStatus string
|
||||
failureStage string
|
||||
totalDuration time.Duration
|
||||
sawFirstOutput bool
|
||||
retryCount int
|
||||
eventCount int
|
||||
turnsCompleted int
|
||||
}
|
||||
|
||||
var (
|
||||
_ harness.Runnable = &Runner{}
|
||||
_ harness.Cleanable = &Runner{}
|
||||
_ harness.Collectable = &Runner{}
|
||||
)
|
||||
|
||||
func NewRunner(client *codersdk.Client, cfg Config) *Runner {
|
||||
return &Runner{
|
||||
client: newChatClient(client),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
logs = loadtestutil.NewSyncWriter(logs)
|
||||
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug).Named(id)
|
||||
r.client.SetLogger(logger)
|
||||
r.client.SetLogBodies(true)
|
||||
|
||||
span.SetAttributes(
|
||||
attribute.String("chat.runner_id", id),
|
||||
attribute.String("chat.workspace_id", r.cfg.WorkspaceID.String()),
|
||||
attribute.Int("chat.turns_requested", r.cfg.Turns),
|
||||
attribute.Int64("chat.turn_start_delay_ms", r.cfg.TurnStartDelay.Milliseconds()),
|
||||
)
|
||||
span.SetAttributes(attribute.String("chat.model_config_id", r.cfg.ModelConfigID.String()))
|
||||
|
||||
markTurnStartReady := func() {}
|
||||
if r.cfg.TurnStartReadyWaitGroup != nil {
|
||||
markTurnStartReady = sync.OnceFunc(r.cfg.TurnStartReadyWaitGroup.Done)
|
||||
}
|
||||
r.markTurnStartReady = markTurnStartReady
|
||||
defer r.markTurnStartReady()
|
||||
|
||||
defer func() {
|
||||
if !r.conversationStart.IsZero() {
|
||||
r.result.totalDuration = time.Since(r.conversationStart)
|
||||
r.cfg.Metrics.ChatConversationDurationSeconds.Observe(r.result.totalDuration.Seconds())
|
||||
}
|
||||
span.SetAttributes(
|
||||
attribute.String("chat.final_status", r.result.finalStatus),
|
||||
attribute.String("chat.failure_stage", r.result.failureStage),
|
||||
attribute.Int("chat.retry_count", r.result.retryCount),
|
||||
attribute.Int("chat.turns_completed", r.result.turnsCompleted),
|
||||
attribute.Bool("chat.saw_first_output", r.result.sawFirstOutput),
|
||||
)
|
||||
if r.result.totalDuration > 0 {
|
||||
span.SetAttributes(attribute.Float64("chat.total_duration_seconds", r.result.totalDuration.Seconds()))
|
||||
}
|
||||
}()
|
||||
|
||||
workspaceID := r.cfg.WorkspaceID
|
||||
modelConfigID := r.cfg.ModelConfigID
|
||||
logger = logger.With(slog.F("workspace_id", workspaceID))
|
||||
logger.Info(ctx, "starting chat runner")
|
||||
|
||||
r.resetConversation(time.Now(), markTurnStartReady)
|
||||
|
||||
createStartedAt := time.Now()
|
||||
chat, err := r.client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: r.cfg.OrganizationID,
|
||||
WorkspaceID: &workspaceID,
|
||||
ModelConfigID: &modelConfigID,
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: r.cfg.Prompt,
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
r.result.failureStage = failureStageCreateChat
|
||||
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
|
||||
return xerrors.Errorf("create chat: %w", err)
|
||||
}
|
||||
r.cfg.Metrics.ChatCreateLatencySeconds.Observe(time.Since(createStartedAt).Seconds())
|
||||
|
||||
r.chatID = chat.ID
|
||||
span.SetAttributes(attribute.String("chat.chat_id", chat.ID.String()))
|
||||
logger = logger.With(slog.F("chat_id", chat.ID))
|
||||
logger.Info(ctx, "created chat session", slog.F("duration", time.Since(createStartedAt)))
|
||||
|
||||
// CreateChat already queues the first prompt for processing on the
|
||||
// server, so the initial turn is in flight as soon as CreateChat
|
||||
// returns. Open the stream immediately and let the conversation loop
|
||||
// drive the gate at the natural phase boundary (after the first turn
|
||||
// reaches a terminal Waiting status), rather than fencing here on a
|
||||
// turn that has already started running.
|
||||
events, closer, err := r.client.StreamChat(ctx, chat.ID, nil)
|
||||
if err != nil {
|
||||
r.result.failureStage = failureStageStreamOpen
|
||||
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
|
||||
return xerrors.Errorf("stream chat: %w", err)
|
||||
}
|
||||
|
||||
r.cfg.Metrics.ActiveChatStreams.Inc()
|
||||
defer func() {
|
||||
r.cfg.Metrics.ActiveChatStreams.Dec()
|
||||
_ = closer.Close()
|
||||
}()
|
||||
|
||||
logger.Info(ctx, "streaming chat events")
|
||||
|
||||
return r.runConversation(ctx, chat.ID, logger, events)
|
||||
}
|
||||
|
||||
func (r *Runner) resetConversation(conversationStart time.Time, markTurnStartReady func()) {
|
||||
if markTurnStartReady == nil {
|
||||
markTurnStartReady = func() {}
|
||||
}
|
||||
|
||||
r.result = runnerResult{}
|
||||
r.conversationStart = conversationStart
|
||||
r.turnStartTime = conversationStart
|
||||
r.currentPhase = phaseInitial
|
||||
r.lastStreamError = ""
|
||||
r.lastStatus = ""
|
||||
r.sawTurnRunning = false
|
||||
r.sawTurnFirstOutput = false
|
||||
r.markTurnStartReady = markTurnStartReady
|
||||
}
|
||||
|
||||
func (r *Runner) runConversation(ctx context.Context, chatID uuid.UUID, logger slog.Logger, events <-chan codersdk.ChatStreamEvent) error {
|
||||
r.chatID = chatID
|
||||
|
||||
for event := range events {
|
||||
r.result.eventCount++
|
||||
|
||||
switch event.Type {
|
||||
case codersdk.ChatStreamEventTypeStatus:
|
||||
if event.Status == nil {
|
||||
continue
|
||||
}
|
||||
done, err := r.handleStatusEvent(ctx, chatID, logger, event.Status.Status)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
}
|
||||
case codersdk.ChatStreamEventTypeMessagePart:
|
||||
r.handleMessagePartEvent(ctx, logger)
|
||||
case codersdk.ChatStreamEventTypeMessage:
|
||||
// StreamChat replays persisted rows as message events, not
|
||||
// message_part deltas, when a turn finished server-side before
|
||||
// the stream attached. Route assistant rows through the same
|
||||
// first-output path; skip user rows so persisted prompts do not
|
||||
// count as model output.
|
||||
if event.Message == nil || event.Message.Role != codersdk.ChatMessageRoleAssistant {
|
||||
continue
|
||||
}
|
||||
r.handleMessagePartEvent(ctx, logger)
|
||||
case codersdk.ChatStreamEventTypeRetry:
|
||||
r.handleRetryEvent(ctx, logger, event.Retry)
|
||||
case codersdk.ChatStreamEventTypeError:
|
||||
r.handleErrorEvent(ctx, logger, event.Error)
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
r.result.failureStage = failureStageStreamEndedEarly
|
||||
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
|
||||
if r.lastStreamError != "" {
|
||||
return xerrors.Errorf("chat %s stream ended before completing %d of %d turns: %s", chatID, r.result.turnsCompleted, r.cfg.Turns, r.lastStreamError)
|
||||
}
|
||||
return xerrors.Errorf("chat %s stream ended before completing %d of %d turns", chatID, r.result.turnsCompleted, r.cfg.Turns)
|
||||
}
|
||||
|
||||
func (r *Runner) handleStatusEvent(ctx context.Context, chatID uuid.UUID, logger slog.Logger, status codersdk.ChatStatus) (bool, error) {
|
||||
if status == r.lastStatus {
|
||||
return false, nil
|
||||
}
|
||||
if status == codersdk.ChatStatusWaiting &&
|
||||
!r.sawTurnFirstOutput &&
|
||||
(r.sawTurnRunning || r.result.turnsCompleted > 0) {
|
||||
return false, nil
|
||||
}
|
||||
r.lastStatus = status
|
||||
|
||||
switch status {
|
||||
case codersdk.ChatStatusRunning:
|
||||
r.sawTurnRunning = true
|
||||
r.cfg.Metrics.ChatTimeToRunningSeconds.WithLabelValues(r.currentPhase).Observe(time.Since(r.turnStartTime).Seconds())
|
||||
logger.Info(ctx, "chat reached running status",
|
||||
slog.F("phase", r.currentPhase),
|
||||
)
|
||||
return false, nil
|
||||
case codersdk.ChatStatusWaiting:
|
||||
r.result.turnsCompleted++
|
||||
turnDuration := time.Since(r.turnStartTime)
|
||||
r.cfg.Metrics.ChatTimeToTerminalStatusSeconds.WithLabelValues(r.currentPhase).Observe(turnDuration.Seconds())
|
||||
r.cfg.Metrics.ChatTerminalStatusTotal.WithLabelValues(string(codersdk.ChatStatusWaiting)).Inc()
|
||||
r.cfg.Metrics.ChatTurnsCompletedTotal.Inc()
|
||||
logger.Info(ctx, "chat completed turn",
|
||||
slog.F("turn", r.result.turnsCompleted),
|
||||
slog.F("turns", r.cfg.Turns),
|
||||
slog.F("duration", turnDuration),
|
||||
)
|
||||
if r.result.turnsCompleted >= r.cfg.Turns {
|
||||
r.result.finalStatus = string(codersdk.ChatStatusWaiting)
|
||||
conversationDuration := time.Since(r.conversationStart)
|
||||
logger.Info(ctx, "chat reached terminal status",
|
||||
slog.F("status", codersdk.ChatStatusWaiting),
|
||||
slog.F("duration", conversationDuration),
|
||||
slog.F("turns_completed", r.result.turnsCompleted),
|
||||
)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// After the very first turn completes, mark this runner ready
|
||||
// for the CLI-coordinated turn-start gate. The inter-phase
|
||||
// delay measures the gap between every chat actually finishing its
|
||||
// initial turn and the start of the follow-up turns, not the gap
|
||||
// between CreateChat returning and the next turn.
|
||||
if r.result.turnsCompleted == 1 {
|
||||
r.markTurnStartReady()
|
||||
if r.cfg.StartTurnsChan != nil {
|
||||
logger.Info(ctx, "chat waiting for turn start release",
|
||||
slog.F("turn_start_delay", r.cfg.TurnStartDelay),
|
||||
)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
case <-r.cfg.StartTurnsChan:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nextTurn := r.result.turnsCompleted + 1
|
||||
r.currentPhase = phaseFollowUp
|
||||
r.turnStartTime = time.Now()
|
||||
r.lastStreamError = ""
|
||||
r.lastStatus = ""
|
||||
r.sawTurnRunning = false
|
||||
r.sawTurnFirstOutput = false
|
||||
if err := r.sendNextTurn(ctx, chatID, logger, nextTurn, r.currentPhase); err != nil {
|
||||
r.result.failureStage = failureStageCreateMessage
|
||||
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
case codersdk.ChatStatusError:
|
||||
r.result.finalStatus = string(codersdk.ChatStatusError)
|
||||
r.result.failureStage = failureStageStatusError
|
||||
turnDuration := time.Since(r.turnStartTime)
|
||||
r.cfg.Metrics.ChatTimeToTerminalStatusSeconds.WithLabelValues(r.currentPhase).Observe(turnDuration.Seconds())
|
||||
r.cfg.Metrics.ChatTerminalStatusTotal.WithLabelValues(string(codersdk.ChatStatusError)).Inc()
|
||||
r.cfg.Metrics.ChatStageFailuresTotal.WithLabelValues(r.result.failureStage).Inc()
|
||||
|
||||
errMessage := r.lastStreamError
|
||||
if errMessage == "" {
|
||||
errMessage = "chat reached error status"
|
||||
}
|
||||
logger.Error(ctx, "chat reached terminal status",
|
||||
slog.F("status", codersdk.ChatStatusError),
|
||||
slog.F("turns_completed", r.result.turnsCompleted),
|
||||
slog.F("turns", r.cfg.Turns),
|
||||
slog.F("error", errMessage),
|
||||
)
|
||||
return false, xerrors.Errorf("chat %s reached error status: %s", chatID, errMessage)
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) sendNextTurn(ctx context.Context, chatID uuid.UUID, logger slog.Logger, nextTurn int, phase string) error {
|
||||
messageStartedAt := time.Now()
|
||||
modelConfigID := r.cfg.ModelConfigID
|
||||
_, err := r.client.CreateChatMessage(ctx, chatID, codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: r.cfg.Prompt,
|
||||
}},
|
||||
ModelConfigID: &modelConfigID,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create chat message for turn %d: %w", nextTurn, err)
|
||||
}
|
||||
|
||||
r.cfg.Metrics.ChatMessageLatencySeconds.WithLabelValues(phase).Observe(time.Since(messageStartedAt).Seconds())
|
||||
logger.Info(ctx, "chat sent message",
|
||||
slog.F("turn", nextTurn),
|
||||
slog.F("turns", r.cfg.Turns),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) handleMessagePartEvent(ctx context.Context, logger slog.Logger) {
|
||||
if r.sawTurnFirstOutput {
|
||||
return
|
||||
}
|
||||
r.sawTurnFirstOutput = true
|
||||
r.result.sawFirstOutput = true
|
||||
firstOutputDuration := time.Since(r.turnStartTime)
|
||||
r.cfg.Metrics.ChatTimeToFirstOutputSeconds.WithLabelValues(r.currentPhase).Observe(firstOutputDuration.Seconds())
|
||||
logger.Info(ctx, "chat received first output",
|
||||
slog.F("phase", r.currentPhase),
|
||||
slog.F("duration", firstOutputDuration),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *Runner) handleRetryEvent(ctx context.Context, logger slog.Logger, retry *codersdk.ChatStreamRetry) {
|
||||
r.result.retryCount++
|
||||
r.cfg.Metrics.ChatRetryEventsTotal.Inc()
|
||||
if retry != nil {
|
||||
logger.Warn(ctx, "chat retry event",
|
||||
slog.F("attempt", retry.Attempt),
|
||||
slog.F("delay_ms", retry.DelayMs),
|
||||
slog.F("error", retry.Error),
|
||||
)
|
||||
return
|
||||
}
|
||||
logger.Warn(ctx, "chat retry event")
|
||||
}
|
||||
|
||||
func (r *Runner) handleErrorEvent(ctx context.Context, logger slog.Logger, eventErr *codersdk.ChatError) {
|
||||
if eventErr != nil && eventErr.Message != "" {
|
||||
r.lastStreamError = eventErr.Message
|
||||
logger.Warn(ctx, "chat stream error",
|
||||
slog.F("error", r.lastStreamError),
|
||||
)
|
||||
return
|
||||
}
|
||||
logger.Warn(ctx, "chat stream error event")
|
||||
}
|
||||
|
||||
func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error {
|
||||
if r.chatID == uuid.Nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
logs = loadtestutil.NewSyncWriter(logs)
|
||||
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug).Named(id).With(slog.F("chat_id", r.chatID))
|
||||
r.client.SetLogger(logger)
|
||||
r.client.SetLogBodies(true)
|
||||
|
||||
archived := true
|
||||
logger.Info(ctx, "archiving chat session")
|
||||
if err := r.client.UpdateChat(ctx, r.chatID, codersdk.UpdateChatRequest{Archived: &archived}); err != nil {
|
||||
logger.Error(ctx, "failed to archive chat", slog.Error(err))
|
||||
return xerrors.Errorf("archive chat: %w", err)
|
||||
}
|
||||
logger.Info(ctx, "archived chat session")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) GetMetrics() map[string]any {
|
||||
return map[string]any{
|
||||
"workspace_id": r.cfg.WorkspaceID.String(),
|
||||
"turn_start_delay_ms": r.cfg.TurnStartDelay.Milliseconds(),
|
||||
"chat_id": r.chatID.String(),
|
||||
"final_status": r.result.finalStatus,
|
||||
"failure_stage": r.result.failureStage,
|
||||
"total_duration_seconds": r.result.totalDuration.Seconds(),
|
||||
"saw_first_output": r.result.sawFirstOutput,
|
||||
"retry_count": r.result.retryCount,
|
||||
"event_count": r.result.eventCount,
|
||||
"turns_requested": r.cfg.Turns,
|
||||
"turns_completed": r.result.turnsCompleted,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestRunnerRunConversation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.MustParse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
|
||||
noopMarkTurnStartReady := func() {}
|
||||
|
||||
t.Run("OneTurnHappyPath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runner := newTestRunner(t, newRunConfig(t))
|
||||
events := make(chan codersdk.ChatStreamEvent, 3)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
|
||||
events <- messagePartEvent(chatID)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
close(events)
|
||||
|
||||
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
|
||||
require.NoError(t, err)
|
||||
result := runner.result
|
||||
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
|
||||
require.Empty(t, result.failureStage)
|
||||
require.True(t, result.sawFirstOutput)
|
||||
require.Equal(t, 1, result.turnsCompleted)
|
||||
require.Equal(t, 3, result.eventCount)
|
||||
})
|
||||
|
||||
t.Run("DuplicateWaitingDoesNotAdvanceTurn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := newRunConfig(t)
|
||||
cfg.Turns = 2
|
||||
|
||||
events := make(chan codersdk.ChatStreamEvent, 7)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
|
||||
events <- messagePartEvent(chatID)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
|
||||
var sendCount atomic.Int64
|
||||
runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() {
|
||||
sendCount.Add(1)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
|
||||
events <- messagePartEvent(chatID)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
close(events)
|
||||
})
|
||||
|
||||
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
|
||||
require.NoError(t, err)
|
||||
result := runner.result
|
||||
require.Equal(t, int64(1), sendCount.Load())
|
||||
require.Equal(t, 2, result.turnsCompleted)
|
||||
require.Equal(t, 7, result.eventCount)
|
||||
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
|
||||
})
|
||||
|
||||
t.Run("StaleWaitingAfterNextTurnRunningDoesNotAdvanceTurn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := newRunConfig(t)
|
||||
cfg.Turns = 2
|
||||
|
||||
events := make(chan codersdk.ChatStreamEvent, 7)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
|
||||
events <- messagePartEvent(chatID)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
|
||||
var sendCount atomic.Int64
|
||||
runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() {
|
||||
sendCount.Add(1)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
events <- messagePartEvent(chatID)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
close(events)
|
||||
})
|
||||
|
||||
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
|
||||
require.NoError(t, err)
|
||||
result := runner.result
|
||||
require.Equal(t, int64(1), sendCount.Load())
|
||||
require.Equal(t, 2, result.turnsCompleted)
|
||||
require.Equal(t, 7, result.eventCount)
|
||||
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
|
||||
})
|
||||
|
||||
t.Run("FirstTurnGatesFollowUpStorm", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Reproduces the contract that the turn-start gate is checked
|
||||
// after the first turn finishes, not before it begins. The runner
|
||||
// must mark itself ready, wait for the release channel, and only
|
||||
// then send turn 2.
|
||||
cfg := newRunConfig(t)
|
||||
cfg.Turns = 2
|
||||
readyWG := &sync.WaitGroup{}
|
||||
readyWG.Add(1)
|
||||
releaseChan := make(chan struct{})
|
||||
cfg.TurnStartReadyWaitGroup = readyWG
|
||||
cfg.StartTurnsChan = releaseChan
|
||||
|
||||
events := make(chan codersdk.ChatStreamEvent, 4)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
|
||||
events <- messagePartEvent(chatID)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
|
||||
ready := make(chan struct{})
|
||||
go func() {
|
||||
readyWG.Wait()
|
||||
close(ready)
|
||||
}()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
var sendCount atomic.Int64
|
||||
runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() {
|
||||
sendCount.Add(1)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
|
||||
events <- messagePartEvent(chatID)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
close(events)
|
||||
})
|
||||
|
||||
runner.resetConversation(time.Now(), sync.OnceFunc(readyWG.Done))
|
||||
|
||||
go func() {
|
||||
runErr := runner.runConversation(context.Background(), chatID, testLogger(), events)
|
||||
errCh <- runErr
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ready:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("runner did not mark turn-start gate ready after first turn")
|
||||
}
|
||||
|
||||
require.Equal(t, int64(0), sendCount.Load(), "next turn was sent before turn-start release")
|
||||
|
||||
close(releaseChan)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("runner did not finish after turn-start release")
|
||||
}
|
||||
require.Equal(t, int64(1), sendCount.Load())
|
||||
})
|
||||
|
||||
t.Run("FirstOutputFromAssistantMessageEvent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Snapshot race: when a turn finishes before stream attach,
|
||||
// StreamChat replays rows as message events, never as
|
||||
// message_part deltas; the assistant row must record first output.
|
||||
runner := newTestRunner(t, newRunConfig(t))
|
||||
events := make(chan codersdk.ChatStreamEvent, 3)
|
||||
events <- messageEvent(chatID, codersdk.ChatMessageRoleUser)
|
||||
events <- messageEvent(chatID, codersdk.ChatMessageRoleAssistant)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
close(events)
|
||||
|
||||
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
|
||||
require.NoError(t, err)
|
||||
result := runner.result
|
||||
require.True(t, result.sawFirstOutput, "first output not recorded from assistant message event")
|
||||
require.Equal(t, 1, result.turnsCompleted)
|
||||
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
|
||||
})
|
||||
|
||||
t.Run("ImmediateWaitingCountsNextTurn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := newRunConfig(t)
|
||||
cfg.Turns = 2
|
||||
|
||||
events := make(chan codersdk.ChatStreamEvent, 3)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
|
||||
var sendCount atomic.Int64
|
||||
runner := newTestRunnerWithChatMessage(t, cfg, chatID, func() {
|
||||
sendCount.Add(1)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusRunning)
|
||||
events <- messagePartEvent(chatID)
|
||||
events <- statusEvent(chatID, codersdk.ChatStatusWaiting)
|
||||
close(events)
|
||||
})
|
||||
|
||||
err := runTestConversation(t, runner, chatID, events, noopMarkTurnStartReady)
|
||||
require.NoError(t, err)
|
||||
result := runner.result
|
||||
require.Equal(t, int64(1), sendCount.Load())
|
||||
require.Equal(t, 2, result.turnsCompleted)
|
||||
require.Equal(t, string(codersdk.ChatStatusWaiting), result.finalStatus)
|
||||
})
|
||||
}
|
||||
|
||||
func runTestConversation(t *testing.T, runner *Runner, chatID uuid.UUID, events <-chan codersdk.ChatStreamEvent, markTurnStartReady func()) error {
|
||||
t.Helper()
|
||||
runner.resetConversation(time.Now(), markTurnStartReady)
|
||||
return runner.runConversation(context.Background(), chatID, testLogger(), events)
|
||||
}
|
||||
|
||||
func TestRunnerCleanup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.MustParse("22222222-2222-2222-2222-222222222222")
|
||||
|
||||
t.Run("ArchivesChat", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runner, archived := newTestRunnerWithChatArchive(t, chatID, nil)
|
||||
|
||||
logs := bytes.NewBuffer(nil)
|
||||
err := runner.Cleanup(context.Background(), "runner-1", logs)
|
||||
require.NoError(t, err)
|
||||
require.True(t, archived())
|
||||
require.Contains(t, logs.String(), "archived chat")
|
||||
})
|
||||
|
||||
t.Run("ArchiveErrorIsReturned", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runner, archived := newTestRunnerWithChatArchive(t, chatID, xerrors.New("boom"))
|
||||
|
||||
err := runner.Cleanup(context.Background(), "runner-1", bytes.NewBuffer(nil))
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "archive chat")
|
||||
require.True(t, archived())
|
||||
})
|
||||
}
|
||||
|
||||
func testLogger() slog.Logger {
|
||||
return slog.Make(sloghuman.Sink(io.Discard)).Leveled(slog.LevelDebug)
|
||||
}
|
||||
|
||||
func newRunConfig(t *testing.T) Config {
|
||||
t.Helper()
|
||||
reg := prometheus.NewRegistry()
|
||||
return Config{
|
||||
OrganizationID: uuid.MustParse("22222222-2222-2222-2222-222222222222"),
|
||||
WorkspaceID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
ModelConfigID: uuid.MustParse("33333333-3333-3333-3333-333333333333"),
|
||||
Prompt: "Reply with one short sentence.",
|
||||
Turns: 1,
|
||||
Metrics: NewMetrics(reg),
|
||||
}
|
||||
}
|
||||
|
||||
type fakeChatClient struct {
|
||||
createChatFunc func(context.Context, codersdk.CreateChatRequest) (codersdk.Chat, error)
|
||||
streamChatFunc func(context.Context, uuid.UUID, *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error)
|
||||
createChatMessageFunc func(context.Context, uuid.UUID, codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error)
|
||||
updateChatFunc func(context.Context, uuid.UUID, codersdk.UpdateChatRequest) error
|
||||
}
|
||||
|
||||
func newFakeChatClient(t *testing.T) *fakeChatClient {
|
||||
t.Helper()
|
||||
return &fakeChatClient{}
|
||||
}
|
||||
|
||||
func (*fakeChatClient) SetLogger(logger slog.Logger) {}
|
||||
|
||||
func (*fakeChatClient) SetLogBodies(logBodies bool) {}
|
||||
|
||||
func (f *fakeChatClient) CreateChat(ctx context.Context, req codersdk.CreateChatRequest) (codersdk.Chat, error) {
|
||||
if f.createChatFunc == nil {
|
||||
return codersdk.Chat{}, xerrors.New("unexpected CreateChat call")
|
||||
}
|
||||
return f.createChatFunc(ctx, req)
|
||||
}
|
||||
|
||||
func (f *fakeChatClient) StreamChat(ctx context.Context, chatID uuid.UUID, opts *codersdk.StreamChatOptions) (<-chan codersdk.ChatStreamEvent, io.Closer, error) {
|
||||
if f.streamChatFunc == nil {
|
||||
return nil, nil, xerrors.New("unexpected StreamChat call")
|
||||
}
|
||||
return f.streamChatFunc(ctx, chatID, opts)
|
||||
}
|
||||
|
||||
func (f *fakeChatClient) CreateChatMessage(ctx context.Context, chatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) {
|
||||
if f.createChatMessageFunc == nil {
|
||||
return codersdk.CreateChatMessageResponse{}, xerrors.New("unexpected CreateChatMessage call")
|
||||
}
|
||||
return f.createChatMessageFunc(ctx, chatID, req)
|
||||
}
|
||||
|
||||
func (f *fakeChatClient) UpdateChat(ctx context.Context, chatID uuid.UUID, req codersdk.UpdateChatRequest) error {
|
||||
if f.updateChatFunc == nil {
|
||||
return xerrors.New("unexpected UpdateChat call")
|
||||
}
|
||||
return f.updateChatFunc(ctx, chatID, req)
|
||||
}
|
||||
|
||||
var _ chatClient = (*fakeChatClient)(nil)
|
||||
|
||||
func newTestRunner(t *testing.T, cfg Config) *Runner {
|
||||
t.Helper()
|
||||
return &Runner{client: newFakeChatClient(t), cfg: cfg}
|
||||
}
|
||||
|
||||
func newTestRunnerWithChatArchive(t *testing.T, chatID uuid.UUID, updateErr error) (*Runner, func() bool) {
|
||||
t.Helper()
|
||||
|
||||
var archived atomic.Bool
|
||||
client := newFakeChatClient(t)
|
||||
client.updateChatFunc = func(ctx context.Context, gotChatID uuid.UUID, req codersdk.UpdateChatRequest) error {
|
||||
if gotChatID != chatID {
|
||||
return xerrors.Errorf("unexpected chat archive ID: %s", gotChatID)
|
||||
}
|
||||
if req.Archived == nil || !*req.Archived {
|
||||
return xerrors.Errorf("unexpected archived value: %v", req.Archived)
|
||||
}
|
||||
archived.Store(true)
|
||||
return updateErr
|
||||
}
|
||||
runner := &Runner{client: client, cfg: Config{}, chatID: chatID}
|
||||
return runner, archived.Load
|
||||
}
|
||||
|
||||
func newTestRunnerWithChatMessage(t *testing.T, cfg Config, chatID uuid.UUID, onMessage func()) *Runner {
|
||||
t.Helper()
|
||||
|
||||
client := newFakeChatClient(t)
|
||||
client.createChatMessageFunc = func(ctx context.Context, gotChatID uuid.UUID, req codersdk.CreateChatMessageRequest) (codersdk.CreateChatMessageResponse, error) {
|
||||
if gotChatID != chatID {
|
||||
return codersdk.CreateChatMessageResponse{}, xerrors.Errorf("unexpected chat message ID: %s", gotChatID)
|
||||
}
|
||||
if err := validatePromptParts(req.Content, cfg.Prompt); err != nil {
|
||||
return codersdk.CreateChatMessageResponse{}, err
|
||||
}
|
||||
if req.ModelConfigID == nil || *req.ModelConfigID != cfg.ModelConfigID {
|
||||
return codersdk.CreateChatMessageResponse{}, xerrors.Errorf("unexpected chat message model config ID: %v", req.ModelConfigID)
|
||||
}
|
||||
|
||||
if onMessage != nil {
|
||||
onMessage()
|
||||
}
|
||||
return codersdk.CreateChatMessageResponse{Queued: true}, nil
|
||||
}
|
||||
return &Runner{client: client, cfg: cfg}
|
||||
}
|
||||
|
||||
func validatePromptParts(parts []codersdk.ChatInputPart, prompt string) error {
|
||||
if len(parts) != 1 || parts[0].Type != codersdk.ChatInputPartTypeText || parts[0].Text != prompt {
|
||||
return xerrors.Errorf("unexpected chat message content: %#v", parts)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func statusEvent(chatID uuid.UUID, status codersdk.ChatStatus) codersdk.ChatStreamEvent {
|
||||
return codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chatID,
|
||||
Status: &codersdk.ChatStreamStatus{Status: status},
|
||||
}
|
||||
}
|
||||
|
||||
func messagePartEvent(chatID uuid.UUID) codersdk.ChatStreamEvent {
|
||||
return codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
ChatID: chatID,
|
||||
}
|
||||
}
|
||||
|
||||
func messageEvent(chatID uuid.UUID, role codersdk.ChatMessageRole) codersdk.ChatStreamEvent {
|
||||
return codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessage,
|
||||
ChatID: chatID,
|
||||
Message: &codersdk.ChatMessage{Role: role},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user