Files
coder/coderd/x/chatd/chatd_test.go
T
Michael Suchacz a554de372a fix: use per-chat plan file paths (#24268)
> This PR was authored by Mux on behalf of Mike.

Chats sharing one workspace (e.g. sibling subagents) all wrote to
`/home/coder/PLAN.md`, causing plan file collisions. This change derives
a unique plan path per chat from the workspace home directory and chat
ID.

## Changes

* `write_file`, `edit_files`, and `propose_plan` reject any `plan.md`
variant (case-insensitive) at the workspace home root, with a clear
error pointing to the chat-specific path.
* Root chats receive a `<plan-file-path>` block inlined in the main
system prompt with the concrete path.
* Prompt and tool descriptions no longer hardcode `/home/coder/PLAN.md`.
* Plan path handling is POSIX-only (forward-slash), relying on the
contract that workspace agent paths are normalized before reaching
chatd.
* Updated `ProposePlanTool.stories.tsx` to use per-chat path examples.
* Full test coverage for plan path detection, legacy-path rejection in
all three tools, inline prompt rendering, and fallback behavior.
2026-04-14 10:50:40 +02:00

5776 lines
187 KiB
Go

package chatd_test
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
mcpgo "github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentcontextconfig"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/provisioner/echo"
proto "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// TestInterruptChatBroadcastsStatusAcrossInstances checks that when one
// replica interrupts a running chat, a subscriber attached to a second
// replica (sharing the same database and pubsub) observes the resulting
// "waiting" status event.
func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	replicaA := newTestServer(t, db, ps, uuid.New())
	replicaB := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	created, err := replicaA.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "interrupt-me",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Put the chat into the running state under a made-up worker so there
	// is something to interrupt.
	worker := uuid.NullUUID{UUID: uuid.New(), Valid: true}
	running, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          created.ID,
		Status:      database.ChatStatusRunning,
		WorkerID:    worker,
		StartedAt:   sql.NullTime{Time: time.Now(), Valid: true},
		HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
	})
	require.NoError(t, err)

	// Subscribe on replica B before interrupting on replica A.
	_, events, cancel, ok := replicaB.Subscribe(ctx, running.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	interrupted := replicaA.InterruptChat(ctx, running)
	require.Equal(t, database.ChatStatusWaiting, interrupted.Status)
	require.False(t, interrupted.WorkerID.Valid)

	// sawWaiting consumes at most one pending event per poll and reports
	// whether it was the waiting status we are looking for.
	sawWaiting := func() bool {
		select {
		case ev := <-events:
			if ev.Type != codersdk.ChatStreamEventTypeStatus || ev.Status == nil {
				t.Logf("skipping unexpected event: type=%s", ev.Type)
				return false
			}
			return ev.Status.Status == codersdk.ChatStatusWaiting
		default:
			return false
		}
	}
	require.Eventually(t, sawWaiting, testutil.WaitMedium, testutil.IntervalFast)
}
// TestSubagentChatExcludesWorkspaceProvisioningTools drives a root chat whose
// first model response calls spawn_agent, records the tool names offered on
// every streamed LLM request, and asserts that the root chat is offered the
// workspace and subagent tools while the spawned subagent is offered neither.
func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitLong)

	dv := coderdtest.DeploymentValues(t)
	dv.Experiments = []string{string(codersdk.ExperimentAgents)}
	client := coderdtest.New(t, &coderdtest.Options{
		DeploymentValues:         dv,
		IncludeProvisionerDaemon: true,
	})
	firstUser := coderdtest.CreateFirstUser(t, client)
	expClient := codersdk.NewExperimentalClient(client)

	agentToken := uuid.NewString()
	version := coderdtest.CreateTemplateVersion(t, client, firstUser.OrganizationID, &echo.Responses{
		Parse:          echo.ParseComplete,
		ProvisionPlan:  echo.PlanComplete,
		ProvisionApply: echo.ApplyComplete,
		ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken),
	})
	coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
	coderdtest.CreateTemplate(t, client, firstUser.OrganizationID, version.ID)
	_ = agenttest.New(t, client.URL, agentToken)

	// Tool names offered on each streamed LLM request, in arrival order.
	// The first streamed request belongs to the root chat (which spawns a
	// subagent); a later one belongs to the subagent itself.
	var (
		toolsMu     sync.Mutex
		toolsByCall = make([][]string, 0, 2)
		streamCalls atomic.Int32
	)
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("ok")
		}

		offered := make([]string, 0, len(req.Tools))
		for _, tool := range req.Tools {
			offered = append(offered, tool.Function.Name)
		}
		toolsMu.Lock()
		toolsByCall = append(toolsByCall, offered)
		toolsMu.Unlock()

		if streamCalls.Add(1) != 1 {
			// Every call after the first (including the subagent's)
			// just replies with text. The text contains a literal
			// \u0000, which is what a real LLM writes when explaining
			// binary output. json.Marshal encodes the backslash as
			// \\, producing \\u0000 in the JSON bytes. The sanitizer
			// must not corrupt this into invalid JSON.
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("The file contains \\u0000 null bytes.")...,
			)
		}
		// First streamed call is the root chat: the model calls
		// spawn_agent.
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAIToolCallChunk("spawn_agent", `{"prompt":"do the thing","title":"sub"}`),
		)
	})

	_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
		Provider: "openai-compat",
		APIKey:   "test-api-key",
		BaseURL:  openAIURL,
	})
	require.NoError(t, err)

	var (
		contextLimit int64 = 4096
		isDefault          = true
	)
	_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
		Provider:     "openai-compat",
		Model:        "gpt-4o-mini",
		ContextLimit: &contextLimit,
		IsDefault:    &isDefault,
	})
	require.NoError(t, err)

	// The root chat's first model call will spawn a subagent.
	rootChat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
		OrganizationID: firstUser.OrganizationID,
		Content: []codersdk.ChatInputPart{
			{
				Type: codersdk.ChatInputPartTypeText,
				Text: "Spawn a subagent to do the thing.",
			},
		},
	})
	require.NoError(t, err)

	// Wait until the root chat has settled (waiting or error) AND at
	// least three streamed calls were recorded: root-1 (spawn_agent),
	// child-1, and root-2.
	require.Eventually(t, func() bool {
		current, getErr := expClient.GetChat(ctx, rootChat.ID)
		if getErr != nil {
			return false
		}
		settled := current.Status == codersdk.ChatStatusWaiting ||
			current.Status == codersdk.ChatStatusError
		if !settled {
			return false
		}
		toolsMu.Lock()
		defer toolsMu.Unlock()
		return len(toolsByCall) >= 3
	}, testutil.WaitLong, testutil.IntervalFast)

	// Snapshot the recorded calls under the lock so later assertions do
	// not race with any straggling requests.
	toolsMu.Lock()
	recorded := make([][]string, len(toolsByCall))
	copy(recorded, toolsByCall)
	toolsMu.Unlock()
	require.GreaterOrEqual(t, len(recorded), 2,
		"expected at least 2 streamed LLM calls (root + subagent)")

	workspaceTools := []string{"propose_plan", "list_templates", "read_template", "create_workspace"}
	subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"}

	// Split the calls by origin: requests that offer spawn_agent came
	// from the root chat, the rest came from the subagent. The root chat
	// makes several calls (before and after spawn_agent), so it may
	// contribute more than one entry.
	var rootCalls, childCalls [][]string
	for _, offered := range recorded {
		if slice.Contains(offered, "spawn_agent") {
			rootCalls = append(rootCalls, offered)
			continue
		}
		childCalls = append(childCalls, offered)
	}
	require.NotEmpty(t, rootCalls, "expected at least one root chat LLM call")
	require.NotEmpty(t, childCalls, "expected at least one subagent LLM call")

	firstRoot, firstChild := rootCalls[0], childCalls[0]

	// The root chat must be offered both workspace and subagent tools.
	for _, name := range workspaceTools {
		require.Contains(t, firstRoot, name,
			"root chat should have workspace tool %q", name)
	}
	for _, name := range subagentTools {
		require.Contains(t, firstRoot, name,
			"root chat should have subagent tool %q", name)
	}

	// The subagent must be offered neither.
	for _, name := range workspaceTools {
		require.NotContains(t, firstChild, name,
			"subagent chat should NOT have workspace tool %q", name)
	}
	for _, name := range subagentTools {
		require.NotContains(t, firstChild, name,
			"subagent chat should NOT have subagent tool %q", name)
	}
}
// TestInterruptChatClearsWorkerInDatabase ensures InterruptChat both returns
// and persists a chat in the waiting state with no worker assigned.
func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	created, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "db-transition",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Pretend some worker is currently processing the chat.
	running, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          created.ID,
		Status:      database.ChatStatusRunning,
		WorkerID:    uuid.NullUUID{UUID: uuid.New(), Valid: true},
		StartedAt:   sql.NullTime{Time: time.Now(), Valid: true},
		HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
	})
	require.NoError(t, err)

	// The returned value must already reflect the interrupt...
	interrupted := server.InterruptChat(ctx, running)
	require.Equal(t, database.ChatStatusWaiting, interrupted.Status)
	require.False(t, interrupted.WorkerID.Valid)

	// ...and so must the row that was written to the database.
	persisted, err := db.GetChatByID(ctx, running.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusWaiting, persisted.Status)
	require.False(t, persisted.WorkerID.Valid)
}
// TestArchiveChatMovesPendingChatToWaiting archives a chat that is still
// pending and asserts the stored row ends up archived, in the waiting state,
// with no worker, start time, heartbeat, or pin order.
func TestArchiveChatMovesPendingChatToWaiting(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	created, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "archive-pending",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Explicitly reset the chat to pending with all worker metadata
	// cleared before archiving it.
	pending, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          created.ID,
		Status:      database.ChatStatusPending,
		WorkerID:    uuid.NullUUID{},
		StartedAt:   sql.NullTime{},
		HeartbeatAt: sql.NullTime{},
		LastError:   sql.NullString{},
	})
	require.NoError(t, err)

	require.NoError(t, server.ArchiveChat(ctx, pending))

	stored, err := db.GetChatByID(ctx, pending.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusWaiting, stored.Status)
	require.False(t, stored.WorkerID.Valid)
	require.False(t, stored.StartedAt.Valid)
	require.False(t, stored.HeartbeatAt.Valid)
	require.True(t, stored.Archived)
	require.Zero(t, stored.PinOrder)
}
// TestArchiveChatInterruptsActiveProcessing archives a chat while its LLM
// stream is still in flight and verifies that:
//   - the in-flight stream's request context is canceled,
//   - a subscriber receives a "waiting" status event,
//   - the chat row ends up archived, waiting, and without worker metadata,
//   - a message queued before the archive stays queued instead of being
//     promoted into the chat history.
func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	streamStarted := make(chan struct{})
	streamCanceled := make(chan struct{})
	// Each streaming request closes the signal channels from its own
	// goroutine. Guard the closes with sync.Once: a check-then-close via
	// select is not atomic and would panic with a double close if the
	// handler were ever invoked for overlapping streaming requests.
	var startedOnce, canceledOnce sync.Once
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("title")
		}
		chunks := make(chan chattest.OpenAIChunk, 1)
		go func() {
			defer close(chunks)
			// Emit one partial chunk, then hold the stream open until
			// the request context is canceled.
			chunks <- chattest.OpenAITextChunks("partial")[0]
			startedOnce.Do(func() { close(streamStarted) })
			<-req.Context().Done()
			canceledOnce.Do(func() { close(streamCanceled) })
		}()
		return chattest.OpenAIResponse{StreamingChunks: chunks}
	})

	server := newActiveTestServer(t, db, ps)
	user, org, model := seedChatDependencies(ctx, t, db)
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID:            user.ID,
		OrganizationID:     org.ID,
		Title:              "archive-interrupt",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Wait for a worker to pick the chat up...
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
		if dbErr != nil {
			return false
		}
		return fromDB.Status == database.ChatStatusRunning && fromDB.WorkerID.Valid
	}, testutil.IntervalFast)

	// ...and for the fake LLM stream to actually begin.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		select {
		case <-streamStarted:
			return true
		default:
			return false
		}
	}, testutil.IntervalFast)

	_, events, cancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	defer cancel()

	// Queue a message while the chat is busy so we can later verify that
	// archiving does not promote it.
	queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
		ChatID:       chat.ID,
		Content:      []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
		BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
	})
	require.NoError(t, err)
	require.True(t, queuedResult.Queued)
	require.NotNil(t, queuedResult.QueuedMessage)

	err = server.ArchiveChat(ctx, chat)
	require.NoError(t, err)

	// Archiving must cancel the in-flight stream.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		select {
		case <-streamCanceled:
			return true
		default:
			return false
		}
	}, testutil.IntervalFast)

	gotWaitingStatus := false
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		// Drain everything currently buffered on each poll, looking
		// for the waiting status event.
		for {
			select {
			case ev, open := <-events:
				if !open {
					// A closed channel is always ready to receive,
					// so without this check the loop would spin
					// forever on zero-value events and never reach
					// the default branch.
					return gotWaitingStatus
				}
				if ev.Type == codersdk.ChatStreamEventTypeStatus &&
					ev.Status != nil &&
					ev.Status.Status == codersdk.ChatStatusWaiting {
					gotWaitingStatus = true
					return true
				}
			default:
				return gotWaitingStatus
			}
		}
	}, testutil.IntervalFast)
	require.True(t, gotWaitingStatus, "expected a waiting status event after archive")

	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
		if dbErr != nil {
			return false
		}
		return fromDB.Archived &&
			fromDB.Status == database.ChatStatusWaiting &&
			!fromDB.WorkerID.Valid &&
			!fromDB.StartedAt.Valid &&
			!fromDB.HeartbeatAt.Valid
	}, testutil.IntervalFast)

	// The queued message must still be sitting in the queue.
	queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID)
	require.NoError(t, err)
	require.Len(t, queuedMessages, 1)
	require.Equal(t, queuedResult.QueuedMessage.ID, queuedMessages[0].ID)

	// Only the initial user message may appear in the chat history.
	messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
		ChatID:  chat.ID,
		AfterID: 0,
	})
	require.NoError(t, err)
	userMessages := 0
	for _, msg := range messages {
		if msg.Role == database.ChatMessageRoleUser {
			userMessages++
		}
	}
	require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
}
// TestUpdateChatHeartbeatsRequiresOwnership verifies that
// UpdateChatHeartbeats returns a chat's ID only when the supplied worker ID
// matches the worker recorded on the running chat.
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	created, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "heartbeat-ownership",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	ownerWorker := uuid.New()
	running, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          created.ID,
		Status:      database.ChatStatusRunning,
		WorkerID:    uuid.NullUUID{UUID: ownerWorker, Valid: true},
		StartedAt:   sql.NullTime{Time: time.Now(), Valid: true},
		HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
	})
	require.NoError(t, err)

	// heartbeatAs issues a heartbeat for the chat on behalf of worker and
	// returns the IDs the query reports as updated.
	heartbeatAs := func(worker uuid.UUID) []uuid.UUID {
		updated, hbErr := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
			IDs:      []uuid.UUID{running.ID},
			WorkerID: worker,
			Now:      time.Now(),
		})
		require.NoError(t, hbErr)
		return updated
	}

	// A worker that does not own the chat gets nothing back.
	require.Empty(t, heartbeatAs(uuid.New()))

	// The owning worker gets exactly the chat's ID back.
	owned := heartbeatAs(ownerWorker)
	require.Len(t, owned, 1)
	require.Equal(t, running.ID, owned[0])
}
// TestSendMessageQueueBehaviorQueuesWhenBusy verifies that sending a message
// with the queue busy-behavior to a running chat enqueues the message and
// leaves the chat's status, worker, and message history untouched.
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	created, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "queue-when-busy",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Mark the chat as busy under a known worker.
	workerID := uuid.New()
	busy, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          created.ID,
		Status:      database.ChatStatusRunning,
		WorkerID:    uuid.NullUUID{UUID: workerID, Valid: true},
		StartedAt:   sql.NullTime{Time: time.Now(), Valid: true},
		HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
	})
	require.NoError(t, err)

	sent, err := server.SendMessage(ctx, chatd.SendMessageOptions{
		ChatID:       busy.ID,
		Content:      []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
		BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
	})
	require.NoError(t, err)
	require.True(t, sent.Queued)
	require.NotNil(t, sent.QueuedMessage)
	// The chat must still be running under the same worker.
	require.Equal(t, database.ChatStatusRunning, sent.Chat.Status)
	require.Equal(t, workerID, sent.Chat.WorkerID.UUID)
	require.True(t, sent.Chat.WorkerID.Valid)

	backlog, err := db.GetChatQueuedMessages(ctx, busy.ID)
	require.NoError(t, err)
	require.Len(t, backlog, 1)

	// Only the initial user message is in the chat history.
	history, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
		ChatID:  busy.ID,
		AfterID: 0,
	})
	require.NoError(t, err)
	require.Len(t, history, 1)
}
// TestSendMessageQueuesWhenWaitingWithQueuedBacklog verifies that a message
// sent to a waiting chat that already has a queued message is appended to the
// queue, behind the existing entry, rather than inserted into the history.
func TestSendMessageQueuesWhenWaitingWithQueuedBacklog(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	created, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "queue-when-waiting-with-backlog",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Seed the queue directly with an older message.
	olderContent, err := json.Marshal([]codersdk.ChatMessagePart{
		codersdk.ChatMessageText("older queued"),
	})
	require.NoError(t, err)
	_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
		ChatID:  created.ID,
		Content: olderContent,
	})
	require.NoError(t, err)

	// Put the chat into waiting with all worker metadata cleared.
	waiting, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          created.ID,
		Status:      database.ChatStatusWaiting,
		WorkerID:    uuid.NullUUID{},
		StartedAt:   sql.NullTime{},
		HeartbeatAt: sql.NullTime{},
		LastError:   sql.NullString{},
	})
	require.NoError(t, err)

	sent, err := server.SendMessage(ctx, chatd.SendMessageOptions{
		ChatID:  waiting.ID,
		Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("newer queued")},
	})
	require.NoError(t, err)
	require.True(t, sent.Queued)
	require.NotNil(t, sent.QueuedMessage)
	require.Equal(t, database.ChatStatusWaiting, sent.Chat.Status)

	backlog, err := db.GetChatQueuedMessages(ctx, waiting.ID)
	require.NoError(t, err)
	require.Len(t, backlog, 2)

	// Queue order: the seeded entry first, the new message second.
	first := db2sdk.ChatQueuedMessage(backlog[0])
	require.Len(t, first.Content, 1)
	require.Equal(t, "older queued", first.Content[0].Text)

	second := db2sdk.ChatQueuedMessage(backlog[1])
	require.Len(t, second.Content, 1)
	require.Equal(t, "newer queued", second.Content[0].Text)

	// Only the initial user message is in the chat history.
	history, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
		ChatID:  waiting.ID,
		AfterID: 0,
	})
	require.NoError(t, err)
	require.Len(t, history, 1)
}
// TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy verifies that
// sending a message with the interrupt busy-behavior to a running chat queues
// the message and moves the chat to waiting, without touching the history.
func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	created, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "interrupt-when-busy",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// CreateChat calls signalWake, which triggers processOnce in the
	// background. Let that finish first so it cannot race with the manual
	// status update below.
	waitForChatProcessed(ctx, t, db, created.ID, server)

	busy, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          created.ID,
		Status:      database.ChatStatusRunning,
		WorkerID:    uuid.NullUUID{UUID: uuid.New(), Valid: true},
		StartedAt:   sql.NullTime{Time: time.Now(), Valid: true},
		HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
	})
	require.NoError(t, err)

	sent, err := server.SendMessage(ctx, chatd.SendMessageOptions{
		ChatID:       busy.ID,
		Content:      []codersdk.ChatMessagePart{codersdk.ChatMessageText("interrupt")},
		BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
	})
	require.NoError(t, err)

	// The message is queued rather than inserted directly.
	require.True(t, sent.Queued)
	require.NotNil(t, sent.QueuedMessage)

	// The interrupt signal moves the chat to waiting, not pending, both
	// in the returned value and in the stored row.
	require.Equal(t, database.ChatStatusWaiting, sent.Chat.Status)
	stored, err := db.GetChatByID(ctx, busy.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusWaiting, stored.Status)

	// The message lives in the queue, not in chat_messages.
	backlog, err := db.GetChatQueuedMessages(ctx, busy.ID)
	require.NoError(t, err)
	require.Len(t, backlog, 1)

	// chat_messages still holds only the initial user message.
	history, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
		ChatID:  busy.ID,
		AfterID: 0,
	})
	require.NoError(t, err)
	require.Len(t, history, 1)
}
// TestEditMessageUpdatesAndTruncatesAndClearsQueue edits the first user
// message of a chat that has follow-up sends and a queued entry, and verifies
// that the edit replaces the history with a single new message, empties the
// queue, and leaves the chat pending with no worker assigned.
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "edit-message",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
	})
	require.NoError(t, err)

	historyParams := database.GetChatMessagesByChatIDParams{
		ChatID:  chat.ID,
		AfterID: 0,
	}
	before, err := db.GetChatMessagesByChatID(ctx, historyParams)
	require.NoError(t, err)
	require.Len(t, before, 1)
	originalID := before[0].ID

	// Send two follow-up messages after the one that will be edited.
	for _, text := range []string{"follow-up", "another"} {
		_, err = server.SendMessage(ctx, chatd.SendMessageOptions{
			ChatID:       chat.ID,
			Content:      []codersdk.ChatMessagePart{codersdk.ChatMessageText(text)},
			BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
		})
		require.NoError(t, err)
	}

	// Seed one queued message directly so the edit has a queue to clear.
	queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
		codersdk.ChatMessageText("queued"),
	})
	require.NoError(t, err)
	_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
		ChatID:  chat.ID,
		Content: queuedContent,
	})
	require.NoError(t, err)

	chat, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          chat.ID,
		Status:      database.ChatStatusRunning,
		WorkerID:    uuid.NullUUID{UUID: uuid.New(), Valid: true},
		StartedAt:   sql.NullTime{Time: time.Now(), Valid: true},
		HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
	})
	require.NoError(t, err)

	edited, err := server.EditMessage(ctx, chatd.EditMessageOptions{
		ChatID:          chat.ID,
		EditedMessageID: originalID,
		Content:         []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
	})
	require.NoError(t, err)

	// The original message is soft-deleted and a new one is inserted, so
	// the returned message carries a different ID.
	require.NotEqual(t, originalID, edited.Message.ID)
	require.Equal(t, database.ChatStatusPending, edited.Chat.Status)
	require.False(t, edited.Chat.WorkerID.Valid)

	returnedSDK := db2sdk.ChatMessage(edited.Message)
	require.Len(t, returnedSDK.Content, 1)
	require.Equal(t, "edited", returnedSDK.Content[0].Text)

	// The history now holds only the replacement message.
	after, err := db.GetChatMessagesByChatID(ctx, historyParams)
	require.NoError(t, err)
	require.Len(t, after, 1)
	require.Equal(t, edited.Message.ID, after[0].ID)

	storedSDK := db2sdk.ChatMessage(after[0])
	require.Len(t, storedSDK.Content, 1)
	require.Equal(t, "edited", storedSDK.Content[0].Text)

	backlog, err := db.GetChatQueuedMessages(ctx, chat.ID)
	require.NoError(t, err)
	require.Len(t, backlog, 0)

	// The wake channel may trigger immediate processing after EditMessage,
	// moving the chat from pending to running and then to error before
	// the row is read. Wait for in-flight processing to settle first.
	// Note: WaitUntilIdleForTest must be called from the test goroutine
	// (not inside require.Eventually) to avoid a WaitGroup Add/Wait race.
	chatd.WaitUntilIdleForTest(server)

	var settled database.Chat
	require.Eventually(t, func() bool {
		current, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		settled = current
		return current.Status != database.ChatStatusRunning
	}, testutil.WaitShort, testutil.IntervalFast)
	require.False(t, settled.WorkerID.Valid)
}
func TestCreateChatInsertsWorkspaceAwarenessMessage(t *testing.T) {
t.Parallel()
t.Run("WithWorkspace", func(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, org, model := seedChatDependencies(ctx, t, db)
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
tpl := dbgen.Template(t, db, database.Template{
CreatedBy: user.ID,
OrganizationID: org.ID,
ActiveVersionID: tv.ID,
})
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: tpl.ID,
})
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OrganizationID: org.ID,
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true},
Title: "test-with-workspace",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
var workspaceMsg *database.ChatMessage
for _, msg := range messages {
if msg.Role == database.ChatMessageRoleSystem {
content := string(msg.Content.RawMessage)
if strings.Contains(content, "attached to a workspace") {
workspaceMsg = &msg
break
}
}
}
require.NotNil(t, workspaceMsg, "workspace awareness system message should exist")
require.Equal(t, database.ChatMessageRoleSystem, workspaceMsg.Role)
require.Equal(t, database.ChatMessageVisibilityModel, workspaceMsg.Visibility)
})
t.Run("WithoutWorkspace", func(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newTestServer(t, db, ps, uuid.New())
ctx := testutil.Context(t, testutil.WaitLong)
user, org, model := seedChatDependencies(ctx, t, db)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OrganizationID: org.ID,
OwnerID: user.ID,
Title: "test-without-workspace",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
var workspaceMsg *database.ChatMessage
for _, msg := range messages {
if msg.Role == database.ChatMessageRoleSystem {
content := string(msg.Content.RawMessage)
if strings.Contains(content, "no workspace associated") {
workspaceMsg = &msg
break
}
}
}
require.NotNil(t, workspaceMsg, "workspace awareness system message should exist")
require.Equal(t, database.ChatMessageRoleSystem, workspaceMsg.Role)
require.Equal(t, database.ChatMessageVisibilityModel, workspaceMsg.Visibility)
})
}
// TestCreateChatRejectsWhenUsageLimitReached configures a daily usage limit,
// records an assistant message whose cost equals that limit, and verifies
// that CreateChat fails with a UsageLimitExceededError and adds no chat.
func TestCreateChatRejectsWhenUsageLimitReached(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
		Enabled:            true,
		DefaultLimitMicros: 100,
		Period:             string(codersdk.ChatUsageLimitPeriodDay),
	})
	require.NoError(t, err)

	// Insert a chat directly (bypassing the server) to carry the spend.
	spentChat, err := db.InsertChat(ctx, database.InsertChatParams{
		OrganizationID:    org.ID,
		Status:            database.ChatStatusWaiting,
		OwnerID:           user.ID,
		Title:             "existing-limit-chat",
		LastModelConfigID: model.ID,
	})
	require.NoError(t, err)

	assistantParts, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
		codersdk.ChatMessageText("assistant"),
	})
	require.NoError(t, err)

	// One assistant message with TotalCostMicros equal to the limit.
	_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
		ChatID:              spentChat.ID,
		CreatedBy:           []uuid.UUID{uuid.Nil},
		ModelConfigID:       []uuid.UUID{model.ID},
		Role:                []database.ChatMessageRole{database.ChatMessageRoleAssistant},
		ContentVersion:      []int16{chatprompt.CurrentContentVersion},
		Content:             []string{string(assistantParts.RawMessage)},
		Visibility:          []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
		InputTokens:         []int64{0},
		OutputTokens:        []int64{0},
		TotalTokens:         []int64{0},
		ReasoningTokens:     []int64{0},
		CacheCreationTokens: []int64{0},
		CacheReadTokens:     []int64{0},
		ContextLimit:        []int64{0},
		Compressed:          []bool{false},
		TotalCostMicros:     []int64{100},
		RuntimeMs:           []int64{0},
	})
	require.NoError(t, err)

	listParams := database.GetChatsParams{
		OwnerID:   user.ID,
		AfterID:   uuid.Nil,
		OffsetOpt: 0,
		LimitOpt:  100,
	}
	chatsBefore, err := db.GetChats(ctx, listParams)
	require.NoError(t, err)
	require.Len(t, chatsBefore, 1)

	_, err = server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "over-limit",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.Error(t, err)

	// The error must carry both the configured limit and the spend.
	var exceeded *chatd.UsageLimitExceededError
	require.ErrorAs(t, err, &exceeded)
	require.Equal(t, int64(100), exceeded.LimitMicros)
	require.Equal(t, int64(100), exceeded.ConsumedMicros)

	// The rejected request must not have left a new chat behind.
	chatsAfter, err := db.GetChats(ctx, listParams)
	require.NoError(t, err)
	require.Len(t, chatsAfter, len(chatsBefore))
}
// TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached checks that
// a message queued before the recorded spend reached the usage limit can still
// be promoted afterwards. After PromoteQueued succeeds, the queue must be
// empty and the third stored message must carry the user role.
func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	srv := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	// Daily limit of 100 micros; the assistant message inserted further
	// down costs exactly that much.
	limitParams := database.UpsertChatUsageLimitConfigParams{
		Enabled:            true,
		DefaultLimitMicros: 100,
		Period:             string(codersdk.ChatUsageLimitPeriodDay),
	}
	_, err := db.UpsertChatUsageLimitConfig(ctx, limitParams)
	require.NoError(t, err)

	createOpts := chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "queued-limit-reached",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	}
	chat, err := srv.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	// CreateChat calls signalWake, which kicks off processOnce in the
	// background. Let that finish first so it cannot race with the manual
	// status update that follows.
	waitForChatProcessed(ctx, t, db, chat.ID, srv)

	// Pretend a worker is running the chat so the next send is queued.
	runningParams := database.UpdateChatStatusParams{
		ID:          chat.ID,
		Status:      database.ChatStatusRunning,
		WorkerID:    uuid.NullUUID{UUID: uuid.New(), Valid: true},
		StartedAt:   sql.NullTime{Time: time.Now(), Valid: true},
		HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
	}
	chat, err = db.UpdateChatStatus(ctx, runningParams)
	require.NoError(t, err)

	sendRes, err := srv.SendMessage(ctx, chatd.SendMessageOptions{
		ChatID:       chat.ID,
		Content:      []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
		BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
	})
	require.NoError(t, err)
	require.True(t, sendRes.Queued)
	require.NotNil(t, sendRes.QueuedMessage)

	// Record an assistant message whose cost equals the configured limit.
	assistantParts := []codersdk.ChatMessagePart{
		codersdk.ChatMessageText("assistant"),
	}
	assistantContent, err := chatprompt.MarshalParts(assistantParts)
	require.NoError(t, err)
	_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
		ChatID:              chat.ID,
		CreatedBy:           []uuid.UUID{uuid.Nil},
		ModelConfigID:       []uuid.UUID{model.ID},
		Role:                []database.ChatMessageRole{database.ChatMessageRoleAssistant},
		ContentVersion:      []int16{chatprompt.CurrentContentVersion},
		Content:             []string{string(assistantContent.RawMessage)},
		Visibility:          []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
		InputTokens:         []int64{0},
		OutputTokens:        []int64{0},
		TotalTokens:         []int64{0},
		ReasoningTokens:     []int64{0},
		CacheCreationTokens: []int64{0},
		CacheReadTokens:     []int64{0},
		ContextLimit:        []int64{0},
		Compressed:          []bool{false},
		TotalCostMicros:     []int64{100},
		RuntimeMs:           []int64{0},
	})
	require.NoError(t, err)

	// Hand the chat back: waiting, no worker, no timestamps, no error.
	waitingParams := database.UpdateChatStatusParams{
		ID:          chat.ID,
		Status:      database.ChatStatusWaiting,
		WorkerID:    uuid.NullUUID{},
		StartedAt:   sql.NullTime{},
		HeartbeatAt: sql.NullTime{},
		LastError:   sql.NullString{},
	}
	chat, err = db.UpdateChatStatus(ctx, waitingParams)
	require.NoError(t, err)

	promoted, err := srv.PromoteQueued(ctx, chatd.PromoteQueuedOptions{
		ChatID:          chat.ID,
		QueuedMessageID: sendRes.QueuedMessage.ID,
		CreatedBy:       user.ID,
	})
	require.NoError(t, err)
	require.Equal(t, database.ChatMessageRoleUser, promoted.PromotedMessage.Role)

	remaining, err := db.GetChatQueuedMessages(ctx, chat.ID)
	require.NoError(t, err)
	require.Empty(t, remaining)

	history, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
		ChatID:  chat.ID,
		AfterID: 0,
	})
	require.NoError(t, err)
	require.Len(t, history, 3)
	require.Equal(t, database.ChatMessageRoleUser, history[2].Role)
}
// TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease interrupts a chat
// while its first response is still streaming, queues one more message, and
// then records spend equal to the configured limit in a separate chat owned by
// the same user. It asserts that both queued messages are still processed in
// order ("hello", "queued", "later queued") and that the chat ends in the
// waiting status with no worker assigned.
func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) {
t.Parallel()
const acquireInterval = 10 * time.Millisecond
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// Enable a daily usage limit of 100 micros.
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: 100,
Period: string(codersdk.ChatUsageLimitPeriodDay),
})
require.NoError(t, err)
// Use a mock clock and trap creation of the ticker tagged
// ("chatd", "acquire") so the test decides when acquisition ticks.
clock := quartz.NewMock(t)
acquireTrap := clock.Trap().NewTicker("chatd", "acquire")
defer acquireTrap.Close()
// assertPendingWithoutQueuedMessages asserts that the chat has no
// queued messages, is in the pending status, and has no worker.
assertPendingWithoutQueuedMessages := func(chatID uuid.UUID) {
t.Helper()
queued, dbErr := db.GetChatQueuedMessages(ctx, chatID)
require.NoError(t, dbErr)
require.Empty(t, queued)
fromDB, dbErr := db.GetChatByID(ctx, chatID)
require.NoError(t, dbErr)
require.Equal(t, database.ChatStatusPending, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
}
// Signalling channels between the fake model handler and the test body.
streamStarted := make(chan struct{})
interrupted := make(chan struct{})
secondRequestStarted := make(chan struct{})
thirdRequestStarted := make(chan struct{})
allowFinish := make(chan struct{})
// requestCount counts streaming requests only.
var requestCount atomic.Int32
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
// Non-streaming requests get a fixed response and are not counted.
if !req.Stream {
return chattest.OpenAINonStreamingResponse("title")
}
switch requestCount.Add(1) {
case 1:
// First streaming request: emit one partial chunk, signal
// streamStarted, block until the request context is done, signal
// interrupted, then keep the stream open until allowFinish closes.
chunks := make(chan chattest.OpenAIChunk, 1)
go func() {
defer close(chunks)
chunks <- chattest.OpenAITextChunks("partial")[0]
// Close-once guard: skip the close if the channel is already closed.
select {
case <-streamStarted:
default:
close(streamStarted)
}
<-req.Context().Done()
select {
case <-interrupted:
default:
close(interrupted)
}
<-allowFinish
}()
return chattest.OpenAIResponse{StreamingChunks: chunks}
case 2:
close(secondRequestStarted)
case 3:
close(thirdRequestStarted)
}
// Every streaming request after the first completes with "done".
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("done")...,
)
})
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
cfg.Clock = clock
cfg.PendingChatAcquireInterval = acquireInterval
cfg.InFlightChatStaleAfter = testutil.WaitSuperLong
})
// Wait for the trapped ticker creation and let it proceed.
acquireTrap.MustWait(ctx).MustRelease(ctx)
user, org, model := seedChatDependencies(ctx, t, db)
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OrganizationID: org.ID,
OwnerID: user.ID,
Title: "interrupt-autopromote-limit",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
// Advance one acquire interval and wait for the first stream to begin.
clock.Advance(acquireInterval).MustWait(ctx)
testutil.TryReceive(ctx, t, streamStarted)
// Send a message with interrupt behaviour while the stream is active;
// the result is expected to report it as queued.
queuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
})
require.NoError(t, err)
require.True(t, queuedResult.Queued)
require.NotNil(t, queuedResult.QueuedMessage)
// Wait for the first request's context to be cancelled, then let the
// fake stream finish.
testutil.TryReceive(ctx, t, interrupted)
close(allowFinish)
chatd.WaitUntilIdleForTest(server)
assertPendingWithoutQueuedMessages(chat.ID)
// Keep the acquire loop frozen here so "queued" stays pending.
// That makes the later send queue because the chat is still busy,
// rather than because the scheduler happened to be slow.
laterQueuedResult, err := server.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("later queued")},
})
require.NoError(t, err)
require.True(t, laterQueuedResult.Queued)
require.NotNil(t, laterQueuedResult.QueuedMessage)
// Record 100 micros of spend in a different chat owned by the same
// user, which equals the configured limit.
spendChat, err := db.InsertChat(ctx, database.InsertChatParams{
OrganizationID: org.ID,
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{},
ParentChatID: uuid.NullUUID{},
RootChatID: uuid.NullUUID{},
LastModelConfigID: model.ID,
Title: "other-spend",
Mode: database.NullChatMode{},
})
require.NoError(t, err)
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("spent elsewhere"),
})
require.NoError(t, err)
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: spendChat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{model.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Content: []string{string(assistantContent.RawMessage)},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{100},
RuntimeMs: []int64{0},
})
require.NoError(t, err)
// Next tick: the second streaming request must start. Afterwards the
// chat is pending again with an empty queue.
clock.Advance(acquireInterval).MustWait(ctx)
testutil.TryReceive(ctx, t, secondRequestStarted)
chatd.WaitUntilIdleForTest(server)
assertPendingWithoutQueuedMessages(chat.ID)
// Following tick: the third streaming request must start as well.
clock.Advance(acquireInterval).MustWait(ctx)
testutil.TryReceive(ctx, t, thirdRequestStarted)
chatd.WaitUntilIdleForTest(server)
// Final state: empty queue, waiting status, no worker.
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
require.Empty(t, queued)
fromDB, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
require.False(t, fromDB.WorkerID.Valid)
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
require.NoError(t, err)
// Collect the text of every single-part user message, in stored order.
userTexts := make([]string, 0, 3)
for _, message := range messages {
if message.Role != database.ChatMessageRoleUser {
continue
}
sdkMessage := db2sdk.ChatMessage(message)
if len(sdkMessage.Content) != 1 {
continue
}
userTexts = append(userTexts, sdkMessage.Content[0].Text)
}
require.Equal(t, []string{"hello", "queued", "later queued"}, userTexts)
}
// TestEditMessageRejectsWhenUsageLimitReached verifies that EditMessage returns
// a *chatd.UsageLimitExceededError once the recorded spend equals the
// configured limit. It also checks that the chat still holds both stored
// messages and that the first one keeps its original text.
func TestEditMessageRejectsWhenUsageLimitReached(t *testing.T) {
	t.Parallel()

	// Used both as the configured limit and as the cost of the assistant
	// message that consumes it.
	const limitMicros = 100

	db, ps := dbtestutil.NewDB(t)
	srv := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	limitParams := database.UpsertChatUsageLimitConfigParams{
		Enabled:            true,
		DefaultLimitMicros: limitMicros,
		Period:             string(codersdk.ChatUsageLimitPeriodDay),
	}
	_, err := db.UpsertChatUsageLimitConfig(ctx, limitParams)
	require.NoError(t, err)

	chat, err := srv.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "edit-limit-reached",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
	})
	require.NoError(t, err)

	listParams := database.GetChatMessagesByChatIDParams{
		ChatID:  chat.ID,
		AfterID: 0,
	}
	stored, err := db.GetChatMessagesByChatID(ctx, listParams)
	require.NoError(t, err)
	require.Len(t, stored, 1)
	targetID := stored[0].ID

	// Insert an assistant message whose cost uses up the whole limit.
	assistantParts := []codersdk.ChatMessagePart{
		codersdk.ChatMessageText("assistant"),
	}
	assistantContent, err := chatprompt.MarshalParts(assistantParts)
	require.NoError(t, err)
	_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
		ChatID:              chat.ID,
		CreatedBy:           []uuid.UUID{uuid.Nil},
		ModelConfigID:       []uuid.UUID{model.ID},
		Role:                []database.ChatMessageRole{database.ChatMessageRoleAssistant},
		ContentVersion:      []int16{chatprompt.CurrentContentVersion},
		Content:             []string{string(assistantContent.RawMessage)},
		Visibility:          []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
		InputTokens:         []int64{0},
		OutputTokens:        []int64{0},
		TotalTokens:         []int64{0},
		ReasoningTokens:     []int64{0},
		CacheCreationTokens: []int64{0},
		CacheReadTokens:     []int64{0},
		ContextLimit:        []int64{0},
		Compressed:          []bool{false},
		TotalCostMicros:     []int64{limitMicros},
		RuntimeMs:           []int64{0},
	})
	require.NoError(t, err)

	_, err = srv.EditMessage(ctx, chatd.EditMessageOptions{
		ChatID:          chat.ID,
		EditedMessageID: targetID,
		Content:         []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
	})
	require.Error(t, err)

	var limitErr *chatd.UsageLimitExceededError
	require.ErrorAs(t, err, &limitErr)
	require.Equal(t, int64(limitMicros), limitErr.LimitMicros)
	require.Equal(t, int64(limitMicros), limitErr.ConsumedMicros)

	// The rejected edit must leave the stored history as it was.
	stored, err = db.GetChatMessagesByChatID(ctx, listParams)
	require.NoError(t, err)
	require.Len(t, stored, 2)
	first := db2sdk.ChatMessage(stored[0])
	require.Len(t, first.Content, 1)
	require.Equal(t, "original", first.Content[0].Text)
}
// TestEditMessageRejectsMissingMessage verifies that EditMessage fails with
// chatd.ErrEditedMessageNotFound when asked to edit a message ID that the chat
// does not contain.
func TestEditMessageRejectsMissingMessage(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	replica := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "missing-edited-message",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// 999999 is an arbitrary ID that is not expected to match any message
	// in the freshly created chat.
	_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
		ChatID:          chat.ID,
		EditedMessageID: 999999,
		Content:         []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
	})
	// ErrorIs fails on a nil error too, and on a mismatch it prints the
	// error that was actually returned.
	require.ErrorIs(t, err, chatd.ErrEditedMessageNotFound)
}
// TestEditMessageRejectsNonUserMessage verifies that EditMessage fails with
// chatd.ErrEditedMessageNotUser when the targeted message has the assistant
// role rather than the user role.
func TestEditMessageRejectsNonUserMessage(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	replica := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "non-user-edited-message",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Insert an assistant message directly so there is a non-user message
	// to target with the edit.
	assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
		codersdk.ChatMessageText("assistant"),
	})
	require.NoError(t, err)
	assistantMessages, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
		ChatID:              chat.ID,
		CreatedBy:           []uuid.UUID{uuid.Nil},
		ModelConfigID:       []uuid.UUID{model.ID},
		Role:                []database.ChatMessageRole{database.ChatMessageRoleAssistant},
		ContentVersion:      []int16{chatprompt.CurrentContentVersion},
		Content:             []string{string(assistantContent.RawMessage)},
		Visibility:          []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
		InputTokens:         []int64{0},
		OutputTokens:        []int64{0},
		TotalTokens:         []int64{0},
		ReasoningTokens:     []int64{0},
		CacheCreationTokens: []int64{0},
		CacheReadTokens:     []int64{0},
		ContextLimit:        []int64{0},
		Compressed:          []bool{false},
		TotalCostMicros:     []int64{0},
		RuntimeMs:           []int64{0},
	})
	require.NoError(t, err)
	// Guard the index below so an unexpected result fails the test with a
	// clear message instead of panicking.
	require.Len(t, assistantMessages, 1)
	assistantMessage := assistantMessages[0]

	_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
		ChatID:          chat.ID,
		EditedMessageID: assistantMessage.ID,
		Content:         []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
	})
	require.Error(t, err)
	require.True(t, errors.Is(err, chatd.ErrEditedMessageNotUser),
		"expected ErrEditedMessageNotUser, got: %v", err)
}
// TestRecoverStaleChatsPeriodically verifies that a running chat with an old
// heartbeat is reset to pending both when it exists before the server starts
// and when it appears after startup, without restarting the server.
func TestRecoverStaleChatsPeriodically(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	// A very short stale threshold makes the periodic recovery kick in
	// quickly during the test.
	staleAfter := 500 * time.Millisecond

	// seedAbandonedChat inserts a chat and simulates a dead worker by
	// marking it running with started/heartbeat times one hour in the
	// past. It returns the chat's ID.
	seedAbandonedChat := func(title string) uuid.UUID {
		t.Helper()
		workerID := uuid.New()
		inserted, insertErr := db.InsertChat(ctx, database.InsertChatParams{
			OrganizationID:    org.ID,
			Status:            database.ChatStatusWaiting,
			OwnerID:           user.ID,
			Title:             title,
			LastModelConfigID: model.ID,
		})
		require.NoError(t, insertErr)
		_, updateErr := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
			ID:          inserted.ID,
			Status:      database.ChatStatusRunning,
			WorkerID:    uuid.NullUUID{UUID: workerID, Valid: true},
			StartedAt:   sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
			HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true},
		})
		require.NoError(t, updateErr)
		return inserted.ID
	}

	// requireBecomesPending polls until the chat's status is pending.
	requireBecomesPending := func(chatID uuid.UUID) {
		t.Helper()
		isPending := func() bool {
			current, getErr := db.GetChatByID(ctx, chatID)
			return getErr == nil && current.Status == database.ChatStatusPending
		}
		require.Eventually(t, isPending, testutil.WaitMedium, testutil.IntervalFast)
	}

	firstID := seedAbandonedChat("stale-recovery-periodic")

	// Start a new replica. Its startup recovery will reset the chat (the
	// heartbeat is old), but the key point is that the periodic loop also
	// recovers chats that become stale later.
	server := chatd.New(chatd.Config{
		Logger:                     slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: testutil.WaitLong,
		InFlightChatStaleAfter:     staleAfter,
	})
	t.Cleanup(func() {
		require.NoError(t, server.Close())
	})

	// Startup recovery should already have reset the first stale chat.
	requireBecomesPending(firstID)

	// Now a second stale chat appears AFTER startup. This exercises the
	// periodic recovery, not just the startup one.
	secondID := seedAbandonedChat("stale-recovery-periodic-2")

	// The periodic stale recovery loop (running at staleAfter/5 = 100ms
	// intervals) should pick this up without a restart.
	requireBecomesPending(secondID)
}
// TestRecoverStaleRequiresActionChat verifies that a chat left in the
// requires_action status with an old updated_at is moved to the error status
// by stale recovery, with a last error mentioning the dynamic tool timeout and
// no worker assigned.
func TestRecoverStaleRequiresActionChat(t *testing.T) {
	t.Parallel()

	db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	// Use a very short stale threshold so the periodic recovery
	// kicks in quickly during the test.
	staleAfter := 500 * time.Millisecond

	// Create a chat and set it to requires_action to simulate a
	// client that disappeared while the chat was waiting for
	// dynamic tool results.
	chat, err := db.InsertChat(ctx, database.InsertChatParams{
		OrganizationID:    org.ID,
		Status:            database.ChatStatusWaiting,
		OwnerID:           user.ID,
		Title:             "stale-requires-action",
		LastModelConfigID: model.ID,
	})
	require.NoError(t, err)
	_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:     chat.ID,
		Status: database.ChatStatusRequiresAction,
	})
	require.NoError(t, err)

	// Backdate updated_at so the chat appears stale to the
	// recovery loop without needing time.Sleep.
	_, err = rawDB.ExecContext(ctx,
		"UPDATE chats SET updated_at = $1 WHERE id = $2",
		time.Now().Add(-time.Hour), chat.ID)
	require.NoError(t, err)

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	server := chatd.New(chatd.Config{
		Logger:                     logger,
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: testutil.WaitLong,
		InFlightChatStaleAfter:     staleAfter,
	})
	t.Cleanup(func() {
		require.NoError(t, server.Close())
	})

	// The stale recovery should transition the requires_action
	// chat to error with the timeout message.
	//
	// The condition keeps its error in a closure-local variable so the
	// polling goroutine never writes to the test's own err. Only a
	// successful read is published to chatResult.
	var chatResult database.Chat
	require.Eventually(t, func() bool {
		got, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		chatResult = got
		return got.Status == database.ChatStatusError
	}, testutil.WaitMedium, testutil.IntervalFast)

	require.Contains(t, chatResult.LastError.String, "Dynamic tool execution timed out",
		"unexpected last_error for recovered requires_action chat")
	require.False(t, chatResult.WorkerID.Valid,
		"recovered chat should not have a worker assigned")
}
// TestNewReplicaRecoversStaleChatFromDeadReplica verifies that starting a new
// replica resets a chat that another replica left running with an old
// heartbeat: the chat becomes pending and its worker ID is cleared.
func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	// hourAgo builds a valid timestamp one hour in the past, well beyond
	// the stale threshold.
	hourAgo := func() sql.NullTime {
		return sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true}
	}

	// Simulate a chat left running by a replica that no longer exists.
	orphan, err := db.InsertChat(ctx, database.InsertChatParams{
		OrganizationID:    org.ID,
		Status:            database.ChatStatusWaiting,
		OwnerID:           user.ID,
		Title:             "orphaned-chat",
		LastModelConfigID: model.ID,
	})
	require.NoError(t, err)
	_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          orphan.ID,
		Status:      database.ChatStatusRunning,
		WorkerID:    uuid.NullUUID{UUID: uuid.New(), Valid: true},
		StartedAt:   hourAgo(),
		HeartbeatAt: hourAgo(),
	})
	require.NoError(t, err)

	// Start a new replica; it should recover the stale chat on startup.
	// The server handle itself is not needed afterwards.
	_ = newTestServer(t, db, ps, uuid.New())

	recovered := func() bool {
		current, getErr := db.GetChatByID(ctx, orphan.ID)
		if getErr != nil {
			return false
		}
		if current.WorkerID.Valid {
			return false
		}
		return current.Status == database.ChatStatusPending
	}
	require.Eventually(t, recovered, testutil.WaitMedium, testutil.IntervalFast)
}
// TestWaitingChatsAreNotRecoveredAsStale verifies that a chat in the waiting
// status keeps that status while a server with a short stale threshold is
// running.
func TestWaitingChatsAreNotRecoveredAsStale(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	// A waiting chat must NOT be touched by stale recovery.
	idle, err := db.InsertChat(ctx, database.InsertChatParams{
		OrganizationID:    org.ID,
		Status:            database.ChatStatusWaiting,
		OwnerID:           user.ID,
		Title:             "waiting-chat",
		LastModelConfigID: model.ID,
	})
	require.NoError(t, err)

	// Start a replica with a short stale threshold.
	cfg := chatd.Config{
		Logger:                     slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: testutil.WaitLong,
		InFlightChatStaleAfter:     500 * time.Millisecond,
	}
	server := chatd.New(cfg)
	t.Cleanup(func() {
		require.NoError(t, server.Close())
	})

	// statusChanged reports whether the chat has left the waiting status.
	// A failed read counts as "not changed".
	statusChanged := func() bool {
		current, getErr := db.GetChatByID(ctx, idle.ID)
		if getErr != nil {
			return false
		}
		return current.Status != database.ChatStatusWaiting
	}

	// Observe long enough for multiple periodic recovery cycles to run
	// (staleAfter/5 = 100ms intervals).
	require.Never(t, statusChanged, time.Second, testutil.IntervalFast,
		"waiting chat should not be modified by stale recovery")
}
// TestUpdateChatStatusPersistsLastError verifies that UpdateChatStatus stores
// the supplied last error together with the error status, and that a later
// update to pending with an empty last error clears it again. Both the value
// returned by the update and a fresh read from the database are checked.
func TestUpdateChatStatusPersistsLastError(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	_ = newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	created, err := db.InsertChat(ctx, database.InsertChatParams{
		OrganizationID:    org.ID,
		Status:            database.ChatStatusWaiting,
		OwnerID:           user.ID,
		Title:             "error-persisted",
		LastModelConfigID: model.ID,
	})
	require.NoError(t, err)

	// Simulate a chat that failed with an error.
	wantErr := sql.NullString{
		String: "stream response: status 500: internal server error",
		Valid:  true,
	}
	failed, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          created.ID,
		Status:      database.ChatStatusError,
		WorkerID:    uuid.NullUUID{},
		StartedAt:   sql.NullTime{},
		HeartbeatAt: sql.NullTime{},
		LastError:   wantErr,
	})
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusError, failed.Status)
	require.Equal(t, wantErr, failed.LastError)

	// The error must still be there when the chat is re-read.
	reloaded, err := db.GetChatByID(ctx, failed.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusError, reloaded.Status)
	require.Equal(t, wantErr, reloaded.LastError)

	// Moving to a non-error status (e.g. pending after a retry) with an
	// empty LastError must clear the stored error.
	retried, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          failed.ID,
		Status:      database.ChatStatusPending,
		WorkerID:    uuid.NullUUID{},
		StartedAt:   sql.NullTime{},
		HeartbeatAt: sql.NullTime{},
		LastError:   sql.NullString{},
	})
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusPending, retried.Status)
	require.False(t, retried.LastError.Valid)

	reloaded, err = db.GetChatByID(ctx, retried.ID)
	require.NoError(t, err)
	require.False(t, reloaded.LastError.Valid)
}
// TestSubscribeSnapshotIncludesStatusEvent verifies that the snapshot returned
// by Subscribe is non-empty and begins with a status event that carries a
// non-nil status.
func TestSubscribeSnapshotIncludesStatusEvent(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	srv := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	createOpts := chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "status-snapshot",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	}
	chat, err := srv.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	events, _, unsubscribe, subscribed := srv.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, subscribed)
	t.Cleanup(unsubscribe)

	// The snapshot must open with a status event. Which status it reports
	// depends on timing: CreateChat sets pending, but the wake signal may
	// trigger processing before Subscribe is called, so only the event
	// type and the presence of a status are asserted.
	require.NotEmpty(t, events)
	first := events[0]
	require.Equal(t, codersdk.ChatStreamEventTypeStatus, first.Type)
	require.NotNil(t, first.Status)
}
// TestPersistToolResultWithBinaryData runs a chat whose model response calls
// the "execute" tool, has the mocked agent return output containing NUL bytes,
// and then asserts that (a) the persisted tool-result message round-trips that
// output unchanged and (b) the second streamed model request carries a tool
// message with the same output.
func TestPersistToolResultWithBinaryData(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// Base64 of "HEADER", "some data", "more data" and "END" separated by
// NUL bytes.
const binaryOutputBase64 = "SEVBREVSAAAAc29tZSBkYXRhAABtb3JlIGRhdGEARU5E"
binaryOutput, err := io.ReadAll(base64.NewDecoder(
base64.StdEncoding,
strings.NewReader(binaryOutputBase64),
))
require.NoError(t, err)
// streamedCalls records a copy of the messages of every streaming
// request; streamedCallsMu guards it because the handler runs outside
// the test goroutine.
var streamedCallCount atomic.Int32
var streamedCallsMu sync.Mutex
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
// Non-streaming requests get a fixed response and are not recorded.
if !req.Stream {
return chattest.OpenAINonStreamingResponse("Binary tool result test")
}
streamedCallsMu.Lock()
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
streamedCallsMu.Unlock()
// First streaming request: respond with a call to the execute tool.
if streamedCallCount.Add(1) == 1 {
return chattest.OpenAIStreamingResponse(
chattest.OpenAIToolCallChunk(
"execute",
`{"command":"cat /home/coder/binary_file.bin"}`,
),
)
}
// Include literal \u0000 in the response text, which is
// what a real LLM writes when explaining binary output.
// json.Marshal encodes the backslash as \\, producing
// \\u0000 in the JSON bytes. The sanitizer must not
// corrupt this into invalid JSON.
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("The file contains \\u0000 null bytes.")...,
)
})
// Use "openai-compat" provider so the chatd framework uses the
// /chat/completions endpoint, where the mock server supports
// streaming tool calls. The default "openai" provider routes to
// /responses which only handles text deltas in the mock.
user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
// Mocked agent connection. Everything except StartProcess may be called
// any number of times; StartProcess must be called exactly once.
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
SetExtraHeaders(gomock.Any()).
AnyTimes()
mockConn.EXPECT().
ContextConfig(gomock.Any()).
Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).
AnyTimes()
mockConn.EXPECT().
ListMCPTools(gomock.Any()).
Return(workspacesdk.ListMCPToolsResponse{}, nil).
AnyTimes()
mockConn.EXPECT().
LS(gomock.Any(), gomock.Any(), gomock.Any()).
Return(workspacesdk.LSResponse{}, nil).
AnyTimes()
mockConn.EXPECT().
ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(io.NopCloser(strings.NewReader("")), "", nil).
AnyTimes()
// The process must be started with exactly the command from the tool
// call and is given the ID "proc-binary".
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
require.Equal(t, "cat /home/coder/binary_file.bin", req.Command)
return workspacesdk.StartProcessResponse{ID: "proc-binary", Started: true}, nil
}).
Times(1)
// Output for "proc-binary": the binary payload, finished, exit code 0.
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-binary", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Output: string(binaryOutput),
Running: false,
ExitCode: ptrRef(0),
}, nil).
AnyTimes()
// Hand the mock connection to the server, checking it is requested for
// the seeded agent.
server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
require.Equal(t, dbAgent.ID, agentID)
return mockConn, func() {}, nil
}
})
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OrganizationID: org.ID,
OwnerID: user.ID,
Title: "binary-tool-result",
ModelConfigID: model.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("Read /home/coder/binary_file.bin."),
},
})
require.NoError(t, err)
// Poll until the chat reaches waiting or error, keeping the last read
// so a failure can report last_error.
var chatResult database.Chat
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
if chatResult.Status == database.ChatStatusError {
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
}
// Find the first persisted message with the tool role.
var toolMessage *database.ChatMessage
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
for i := range messages {
if messages[i].Role == database.ChatMessageRoleTool {
toolMessage = &messages[i]
return true
}
}
return false
}, testutil.IntervalFast)
require.NotNil(t, toolMessage)
// The stored tool result must decode to an ExecuteResult whose output
// equals the binary payload.
parts, err := chatprompt.ParseContent(*toolMessage)
require.NoError(t, err)
require.Len(t, parts, 1)
require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[0].Type)
require.Equal(t, "execute", parts[0].ToolName)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal(parts[0].Result, &result))
require.True(t, result.Success)
require.Equal(t, string(binaryOutput), result.Output)
require.Equal(t, 0, result.ExitCode)
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
// Snapshot the recorded calls under the lock before inspecting them.
streamedCallsMu.Lock()
recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
streamedCallsMu.Unlock()
require.GreaterOrEqual(t, len(recordedStreamCalls), 2)
// Look in the second streamed request for a tool message whose JSON
// content decodes to the same output; other messages are skipped.
var foundToolResultInSecondCall bool
for _, message := range recordedStreamCalls[1] {
if message.Role != "tool" {
continue
}
if !json.Valid([]byte(message.Content)) {
continue
}
var result chattool.ExecuteResult
if err := json.Unmarshal([]byte(message.Content), &result); err != nil {
continue
}
if result.Output == string(binaryOutput) {
foundToolResultInSecondCall = true
break
}
}
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include execute tool output")
}
func TestDynamicToolCallPausesAndResumes(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// Track streaming calls to the mock LLM.
var streamedCallCount atomic.Int32
var streamedCallsMu sync.Mutex
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
// Non-streaming requests are title generation — return a
// simple title.
if !req.Stream {
return chattest.OpenAINonStreamingResponse("Dynamic tool test")
}
// Capture the full request for later assertions.
streamedCallsMu.Lock()
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
Stream: req.Stream,
})
streamedCallsMu.Unlock()
if streamedCallCount.Add(1) == 1 {
// First call: the LLM invokes our dynamic tool.
return chattest.OpenAIStreamingResponse(
chattest.OpenAIToolCallChunk(
"my_dynamic_tool",
`{"input":"hello world"}`,
),
)
}
// Second call: the LLM returns a normal text response.
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("Dynamic tool result received.")...,
)
})
user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
// Dynamic tools do not need a workspace connection, but the
// chatd server always builds workspace tools. Use an active
// server without an agent connection — the built-in tools
// are never invoked because the only tool call targets our
// dynamic tool.
server := newActiveTestServer(t, db, ps)
// Create a chat with a dynamic tool.
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
Name: "my_dynamic_tool",
Description: "A test dynamic tool.",
InputSchema: mcpgo.ToolInputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{"type": "string"},
},
Required: []string{"input"},
},
}})
require.NoError(t, err)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OrganizationID: org.ID,
OwnerID: user.ID,
Title: "dynamic-tool-pause-resume",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("Please call the dynamic tool."),
},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 1. Wait for the chat to reach requires_action status.
var chatResult database.Chat
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusRequiresAction ||
got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
"expected requires_action, got %s (last_error=%q)",
chatResult.Status, chatResult.LastError.String)
// 2. Read the assistant message to find the tool-call ID.
var toolCallID string
var toolCallFound bool
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
for _, msg := range messages {
if msg.Role != database.ChatMessageRoleAssistant {
continue
}
parts, parseErr := chatprompt.ParseContent(msg)
if parseErr != nil {
continue
}
for _, part := range parts {
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
toolCallID = part.ToolCallID
toolCallFound = true
return true
}
}
}
return false
}, testutil.IntervalFast)
require.True(t, toolCallFound, "expected to find tool call for my_dynamic_tool")
require.NotEmpty(t, toolCallID)
// 3. Submit tool results via SubmitToolResults.
toolResultOutput := json.RawMessage(`{"result":"dynamic tool output"}`)
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
ChatID: chat.ID,
UserID: user.ID,
ModelConfigID: chatResult.LastModelConfigID,
Results: []codersdk.ToolResult{{
ToolCallID: toolCallID,
Output: toolResultOutput,
}},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 4. Wait for the chat to reach a terminal status.
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
// 5. Verify the chat completed successfully.
if chatResult.Status == database.ChatStatusError {
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
}
// 6. Verify the mock received exactly 2 streaming calls.
require.Equal(t, int32(2), streamedCallCount.Load(),
"expected exactly 2 streaming calls to the LLM")
streamedCallsMu.Lock()
recordedCalls := append([]chattest.OpenAIRequest(nil), streamedCalls...)
streamedCallsMu.Unlock()
require.Len(t, recordedCalls, 2)
// 7. Verify the dynamic tool appeared in the first call's tool list.
var foundDynamicTool bool
for _, tool := range recordedCalls[0].Tools {
if tool.Function.Name == "my_dynamic_tool" {
foundDynamicTool = true
break
}
}
require.True(t, foundDynamicTool,
"expected 'my_dynamic_tool' in the first LLM call's tool list")
// 8. Verify the second call's messages contain the tool result.
var foundToolResultInSecondCall bool
for _, message := range recordedCalls[1].Messages {
if message.Role != "tool" {
continue
}
if strings.Contains(message.Content, "dynamic tool output") {
foundToolResultInSecondCall = true
break
}
}
require.True(t, foundToolResultInSecondCall,
"expected second LLM call to include the submitted dynamic tool result")
}
// TestDynamicToolCallMixedWithBuiltIn covers an assistant turn that asks
// for a built-in tool (read_file) and a dynamic tool (my_dynamic_tool) in
// the same response. The test expects the chat to pause in
// requires_action with a read_file tool result already stored, and to
// reach the waiting status after the dynamic tool result is submitted,
// using exactly two streaming LLM calls in total.
func TestDynamicToolCallMixedWithBuiltIn(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Only the number of streaming calls is asserted below, so the mock
	// keeps a counter instead of copying every request under a lock.
	var streamedCallCount atomic.Int32

	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("Mixed tool test")
		}

		if streamedCallCount.Add(1) == 1 {
			// First call: return TWO tool calls in one
			// response — a built-in tool (read_file) and a
			// dynamic tool (my_dynamic_tool).
			builtinChunk := chattest.OpenAIToolCallChunk(
				"read_file",
				`{"path":"/tmp/test.txt"}`,
			)
			dynamicChunk := chattest.OpenAIToolCallChunk(
				"my_dynamic_tool",
				`{"input":"hello world"}`,
			)

			// Merge both tool calls into one chunk with
			// separate indices so the LLM appears to have
			// requested both tools simultaneously.
			mergedChunk := builtinChunk
			dynCall := dynamicChunk.Choices[0].ToolCalls[0]
			dynCall.Index = 1
			mergedChunk.Choices[0].ToolCalls = append(
				mergedChunk.Choices[0].ToolCalls,
				dynCall,
			)
			return chattest.OpenAIStreamingResponse(mergedChunk)
		}

		// Second call (after tool results): normal text
		// response.
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAITextChunks("All done.")...,
		)
	})

	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
	server := newActiveTestServer(t, db, ps)

	// Create a chat with a dynamic tool.
	dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
		Name:        "my_dynamic_tool",
		Description: "A test dynamic tool.",
		InputSchema: mcpgo.ToolInputSchema{
			Type: "object",
			Properties: map[string]any{
				"input": map[string]any{"type": "string"},
			},
			Required: []string{"input"},
		},
	}})
	require.NoError(t, err)

	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "mixed-builtin-dynamic",
		ModelConfigID:  model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("Call both tools."),
		},
		DynamicTools: dynamicToolsJSON,
	})
	require.NoError(t, err)

	// 1. Wait for the chat to reach requires_action status. The error
	//    status also ends the wait so a failure is reported right away
	//    with its last_error instead of timing out.
	var chatResult database.Chat
	require.Eventually(t, func() bool {
		got, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		chatResult = got
		return got.Status == database.ChatStatusRequiresAction ||
			got.Status == database.ChatStatusError
	}, testutil.WaitLong, testutil.IntervalFast)
	require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
		"expected requires_action, got %s (last_error=%q)",
		chatResult.Status, chatResult.LastError.String)

	// 2. Verify the built-in tool (read_file) was already
	//    executed by checking that a tool result message
	//    exists for it in the database. The same scan picks up the
	//    tool-call ID of the pending dynamic tool call.
	var builtinToolResultFound bool
	var toolCallID string
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
			ChatID:  chat.ID,
			AfterID: 0,
		})
		if dbErr != nil {
			return false
		}
		for _, msg := range messages {
			parts, parseErr := chatprompt.ParseContent(msg)
			if parseErr != nil {
				continue
			}
			for _, part := range parts {
				// Check for the built-in tool result.
				if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "read_file" {
					builtinToolResultFound = true
				}
				// Find the dynamic tool call ID.
				if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
					toolCallID = part.ToolCallID
				}
			}
		}
		return builtinToolResultFound && toolCallID != ""
	}, testutil.IntervalFast)
	require.True(t, builtinToolResultFound,
		"expected read_file tool result in the DB before dynamic tool resolution")
	require.NotEmpty(t, toolCallID)

	// 3. Submit dynamic tool results.
	err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
		ChatID:        chat.ID,
		UserID:        user.ID,
		ModelConfigID: chatResult.LastModelConfigID,
		Results: []codersdk.ToolResult{{
			ToolCallID: toolCallID,
			Output:     json.RawMessage(`{"result":"dynamic output"}`),
		}},
		DynamicTools: dynamicToolsJSON,
	})
	require.NoError(t, err)

	// 4. Wait for the chat to complete.
	require.Eventually(t, func() bool {
		got, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		chatResult = got
		return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
	}, testutil.WaitLong, testutil.IntervalFast)
	if chatResult.Status == database.ChatStatusError {
		require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
	}

	// 5. Verify the LLM received exactly 2 streaming calls.
	require.Equal(t, int32(2), streamedCallCount.Load(),
		"expected exactly 2 streaming calls to the LLM")
}
// TestSubmitToolResultsConcurrency races several SubmitToolResults calls
// for the same pending dynamic tool call and asserts that exactly one
// call succeeds while every other call fails with a
// ToolResultStatusConflictError.
func TestSubmitToolResultsConcurrency(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Mock LLM: the first streaming request yields a dynamic tool call,
	// every later streaming request yields plain text.
	var streamedCallCount atomic.Int32
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("Concurrency test")
		}
		if streamedCallCount.Add(1) != 1 {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("Done.")...,
			)
		}
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAIToolCallChunk(
				"my_dynamic_tool",
				`{"input":"hello"}`,
			),
		)
	})

	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
	server := newActiveTestServer(t, db, ps)

	// Register a single dynamic tool on the chat.
	toolDefs := []mcpgo.Tool{{
		Name:        "my_dynamic_tool",
		Description: "A test dynamic tool.",
		InputSchema: mcpgo.ToolInputSchema{
			Type: "object",
			Properties: map[string]any{
				"input": map[string]any{"type": "string"},
			},
			Required: []string{"input"},
		},
	}}
	dynamicToolsJSON, err := json.Marshal(toolDefs)
	require.NoError(t, err)

	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "concurrency-tool-results",
		ModelConfigID:  model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("Please call the dynamic tool."),
		},
		DynamicTools: dynamicToolsJSON,
	})
	require.NoError(t, err)

	// Poll until the chat pauses for the tool result (or errors out).
	var paused database.Chat
	require.Eventually(t, func() bool {
		got, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		paused = got
		switch got.Status {
		case database.ChatStatusRequiresAction, database.ChatStatusError:
			return true
		default:
			return false
		}
	}, testutil.WaitLong, testutil.IntervalFast)
	require.Equal(t, database.ChatStatusRequiresAction, paused.Status,
		"expected requires_action, got %s (last_error=%q)",
		paused.Status, paused.LastError.String)

	// Locate the pending tool call ID in the assistant messages.
	var toolCallID string
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		stored, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
			ChatID:  chat.ID,
			AfterID: 0,
		})
		if dbErr != nil {
			return false
		}
		for _, m := range stored {
			if m.Role != database.ChatMessageRoleAssistant {
				continue
			}
			parts, parseErr := chatprompt.ParseContent(m)
			if parseErr != nil {
				continue
			}
			for _, p := range parts {
				isDynamicCall := p.Type == codersdk.ChatMessagePartTypeToolCall &&
					p.ToolName == "my_dynamic_tool"
				if isDynamicCall {
					toolCallID = p.ToolCallID
					return true
				}
			}
		}
		return false
	}, testutil.IntervalFast)
	require.NotEmpty(t, toolCallID)

	// Launch the submitters. They all block on the start gate so the
	// calls are issued as close together as possible. Exactly one
	// should succeed; the rest must get a
	// ToolResultStatusConflictError.
	const numGoroutines = 10
	var (
		wg            sync.WaitGroup
		okCount       atomic.Int32
		conflictCount atomic.Int32
	)
	startGate := make(chan struct{})
	// Buffered for the worst case so a submitter never blocks on send.
	otherErrs := make(chan error, numGoroutines)

	for range numGoroutines {
		wg.Go(func() {
			<-startGate

			submitErr := server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
				ChatID:        chat.ID,
				UserID:        user.ID,
				ModelConfigID: paused.LastModelConfigID,
				Results: []codersdk.ToolResult{{
					ToolCallID: toolCallID,
					Output:     json.RawMessage(`{"result":"concurrent output"}`),
				}},
				DynamicTools: dynamicToolsJSON,
			})

			var conflict *chatd.ToolResultStatusConflictError
			switch {
			case submitErr == nil:
				okCount.Add(1)
			case errors.As(submitErr, &conflict):
				conflictCount.Add(1)
			default:
				// Hand anything else back to the test goroutine;
				// require.* calls t.FailNow, which must not run on
				// a spawned goroutine.
				otherErrs <- submitErr
			}
		})
	}

	// Open the gate, then wait for every submitter to return.
	close(startGate)
	wg.Wait()
	close(otherErrs)

	for unexpected := range otherErrs {
		require.NoError(t, unexpected, "unexpected error from SubmitToolResults")
	}
	require.Equal(t, int32(1), okCount.Load(),
		"expected exactly 1 goroutine to succeed")
	require.Equal(t, int32(numGoroutines-1), conflictCount.Load(),
		"expected %d conflict errors", numGoroutines-1)
}
// ptrRef returns a pointer to a fresh copy of v. It is handy for
// populating optional pointer fields from literals or constants.
func ptrRef[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
// TestSubscribeNoPubsubNoDuplicateMessageParts subscribes to a chat on a
// server built without pubsub and asserts that the snapshot is non-empty
// and that the events channel does not deliver anything in the 200ms
// that follow.
func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
	t.Parallel()

	// Passing a nil pubsub forces the no-pubsub path.
	db, _ := dbtestutil.NewDB(t)
	server := newTestServer(t, db, nil, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	createOpts := chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "no-dup-parts",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	}
	chat, err := server.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	// Let any wake-triggered processing settle before subscribing so
	// the snapshot reflects the final state. The wake signal may start
	// processOnce, which fails (no LLM configured) and moves the chat
	// to the error status. Wait until the chat is neither pending nor
	// running and the processing goroutine has finished.
	waitForChatProcessed(ctx, t, db, chat.ID, server)

	snapshot, events, cancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	// The snapshot carries at least a status and a message event.
	require.NotEmpty(t, snapshot)

	// Non-blocking probe: reports true when the channel has something
	// ready to receive.
	eventPending := func() bool {
		select {
		case <-events:
			return true
		default:
			return false
		}
	}

	// Nothing should arrive right after the snapshot, because the
	// snapshot already contained everything. Before the fix,
	// localSnapshot was replayed into the channel and produced
	// duplicates.
	require.Never(t, eventPending, 200*time.Millisecond, testutil.IntervalFast,
		"expected no duplicate events after snapshot")
}
// TestSubscribeAfterMessageID verifies the afterMessageID argument of
// Subscribe. With three messages stored for a chat, afterMessageID=0
// must yield all three in the snapshot, while passing the ID of the
// second message must yield only the third.
func TestSubscribeAfterMessageID(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	replica := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	// Create a chat — this inserts one initial "user" message.
	chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "after-id-test",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("first")},
	})
	require.NoError(t, err)

	// Insert two more messages so we have three total visible
	// messages (the initial user message plus these two).
	secondContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
		codersdk.ChatMessageText("second"),
	})
	require.NoError(t, err)
	msg2Results, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
		ChatID:              chat.ID,
		CreatedBy:           []uuid.UUID{uuid.Nil},
		ModelConfigID:       []uuid.UUID{model.ID},
		Role:                []database.ChatMessageRole{database.ChatMessageRoleAssistant},
		ContentVersion:      []int16{chatprompt.CurrentContentVersion},
		Content:             []string{string(secondContent.RawMessage)},
		Visibility:          []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
		InputTokens:         []int64{0},
		OutputTokens:        []int64{0},
		TotalTokens:         []int64{0},
		ReasoningTokens:     []int64{0},
		CacheCreationTokens: []int64{0},
		CacheReadTokens:     []int64{0},
		ContextLimit:        []int64{0},
		Compressed:          []bool{false},
		TotalCostMicros:     []int64{0},
		RuntimeMs:           []int64{0},
	})
	require.NoError(t, err)
	// Check the row count before indexing so an unexpected result
	// fails the test with a clear message rather than an
	// index-out-of-range panic.
	require.Len(t, msg2Results, 1, "expected exactly one inserted message")
	msg2 := msg2Results[0]

	thirdContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
		codersdk.ChatMessageText("third"),
	})
	require.NoError(t, err)
	_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
		ChatID:              chat.ID,
		CreatedBy:           []uuid.UUID{uuid.Nil},
		ModelConfigID:       []uuid.UUID{model.ID},
		Role:                []database.ChatMessageRole{database.ChatMessageRoleUser},
		ContentVersion:      []int16{chatprompt.CurrentContentVersion},
		Content:             []string{string(thirdContent.RawMessage)},
		Visibility:          []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
		InputTokens:         []int64{0},
		OutputTokens:        []int64{0},
		TotalTokens:         []int64{0},
		ReasoningTokens:     []int64{0},
		CacheCreationTokens: []int64{0},
		CacheReadTokens:     []int64{0},
		ContextLimit:        []int64{0},
		Compressed:          []bool{false},
		TotalCostMicros:     []int64{0},
		RuntimeMs:           []int64{0},
	})
	require.NoError(t, err)

	// Control: Subscribe with afterMessageID=0 returns ALL messages.
	// Only the snapshot is inspected, so the subscription is canceled
	// immediately.
	allSnapshot, _, cancelAll, ok := replica.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	cancelAll()
	allMessages := filterMessageEvents(allSnapshot)
	require.Len(t, allMessages, 3, "afterMessageID=0 should return all three messages")

	// Subscribe with afterMessageID set to the second message's ID.
	// Only the third message (inserted after msg2) should appear.
	partialSnapshot, _, cancelPartial, ok := replica.Subscribe(ctx, chat.ID, nil, msg2.ID)
	require.True(t, ok)
	cancelPartial()
	partialMessages := filterMessageEvents(partialSnapshot)
	require.Len(t, partialMessages, 1, "afterMessageID=msg2.ID should return only messages after msg2")
	require.Equal(t, codersdk.ChatMessageRoleUser, partialMessages[0].Message.Role)
}
// filterMessageEvents keeps only the events whose Type is
// ChatStreamEventTypeMessage, letting callers ignore status and queue
// events in a snapshot.
func filterMessageEvents(events []codersdk.ChatStreamEvent) []codersdk.ChatStreamEvent {
	isMessage := func(ev codersdk.ChatStreamEvent) bool {
		return ev.Type == codersdk.ChatStreamEventTypeMessage
	}
	return slice.Filter(events, isMessage)
}
// TestCreateWorkspaceTool_EndToEnd drives a chat whose mock LLM calls the
// create_workspace tool against a full coderd test deployment. It
// asserts that the workspace is created with the requested name, that
// the tool result reports created=true, that the agent is in the ready
// lifecycle state afterwards, and that the second streamed LLM request
// carries the tool output.
func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitLong)

	dv := coderdtest.DeploymentValues(t)
	dv.Experiments = []string{string(codersdk.ExperimentAgents)}
	client := coderdtest.New(t, &coderdtest.Options{
		DeploymentValues:         dv,
		IncludeProvisionerDaemon: true,
	})
	user := coderdtest.CreateFirstUser(t, client)
	expClient := codersdk.NewExperimentalClient(client)

	agentToken := uuid.NewString()

	// The template gives the agent a startup script so it spends time
	// in the "starting" lifecycle state. That lets the test check that
	// create_workspace waits for scripts to finish.
	withStartupScript := func(g *proto.GraphComplete) {
		g.Resources[0].Agents[0].Scripts = []*proto.Script{{
			DisplayName: "setup",
			Script:      "sleep 5",
			RunOnStart:  true,
		}}
	}
	version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
		Parse:          echo.ParseComplete,
		ProvisionPlan:  echo.PlanComplete,
		ProvisionApply: echo.ApplyComplete,
		ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken, withStartupScript),
	})
	coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
	template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)

	// Run a test workspace agent so create_workspace can wait for the
	// agent to become reachable before returning.
	_ = agenttest.New(t, client.URL, agentToken)

	workspaceName := "chat-ws-" + strings.ReplaceAll(uuid.NewString(), "-", "")[:8]
	createWorkspaceArgs := fmt.Sprintf(
		`{"template_id":%q,"name":%q}`,
		template.ID.String(),
		workspaceName,
	)

	// The mock LLM records the messages of every streaming request.
	// The first one is answered with a create_workspace tool call, the
	// rest with plain text.
	var (
		streamedCallCount atomic.Int32
		streamedCallsMu   sync.Mutex
	)
	streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("Create workspace test")
		}

		msgsCopy := append([]chattest.OpenAIMessage(nil), req.Messages...)
		streamedCallsMu.Lock()
		streamedCalls = append(streamedCalls, msgsCopy)
		streamedCallsMu.Unlock()

		if streamedCallCount.Add(1) != 1 {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("Workspace created and ready.")...,
			)
		}
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAIToolCallChunk("create_workspace", createWorkspaceArgs),
		)
	})

	_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
		Provider: "openai-compat",
		APIKey:   "test-api-key",
		BaseURL:  openAIURL,
	})
	require.NoError(t, err)

	var (
		contextLimit = int64(4096)
		isDefault    = true
	)
	_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
		Provider:     "openai-compat",
		Model:        "gpt-4o-mini",
		ContextLimit: &contextLimit,
		IsDefault:    &isDefault,
	})
	require.NoError(t, err)

	chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
		OrganizationID: user.OrganizationID,
		Content: []codersdk.ChatInputPart{
			{
				Type: codersdk.ChatInputPartTypeText,
				Text: "Create a workspace from the template and continue.",
			},
		},
	})
	require.NoError(t, err)

	// Poll until the chat run ends, either in waiting or in error.
	var finalChat codersdk.Chat
	require.Eventually(t, func() bool {
		got, getErr := expClient.GetChat(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		finalChat = got
		switch got.Status {
		case codersdk.ChatStatusWaiting, codersdk.ChatStatusError:
			return true
		default:
			return false
		}
	}, testutil.WaitLong, testutil.IntervalFast)

	if finalChat.Status == codersdk.ChatStatusError {
		var lastError string
		if finalChat.LastError != nil {
			lastError = *finalChat.LastError
		}
		require.FailNowf(t, "chat run failed", "last_error=%q", lastError)
	}

	// The chat must now point at the newly created workspace.
	require.NotNil(t, finalChat.WorkspaceID)
	workspaceID := *finalChat.WorkspaceID
	workspace, err := client.Workspace(ctx, workspaceID)
	require.NoError(t, err)
	require.Equal(t, workspaceName, workspace.Name)

	// The stored tool message must report created=true.
	chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
	require.NoError(t, err)

	var foundCreateWorkspaceResult bool
	for _, msg := range chatMsgs.Messages {
		if msg.Role != codersdk.ChatMessageRoleTool {
			continue
		}
		for _, part := range msg.Content {
			isCreateResult := part.Type == codersdk.ChatMessagePartTypeToolResult &&
				part.ToolName == "create_workspace"
			if !isCreateResult {
				continue
			}
			var payload map[string]any
			require.NoError(t, json.Unmarshal(part.Result, &payload))
			created, ok := payload["created"].(bool)
			require.True(t, ok)
			require.True(t, created)
			foundCreateWorkspaceResult = true
		}
	}
	require.True(t, foundCreateWorkspaceResult, "expected create_workspace tool result message")

	// The tool should have waited for startup scripts, so the agent is
	// expected to be "ready" once create_workspace has returned. When
	// several agents exist, the last one listed is the one compared.
	workspace, err = client.Workspace(ctx, workspaceID)
	require.NoError(t, err)
	var agentLifecycle codersdk.WorkspaceAgentLifecycle
	for _, resource := range workspace.LatestBuild.Resources {
		for _, agent := range resource.Agents {
			agentLifecycle = agent.LifecycleState
		}
	}
	require.Equal(t, codersdk.WorkspaceAgentLifecycleReady, agentLifecycle,
		"agent should be ready after create_workspace returns; startup scripts were not awaited")

	// The second streamed request must carry the tool output.
	require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
	streamedCallsMu.Lock()
	recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
	streamedCallsMu.Unlock()
	require.GreaterOrEqual(t, len(recordedStreamCalls), 2)

	var foundToolResultInSecondCall bool
	for _, msg := range recordedStreamCalls[1] {
		if msg.Role != "tool" {
			continue
		}
		// Content that is not a JSON object fails to decode and is
		// skipped.
		var payload map[string]any
		if decodeErr := json.Unmarshal([]byte(msg.Content), &payload); decodeErr != nil {
			continue
		}
		if created, ok := payload["created"].(bool); ok && created {
			foundToolResultInSecondCall = true
			break
		}
	}
	require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include create_workspace tool output")
}
// TestStartWorkspaceTool_EndToEnd drives a chat, attached to a stopped
// workspace, whose mock LLM calls the start_workspace tool. It asserts
// that the latest build becomes a start transition, that the tool
// result reports started=true, and that the second streamed LLM request
// carries the tool output.
func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitSuperLong)

	dv := coderdtest.DeploymentValues(t)
	dv.Experiments = []string{string(codersdk.ExperimentAgents)}
	client := coderdtest.New(t, &coderdtest.Options{
		DeploymentValues:         dv,
		IncludeProvisionerDaemon: true,
	})
	user := coderdtest.CreateFirstUser(t, client)
	expClient := codersdk.NewExperimentalClient(client)

	version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
		Parse:          echo.ParseComplete,
		ProvisionPlan:  echo.PlanComplete,
		ProvisionApply: echo.ApplyComplete,
	})
	coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
	template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)

	// Build a workspace and stop it so start_workspace has something
	// to start. No test agent is started on purpose: the echo
	// provisioner creates new agent rows for each build, so an agent
	// started for build 1 cannot serve build 3. The tool handles the
	// no-agent case gracefully.
	workspace := coderdtest.CreateWorkspace(t, client, template.ID)
	coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
	workspace = coderdtest.MustTransitionWorkspace(
		t, client, workspace.ID,
		codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop,
	)

	// The mock LLM records the messages of every streaming request.
	// The first one is answered with a start_workspace tool call, the
	// rest with plain text.
	var (
		streamedCallCount atomic.Int32
		streamedCallsMu   sync.Mutex
	)
	streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("Start workspace test")
		}

		msgsCopy := append([]chattest.OpenAIMessage(nil), req.Messages...)
		streamedCallsMu.Lock()
		streamedCalls = append(streamedCalls, msgsCopy)
		streamedCallsMu.Unlock()

		if streamedCallCount.Add(1) != 1 {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("Workspace started and ready.")...,
			)
		}
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAIToolCallChunk("start_workspace", "{}"),
		)
	})

	_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
		Provider: "openai-compat",
		APIKey:   "test-api-key",
		BaseURL:  openAIURL,
	})
	require.NoError(t, err)

	var (
		contextLimit = int64(4096)
		isDefault    = true
	)
	_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
		Provider:     "openai-compat",
		Model:        "gpt-4o-mini",
		ContextLimit: &contextLimit,
		IsDefault:    &isDefault,
	})
	require.NoError(t, err)

	// The chat is created with the stopped workspace already attached.
	chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
		OrganizationID: user.OrganizationID,
		Content: []codersdk.ChatInputPart{
			{
				Type: codersdk.ChatInputPartTypeText,
				Text: "Start the workspace.",
			},
		},
		WorkspaceID: &workspace.ID,
	})
	require.NoError(t, err)

	// Poll until the chat run ends, either in waiting or in error.
	var finalChat codersdk.Chat
	require.Eventually(t, func() bool {
		got, getErr := expClient.GetChat(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		finalChat = got
		switch got.Status {
		case codersdk.ChatStatusWaiting, codersdk.ChatStatusError:
			return true
		default:
			return false
		}
	}, testutil.WaitSuperLong, testutil.IntervalFast)

	if finalChat.Status == codersdk.ChatStatusError {
		var lastError string
		if finalChat.LastError != nil {
			lastError = *finalChat.LastError
		}
		require.FailNowf(t, "chat run failed", "last_error=%q", lastError)
	}

	// The latest build of the workspace must now be a start.
	require.NotNil(t, finalChat.WorkspaceID)
	updatedWorkspace, err := client.Workspace(ctx, workspace.ID)
	require.NoError(t, err)
	require.Equal(t, codersdk.WorkspaceTransitionStart, updatedWorkspace.LatestBuild.Transition)

	// The stored tool message must report started=true.
	chatMsgs, err := expClient.GetChatMessages(ctx, chat.ID, nil)
	require.NoError(t, err)

	var foundStartWorkspaceResult bool
	for _, msg := range chatMsgs.Messages {
		if msg.Role != codersdk.ChatMessageRoleTool {
			continue
		}
		for _, part := range msg.Content {
			isStartResult := part.Type == codersdk.ChatMessagePartTypeToolResult &&
				part.ToolName == "start_workspace"
			if !isStartResult {
				continue
			}
			var payload map[string]any
			require.NoError(t, json.Unmarshal(part.Result, &payload))
			started, ok := payload["started"].(bool)
			require.True(t, ok)
			require.True(t, started)
			foundStartWorkspaceResult = true
		}
	}
	require.True(t, foundStartWorkspaceResult, "expected start_workspace tool result message")

	// The second streamed request must carry the tool output.
	require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
	streamedCallsMu.Lock()
	recordedStreamCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
	streamedCallsMu.Unlock()
	require.GreaterOrEqual(t, len(recordedStreamCalls), 2)

	var foundToolResultInSecondCall bool
	for _, msg := range recordedStreamCalls[1] {
		if msg.Role != "tool" {
			continue
		}
		// Content that is not a JSON object fails to decode and is
		// skipped.
		var payload map[string]any
		if decodeErr := json.Unmarshal([]byte(msg.Content), &payload); decodeErr != nil {
			continue
		}
		if started, ok := payload["started"].(bool); ok && started {
			foundToolResultInSecondCall = true
			break
		}
	}
	require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include start_workspace tool output")
}
// TestStoppedWorkspaceWithPersistedAgentBindingDoesNotBlockChat is a
// regression test for a chat that still holds a build/agent binding
// after its workspace has been stopped. The AgentConn hook blocks until
// its context is canceled, and the test asserts that the chat still
// reaches the waiting status, that the agent is dialed exactly once, and
// that the execute tool call is answered with a "workspace has no
// running agent" error result which is forwarded to the LLM.
func TestStoppedWorkspaceWithPersistedAgentBindingDoesNotBlockChat(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// The mock LLM records the messages and tool names of every
	// streaming request. The first one is answered with an execute
	// tool call, the rest with plain text.
	var streamedCallCount atomic.Int32
	var streamedCallsMu sync.Mutex
	streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
	toolsByCall := make([][]string, 0, 2)

	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("Stopped workspace regression")
		}

		names := make([]string, 0, len(req.Tools))
		for _, tool := range req.Tools {
			names = append(names, tool.Function.Name)
		}

		streamedCallsMu.Lock()
		streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
		toolsByCall = append(toolsByCall, names)
		streamedCallsMu.Unlock()

		if streamedCallCount.Add(1) == 1 {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAIToolCallChunk("execute", `{"command":"echo hi"}`),
			)
		}
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAITextChunks("The workspace is unavailable. Start it before retrying workspace tools.")...,
		)
	})

	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
	ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)

	inactive := newTestServer(t, db, ps, uuid.New())
	chat, err := inactive.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "stopped-workspace-regression",
		ModelConfigID:  model.ID,
		WorkspaceID:    uuid.NullUUID{UUID: ws.ID, Valid: true},
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("Run echo hi in the workspace."),
		},
	})
	require.NoError(t, err)

	// Close the inactive server so its wake-triggered processing
	// stops and releases the chat. Then reset to pending so the
	// active server (created below) can acquire it cleanly.
	require.NoError(t, inactive.Close())
	_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
		ID:          chat.ID,
		Status:      database.ChatStatusPending,
		WorkerID:    uuid.NullUUID{},
		StartedAt:   sql.NullTime{},
		HeartbeatAt: sql.NullTime{},
		LastError:   sql.NullString{},
	})
	require.NoError(t, err)

	// Bind the chat to the current build and agent, then record a stop
	// build so the binding refers to a workspace that is now stopped.
	build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
	require.NoError(t, err)
	chat, err = db.UpdateChatBuildAgentBinding(ctx, database.UpdateChatBuildAgentBindingParams{
		ID:      chat.ID,
		BuildID: uuid.NullUUID{UUID: build.ID, Valid: true},
		AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true},
	})
	require.NoError(t, err)

	dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
		Transition:  database.WorkspaceTransitionStop,
		BuildNumber: 2,
	}).Do()

	var dialCalls atomic.Int32
	_ = newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
		cfg.AgentConn = func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
			dialCalls.Add(1)
			// This hook runs on a chatd goroutine, not the test
			// goroutine, so require.* (which calls t.FailNow) must not
			// be used here. t.Errorf is safe from any goroutine and
			// still fails the test.
			if agentID != dbAgent.ID {
				t.Errorf("AgentConn dialed agent %s, want %s", agentID, dbAgent.ID)
			}
			// Simulate an unreachable agent: block until the caller
			// gives up.
			<-ctx.Done()
			return nil, nil, ctx.Err()
		}
	})

	// Poll until the chat run ends, either in waiting or in error.
	var chatResult database.Chat
	require.Eventually(t, func() bool {
		got, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		chatResult = got
		return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
	}, testutil.WaitLong, testutil.IntervalFast)

	if chatResult.Status == database.ChatStatusError {
		require.FailNowf(t, "chat failed", "last_error=%q", chatResult.LastError.String)
	}

	require.EqualValues(t, 1, dialCalls.Load())
	require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))

	streamedCallsMu.Lock()
	recordedCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
	recordedTools := append([][]string(nil), toolsByCall...)
	streamedCallsMu.Unlock()
	require.GreaterOrEqual(t, len(recordedCalls), 2)
	require.NotEmpty(t, recordedTools)
	require.Contains(t, recordedTools[0], "execute")
	require.Contains(t, recordedTools[0], "start_workspace")

	// The second streamed request must contain a tool message that
	// mentions the missing agent, either as raw text or inside the
	// "error" / "output" field of a JSON object.
	var foundUnavailableToolResult bool
	for _, message := range recordedCalls[1] {
		if message.Role != "tool" {
			continue
		}
		if strings.Contains(message.Content, "workspace has no running agent") {
			foundUnavailableToolResult = true
			break
		}
		if !json.Valid([]byte(message.Content)) {
			continue
		}
		var toolResult map[string]any
		if err := json.Unmarshal([]byte(message.Content), &toolResult); err != nil {
			continue
		}
		errMsg, _ := toolResult["error"].(string)
		outputMsg, _ := toolResult["output"].(string)
		if strings.Contains(errMsg, "workspace has no running agent") ||
			strings.Contains(outputMsg, "workspace has no running agent") {
			foundUnavailableToolResult = true
			break
		}
	}
	require.True(t, foundUnavailableToolResult,
		"expected the second streamed model call to include the unavailable workspace tool result")

	// The stored tool message must be a single execute error result.
	var toolMessage *database.ChatMessage
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
			ChatID:  chat.ID,
			AfterID: 0,
		})
		if dbErr != nil {
			return false
		}
		for i := range messages {
			if messages[i].Role == database.ChatMessageRoleTool {
				toolMessage = &messages[i]
				return true
			}
		}
		return false
	}, testutil.IntervalFast)
	require.NotNil(t, toolMessage)

	parts, err := chatprompt.ParseContent(*toolMessage)
	require.NoError(t, err)
	require.Len(t, parts, 1)
	require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[0].Type)
	require.Equal(t, "execute", parts[0].ToolName)
	require.True(t, parts[0].IsError)
	require.Contains(t, string(parts[0].Result), "workspace has no running agent")
}
// TestHeartbeatBumpsWorkspaceUsage checks that the heartbeat of a running
// chat reports workspace usage even when the workspace is linked to the
// chat only after processing has started. It asserts on the usage
// tracker's flush count, the workspace's last_used_at, and the latest
// build's deadline.
func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	// Streaming requests never yield a chunk: the stream is closed
	// only once the request context is canceled. This keeps the chat
	// in a processing state long enough for heartbeats to fire.
	holdStreamOpen := func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if req.Stream {
			stream := make(chan chattest.OpenAIChunk)
			go func() {
				defer close(stream)
				<-req.Context().Done()
			}()
			return chattest.OpenAIResponse{StreamingChunks: stream}
		}
		return chattest.OpenAINonStreamingResponse("ok")
	}
	setOpenAIProviderBaseURL(ctx, t, db, chattest.NewOpenAI(t, holdStreamOpen))

	// Seed a workspace with a complete build chain so that both the
	// last_used_at (dormancy) and deadline (autostop) bumps can be
	// observed.
	templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
		OrganizationID: org.ID,
		CreatedBy:      user.ID,
	})
	tpl := dbgen.Template(t, db, database.Template{
		OrganizationID:  org.ID,
		ActiveVersionID: templateVersion.ID,
		CreatedBy:       user.ID,
	})
	scheduleErr := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
		ID:                tpl.ID,
		UpdatedAt:         dbtime.Now(),
		AllowUserAutostop: true,
		ActivityBump:      int64(time.Hour),
	})
	require.NoError(t, scheduleErr)
	wsRow := dbgen.Workspace(t, db, database.WorkspaceTable{
		OwnerID:        user.ID,
		OrganizationID: org.ID,
		TemplateID:     tpl.ID,
		Ttl:            sql.NullInt64{Valid: true, Int64: int64(8 * time.Hour)},
	})
	provJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
		OrganizationID: org.ID,
		CompletedAt: sql.NullTime{
			Valid: true,
			Time:  dbtime.Now().Add(-30 * time.Minute),
		},
	})
	// The start build's deadline is set 30 minutes in the past; the
	// template above configures a one-hour activity bump.
	startBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
		WorkspaceID:       wsRow.ID,
		TemplateVersionID: templateVersion.ID,
		JobID:             provJob.ID,
		Transition:        database.WorkspaceTransitionStart,
		Deadline:          dbtime.Now().Add(-30 * time.Minute),
	})
	originalDeadline := startBuild.Deadline

	// The usage tracker is driven manually: a send on tickCh requests
	// a flush and flushedCh reports how many workspaces were flushed,
	// so last_used_at is written to the DB on demand.
	tickCh := make(chan time.Time)
	flushedCh := make(chan int, 1)
	tracker := workspacestats.NewTracker(db,
		workspacestats.TrackerWithTickFlush(tickCh, flushedCh),
		workspacestats.TrackerWithLogger(slogtest.Make(t, nil)),
	)
	t.Cleanup(func() { tracker.Close() })

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	// The store is wrapped in dbauthz so that the chatd server's
	// AsChatd context is enforced on every query, mirroring
	// production behavior.
	authzDB := dbauthz.New(
		db,
		rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()),
		slogtest.Make(t, nil),
		coderdtest.AccessControlStorePointer(),
	)
	srv := chatd.New(chatd.Config{
		Logger:                     logger,
		Database:                   authzDB,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitLong,
		ChatHeartbeatInterval:      100 * time.Millisecond,
		UsageTracker:               tracker,
	})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})

	// The chat starts out WITHOUT a workspace, which is the normal
	// initial state. In production CreateChat runs in the HTTP handler
	// under the authenticated user's context; AsChatd is used here
	// because the chatd server processes everything under that role.
	chatCtx := dbauthz.AsChatd(ctx)
	chat, err := srv.CreateChat(chatCtx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "usage-tracking-test",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Block until the chat is running and has recorded a heartbeat
	// that is newer than its creation time.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		current, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		if current.Status != database.ChatStatusRunning {
			return false
		}
		if !current.HeartbeatAt.Valid {
			return false
		}
		return current.HeartbeatAt.Time.After(current.CreatedAt)
	}, testutil.IntervalFast,
		"chat should be running with at least one heartbeat")

	// With no workspace linked yet, a forced flush must report zero
	// tracked workspaces.
	testutil.RequireSend(ctx, t, tickCh, time.Now())
	flushed := testutil.RequireReceive(ctx, t, flushedCh)
	require.Equal(t, 0, flushed,
		"expected no workspaces to be flushed before association")

	// Bind the workspace to the chat directly in the DB, simulating
	// what the create_workspace tool does mid-conversation.
	_, err = db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
		WorkspaceID: uuid.NullUUID{UUID: wsRow.ID, Valid: true},
		ID:          chat.ID,
	})
	require.NoError(t, err)

	// The heartbeat re-reads the workspace association from the DB on
	// every tick. Keep forcing flushes until one of them reports a
	// tracked workspace.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		select {
		case <-ctx.Done():
			return false
		case tickCh <- time.Now():
		}
		select {
		case <-ctx.Done():
			return false
		case n := <-flushedCh:
			return n > 0
		}
	}, testutil.IntervalMedium,
		"expected usage tracker to flush the late-associated workspace")

	// last_used_at on the workspace row must have moved forward.
	refreshedWS, err := db.GetWorkspaceByID(ctx, wsRow.ID)
	require.NoError(t, err)
	require.True(t, refreshedWS.LastUsedAt.After(wsRow.LastUsedAt),
		"workspace last_used_at should have been bumped")

	// The build deadline must have been extended as well. The SQL
	// only writes once 5% of the deadline has elapsed; most calls
	// perform a read-only CTE lookup. The tolerance of two minutes on
	// either side is wider than in activitybump_test.go because the
	// bump happens asynchronously via the heartbeat goroutine.
	const tolerance = 2 * time.Minute
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		latest, buildErr := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, wsRow.ID)
		if buildErr != nil {
			return false
		}
		if !latest.Deadline.After(originalDeadline) {
			return false
		}
		now := dbtime.Now()
		lowerBound := now.Add(time.Hour - tolerance)
		upperBound := now.Add(time.Hour + tolerance)
		return latest.Deadline.After(lowerBound) && latest.Deadline.Before(upperBound)
	}, testutil.IntervalFast,
		"workspace build deadline should have been bumped to ~now+1h")
}
// TestHeartbeatNoWorkspaceNoBump checks that the heartbeat of a running
// chat that has no workspace linked does not report any workspace usage:
// a forced flush of the usage tracker must report zero workspaces.
func TestHeartbeatNoWorkspaceNoBump(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(ctx, t, db)

	// Streaming requests block until the request context is canceled
	// so the chat stays in a processing state while heartbeats fire.
	setOpenAIProviderBaseURL(ctx, t, db, chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("ok")
		}
		chunks := make(chan chattest.OpenAIChunk)
		go func() {
			defer close(chunks)
			<-req.Context().Done()
		}()
		return chattest.OpenAIResponse{StreamingChunks: chunks}
	}))

	// Set up UsageTracker with manual tick/flush.
	usageTickCh := make(chan time.Time)
	flushCh := make(chan int, 1)
	tracker := workspacestats.NewTracker(db,
		workspacestats.TrackerWithTickFlush(usageTickCh, flushCh),
		workspacestats.TrackerWithLogger(slogtest.Make(t, nil)),
	)
	t.Cleanup(func() { tracker.Close() })

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	server := chatd.New(chatd.Config{
		Logger:                     logger,
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitLong,
		ChatHeartbeatInterval:      100 * time.Millisecond,
		// The tracker must be handed to the server; otherwise the
		// zero-flush assertion below holds trivially and the test
		// cannot detect a heartbeat that wrongly records usage.
		UsageTracker: tracker,
	})
	t.Cleanup(func() {
		require.NoError(t, server.Close())
	})

	// Create a chat WITHOUT linking a workspace.
	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "no-workspace-test",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Wait for the chat to be acquired and at least one heartbeat
	// to fire.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		fromDB, listErr := db.GetChatByID(ctx, chat.ID)
		if listErr != nil {
			return false
		}
		return fromDB.Status == database.ChatStatusRunning &&
			fromDB.HeartbeatAt.Valid &&
			fromDB.HeartbeatAt.Time.After(fromDB.CreatedAt)
	}, testutil.IntervalFast,
		"chat should be running with at least one heartbeat")

	// Flush the tracker. Since no workspace was linked, count
	// should be 0.
	testutil.RequireSend(ctx, t, usageTickCh, time.Now())
	count := testutil.RequireReceive(ctx, t, flushCh)
	require.Equal(t, 0, count, "expected no workspaces to be flushed when chat has no workspace")
}
// waitForChatProcessed blocks until a wake-triggered processOnce has
// fully finished for the given chat. It first polls the database until
// the chat is neither pending nor running (i.e. processChat has done its
// cleanup and written the result), and only then calls
// WaitUntilIdleForTest on the server.
//
// Polling for "neither pending nor running" rather than just "not
// pending" avoids a WaitGroup Add/Wait race: AcquireChats flips the DB
// status to running before processOnce calls inflight.Add(1), so a
// Wait() issued as soon as the status stops being pending could run
// before Add(1) has happened.
func waitForChatProcessed(
	ctx context.Context,
	t *testing.T,
	db database.Store,
	chatID uuid.UUID,
	server *chatd.Server,
) {
	t.Helper()

	settled := func() bool {
		chat, err := db.GetChatByID(ctx, chatID)
		if err != nil {
			return false
		}
		// Pending means the chat is still waiting to be acquired and
		// running means it is being processed; any other status
		// implies processOnce has already called inflight.Add(1).
		switch chat.Status {
		case database.ChatStatusPending, database.ChatStatusRunning:
			return false
		default:
			return true
		}
	}
	require.Eventually(t, settled, testutil.WaitShort, testutil.IntervalFast)

	chatd.WaitUntilIdleForTest(server)
}
// newTestServer builds a chatd server for the given store, pubsub and
// replica ID and registers a cleanup that closes it. The pending-chat
// acquire interval is set to testutil.WaitLong, so periodic polling is
// unlikely to fire during a test (presumably chats are then only
// processed when the server is explicitly woken; see
// newActiveTestServer for a polling variant).
func newTestServer(
	t *testing.T,
	db database.Store,
	ps dbpubsub.Pubsub,
	replicaID uuid.UUID,
) *chatd.Server {
	t.Helper()

	cfg := chatd.Config{
		Logger:                     slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		Database:                   db,
		ReplicaID:                  replicaID,
		Pubsub:                     ps,
		PendingChatAcquireInterval: testutil.WaitLong,
	}
	srv := chatd.New(cfg)
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})
	return srv
}
// newActiveTestServer builds a chatd server that actively polls for and
// processes pending chats (10ms acquire interval). Prefer it over
// newTestServer whenever a test needs the chat loop to actually run.
// Each override is applied, in order, to the default config before the
// server is constructed. A cleanup that closes the server is registered
// on t.
func newActiveTestServer(
	t *testing.T,
	db database.Store,
	ps dbpubsub.Pubsub,
	overrides ...func(*chatd.Config),
) *chatd.Server {
	t.Helper()

	cfg := chatd.Config{
		Logger:                     slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitSuperLong,
	}
	for i := range overrides {
		overrides[i](&cfg)
	}

	srv := chatd.New(cfg)
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})
	return srv
}
// seedChatDependencies seeds a user, an organization and a model config
// for the "openai" provider with an empty base URL by delegating to
// seedChatDependenciesWithProvider.
func seedChatDependencies(
	ctx context.Context,
	t *testing.T,
	db database.Store,
) (database.User, database.Organization, database.ChatModelConfig) {
	t.Helper()

	const (
		defaultProvider = "openai"
		defaultBaseURL  = ""
	)
	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, defaultProvider, defaultBaseURL)
	return user, org, model
}
// seedChatDependenciesWithProvider inserts a user, an organization (with
// the user as a member), an enabled chat provider of the given type and
// base URL, and a default model config for that provider. It returns the
// user, the organization and the model config; any insert error fails
// the test.
func seedChatDependenciesWithProvider(
	ctx context.Context,
	t *testing.T,
	db database.Store,
	provider string,
	baseURL string,
) (database.User, database.Organization, database.ChatModelConfig) {
	t.Helper()

	user := dbgen.User(t, db, database.User{})
	org := dbgen.Organization(t, db, database.Organization{})
	dbgen.OrganizationMember(t, db, database.OrganizationMember{
		UserID:         user.ID,
		OrganizationID: org.ID,
	})

	// The seeded user is recorded as creator/updater of both rows.
	author := uuid.NullUUID{UUID: user.ID, Valid: true}

	providerParams := database.InsertChatProviderParams{
		Provider:             provider,
		DisplayName:          provider,
		APIKey:               "test-key",
		BaseUrl:              baseURL,
		CreatedBy:            author,
		Enabled:              true,
		CentralApiKeyEnabled: true,
	}
	_, err := db.InsertChatProvider(ctx, providerParams)
	require.NoError(t, err)

	modelParams := database.InsertChatModelConfigParams{
		Provider:             provider,
		Model:                "gpt-4o-mini",
		DisplayName:          "Test Model",
		CreatedBy:            author,
		UpdatedBy:            author,
		Enabled:              true,
		IsDefault:            true,
		ContextLimit:         128000,
		CompressionThreshold: 70,
		Options:              json.RawMessage(`{}`),
	}
	model, err := db.InsertChatModelConfig(ctx, modelParams)
	require.NoError(t, err)

	return user, org, model
}
// seedChatDependenciesWithProviderPolicy is the policy-aware variant of
// seedChatDependenciesWithProvider: in addition to provider type and
// base URL, the caller chooses the provider API key and the three key
// policy flags (central key enabled, user keys allowed, fallback to the
// central key allowed). It returns the user, the organization, the
// inserted provider row and the default model config.
func seedChatDependenciesWithProviderPolicy(
	ctx context.Context,
	t *testing.T,
	db database.Store,
	provider string,
	baseURL string,
	apiKey string,
	centralAPIKeyEnabled bool,
	allowUserAPIKey bool,
	allowCentralAPIKeyFallback bool,
) (database.User, database.Organization, database.ChatProvider, database.ChatModelConfig) {
	t.Helper()

	user := dbgen.User(t, db, database.User{})
	org := dbgen.Organization(t, db, database.Organization{})
	dbgen.OrganizationMember(t, db, database.OrganizationMember{
		UserID:         user.ID,
		OrganizationID: org.ID,
	})

	// The seeded user is recorded as creator/updater of both rows.
	author := uuid.NullUUID{UUID: user.ID, Valid: true}

	providerParams := database.InsertChatProviderParams{
		Provider:                   provider,
		DisplayName:                provider,
		APIKey:                     apiKey,
		BaseUrl:                    baseURL,
		CreatedBy:                  author,
		Enabled:                    true,
		CentralApiKeyEnabled:       centralAPIKeyEnabled,
		AllowUserApiKey:            allowUserAPIKey,
		AllowCentralApiKeyFallback: allowCentralAPIKeyFallback,
	}
	providerConfig, err := db.InsertChatProvider(ctx, providerParams)
	require.NoError(t, err)

	modelParams := database.InsertChatModelConfigParams{
		Provider:             provider,
		Model:                "gpt-4o-mini",
		DisplayName:          "Test Model",
		CreatedBy:            author,
		UpdatedBy:            author,
		Enabled:              true,
		IsDefault:            true,
		ContextLimit:         128000,
		CompressionThreshold: 70,
		Options:              json.RawMessage(`{}`),
	}
	model, err := db.InsertChatModelConfig(ctx, modelParams)
	require.NoError(t, err)

	return user, org, providerConfig, model
}
// waitForTerminalChatStatusEvent polls the event channel until it sees a
// status event whose status is waiting or error, and returns that
// status. Each poll drains whatever events are currently buffered
// without blocking; non-status events and status events with a nil
// payload are skipped. A closed channel makes the poll report "not yet",
// so the surrounding testutil.Eventually keeps retrying.
func waitForTerminalChatStatusEvent(
	ctx context.Context,
	t *testing.T,
	events <-chan codersdk.ChatStreamEvent,
) codersdk.ChatStatus {
	t.Helper()

	var terminalStatus codersdk.ChatStatus
	drainOnce := func(context.Context) bool {
		for {
			select {
			case ev, open := <-events:
				if !open {
					return false
				}
				if ev.Status == nil || ev.Type != codersdk.ChatStreamEventTypeStatus {
					continue
				}
				switch ev.Status.Status {
				case codersdk.ChatStatusWaiting, codersdk.ChatStatusError:
					terminalStatus = ev.Status.Status
					return true
				}
			default:
				// Nothing buffered right now; try again on the next tick.
				return false
			}
		}
	}
	testutil.Eventually(ctx, t, drainOnce, testutil.IntervalFast)

	return terminalStatus
}
// waitForTerminalChat polls the database until the chat's status is
// waiting or error and returns the row read by the last successful
// lookup.
func waitForTerminalChat(
	ctx context.Context,
	t *testing.T,
	db database.Store,
	chatID uuid.UUID,
) database.Chat {
	t.Helper()

	var latest database.Chat
	reachedTerminal := func(ctx context.Context) bool {
		row, err := db.GetChatByID(ctx, chatID)
		if err != nil {
			return false
		}
		latest = row
		switch row.Status {
		case database.ChatStatusWaiting, database.ChatStatusError:
			return true
		default:
			return false
		}
	}
	testutil.Eventually(ctx, t, reachedTerminal, testutil.IntervalFast)

	return latest
}
// seedWorkspaceWithAgent seeds the full chain of rows behind a workspace
// agent: a fresh organization, template version, template, workspace,
// provisioner job, workspace build, a start-transition resource and
// finally the agent attached to that resource. It is the shared setup
// for tests that exercise tool execution against a workspace and returns
// the workspace and agent rows.
func seedWorkspaceWithAgent(
	t *testing.T,
	db database.Store,
	userID uuid.UUID,
) (database.WorkspaceTable, database.WorkspaceAgent) {
	t.Helper()

	org := dbgen.Organization(t, db, database.Organization{})
	templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{
		OrganizationID: org.ID,
		CreatedBy:      userID,
	})
	tpl := dbgen.Template(t, db, database.Template{
		CreatedBy:       userID,
		OrganizationID:  org.ID,
		ActiveVersionID: templateVersion.ID,
	})
	workspaceRow := dbgen.Workspace(t, db, database.WorkspaceTable{
		TemplateID:     tpl.ID,
		OwnerID:        userID,
		OrganizationID: org.ID,
	})
	provJob := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
		InitiatorID:    userID,
		OrganizationID: org.ID,
	})
	// The build row itself is not needed by callers; it only links
	// the workspace to the provisioner job.
	_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
		TemplateVersionID: templateVersion.ID,
		WorkspaceID:       workspaceRow.ID,
		JobID:             provJob.ID,
	})
	resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
		Transition: database.WorkspaceTransitionStart,
		JobID:      provJob.ID,
	})
	agentRow := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
		ResourceID: resource.ID,
	})

	return workspaceRow, agentRow
}
// setOpenAIProviderBaseURL rewrites the base URL of the stored "openai"
// chat provider to baseURL. Every other field passed to the update is
// copied from the existing row. A failed lookup or update fails the
// test.
func setOpenAIProviderBaseURL(
	ctx context.Context,
	t *testing.T,
	db database.Store,
	baseURL string,
) {
	t.Helper()

	existing, err := db.GetChatProviderByProvider(ctx, "openai")
	require.NoError(t, err)

	update := database.UpdateChatProviderParams{
		ID:                         existing.ID,
		DisplayName:                existing.DisplayName,
		APIKey:                     existing.APIKey,
		BaseUrl:                    baseURL,
		ApiKeyKeyID:                existing.ApiKeyKeyID,
		Enabled:                    existing.Enabled,
		CentralApiKeyEnabled:       existing.CentralApiKeyEnabled,
		AllowUserApiKey:            existing.AllowUserApiKey,
		AllowCentralApiKeyFallback: existing.AllowCentralApiKeyFallback,
	}
	_, err = db.UpdateChatProvider(ctx, update)
	require.NoError(t, err)
}
// TestInterruptChatDoesNotSendWebPushNotification interrupts a chat
// while its model response is still streaming and asserts that no web
// push notification is dispatched for it.
func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// The mock OpenAI emits one partial chunk, signals streamStarted,
	// and then blocks until the request context is canceled (i.e.
	// until the chat is interrupted).
	streamStarted := make(chan struct{})
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if req.Stream {
			stream := make(chan chattest.OpenAIChunk, 1)
			go func() {
				defer close(stream)
				stream <- chattest.OpenAITextChunks("partial")[0]
				// Close streamStarted at most once.
				select {
				case <-streamStarted:
				default:
					close(streamStarted)
				}
				// Hold the stream open until the interrupt cancels
				// the chat context.
				<-req.Context().Done()
			}()
			return chattest.OpenAIResponse{StreamingChunks: stream}
		}
		return chattest.OpenAINonStreamingResponse("title")
	})

	// Recording dispatcher: the assertion at the end reads its call
	// counter.
	pushRecorder := &mockWebpushDispatcher{}

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	srv := chatd.New(chatd.Config{
		Logger:                     logger,
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitSuperLong,
		WebpushDispatcher:          pushRecorder,
	})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})

	user, org, model := seedChatDependencies(ctx, t, db)
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	chat, err := srv.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "interrupt-no-push",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// The chat must be acquired by a worker ...
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		current, dbErr := db.GetChatByID(ctx, chat.ID)
		if dbErr != nil {
			return false
		}
		if current.Status != database.ChatStatusRunning {
			return false
		}
		return current.WorkerID.Valid
	}, testutil.IntervalFast)

	// ... and the model stream must have begun.
	streamHasStarted := func(context.Context) bool {
		select {
		case <-streamStarted:
			return true
		default:
			return false
		}
	}
	testutil.Eventually(ctx, t, streamHasStarted, testutil.IntervalFast)

	// Interrupt the chat; the returned row already reports waiting.
	interrupted := srv.InterruptChat(ctx, chat)
	require.Equal(t, database.ChatStatusWaiting, interrupted.Status)

	// Wait until processing has wound down: status is waiting and the
	// worker has released the chat.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		current, dbErr := db.GetChatByID(ctx, chat.ID)
		if dbErr != nil {
			return false
		}
		if current.Status != database.ChatStatusWaiting {
			return false
		}
		return !current.WorkerID.Valid
	}, testutil.IntervalFast)

	// An interrupted chat must not trigger a web push.
	require.Equal(t, int32(0), pushRecorder.dispatchCount.Load(),
		"expected no web push dispatch for an interrupted chat")
}
// mockWebpushDispatcher implements webpush.Dispatcher and records Dispatch calls.
type mockWebpushDispatcher struct {
	// dispatchCount is the number of Dispatch calls whose payload has
	// already been recorded in lastMessage/lastUserID.
	dispatchCount atomic.Int32
	// mu guards lastMessage and lastUserID.
	mu          sync.Mutex
	lastMessage codersdk.WebpushMessage
	lastUserID  uuid.UUID
}

// Dispatch records the message and recipient of the call and then bumps
// dispatchCount. It always returns nil.
func (m *mockWebpushDispatcher) Dispatch(_ context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
	m.mu.Lock()
	m.lastMessage = msg
	m.lastUserID = userID
	m.mu.Unlock()
	// Increment only after the payload is stored. Tests poll
	// dispatchCount and then read lastMessage; bumping the counter
	// first would let them observe a non-zero count together with a
	// still-empty message.
	m.dispatchCount.Add(1)
	return nil
}

// getLastMessage returns the message passed to the most recently
// recorded Dispatch call, or the zero value if there was none.
func (m *mockWebpushDispatcher) getLastMessage() codersdk.WebpushMessage {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.lastMessage
}

// Test is a no-op that always reports success.
func (*mockWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error {
	return nil
}

// PublicKey returns a fixed placeholder VAPID public key.
func (*mockWebpushDispatcher) PublicKey() string {
	return "test-vapid-public-key"
}
// TestSuccessfulChatSendsWebPushWithNavigationData runs a chat to
// completion and asserts that exactly one web push is dispatched to the
// chat owner and that its Data carries the "/agents/<chat id>"
// navigation URL.
func TestSuccessfulChatSendsWebPushWithNavigationData(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// The mock OpenAI answers every streaming request with a short,
	// successful response.
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if req.Stream {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("done")...,
			)
		}
		return chattest.OpenAINonStreamingResponse("title")
	})

	// Recording dispatcher that captures the dispatched message.
	pushRecorder := &mockWebpushDispatcher{}

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	srv := chatd.New(chatd.Config{
		Logger:                     logger,
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitSuperLong,
		WebpushDispatcher:          pushRecorder,
	})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})

	user, org, model := seedChatDependencies(ctx, t, db)
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	chat, err := srv.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "push-nav-test",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Wait until the chat is back in waiting, released by its worker,
	// and the dispatcher has been called once.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		current, dbErr := db.GetChatByID(ctx, chat.ID)
		if dbErr != nil {
			return false
		}
		if current.Status != database.ChatStatusWaiting || current.WorkerID.Valid {
			return false
		}
		return pushRecorder.dispatchCount.Load() == 1
	}, testutil.IntervalFast)

	// Exactly one notification for the completed chat.
	require.Equal(t, int32(1), pushRecorder.dispatchCount.Load(),
		"expected exactly one web push dispatch for a completed chat")

	// Snapshot the recorded payload under the dispatcher's lock.
	pushRecorder.mu.Lock()
	gotMsg, gotUserID := pushRecorder.lastMessage, pushRecorder.lastUserID
	pushRecorder.mu.Unlock()

	// The recipient must be the chat owner.
	require.Equal(t, user.ID, gotUserID,
		"web push should be dispatched to the chat owner")

	// The Data field must point at the chat's page.
	wantURL := fmt.Sprintf("/agents/%s", chat.ID)
	require.Equal(t, wantURL, gotMsg.Data["url"],
		"web push Data should contain the chat navigation URL")
}
// TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica closes a
// server while one of its chats is mid-stream and asserts that the chat
// returns to pending without a worker or error, is then processed by a
// second server, and ends up in waiting without an error.
func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	var requestCount atomic.Int32
	streamStarted := make(chan struct{})
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		// Non-streaming requests (e.g. title generation) are
		// answered without touching the counter that coordinates
		// the streaming flow.
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("shutdown-retry")
		}
		// Every streaming request after the first completes
		// normally.
		if requestCount.Add(1) != 1 {
			return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("retry", " complete")...)
		}
		// The first streaming request emits one chunk, signals
		// streamStarted and then hangs until its context is
		// canceled.
		stream := make(chan chattest.OpenAIChunk, 1)
		go func() {
			defer close(stream)
			stream <- chattest.OpenAITextChunks("partial")[0]
			select {
			case <-streamStarted:
			default:
				close(streamStarted)
			}
			<-req.Context().Done()
		}()
		return chattest.OpenAIResponse{StreamingChunks: stream}
	})

	replicaA := chatd.New(chatd.Config{
		Logger:                     slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitLong,
	})
	t.Cleanup(func() {
		require.NoError(t, replicaA.Close())
	})

	user, org, model := seedChatDependencies(ctx, t, db)
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	chat, err := replicaA.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "shutdown-retry",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// chatSatisfies builds a poll function that loads the chat and
	// applies pred to it; lookup errors count as "not yet".
	chatSatisfies := func(pred func(database.Chat) bool) func() bool {
		return func() bool {
			current, dbErr := db.GetChatByID(ctx, chat.ID)
			if dbErr != nil {
				return false
			}
			return pred(current)
		}
	}
	// releasedCleanlyAs matches a chat in the given status that has no
	// worker assigned and no last error recorded.
	releasedCleanlyAs := func(status database.ChatStatus) func(database.Chat) bool {
		return func(c database.Chat) bool {
			return c.Status == status &&
				!c.WorkerID.Valid &&
				!c.LastError.Valid
		}
	}

	// Replica A picks the chat up ...
	require.Eventually(t, chatSatisfies(func(c database.Chat) bool {
		return c.Status == database.ChatStatusRunning && c.WorkerID.Valid
	}), testutil.WaitMedium, testutil.IntervalFast)

	// ... and its first model stream has started.
	require.Eventually(t, func() bool {
		select {
		case <-streamStarted:
			return true
		default:
			return false
		}
	}, testutil.WaitMedium, testutil.IntervalFast)

	// Shut replica A down mid-stream.
	require.NoError(t, replicaA.Close())

	// The chat must be handed back as pending, with no worker and no
	// error recorded.
	require.Eventually(t,
		chatSatisfies(releasedCleanlyAs(database.ChatStatusPending)),
		testutil.WaitMedium, testutil.IntervalFast)

	replicaB := chatd.New(chatd.Config{
		Logger:                     slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitLong,
	})
	t.Cleanup(func() {
		require.NoError(t, replicaB.Close())
	})

	// A second streaming request must reach the mock provider.
	require.Eventually(t, func() bool {
		return requestCount.Load() >= 2
	}, testutil.WaitMedium, testutil.IntervalFast)

	// The retried chat must finish in waiting without an error.
	require.Eventually(t,
		chatSatisfies(releasedCleanlyAs(database.ChatStatusWaiting)),
		testutil.WaitMedium, testutil.IntervalFast)
}
// TestSuccessfulChatSendsWebPushWithSummary runs a chat to completion
// and asserts that the dispatched web push body is the text returned by
// the mock provider's non-streaming endpoint, and that exactly one
// non-streaming request was made.
func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	const (
		assistantText = "I have completed the task successfully and all tests are passing now."
		summaryText   = "Completed task and verified all tests pass."
	)

	// Streaming requests return the assistant text; non-streaming
	// requests are counted and answered with the summary text.
	var nonStreamingRequests atomic.Int32
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if req.Stream {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks(assistantText)...,
			)
		}
		nonStreamingRequests.Add(1)
		return chattest.OpenAINonStreamingResponse(summaryText)
	})

	pushRecorder := &mockWebpushDispatcher{}
	srv := chatd.New(chatd.Config{
		Logger:                     slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitSuperLong,
		WebpushDispatcher:          pushRecorder,
	})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})

	user, org, model := seedChatDependencies(ctx, t, db)
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	_, err := srv.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "summary-push-test",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
	})
	require.NoError(t, err)

	// The push is dispatched asynchronously after the chat finishes,
	// so poll the dispatcher instead of checking right after the
	// status flips to waiting.
	pushDispatched := func(context.Context) bool {
		return pushRecorder.dispatchCount.Load() >= 1
	}
	testutil.Eventually(ctx, t, pushDispatched, testutil.IntervalFast)

	pushed := pushRecorder.getLastMessage()
	require.Equal(t, summaryText, pushed.Body,
		"push body should be the LLM-generated summary")
	require.NotEqual(t, "Agent has finished running.", pushed.Body,
		"push body should not use the default fallback text")
	require.Equal(t, int32(1), nonStreamingRequests.Load(),
		"expected exactly one non-streaming request for push summary generation")
}
// TestSuccessfulChatSendsWebPushFallbackWithoutSummaryForEmptyAssistantText
// runs a chat whose streamed assistant text is whitespace only and
// asserts that the web push body is the fixed fallback text and that no
// non-streaming (summary) request reached the mock provider.
func TestSuccessfulChatSendsWebPushFallbackWithoutSummaryForEmptyAssistantText(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Streaming requests return whitespace only; any non-streaming
	// request is counted so the test can assert that none occurred.
	var nonStreamingRequests atomic.Int32
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if req.Stream {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks(" ")...,
			)
		}
		nonStreamingRequests.Add(1)
		return chattest.OpenAINonStreamingResponse("unexpected summary request")
	})

	pushRecorder := &mockWebpushDispatcher{}
	srv := chatd.New(chatd.Config{
		Logger:                     slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitSuperLong,
		WebpushDispatcher:          pushRecorder,
	})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})

	user, org, model := seedChatDependencies(ctx, t, db)
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	_, err := srv.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "empty-summary-push-test",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
	})
	require.NoError(t, err)

	// Poll until the dispatcher has recorded at least one call.
	pushDispatched := func(context.Context) bool {
		return pushRecorder.dispatchCount.Load() >= 1
	}
	testutil.Eventually(ctx, t, pushDispatched, testutil.IntervalFast)

	pushed := pushRecorder.getLastMessage()
	require.Equal(t, "Agent has finished running.", pushed.Body,
		"push body should fall back when the final assistant text is empty")
	require.Equal(t, int32(0), nonStreamingRequests.Load(),
		"push summary should not be requested when final assistant text has no usable text")
}
// TestComputerUseSubagentToolsAndModel verifies that the child chat
// created via spawn_computer_use_agent calls the Anthropic provider
// with chattool.ComputerUseModelName, is offered the computer tool
// plus the standard workspace tools (but neither workspace
// provisioning nor subagent tools), and is stored with
// Mode = computer_use.
func TestComputerUseSubagentToolsAndModel(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Track tools and model from the Anthropic LLM calls (the
	// computer use child chat). We use a raw HTTP handler because
	// the chattest AnthropicRequest struct does not capture tools.
	type anthropicCall struct {
		Model string
		Tools []string
	}
	var anthropicMu sync.Mutex
	var anthropicCalls []anthropicCall

	anthropicSrv := httptest.NewServer(http.HandlerFunc(
		func(w http.ResponseWriter, r *http.Request) {
			body, err := io.ReadAll(r.Body)
			if err != nil {
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}

			var req struct {
				Model  string `json:"model"`
				Stream bool   `json:"stream"`
				Tools  []struct {
					Name string `json:"name"`
				} `json:"tools"`
			}
			if err := json.Unmarshal(body, &req); err != nil {
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}

			names := make([]string, len(req.Tools))
			for i, tool := range req.Tools {
				names[i] = tool.Name
			}

			anthropicMu.Lock()
			anthropicCalls = append(anthropicCalls, anthropicCall{
				Model: req.Model,
				Tools: names,
			})
			anthropicMu.Unlock()

			if !req.Stream {
				w.Header().Set("Content-Type", "application/json")
				// Best effort: an encode failure here only means the
				// client went away, which the test detects elsewhere.
				_ = json.NewEncoder(w).Encode(map[string]any{
					"id":          "msg-test",
					"type":        "message",
					"role":        "assistant",
					"model":       chattool.ComputerUseModelName,
					"content":     []map[string]any{{"type": "text", "text": "Done."}},
					"stop_reason": "end_turn",
					"usage":       map[string]any{"input_tokens": 10, "output_tokens": 5},
				})
				return
			}

			// Stream a minimal Anthropic SSE response. Check the
			// Flusher assertion instead of discarding it so that a
			// non-flushing writer yields an HTTP error rather than a
			// nil-interface method call.
			flusher, ok := w.(http.Flusher)
			if !ok {
				http.Error(w, "streaming unsupported", http.StatusInternalServerError)
				return
			}
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")

			chunks := []map[string]any{
				{
					"type": "message_start",
					"message": map[string]any{
						"id":    "msg-test",
						"type":  "message",
						"role":  "assistant",
						"model": chattool.ComputerUseModelName,
					},
				},
				{
					"type":  "content_block_start",
					"index": 0,
					"content_block": map[string]any{
						"type": "text",
						"text": "",
					},
				},
				{
					"type":  "content_block_delta",
					"index": 0,
					"delta": map[string]any{
						"type": "text_delta",
						"text": "Done.",
					},
				},
				{"type": "content_block_stop", "index": 0},
				{
					"type":  "message_delta",
					"delta": map[string]any{"stop_reason": "end_turn"},
					"usage": map[string]any{"output_tokens": 5},
				},
				{"type": "message_stop"},
			}
			for _, chunk := range chunks {
				// Marshal cannot fail here: every chunk above holds
				// only strings, ints and nested maps of the same.
				chunkBytes, _ := json.Marshal(chunk)
				eventType, _ := chunk["type"].(string)
				_, _ = fmt.Fprintf(w, "event: %s\ndata: %s\n\n",
					eventType, chunkBytes)
				flusher.Flush()
			}
		},
	))
	t.Cleanup(anthropicSrv.Close)

	// OpenAI mock for the root chat. The first streaming call
	// triggers spawn_computer_use_agent; subsequent calls reply
	// with text.
	var openAICallCount atomic.Int32
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("title")
		}
		if openAICallCount.Add(1) == 1 {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAIToolCallChunk(
					"spawn_computer_use_agent",
					`{"prompt":"do the desktop thing","title":"cu-sub"}`,
				),
			)
		}
		// Include literal \u0000 in the response text, which is
		// what a real LLM writes when explaining binary output.
		// json.Marshal encodes the backslash as \\, producing
		// \\u0000 in the JSON bytes. The sanitizer must not
		// corrupt this into invalid JSON.
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAITextChunks("The file contains \\u0000 null bytes.")...,
		)
	})

	// Seed the DB: user, openai-compat provider, model config.
	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)

	// Add an Anthropic provider pointing to our mock server.
	_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
		Provider:             "anthropic",
		DisplayName:          "Anthropic",
		APIKey:               "test-anthropic-key",
		BaseUrl:              anthropicSrv.URL,
		CreatedBy:            uuid.NullUUID{UUID: user.ID, Valid: true},
		Enabled:              true,
		CentralApiKeyEnabled: true,
	})
	require.NoError(t, err)

	err = db.UpsertChatDesktopEnabled(ctx, true)
	require.NoError(t, err)

	// Build workspace + agent records so getWorkspaceConn can
	// resolve the agent for the computer use child.
	ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)

	// Mock agent connection that returns valid display dimensions
	// for the initial screenshot check in the computer use path.
	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)
	mockConn.EXPECT().
		ListMCPTools(gomock.Any()).
		Return(workspacesdk.ListMCPToolsResponse{}, nil).
		AnyTimes()
	mockConn.EXPECT().
		ExecuteDesktopAction(gomock.Any(), gomock.Any()).
		Return(workspacesdk.DesktopActionResponse{
			ScreenshotWidth:  1920,
			ScreenshotHeight: 1080,
			ScreenshotData:   "iVBOR",
		}, nil).
		AnyTimes()
	mockConn.EXPECT().
		SetExtraHeaders(gomock.Any()).
		AnyTimes()
	mockConn.EXPECT().
		ContextConfig(gomock.Any()).
		Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).
		AnyTimes()
	mockConn.EXPECT().
		LS(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(workspacesdk.LSResponse{}, xerrors.New("not found")).
		AnyTimes()

	server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
		cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
			require.Equal(t, dbAgent.ID, agentID)
			return mockConn, func() {}, nil
		}
	})

	// Create a root chat with a workspace so the child inherits it.
	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "computer-use-detection",
		ModelConfigID:  model.ID,
		WorkspaceID:    uuid.NullUUID{UUID: ws.ID, Valid: true},
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("Use the desktop to check the UI"),
		},
	})
	require.NoError(t, err)

	// Wait until the root chat has settled (waiting or error) and
	// the Anthropic mock has recorded at least one request, i.e. the
	// chatd server picked up the spawned child and ran it.
	require.Eventually(t, func() bool {
		got, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		if got.Status != database.ChatStatusWaiting &&
			got.Status != database.ChatStatusError {
			return false
		}
		// Ensure the Anthropic mock received at least one call.
		anthropicMu.Lock()
		n := len(anthropicCalls)
		anthropicMu.Unlock()
		return n >= 1
	}, testutil.WaitLong, testutil.IntervalFast)

	// Snapshot the recorded calls under the lock; the handler may
	// still be appending.
	anthropicMu.Lock()
	calls := append([]anthropicCall(nil), anthropicCalls...)
	anthropicMu.Unlock()
	require.NotEmpty(t, calls,
		"expected at least one Anthropic LLM call")

	childModel := calls[0].Model
	childTools := calls[0].Tools

	// 1. Verify the model is the computer use model.
	require.Equal(t, chattool.ComputerUseModelName, childModel,
		"computer use subagent should use %s",
		chattool.ComputerUseModelName)

	// 2. Verify the computer tool is present.
	require.Contains(t, childTools, "computer",
		"computer use subagent should have the computer tool")

	// 3. Verify standard workspace tools are present (the same
	// set a regular subagent gets).
	standardTools := []string{
		"read_file", "write_file", "edit_files", "execute",
		"process_output", "process_list", "process_signal",
	}
	for _, tool := range standardTools {
		require.Contains(t, childTools, tool,
			"computer use subagent should have standard tool %q",
			tool)
	}

	// 4. Verify workspace provisioning tools are NOT present.
	workspaceProvisioningTools := []string{
		"list_templates", "read_template",
		"create_workspace", "start_workspace",
	}
	for _, tool := range workspaceProvisioningTools {
		require.NotContains(t, childTools, tool,
			"computer use subagent should NOT have workspace "+
				"provisioning tool %q", tool)
	}

	// 5. Verify subagent tools are NOT present.
	subagentTools := []string{
		"spawn_agent", "spawn_computer_use_agent",
		"wait_agent", "message_agent", "close_agent",
	}
	for _, tool := range subagentTools {
		require.NotContains(t, childTools, tool,
			"computer use subagent should NOT have subagent "+
				"tool %q", tool)
	}

	// 6. Verify the child chat has Mode = computer_use in
	// the DB.
	allChats, err := db.GetChats(ctx, database.GetChatsParams{
		OwnerID: user.ID,
	})
	require.NoError(t, err)
	var children []database.Chat
	for _, c := range allChats {
		if c.Chat.ParentChatID.Valid && c.Chat.ParentChatID.UUID == chat.ID {
			children = append(children, c.Chat)
		}
	}
	require.Len(t, children, 1)
	require.True(t, children[0].Mode.Valid)
	require.Equal(t, database.ChatModeComputerUse,
		children[0].Mode.ChatMode)
}
// TestInterruptChatPersistsPartialResponse verifies that interrupting
// a chat while the LLM response is still streaming leaves the chat in
// the waiting state and persists the text streamed so far as an
// assistant message.
func TestInterruptChatPersistsPartialResponse(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Set up a mock OpenAI that streams a partial response and then
	// blocks until the request context is canceled (simulating an
	// interrupt mid-stream).
	chunksDelivered := make(chan struct{})
	// The handler may run more than once; sync.Once makes closing
	// chunksDelivered safe against a concurrent double close.
	var signalDelivered sync.Once
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("title")
		}
		chunks := make(chan chattest.OpenAIChunk, 1)
		go func() {
			defer close(chunks)
			// Send the partial text chunks so there is meaningful
			// content to persist. Stop early if the request is
			// canceled so this goroutine never blocks on a send
			// nobody will receive.
			for _, c := range chattest.OpenAITextChunks("hello world") {
				select {
				case chunks <- c:
				case <-req.Context().Done():
					return
				}
			}
			// Signal that every chunk has been handed off (the last
			// one may still sit in the channel's one-slot buffer).
			signalDelivered.Do(func() { close(chunksDelivered) })
			// Block until interrupt cancels the context.
			<-req.Context().Done()
		}()
		return chattest.OpenAIResponse{StreamingChunks: chunks}
	})

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	server := chatd.New(chatd.Config{
		Logger:                     logger,
		Database:                   db,
		ReplicaID:                  uuid.New(),
		Pubsub:                     ps,
		PendingChatAcquireInterval: 10 * time.Millisecond,
		InFlightChatStaleAfter:     testutil.WaitSuperLong,
	})
	t.Cleanup(func() {
		require.NoError(t, server.Close())
	})

	user, org, model := seedChatDependencies(ctx, t, db)
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "interrupt-persist-test",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	// Subscribe to the chat's event stream so we can observe
	// message_part events — proof the chatloop has actually
	// processed the streamed chunks.
	_, events, subCancel, ok := server.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	defer subCancel()

	// Wait for the mock to finish sending chunks.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		select {
		case <-chunksDelivered:
			return true
		default:
			return false
		}
	}, testutil.IntervalFast)

	// Drain the event channel until we see a message_part event,
	// which means the chatloop has consumed and published the chunk.
	gotMessagePart := false
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		for {
			select {
			case ev := <-events:
				if ev.Type == codersdk.ChatStreamEventTypeMessagePart {
					gotMessagePart = true
					return true
				}
			default:
				return gotMessagePart
			}
		}
	}, testutil.IntervalFast)
	require.True(t, gotMessagePart, "should have received at least one message_part event")

	// Now interrupt the chat — the chatloop has processed content.
	updated := server.InterruptChat(ctx, chat)
	require.Equal(t, database.ChatStatusWaiting, updated.Status)

	// Wait for the partial assistant message to be persisted.
	// After the interrupt, the chatloop runs persistInterruptedStep
	// which inserts the message and publishes a "message" event.
	// We poll the DB directly for the assistant message rather than
	// relying on the chat status (which transitions to "waiting"
	// before the persist completes).
	var assistantMsg *database.ChatMessage
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		msgs, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
			ChatID:  chat.ID,
			AfterID: 0,
		})
		if dbErr != nil {
			return false
		}
		for i := range msgs {
			if msgs[i].Role == database.ChatMessageRoleAssistant {
				assistantMsg = &msgs[i]
				return true
			}
		}
		return false
	}, testutil.IntervalFast)
	require.NotNil(t, assistantMsg, "expected a persisted assistant message after interrupt")

	// Parse the content and verify it contains the partial text.
	parts, err := chatprompt.ParseContent(*assistantMsg)
	require.NoError(t, err)

	var foundText string
	for _, part := range parts {
		if part.Type == codersdk.ChatMessagePartTypeText {
			foundText += part.Text
		}
	}
	require.Contains(t, foundText, "hello world",
		"partial assistant response should contain the streamed text")
}
// TestProcessChat_UserProviderKey_Success stores an API key for the
// user against the seeded provider and checks that the chat is
// processed to the waiting state without a LastError, and that the
// mock LLM received "Bearer <user key>" as its Authorization header.
//
// NOTE(review): the provider is seeded with an empty string and the
// flags false/true/false — presumably "no central key, user keys
// allowed"; confirm against seedChatDependenciesWithProviderPolicy.
func TestProcessChat_UserProviderKey_Success(t *testing.T) {
	t.Parallel()

	const userAPIKey = "user-test-key"

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Every Authorization header seen by the mock LLM is recorded
	// under seenMu for inspection after the chat finishes.
	var seenMu sync.Mutex
	seenAuth := make([]string, 0, 1)
	llmHandler := func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		seenMu.Lock()
		seenAuth = append(seenAuth, req.Header.Get("Authorization"))
		seenMu.Unlock()

		if req.Stream {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("hello from the saved user key")...,
			)
		}
		return chattest.OpenAINonStreamingResponse("user provider key success")
	}
	openAIURL := chattest.NewOpenAI(t, llmHandler)

	user, org, provider, model := seedChatDependenciesWithProviderPolicy(
		ctx, t, db,
		"openai-compat", openAIURL, "",
		false, true, false,
	)

	// Save the user's own key for the provider.
	keyParams := database.UpsertUserChatProviderKeyParams{
		UserID:         user.ID,
		ChatProviderID: provider.ID,
		APIKey:         userAPIKey,
	}
	_, err := db.UpsertUserChatProviderKey(ctx, keyParams)
	require.NoError(t, err)

	// The chat is created and subscribed to on one server instance
	// before a second, active instance is started.
	creator := newTestServer(t, db, ps, uuid.New())
	createOpts := chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "user-provider-key-success",
		ModelConfigID:  model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("say hello"),
		},
	}
	chat, err := creator.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	_ = newActiveTestServer(t, db, ps)

	// Both the event stream and the database must report "waiting".
	streamStatus := waitForTerminalChatStatusEvent(ctx, t, events)
	require.Equal(t, codersdk.ChatStatusWaiting, streamStatus)

	stored := waitForTerminalChat(ctx, t, db, chat.ID)
	require.Equal(t, database.ChatStatusWaiting, stored.Status)
	require.False(t, stored.LastError.Valid)

	// Copy the recorded headers under the lock before asserting.
	seenMu.Lock()
	headers := append([]string(nil), seenAuth...)
	seenMu.Unlock()
	require.Contains(t, headers, "Bearer "+userAPIKey)
}
// TestProcessChat_UserProviderKey_MissingKeyError seeds the same
// provider policy as the success case but stores no user key, then
// checks that the chat ends in the error state with a non-empty
// LastError that does not mention a panic, and that the mock LLM
// was never called.
func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Count every request that reaches the mock LLM; none is expected.
	var llmRequests atomic.Int32
	llmHandler := func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		llmRequests.Add(1)
		if req.Stream {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("unexpected streaming request")...,
			)
		}
		return chattest.OpenAINonStreamingResponse("unexpected non-streaming request")
	}
	openAIURL := chattest.NewOpenAI(t, llmHandler)

	user, org, _, model := seedChatDependenciesWithProviderPolicy(
		ctx, t, db,
		"openai-compat", openAIURL, "",
		false, true, false,
	)

	// Create and subscribe on one server instance, then start a
	// second, active instance.
	creator := newTestServer(t, db, ps, uuid.New())
	createOpts := chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "user-provider-key-missing",
		ModelConfigID:  model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("say hello"),
		},
	}
	chat, err := creator.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	_ = newActiveTestServer(t, db, ps)

	streamStatus := waitForTerminalChatStatusEvent(ctx, t, events)
	require.Equal(t, codersdk.ChatStatusError, streamStatus)

	stored := waitForTerminalChat(ctx, t, db, chat.ID)
	require.Equal(t, database.ChatStatusError, stored.Status)

	// The failure must be an ordinary error, not a recovered panic.
	lastErr := stored.LastError
	require.True(t, lastErr.Valid, "LastError should be set")
	require.NotEmpty(t, lastErr.String)
	require.NotContains(t, lastErr.String, "panicked")
	require.NotEqual(t, database.ChatStatusRunning, stored.Status)

	require.Zero(t, llmRequests.Load(), "missing user key should fail before any LLM request")
}
// TestProcessChatPanicRecovery checks that a panic raised while a
// chat is being processed is recovered: the chat must end in the
// error state with a LastError containing both "chat processing
// panicked" and the panic value.
func TestProcessChatPanicRecovery(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)

	// The panic has to happen on processChat's own goroutine.
	// executeTools in the chatloop recovers by itself, so a panic
	// inside a tool goroutine would never reach the processChat-level
	// recovery. PersistStep's InTx call runs synchronously on the
	// processChat goroutine, so the store wrapper panics there.
	faultyStore := &panicOnInTxDB{Store: db}

	llmURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if req.Stream {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("hello")...,
			)
		}
		return chattest.OpenAINonStreamingResponse("Panic recovery test")
	})

	ctx := testutil.Context(t, testutil.WaitLong)

	// Seeding goes through the real database so it cannot panic;
	// only the server is handed the wrapper.
	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", llmURL)
	server := newActiveTestServer(t, faultyStore, ps)

	createOpts := chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "panic-recovery",
		ModelConfigID:  model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("hello"),
		},
	}
	chat, err := server.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	// CreateChat's own InTx is done, so arm the wrapper now. The
	// next InTx is PersistStep inside the chatloop, which runs
	// synchronously on the processChat goroutine.
	faultyStore.enablePanic()

	// Poll until the recovered panic has moved the chat to the
	// error status, keeping the last row read for the assertions.
	var final database.Chat
	reachedError := func() bool {
		current, lookupErr := db.GetChatByID(ctx, chat.ID)
		if lookupErr != nil {
			return false
		}
		final = current
		return current.Status == database.ChatStatusError
	}
	require.Eventually(t, reachedError, testutil.WaitLong, testutil.IntervalFast)

	require.True(t, final.LastError.Valid, "LastError should be set")
	require.Contains(t, final.LastError.String, "chat processing panicked")
	require.Contains(t, final.LastError.String, "intentional test panic")
}
// panicOnInTxDB wraps a database.Store and panics on exactly one InTx
// call: the first one made after enablePanic is called. Every other
// call passes through to the wrapped store, so the processChat
// cleanup defer can still update the chat status.
type panicOnInTxDB struct {
	database.Store
	// active is set by enablePanic to arm the wrapper.
	active atomic.Bool
	// panicked records that the single panic has already fired.
	panicked atomic.Bool
}

// enablePanic arms the wrapper; the next InTx call will panic.
func (d *panicOnInTxDB) enablePanic() { d.active.Store(true) }

// InTx panics with "intentional test panic" once after the wrapper
// has been armed and otherwise delegates to the embedded store.
func (d *panicOnInTxDB) InTx(f func(database.Store) error, opts *database.TxOptions) error {
	// CompareAndSwap claims the one-shot panic atomically, so two
	// concurrent InTx calls cannot both observe panicked == false.
	if d.active.Load() && d.panicked.CompareAndSwap(false, true) {
		panic("intentional test panic")
	}
	return d.Store.InTx(f, opts)
}
// TestMCPServerToolInvocation verifies that when a chat has
// mcp_server_ids set, the chat loop connects to those MCP servers,
// discovers their tools, and the LLM can invoke them.
//
// NOTE: This test uses a raw database.Store (no dbauthz wrapper).
// The chatd RBAC authorization of GetMCPServerConfigsByIDs (which
// requires ActionRead on ResourceDeploymentConfig) is covered by
// the chatd role definition tests, not here.
func TestMCPServerToolInvocation(t *testing.T) {
	t.Parallel()

	const (
		// Name the LLM sees for the MCP server's echo tool.
		mcpToolName = "test-mcp__echo"
		// Text the echo tool returns for the input used below.
		echoedText = "echo: hello from LLM"
	)

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// A real MCP server with a single "echo" tool, served over
	// streamable HTTP by an httptest server.
	echoServer := mcpserver.NewMCPServer("test-mcp", "1.0.0")
	echoTool := mcpgo.NewTool("echo",
		mcpgo.WithDescription("Echoes the input"),
		mcpgo.WithString("input",
			mcpgo.Description("The input string"),
			mcpgo.Required(),
		),
	)
	echoServer.AddTools(mcpserver.ServerTool{
		Tool: echoTool,
		Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
			input, _ := req.GetArguments()["input"].(string)
			return mcpgo.NewToolResultText("echo: " + input), nil
		},
	})
	mcpTS := httptest.NewServer(mcpserver.NewStreamableHTTPServer(echoServer))
	t.Cleanup(mcpTS.Close)

	// State captured by the mock LLM: the tool names offered on the
	// first streamed call, and whether the echo result shows up as a
	// tool message on a later call.
	var streamedCalls atomic.Int32
	var sawToolResult atomic.Bool
	var offeredMu sync.Mutex
	var offeredTools []string

	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("title")
		}

		if streamedCalls.Add(1) != 1 {
			// Later calls: look for the tool result being fed back.
			for _, msg := range req.Messages {
				if msg.Role != "tool" {
					continue
				}
				if strings.Contains(msg.Content, echoedText) {
					sawToolResult.Store(true)
				}
			}
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("Got it!")...,
			)
		}

		// First streamed call: remember the offered tool names and
		// respond with a call to the MCP echo tool.
		offered := make([]string, 0, len(req.Tools))
		for _, tool := range req.Tools {
			offered = append(offered, tool.Function.Name)
		}
		offeredMu.Lock()
		offeredTools = offered
		offeredMu.Unlock()

		return chattest.OpenAIStreamingResponse(
			chattest.OpenAIToolCallChunk(
				mcpToolName,
				`{"input":"hello from LLM"}`,
			),
		)
	})

	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)

	// The MCP server config references user.ID, so it is inserted
	// only after seedChatDependenciesWithProvider created the user.
	mcpConfig, err := db.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
		DisplayName:   "Test MCP",
		Slug:          "test-mcp",
		Url:           mcpTS.URL,
		Transport:     "streamable_http",
		AuthType:      "none",
		Availability:  "default_off",
		Enabled:       true,
		ToolAllowList: []string{},
		ToolDenyList:  []string{},
		CreatedBy:     user.ID,
		UpdatedBy:     user.ID,
	})
	require.NoError(t, err)

	ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)

	// Permissive agent connection mock: every call is optional and
	// returns an empty (or "not supported") answer.
	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)
	mockConn.EXPECT().
		SetExtraHeaders(gomock.Any()).
		AnyTimes()
	mockConn.EXPECT().
		ContextConfig(gomock.Any()).
		Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).
		AnyTimes()
	mockConn.EXPECT().
		ListMCPTools(gomock.Any()).
		Return(workspacesdk.ListMCPToolsResponse{}, nil).
		AnyTimes()
	mockConn.EXPECT().
		LS(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(workspacesdk.LSResponse{}, nil).
		AnyTimes()
	mockConn.EXPECT().
		ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
		Return(io.NopCloser(strings.NewReader("")), "", nil).
		AnyTimes()

	server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
		cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
			require.Equal(t, dbAgent.ID, agentID)
			return mockConn, func() {}, nil
		}
	})

	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "mcp-tool-test",
		ModelConfigID:  model.ID,
		WorkspaceID:    uuid.NullUUID{UUID: ws.ID, Valid: true},
		MCPServerIDs:   []uuid.UUID{mcpConfig.ID},
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("Echo something via MCP."),
		},
	})
	require.NoError(t, err)

	// The chat row must carry the MCP server IDs it was created with.
	dbChat, getErr := db.GetChatByID(ctx, chat.ID)
	require.NoError(t, getErr)
	require.Equal(t, []uuid.UUID{mcpConfig.ID}, dbChat.MCPServerIDs)

	// Poll until the chat settles in waiting or error, then fail
	// with the stored error if it did not succeed.
	var final database.Chat
	settled := func() bool {
		current, lookupErr := db.GetChatByID(ctx, chat.ID)
		if lookupErr != nil {
			return false
		}
		final = current
		switch current.Status {
		case database.ChatStatusWaiting, database.ChatStatusError:
			return true
		default:
			return false
		}
	}
	require.Eventually(t, settled, testutil.WaitLong, testutil.IntervalFast)
	if final.Status == database.ChatStatusError {
		require.FailNowf(t, "chat failed", "last_error=%q", final.LastError.String)
	}

	// The MCP tool must have been part of the tool list sent to
	// the LLM on the first streamed call.
	offeredMu.Lock()
	recordedNames := append([]string(nil), offeredTools...)
	offeredMu.Unlock()
	require.Contains(t, recordedNames, mcpToolName,
		"MCP tool should be in the tool list sent to the LLM")

	// The echo result must have come back to the LLM as a tool
	// message on the follow-up call.
	require.True(t, sawToolResult.Load(),
		"MCP tool result should appear in the second LLM call")

	// Finally, the tool result must be stored as a tool message.
	isEchoResult := func(part codersdk.ChatMessagePart) bool {
		return part.Type == codersdk.ChatMessagePartTypeToolResult &&
			part.ToolName == mcpToolName &&
			strings.Contains(string(part.Result), echoedText)
	}
	var foundToolMessage bool
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
			ChatID:  chat.ID,
			AfterID: 0,
		})
		if dbErr != nil {
			return false
		}
		for _, msg := range messages {
			if msg.Role != database.ChatMessageRoleTool {
				continue
			}
			parts, parseErr := chatprompt.ParseContent(msg)
			if parseErr != nil || len(parts) == 0 {
				continue
			}
			for _, part := range parts {
				if isEchoResult(part) {
					foundToolMessage = true
					return true
				}
			}
		}
		return false
	}, testutil.IntervalFast)
	require.True(t, foundToolMessage,
		"MCP tool result should be persisted as a tool message in the database")
}
// TestMCPServerOAuth2TokenRefresh verifies that when a chat uses an
// MCP server with OAuth2 auth and the stored access token is expired,
// chatd refreshes the token using the stored refresh_token before
// connecting. The refreshed token is persisted to the database and
// the MCP tool call succeeds.
func TestMCPServerOAuth2TokenRefresh(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// The "fresh" token that the mock OAuth2 server returns after
	// a successful refresh_token grant.
	freshAccessToken := "fresh-access-token-" + uuid.New().String()

	// Mock OAuth2 token endpoint that exchanges a refresh token
	// for a new access token. refreshCalled counts only requests
	// that were actually served as a refresh_token grant, so the
	// assertion below cannot be satisfied by a rejected request.
	var refreshCalled atomic.Int32
	tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		grantType := r.FormValue("grant_type")
		if grantType != "refresh_token" {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte(`{"error":"unsupported_grant_type"}`))
			return
		}
		refreshCalled.Add(1)
		w.Header().Set("Content-Type", "application/json")
		_, _ = fmt.Fprintf(w, `{"access_token":%q,"token_type":"Bearer","expires_in":3600,"refresh_token":"rotated-refresh-token"}`, freshAccessToken)
	}))
	t.Cleanup(tokenSrv.Close)

	// Start a real MCP server with an auth middleware that only
	// accepts the fresh access token. An expired token (or any
	// other value) gets a 401.
	mcpSrv := mcpserver.NewMCPServer("authed-mcp", "1.0.0")
	mcpSrv.AddTools(mcpserver.ServerTool{
		Tool: mcpgo.NewTool("echo",
			mcpgo.WithDescription("Echoes the input"),
			mcpgo.WithString("input",
				mcpgo.Description("The input string"),
				mcpgo.Required(),
			),
		),
		Handler: func(_ context.Context, req mcpgo.CallToolRequest) (*mcpgo.CallToolResult, error) {
			input, _ := req.GetArguments()["input"].(string)
			return mcpgo.NewToolResultText("echo: " + input), nil
		},
	})
	mcpHTTP := mcpserver.NewStreamableHTTPServer(mcpSrv)
	// Wrap with auth check.
	authMux := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		auth := r.Header.Get("Authorization")
		if auth != "Bearer "+freshAccessToken {
			w.WriteHeader(http.StatusUnauthorized)
			_, _ = w.Write([]byte(`{"error":"invalid_token","error_description":"The access token is invalid or expired"}`))
			return
		}
		mcpHTTP.ServeHTTP(w, r)
	})
	mcpTS := httptest.NewServer(authMux)
	t.Cleanup(mcpTS.Close)

	// Track LLM interactions: the tool names offered on the first
	// streamed call and whether the echo result was fed back later.
	var (
		callCount      atomic.Int32
		llmToolNames   []string
		llmToolsMu     sync.Mutex
		foundMCPResult atomic.Bool
	)
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("title")
		}
		if callCount.Add(1) == 1 {
			names := make([]string, 0, len(req.Tools))
			for _, tool := range req.Tools {
				names = append(names, tool.Function.Name)
			}
			llmToolsMu.Lock()
			llmToolNames = names
			llmToolsMu.Unlock()
			// Ask the LLM to call the MCP echo tool.
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAIToolCallChunk(
					"authed-mcp__echo",
					`{"input":"hello via refreshed token"}`,
				),
			)
		}
		// Second call: verify the tool result was fed back.
		for _, msg := range req.Messages {
			if msg.Role == "tool" && strings.Contains(msg.Content, "echo: hello via refreshed token") {
				foundMCPResult.Store(true)
			}
		}
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAITextChunks("Done!")...,
		)
	})

	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)

	// Seed the MCP server config with OAuth2 auth pointing to our
	// mock token endpoint.
	mcpConfig, err := db.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{
		DisplayName:    "Authed MCP",
		Slug:           "authed-mcp",
		Url:            mcpTS.URL,
		Transport:      "streamable_http",
		AuthType:       "oauth2",
		OAuth2ClientID: "test-client-id",
		OAuth2TokenURL: tokenSrv.URL,
		Availability:   "default_off",
		Enabled:        true,
		ToolAllowList:  []string{},
		ToolDenyList:   []string{},
		CreatedBy:      user.ID,
		UpdatedBy:      user.ID,
	})
	require.NoError(t, err)

	// Seed an expired OAuth2 token with a valid refresh_token.
	_, err = db.UpsertMCPServerUserToken(ctx, database.UpsertMCPServerUserTokenParams{
		MCPServerConfigID: mcpConfig.ID,
		UserID:            user.ID,
		AccessToken:       "old-expired-access-token",
		RefreshToken:      "old-refresh-token",
		TokenType:         "Bearer",
		Expiry:            sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
	})
	require.NoError(t, err)

	ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)

	// Permissive agent connection mock; every call is optional.
	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)
	mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
	mockConn.EXPECT().ContextConfig(gomock.Any()).
		Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes()
	mockConn.EXPECT().ListMCPTools(gomock.Any()).
		Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes()
	mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(workspacesdk.LSResponse{}, nil).AnyTimes()
	mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
		Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes()

	server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
		cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
			require.Equal(t, dbAgent.ID, agentID)
			return mockConn, func() {}, nil
		}
	})

	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "oauth2-refresh-test",
		ModelConfigID:  model.ID,
		WorkspaceID:    uuid.NullUUID{UUID: ws.ID, Valid: true},
		MCPServerIDs:   []uuid.UUID{mcpConfig.ID},
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("Echo something via the authed MCP."),
		},
	})
	require.NoError(t, err)

	// Wait for the chat to finish processing.
	var chatResult database.Chat
	require.Eventually(t, func() bool {
		got, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		chatResult = got
		return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
	}, testutil.WaitLong, testutil.IntervalFast)

	if chatResult.Status == database.ChatStatusError {
		require.FailNowf(t, "chat failed", "last_error=%q", chatResult.LastError.String)
	}

	// The token should have been refreshed.
	require.Greater(t, refreshCalled.Load(), int32(0),
		"OAuth2 token endpoint should have been called to refresh the expired token")

	// The MCP tool should appear in the tool list.
	llmToolsMu.Lock()
	recordedNames := append([]string(nil), llmToolNames...)
	llmToolsMu.Unlock()
	require.Contains(t, recordedNames, "authed-mcp__echo",
		"MCP tool should be in the tool list sent to the LLM")

	// The tool result should have been fed back to the LLM.
	require.True(t, foundMCPResult.Load(),
		"MCP tool result should appear in the second LLM call")

	// Verify the refreshed token was persisted to the database.
	dbToken, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
		MCPServerConfigID: mcpConfig.ID,
		UserID:            user.ID,
	})
	require.NoError(t, err)
	require.Equal(t, freshAccessToken, dbToken.AccessToken,
		"refreshed access token should be persisted in the database")
	require.Equal(t, "rotated-refresh-token", dbToken.RefreshToken,
		"rotated refresh token should be persisted in the database")
}
// TestMCPServerOAuth2TokenRefreshFailureGraceful exercises the degraded
// OAuth2 refresh path: the token endpoint answers every request with a
// 502, and the chat is still expected to settle in the waiting state
// with the LLM having been called at least once. Afterwards the stored
// access token must still be the original expired one.
func TestMCPServerOAuth2TokenRefreshFailureGraceful(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Stand-in OAuth2 token endpoint that rejects every request.
	rejectRefresh := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadGateway)
		_, _ = w.Write([]byte(`{"error":"server_error","error_description":"token endpoint unavailable"}`))
	}
	tokenSrv := httptest.NewServer(http.HandlerFunc(rejectRefresh))
	t.Cleanup(tokenSrv.Close)

	// Fake LLM that only ever answers with text, never a tool call.
	// Streaming requests are counted so the test can prove the model
	// was reached.
	var streamCalls atomic.Int32
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if req.Stream {
			streamCalls.Add(1)
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("I responded without MCP tools.")...,
			)
		}
		return chattest.OpenAINonStreamingResponse("title")
	})

	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)

	// OAuth2-authenticated MCP server config that points its token URL
	// at the failing endpoint above.
	mcpParams := database.InsertMCPServerConfigParams{
		DisplayName:    "Broken MCP",
		Slug:           "broken-mcp",
		Url:            "http://127.0.0.1:0/does-not-exist",
		Transport:      "streamable_http",
		AuthType:       "oauth2",
		OAuth2ClientID: "test-client-id",
		OAuth2TokenURL: tokenSrv.URL,
		Availability:   "default_off",
		Enabled:        true,
		ToolAllowList:  []string{},
		ToolDenyList:   []string{},
		CreatedBy:      user.ID,
		UpdatedBy:      user.ID,
	}
	mcpConfig, err := db.InsertMCPServerConfig(ctx, mcpParams)
	require.NoError(t, err)

	// Seed a user token that expired one hour ago.
	expiredAt := time.Now().Add(-time.Hour)
	_, err = db.UpsertMCPServerUserToken(ctx, database.UpsertMCPServerUserTokenParams{
		MCPServerConfigID: mcpConfig.ID,
		UserID:            user.ID,
		AccessToken:       "old-expired-token",
		RefreshToken:      "old-refresh-token",
		TokenType:         "Bearer",
		Expiry:            sql.NullTime{Time: expiredAt, Valid: true},
	})
	require.NoError(t, err)

	server := newActiveTestServer(t, db, ps)

	createOpts := chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "graceful-degradation-test",
		ModelConfigID:  model.ID,
		MCPServerIDs:   []uuid.UUID{mcpConfig.ID},
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("Hello, just reply."),
		},
	}
	chat, err := server.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	// Poll until the chat lands in waiting or error, remembering the
	// last row that was read so the failure message can include it.
	var finalChat database.Chat
	settled := func() bool {
		current, lookupErr := db.GetChatByID(ctx, chat.ID)
		if lookupErr != nil {
			return false
		}
		finalChat = current
		switch current.Status {
		case database.ChatStatusWaiting, database.ChatStatusError:
			return true
		default:
			return false
		}
	}
	require.Eventually(t, settled, testutil.WaitLong, testutil.IntervalFast)

	// A failed refresh must not turn into a failed chat.
	if finalChat.Status == database.ChatStatusError {
		require.FailNowf(t, "chat should not fail", "last_error=%q", finalChat.LastError.String)
	}

	// The model must have been reached at least once.
	require.Greater(t, streamCalls.Load(), int32(0),
		"LLM should be called even when MCP token refresh fails")

	// The access token stored in the database must be the one seeded
	// before the chat ran.
	storedToken, err := db.GetMCPServerUserToken(ctx, database.GetMCPServerUserTokenParams{
		MCPServerConfigID: mcpConfig.ID,
		UserID:            user.ID,
	})
	require.NoError(t, err)
	require.Equal(t, "old-expired-token", storedToken.AccessToken,
		"original token should be preserved when refresh fails")
}
// TestChatTemplateAllowlistEnforcement drives a chat through a scripted
// sequence of tool calls against two templates, only one of which is on
// the chat template allowlist. It then inspects the persisted tool
// results to check that list_templates omits the blocked template,
// read_template reports it as not found, and create_workspace rejects
// it without invoking the configured CreateWorkspace callback.
func TestChatTemplateAllowlistEnforcement(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitLong)
	db, ps := dbtestutil.NewDB(t)

	// The templates are created after the LLM handler is registered,
	// so they are declared up front and the handler reads their IDs
	// lazily, at the time each tool call is scripted.
	var tplAllowed, tplBlocked database.Template
	templateArgs := func(tpl database.Template) string {
		return fmt.Sprintf(`{"template_id":%q}`, tpl.ID.String())
	}

	// Scripted streaming responses, one per request:
	//   1. list_templates
	//   2. read_template    (blocked template, expected to fail)
	//   3. read_template    (allowed template, expected to succeed)
	//   4. create_workspace (blocked template, expected to fail)
	//   5+ plain text
	var callCount atomic.Int32
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("title")
		}

		var toolName, toolArgs string
		switch callCount.Add(1) {
		case 1:
			toolName, toolArgs = "list_templates", `{}`
		case 2:
			toolName, toolArgs = "read_template", templateArgs(tplBlocked)
		case 3:
			toolName, toolArgs = "read_template", templateArgs(tplAllowed)
		case 4:
			toolName, toolArgs = "create_workspace", templateArgs(tplBlocked)
		default:
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("Done testing.")...,
			)
		}
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAIToolCallChunk(toolName, toolArgs),
		)
	})

	user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)

	// Two templates in the same organization, created by the same user.
	tplAllowed = dbgen.Template(t, db, database.Template{
		OrganizationID: org.ID,
		CreatedBy:      user.ID,
		Name:           "allowed-template",
	})
	tplBlocked = dbgen.Template(t, db, database.Template{
		OrganizationID: org.ID,
		CreatedBy:      user.ID,
		Name:           "blocked-template",
	})

	// Only tplAllowed goes on the allowlist.
	allowlistJSON, err := json.Marshal([]string{tplAllowed.ID.String()})
	require.NoError(t, err)
	err = db.UpsertChatTemplateAllowlist(dbauthz.AsSystemRestricted(ctx), string(allowlistJSON))
	require.NoError(t, err)

	server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
		// A CreateWorkspace function is supplied so the tool gets as
		// far as the allowlist check rather than bailing out with
		// "not configured". With correct enforcement it is never
		// invoked; if it is, the test is flagged as failed.
		cfg.CreateWorkspace = func(
			_ context.Context,
			_ uuid.UUID,
			_ codersdk.CreateWorkspaceRequest,
		) (codersdk.Workspace, error) {
			t.Error("CreateWorkspace should not be called for a blocked template")
			return codersdk.Workspace{}, xerrors.New("unexpected call")
		}
	})

	createOpts := chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "allowlist-test",
		ModelConfigID:  model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("Test allowlist enforcement"),
		},
	}
	chat, err := server.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	// Poll until the chat lands in waiting or error.
	var finalChat database.Chat
	settled := func() bool {
		current, lookupErr := db.GetChatByID(ctx, chat.ID)
		if lookupErr != nil {
			return false
		}
		finalChat = current
		switch current.Status {
		case database.ChatStatusWaiting, database.ChatStatusError:
			return true
		default:
			return false
		}
	}
	require.Eventually(t, settled, testutil.WaitLong, testutil.IntervalFast)
	if finalChat.Status == database.ChatStatusError {
		require.FailNowf(t, "chat run failed", "last_error=%q", finalChat.LastError.String)
	}

	// Gather tool results grouped by tool name, in message order. A
	// tool can be called more than once, hence a slice per name. The
	// map is rebuilt from scratch on every poll.
	var toolResults map[string][]string
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		toolResults = make(map[string][]string)

		messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
			ChatID:  chat.ID,
			AfterID: 0,
		})
		if dbErr != nil {
			return false
		}

		for _, msg := range messages {
			if msg.Role != database.ChatMessageRoleTool {
				continue
			}
			parts, parseErr := chatprompt.ParseContent(msg)
			if parseErr != nil {
				// Unparseable messages are ignored.
				continue
			}
			for _, part := range parts {
				if part.Type != codersdk.ChatMessagePartTypeToolResult {
					continue
				}
				toolResults[part.ToolName] = append(
					toolResults[part.ToolName], string(part.Result))
			}
		}

		// All four scripted tool calls must have produced a result.
		switch {
		case len(toolResults["list_templates"]) < 1,
			len(toolResults["read_template"]) < 2,
			len(toolResults["create_workspace"]) < 1:
			return false
		}
		return true
	}, testutil.IntervalFast)

	// list_templates: the allowed template is listed, the blocked one
	// is not.
	listed := toolResults["list_templates"][0]
	require.Contains(t, listed, tplAllowed.ID.String(),
		"allowed template should appear in list_templates result")
	require.NotContains(t, listed, tplBlocked.ID.String(),
		"blocked template should NOT appear in list_templates result")

	// read_template: first call used the blocked ID, second the
	// allowed ID.
	reads := toolResults["read_template"]
	require.Contains(t, reads[0], "not found",
		"read_template for blocked template should return not-found error")
	require.Contains(t, reads[1], tplAllowed.ID.String(),
		"read_template for allowed template should return template details")

	// create_workspace: the blocked ID is rejected.
	require.Contains(t, toolResults["create_workspace"][0], "not available",
		"create_workspace for blocked template should be rejected")
}
// TestSignalWakeImmediateAcquisition checks that a chat created via
// CreateChat is picked up right away through signalWake rather than by
// the polling ticker. The acquire interval is configured to one hour,
// so the ticker cannot fire within the test; any processing observed
// must therefore have been triggered by the wake channel.
func TestSignalWakeImmediateAcquisition(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Closed by the fake LLM on its first streaming request, which
	// shows the chat was acquired and processing began.
	processed := make(chan struct{})
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if req.Stream {
			select {
			case <-processed:
				// Already signaled by an earlier request.
			default:
				close(processed)
			}
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("hello from the model")...,
			)
		}
		return chattest.OpenAINonStreamingResponse("title")
	})

	// One-hour acquire interval: the ticker never fires in this test.
	server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
		cfg.PendingChatAcquireInterval = time.Hour
		cfg.InFlightChatStaleAfter = testutil.WaitSuperLong
	})

	user, org, model := seedChatDependencies(ctx, t, db)
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	// CreateChat marks the chat pending and calls signalWake().
	createOpts := chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "wake-test",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	}
	chat, err := server.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	// If signalWake did not trigger processing, this receive would
	// never succeed because the one-hour ticker does not fire.
	testutil.TryReceive(ctx, t, processed)
	chatd.WaitUntilIdleForTest(server)

	// Once the server is idle the chat should be back in waiting.
	stored, err := db.GetChatByID(ctx, chat.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusWaiting, stored.Status,
		"chat should be in waiting status after processing completes")
}
// TestSignalWakeSendMessage checks that calling SendMessage on a chat
// that has gone idle triggers immediate processing via signalWake. As
// in the acquisition test, the polling interval is one hour so the
// ticker cannot be what picks the chat up.
func TestSignalWakeSendMessage(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitSuperLong)

	// One channel per expected turn; each is closed when the fake LLM
	// sees the corresponding streaming request.
	var requestCount atomic.Int32
	firstProcessed := make(chan struct{})
	secondProcessed := make(chan struct{})
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("title")
		}

		if seen := requestCount.Add(1); seen == 1 {
			select {
			case <-firstProcessed:
			default:
				close(firstProcessed)
			}
		} else if seen == 2 {
			close(secondProcessed)
		}

		return chattest.OpenAIStreamingResponse(
			chattest.OpenAITextChunks("response")...,
		)
	})

	server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
		cfg.PendingChatAcquireInterval = time.Hour
		cfg.InFlightChatStaleAfter = testutil.WaitSuperLong
	})

	user, org, model := seedChatDependencies(ctx, t, db)
	setOpenAIProviderBaseURL(ctx, t, db, openAIURL)

	// First turn: CreateChat wakes the server.
	createOpts := chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "wake-send-test",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("first")},
	}
	chat, err := server.CreateChat(ctx, createOpts)
	require.NoError(t, err)

	// Block until the first turn has hit the LLM, then until the
	// server is idle so the chat has moved to "waiting".
	testutil.TryReceive(ctx, t, firstProcessed)
	chatd.WaitUntilIdleForTest(server)

	// Second turn: the follow-up message should likewise be handled
	// immediately through signalWake.
	sendOpts := chatd.SendMessageOptions{
		ChatID:  chat.ID,
		Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("second")},
	}
	_, err = server.SendMessage(ctx, sendOpts)
	require.NoError(t, err)

	testutil.TryReceive(ctx, t, secondProcessed)
	chatd.WaitUntilIdleForTest(server)

	// Both turns ran, so the LLM saw at least two streaming requests.
	require.GreaterOrEqual(t, requestCount.Load(), int32(2),
		"LLM should have received at least 2 streaming requests")
}
// TestAgentContextFilesAndSkillsLoadedIntoChat is an end-to-end check
// of workspace context loading: a workspace agent reads an instruction
// file and discovers a skill on disk, chatd obtains them over a real
// tailnet agent connection, and the first LLM request must carry both
// the <workspace-context> block and the <available-skills> index in
// its system messages. It also verifies that the chat-specific
// <plan-file-path> block appears exactly once and is embedded in a
// larger system message rather than sent on its own.
//
// This test is NOT parallel: it configures the agent's context through
// process-wide environment variables using t.Setenv.
func TestAgentContextFilesAndSkillsLoadedIntoChat(t *testing.T) {
	// Point the home directory at a scratch dir.
	fakeHome := t.TempDir()
	t.Setenv("HOME", fakeHome)
	t.Setenv("USERPROFILE", fakeHome)

	instructionsDir := filepath.Join(fakeHome, ".coder")
	skillsDir := filepath.Join(fakeHome, ".coder", "skills")
	require.NoError(t, os.MkdirAll(instructionsDir, 0o755))
	require.NoError(t, os.MkdirAll(skillsDir, 0o755))

	// Agent context configuration, supplied through the environment.
	t.Setenv(agentcontextconfig.EnvInstructionsDirs, instructionsDir)
	t.Setenv(agentcontextconfig.EnvInstructionsFile, "AGENTS.md")
	t.Setenv(agentcontextconfig.EnvSkillsDirs, skillsDir)
	t.Setenv(agentcontextconfig.EnvSkillMetaFile, "SKILL.md")
	t.Setenv(agentcontextconfig.EnvMCPConfigFiles, filepath.Join(fakeHome, "nonexistent-mcp.json"))

	// One instruction file and one skill on disk.
	agentsFile := filepath.Join(instructionsDir, "AGENTS.md")
	require.NoError(t, os.WriteFile(
		agentsFile,
		[]byte("# Project Rules\nAlways write tests."),
		0o600,
	))
	skillDir := filepath.Join(skillsDir, "my-cool-skill")
	require.NoError(t, os.MkdirAll(skillDir, 0o755))
	require.NoError(t, os.WriteFile(
		filepath.Join(skillDir, "SKILL.md"),
		[]byte("---\nname: my-cool-skill\ndescription: A test skill\n---\nDo the cool thing.\n"),
		0o600,
	))

	ctx := testutil.Context(t, testutil.WaitSuperLong)

	// Bring up coderd with the agents experiment enabled.
	deploymentValues := coderdtest.DeploymentValues(t)
	deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)}
	client := coderdtest.New(t, &coderdtest.Options{
		DeploymentValues:              deploymentValues,
		IncludeProvisionerDaemon:      true,
		ChatdInstructionLookupTimeout: testutil.WaitLong,
	})
	user := coderdtest.CreateFirstUser(t, client)
	expClient := codersdk.NewExperimentalClient(client)

	// Provision a workspace with an agent and wait for the agent.
	agentToken := uuid.NewString()
	version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
		Parse:          echo.ParseComplete,
		ProvisionPlan:  echo.PlanComplete,
		ProvisionApply: echo.ApplyComplete,
		ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken),
	})
	coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
	template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
	workspace := coderdtest.CreateWorkspace(t, client, template.ID)
	coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
	_ = agenttest.New(t, client.URL, agentToken)
	coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()

	// Record the messages of every streaming LLM request so the
	// system prompt can be inspected afterwards.
	var (
		streamedCallsMu sync.Mutex
		streamedCalls   = make([][]chattest.OpenAIMessage, 0, 2)
	)
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			return chattest.OpenAINonStreamingResponse("context test")
		}

		snapshot := append([]chattest.OpenAIMessage(nil), req.Messages...)
		streamedCallsMu.Lock()
		streamedCalls = append(streamedCalls, snapshot)
		streamedCallsMu.Unlock()

		return chattest.OpenAIStreamingResponse(
			chattest.OpenAITextChunks("Got it.")...,
		)
	})

	// Register the fake provider and a default model.
	_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
		Provider: "openai-compat",
		APIKey:   "test-api-key",
		BaseURL:  openAIURL,
	})
	require.NoError(t, err)

	contextLimit := int64(4096)
	isDefault := true
	_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
		Provider:     "openai-compat",
		Model:        "gpt-4o-mini",
		ContextLimit: &contextLimit,
		IsDefault:    &isDefault,
	})
	require.NoError(t, err)

	// Start a chat bound to the workspace.
	wsID := workspace.ID
	chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
		OrganizationID: user.OrganizationID,
		WorkspaceID:    &wsID,
		Content: []codersdk.ChatInputPart{
			{
				Type: codersdk.ChatInputPartTypeText,
				Text: "Hello, what are the project rules?",
			},
		},
	})
	require.NoError(t, err)

	// Poll until the chat lands in waiting or error.
	settled := func() bool {
		current, lookupErr := expClient.GetChat(ctx, chat.ID)
		if lookupErr != nil {
			return false
		}
		switch current.Status {
		case codersdk.ChatStatusWaiting, codersdk.ChatStatusError:
			return true
		default:
			return false
		}
	}
	require.Eventually(t, settled, testutil.WaitSuperLong, testutil.IntervalFast)

	streamedCallsMu.Lock()
	recordedCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
	streamedCallsMu.Unlock()
	require.NotEmpty(t, recordedCalls, "LLM should have received at least one streaming request")

	// Single pass over the first request's system messages: join
	// their text, count plan-file-path blocks, and count messages
	// that consist of nothing but such a block.
	var systemText strings.Builder
	planBlockCount := 0
	standalonePlanBlockCount := 0
	for _, msg := range recordedCalls[0] {
		if msg.Role != "system" {
			continue
		}
		systemText.WriteString(msg.Content)
		systemText.WriteString("\n")

		planBlockCount += strings.Count(
			msg.Content,
			"<plan-file-path>\nYour plan file path for this chat is:",
		)

		body := strings.TrimSpace(msg.Content)
		if strings.HasPrefix(body, "<plan-file-path>") &&
			strings.HasSuffix(body, "</plan-file-path>") {
			standalonePlanBlockCount++
		}
	}
	allSystemContent := systemText.String()

	// Workspace instructions.
	require.Contains(t, allSystemContent, "<workspace-context>",
		"system prompt should contain workspace-context block")
	require.Contains(t, allSystemContent, "Always write tests.",
		"system prompt should contain AGENTS.md content")
	require.Contains(t, allSystemContent, "AGENTS.md",
		"system prompt should reference the source file")

	// Discovered skills.
	require.Contains(t, allSystemContent, "<available-skills>",
		"system prompt should contain available-skills block")
	require.Contains(t, allSystemContent, "my-cool-skill",
		"system prompt should list the discovered skill")
	require.Contains(t, allSystemContent, "A test skill",
		"system prompt should include the skill description")

	// Per-chat plan file path.
	require.Contains(t, allSystemContent, "<plan-file-path>",
		"system prompt should contain the plan-file-path block")
	require.Contains(t, allSystemContent, "PLAN-"+chat.ID.String()+".md",
		"system prompt should use the chat-specific plan path")
	require.Contains(t, allSystemContent,
		"Do not use "+strings.TrimRight(fakeHome, "/")+"/PLAN.md.",
		"system prompt should warn against the home-root plan path")
	require.Equal(t, 1, planBlockCount,
		"system prompt should contain a single plan-file-path block")
	require.Zero(t, standalonePlanBlockCount,
		"plan-file-path block should be part of the main system prompt, not a standalone message")
}