Files
coder/coderd/x/chatd/chatd_internal_test.go
T
Michael Suchacz 9d0469fc4c feat: allow approved external MCP tools in root plan mode (#24509)
## Summary

Allow root plan-mode chats to use MCP tools from external servers that
an admin has explicitly approved for plan mode. Workspace MCP and
plan-mode subagents remain blocked.

## Problem

`chatd.go` excluded every MCP tool when `isPlanModeTurn` was true, so
planning had no access to tools like docs search, ticketing, etc.
Lifting that guard wholesale was unsafe: `mcp_server_configs` already
has centralized admin governance, but workspace-local MCP (discovered
from agent `.mcp.json`) does not, and subagents use a narrower trust
boundary.

## Fix

Add an admin-controlled per-server `allow_in_plan_mode` flag (default
`false`) and gate plan-mode MCP access on it.

### Backend / schema
- New migration `000472_mcp_server_allow_in_plan_mode.{up,down}.sql` and
matching fixture update.
- `mcpserverconfigs.sql` + generated code: persist and read the new
column.
- `codersdk/mcp.go`: thread the field through `MCPServerConfig`,
`Create*`, and `Update*` request types.
- `coderd/mcp.go`: validate, persist, and return the flag in
get/list/create/update handlers.

### chatd
- `coderd/x/chatd/chatd.go`: pre-filter selected external MCP configs by
`AllowInPlanMode` before calling `mcpclient.ConnectAll` on plan-mode
root turns. Workspace MCP discovery is skipped entirely on plan-mode
turns.
- Single helper decides whether a tool is available in plan mode, used
both at construction and for active-tool filtering (defense in depth).
Plan-mode subagents, dynamic tools, provider-native tools, computer-use,
and workspace MCP stay unchanged.
- `coderd/x/chatd/prompt.go`: update the root plan-mode overlay text to
match the new boundary.

### UI
- `MCPServerAdminPanel.tsx`: add an explicit toggle ("Allow all tools
from this MCP server in root plan mode") next to the existing governance
controls.
- Regenerated `site/src/api/typesGenerated.ts`.

### Docs
- `docs/ai-coder/agents/architecture.md`: replace the blanket "MCP is
unavailable in plan mode" note with the new root-only, external-only,
admin-approved policy. Explicitly call out that workspace MCP and
plan-mode subagents are still excluded.

### Tests
- Plan-mode visibility (approved vs non-approved external server).
- Plan-mode invocation of an approved external MCP tool.
- End-to-end plan-mode workflow that uses an approved MCP tool and then
reaches `propose_plan`.
- Regressions: workspace MCP still excluded in plan mode; plan-mode
subagents still on the restricted tool boundary; existing tool
allow/deny list filtering still applies.

## Policy precedence

`allow_in_plan_mode` is an **additional** requirement on top of existing
`enabled`, availability, chat-selected / forced server IDs, and tool
allow/deny lists. It approves **all tools on that server** for root plan
mode; a per-tool plan allowlist is deliberately deferred.

## Follow-ups (explicitly out of scope)

- Whether plan-mode subagents should inherit approved external MCP
tools.
- Workspace-local MCP safety model (agent-side `.mcp.json` schema vs. a
coderd-managed workspace MCP config).

## Validation

- `go vet ./coderd/x/chatd/...`
- `go test ./coderd/x/chatd -run 'TestPlan.*|TestMCP.*' -count=1`
- `go test ./coderd/x/chatd -count=1 -timeout 5m` (full chatd suite)
- `make fmt` (no diff)

> Mux opened this PR on Mike's behalf.
2026-04-21 12:26:12 +02:00

3696 lines
110 KiB
Go

package chatd
import (
"context"
"database/sql"
"encoding/json"
"sync"
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
type testAgentTool struct {
info fantasy.ToolInfo
providerOptions fantasy.ProviderOptions
}
func newTestAgentTool(name string) fantasy.AgentTool {
return &testAgentTool{info: fantasy.ToolInfo{Name: name}}
}
func (t *testAgentTool) Info() fantasy.ToolInfo {
return t.info
}
func (t *testAgentTool) Run(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) {
_ = t
return fantasy.ToolResponse{}, nil
}
func (t *testAgentTool) ProviderOptions() fantasy.ProviderOptions {
return t.providerOptions
}
func (t *testAgentTool) SetProviderOptions(opts fantasy.ProviderOptions) {
t.providerOptions = opts
}
type testMCPAgentTool struct {
*testAgentTool
configID uuid.UUID
}
func newTestMCPAgentTool(name string, configID uuid.UUID) fantasy.AgentTool {
return &testMCPAgentTool{
testAgentTool: &testAgentTool{info: fantasy.ToolInfo{Name: name}},
configID: configID,
}
}
func (t *testMCPAgentTool) MCPServerConfigID() uuid.UUID {
return t.configID
}
func TestFilterExternalMCPConfigsForTurn(t *testing.T) {
t.Parallel()
approvedConfig := database.MCPServerConfig{ID: uuid.New(), AllowInPlanMode: true}
blockedConfig := database.MCPServerConfig{ID: uuid.New(), AllowInPlanMode: false}
configs := []database.MCPServerConfig{approvedConfig, blockedConfig}
planMode := database.NullChatPlanMode{
ChatPlanMode: database.ChatPlanModePlan,
Valid: true,
}
t.Run("NonPlanModePassesThroughAllConfigs", func(t *testing.T) {
t.Parallel()
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
configs,
database.NullChatPlanMode{},
uuid.NullUUID{},
)
require.Equal(t, configs, filtered)
require.Nil(t, approvedIDs)
})
t.Run("PlanModeSubagentsReturnNoConfigs", func(t *testing.T) {
t.Parallel()
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
configs,
planMode,
uuid.NullUUID{UUID: uuid.New(), Valid: true},
)
require.Nil(t, filtered)
require.NotNil(t, approvedIDs)
require.Empty(t, approvedIDs)
})
t.Run("PlanModeRootFiltersToApprovedConfigs", func(t *testing.T) {
t.Parallel()
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
configs,
planMode,
uuid.NullUUID{},
)
require.Equal(t, []database.MCPServerConfig{approvedConfig}, filtered)
require.Equal(t, map[uuid.UUID]struct{}{approvedConfig.ID: {}}, approvedIDs)
})
}
func TestActiveToolNamesForTurn(t *testing.T) {
t.Parallel()
makeTools := func(names ...string) []fantasy.AgentTool {
tools := make([]fantasy.AgentTool, 0, len(names))
for _, name := range names {
tools = append(tools, newTestAgentTool(name))
}
return tools
}
planMode := database.NullChatPlanMode{
ChatPlanMode: database.ChatPlanModePlan,
Valid: true,
}
t.Run("NormalModeReturnsAllRegisteredTools", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"read_file",
"propose_plan",
"custom_tool",
"execute",
), database.NullChatPlanMode{}, uuid.NullUUID{}, nil)
require.Equal(t, []string{
"read_file",
"propose_plan",
"custom_tool",
"execute",
}, got)
})
t.Run("PlanModeIncludesOnlyAllowlistedBuiltIns", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"read_file",
"write_file",
"edit_files",
"execute",
"process_output",
"process_list",
"process_signal",
"list_templates",
"read_template",
"create_workspace",
"start_workspace",
"propose_plan",
"spawn_agent",
"spawn_explore_agent",
"wait_agent",
"message_agent",
"close_agent",
"spawn_computer_use_agent",
"read_skill",
"read_skill_file",
"ask_user_question",
), planMode, uuid.NullUUID{}, nil)
require.Equal(t, []string{
"read_file",
"write_file",
"edit_files",
"execute",
"process_output",
"list_templates",
"read_template",
"create_workspace",
"start_workspace",
"propose_plan",
"spawn_agent",
"spawn_explore_agent",
"wait_agent",
"read_skill",
"read_skill_file",
"ask_user_question",
}, got)
})
t.Run("PlanModeChildChatsAllowExplorationOnly", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"read_file",
"write_file",
"edit_files",
"execute",
"process_output",
"list_templates",
"read_template",
"create_workspace",
"start_workspace",
"propose_plan",
"spawn_agent",
"spawn_explore_agent",
"wait_agent",
"read_skill",
"read_skill_file",
"ask_user_question",
), planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true}, nil)
require.Equal(t, []string{
"read_file",
"execute",
"process_output",
"read_skill",
"read_skill_file",
}, got)
require.NotContains(t, got, "write_file")
require.NotContains(t, got, "edit_files")
require.NotContains(t, got, "ask_user_question")
require.NotContains(t, got, "propose_plan")
require.NotContains(t, got, "spawn_explore_agent")
})
t.Run("PlanModeStillExcludesDangerousTools", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"execute",
"process_output",
"message_agent",
"spawn_computer_use_agent",
"propose_plan",
), planMode, uuid.NullUUID{}, nil)
require.Equal(t, []string{"execute", "process_output", "propose_plan"}, got)
require.NotContains(t, got, "message_agent")
require.NotContains(t, got, "spawn_computer_use_agent")
})
t.Run("PlanModeExcludesUnknownTools", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"read_file",
"custom_tool",
"another_custom_tool",
"propose_plan",
), planMode, uuid.NullUUID{}, nil)
require.Equal(t, []string{
"read_file",
"propose_plan",
}, got)
require.NotContains(t, got, "custom_tool")
require.NotContains(t, got, "another_custom_tool")
})
t.Run("PlanModeIncludesOnlyApprovedExternalMCPTools", func(t *testing.T) {
t.Parallel()
approvedConfigID := uuid.New()
blockedConfigID := uuid.New()
got := activeToolNamesForTurn([]fantasy.AgentTool{
newTestAgentTool("read_file"),
newTestMCPAgentTool("approved-mcp__echo", approvedConfigID),
newTestMCPAgentTool("blocked-mcp__echo", blockedConfigID),
newTestAgentTool("workspace-mcp__echo"),
}, planMode, uuid.NullUUID{}, map[uuid.UUID]struct{}{
approvedConfigID: {},
})
require.Equal(t, []string{
"read_file",
"approved-mcp__echo",
}, got)
require.NotContains(t, got, "blocked-mcp__echo")
require.NotContains(t, got, "workspace-mcp__echo")
})
}
func TestAllowedExploreToolNames(t *testing.T) {
t.Parallel()
makeTools := func(names ...string) []fantasy.AgentTool {
tools := make([]fantasy.AgentTool, 0, len(names))
for _, name := range names {
tools = append(tools, newTestAgentTool(name))
}
return tools
}
got := allowedExploreToolNames(makeTools(
"read_file",
"write_file",
"edit_files",
"execute",
"process_output",
"process_list",
"process_signal",
"spawn_agent",
"spawn_explore_agent",
"wait_agent",
"read_skill",
"read_skill_file",
"ask_user_question",
))
require.Equal(t, []string{
"read_file",
"execute",
"process_output",
"read_skill",
"read_skill_file",
}, got)
}
func TestAllowedBehaviorToolNames(t *testing.T) {
t.Parallel()
makeTools := func(names ...string) []fantasy.AgentTool {
tools := make([]fantasy.AgentTool, 0, len(names))
for _, name := range names {
tools = append(tools, newTestAgentTool(name))
}
return tools
}
allTools := makeTools("read_file", "custom_tool", "spawn_explore_agent")
exploreMode := database.NullChatMode{
ChatMode: database.ChatModeExplore,
Valid: true,
}
t.Run("DefaultModeReturnsAllTools", func(t *testing.T) {
t.Parallel()
require.Equal(t, []string{"read_file", "custom_tool", "spawn_explore_agent"}, allowedBehaviorToolNames(
allTools,
database.NullChatMode{},
))
})
t.Run("ExploreModeUsesExploreAllowlist", func(t *testing.T) {
t.Parallel()
require.Equal(t, []string{"read_file"}, allowedBehaviorToolNames(
allTools,
exploreMode,
))
})
}
func TestStopAfterPlanTools(t *testing.T) {
t.Parallel()
planMode := database.NullChatPlanMode{
ChatPlanMode: database.ChatPlanModePlan,
Valid: true,
}
t.Run("NormalModeReturnsNil", func(t *testing.T) {
t.Parallel()
require.Nil(t, stopAfterPlanTools(database.NullChatPlanMode{}, uuid.NullUUID{}))
})
t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
"ask_user_question": {},
}, stopAfterPlanTools(planMode, uuid.NullUUID{}))
})
t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
}, stopAfterPlanTools(planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true}))
})
}
func TestStopAfterBehaviorTools(t *testing.T) {
t.Parallel()
planMode := database.NullChatPlanMode{
ChatPlanMode: database.ChatPlanModePlan,
Valid: true,
}
exploreMode := database.NullChatMode{
ChatMode: database.ChatModeExplore,
Valid: true,
}
t.Run("DefaultModeReturnsNil", func(t *testing.T) {
t.Parallel()
require.Nil(t, stopAfterBehaviorTools(
database.NullChatPlanMode{},
database.NullChatMode{},
uuid.NullUUID{},
))
})
t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
"ask_user_question": {},
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{}))
})
t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{UUID: uuid.New(), Valid: true}))
})
t.Run("ExploreModeReturnsNil", func(t *testing.T) {
t.Parallel()
require.Nil(t, stopAfterBehaviorTools(planMode, exploreMode, uuid.NullUUID{}))
})
}
// TestWaitForActiveChatStop and TestWaitForActiveChatStop_WaitsForReplacementRun
// were removed along with the process-local activeChats mechanism.
// Debug cleanup is now best-effort; stale finalization handles orphaned rows.
// TestArchiveChatWaitsForActiveChatStop and
// TestArchiveChatWaitsForEveryInterruptedChat were removed along with
// the process-local activeChats mechanism. Archive cleanup is now
// best-effort; stale finalization handles any orphaned rows.
func TestRenameChatTitle(t *testing.T) {
t.Parallel()
setupRealWorkerLock := func(
db *dbmock.MockStore,
chatID uuid.UUID,
lockedChat database.Chat,
) {
lockTx := dbmock.NewMockStore(gomock.NewController(t))
unlockTx := dbmock.NewMockStore(gomock.NewController(t))
gomock.InOrder(
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(lockTx)
},
),
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(unlockTx)
},
),
)
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
}
t.Run("WritesAndReturnsWroteTrue", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
chatID := uuid.New()
workerID := uuid.New()
stored := database.Chat{
ID: chatID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
Title: "original",
}
updated := stored
updated.Title = "renamed"
server := &Server{db: db, logger: logger}
setupRealWorkerLock(db, chatID, stored)
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(stored, nil)
db.EXPECT().UpdateChatTitleByID(gomock.Any(), database.UpdateChatTitleByIDParams{
ID: chatID,
Title: "renamed",
}).Return(updated, nil)
got, wrote, err := server.RenameChatTitle(ctx, stored, "renamed")
require.NoError(t, err)
require.True(t, wrote, "fresh rename must report wrote=true")
require.Equal(t, updated, got)
})
t.Run("SkipsWriteWhenAlreadyAtNewTitle", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
chatID := uuid.New()
workerID := uuid.New()
stale := database.Chat{
ID: chatID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
Title: "pre-race",
}
landed := stale
landed.Title = "landed-concurrently"
server := &Server{db: db, logger: logger}
setupRealWorkerLock(db, chatID, landed)
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(landed, nil)
got, wrote, err := server.RenameChatTitle(ctx, stale, "landed-concurrently")
require.NoError(t, err)
require.False(t, wrote,
"must report wrote=false when the stored row already matches newTitle so the handler suppresses a redundant title_change event")
require.Equal(t, landed, got)
})
}
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
lockTx := dbmock.NewMockStore(ctrl)
usageTx := dbmock.NewMockStore(ctrl)
unlockTx := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
pubsub := dbpubsub.NewInMemory()
clock := quartz.NewReal()
ownerID := uuid.New()
chatID := uuid.New()
modelConfigID := uuid.New()
workerID := uuid.New()
userPrompt := "review pull request 23633 and fix review threads"
wantTitle := "Review PR 23633"
chat := database.Chat{
ID: chatID,
OwnerID: ownerID,
LastModelConfigID: modelConfigID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
Title: fallbackChatTitle(userPrompt),
}
modelConfig := database.ChatModelConfig{
ID: modelConfigID,
Provider: "openai",
Model: "gpt-4o-mini",
ContextLimit: 8192,
}
updatedChat := chat
updatedChat.Title = wantTitle
messageEvents := make(chan struct {
payload codersdk.ChatWatchEvent
err error
}, 1)
cancelSub, err := pubsub.SubscribeWithErr(
coderdpubsub.ChatWatchEventChannel(ownerID),
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
messageEvents <- struct {
payload codersdk.ChatWatchEvent
err error
}{payload: payload, err: err}
}),
)
require.NoError(t, err)
defer cancelSub()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
require.Equal(t, "gpt-4o-mini", req.Model)
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
})
server := &Server{
db: db,
logger: logger,
pubsub: pubsub,
configCache: newChatConfigCache(context.Background(), db, clock),
}
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
APIKey: "test-key",
BaseUrl: serverURL,
}}, nil)
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDAscPaginatedParams{
ChatID: chatID,
AfterID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return([]database.ChatMessage{
mustChatMessage(
t,
database.ChatMessageRoleUser,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText(userPrompt),
),
mustChatMessage(
t,
database.ChatMessageRoleAssistant,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText("checking the diff now"),
),
}, nil)
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDDescPaginatedParams{
ChatID: chatID,
BeforeID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return(nil, nil)
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
gomock.InOrder(
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
return fn(lockTx)
},
),
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Nil(t, opts)
return fn(usageTx)
},
),
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
return fn(unlockTx)
},
),
)
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
require.Equal(t, []string{"[]"}, arg.Content)
return []database.ChatMessage{{ID: 91}}, nil
},
)
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
ID: chatID,
Title: wantTitle,
}).Return(updatedChat, nil)
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(updatedChat, nil)
gotChat, err := server.RegenerateChatTitle(ctx, chat)
require.NoError(t, err)
require.Equal(t, updatedChat, gotChat)
select {
case event := <-messageEvents:
require.NoError(t, event.err)
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
require.Equal(t, chatID, event.payload.Chat.ID)
require.Equal(t, wantTitle, event.payload.Chat.Title)
case <-time.After(time.Second):
t.Fatal("timed out waiting for title change pubsub event")
}
}
func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
lockTx := dbmock.NewMockStore(ctrl)
usageTx := dbmock.NewMockStore(ctrl)
unlockTx := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
pubsub := dbpubsub.NewInMemory()
clock := quartz.NewReal()
ownerID := uuid.New()
chatID := uuid.New()
modelConfigID := uuid.New()
userPrompt := "review pull request 23633 and fix review threads"
wantTitle := "Review PR 23633"
chat := database.Chat{
ID: chatID,
OwnerID: ownerID,
LastModelConfigID: modelConfigID,
Status: database.ChatStatusCompleted,
Title: fallbackChatTitle(userPrompt),
}
lockedChat := chat
lockedChat.WorkerID = uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}
lockedChat.StartedAt = sql.NullTime{Time: time.Now(), Valid: true}
modelConfig := database.ChatModelConfig{
ID: modelConfigID,
Provider: "openai",
Model: "gpt-4o-mini",
ContextLimit: 8192,
}
updatedChat := lockedChat
updatedChat.Title = wantTitle
unlockedChat := updatedChat
unlockedChat.WorkerID = uuid.NullUUID{}
unlockedChat.StartedAt = sql.NullTime{}
messageEvents := make(chan struct {
payload codersdk.ChatWatchEvent
err error
}, 1)
cancelSub, err := pubsub.SubscribeWithErr(
coderdpubsub.ChatWatchEventChannel(ownerID),
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
messageEvents <- struct {
payload codersdk.ChatWatchEvent
err error
}{payload: payload, err: err}
}),
)
require.NoError(t, err)
defer cancelSub()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
require.Equal(t, "gpt-4o-mini", req.Model)
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
})
server := &Server{
db: db,
logger: logger,
pubsub: pubsub,
configCache: newChatConfigCache(context.Background(), db, clock),
}
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
APIKey: "test-key",
BaseUrl: serverURL,
}}, nil)
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDAscPaginatedParams{
ChatID: chatID,
AfterID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return([]database.ChatMessage{
mustChatMessage(
t,
database.ChatMessageRoleUser,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText(userPrompt),
),
mustChatMessage(
t,
database.ChatMessageRoleAssistant,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText("checking the diff now"),
),
}, nil)
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDDescPaginatedParams{
ChatID: chatID,
BeforeID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return(nil, nil)
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
gomock.InOrder(
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
return fn(lockTx)
},
),
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Nil(t, opts)
return fn(usageTx)
},
),
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
return fn(unlockTx)
},
),
)
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
lockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(
gomock.Any(),
gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{}),
).DoAndReturn(func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
require.Equal(t, chat.ID, arg.ID)
require.Equal(t, chat.Status, arg.Status)
require.Equal(t, uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}, arg.WorkerID)
require.True(t, arg.StartedAt.Valid)
require.WithinDuration(t, time.Now(), arg.StartedAt.Time, time.Second)
require.False(t, arg.HeartbeatAt.Valid)
require.Equal(t, chat.LastError, arg.LastError)
require.Equal(t, chat.UpdatedAt, arg.UpdatedAt)
return lockedChat, nil
})
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
require.Equal(t, []string{"[]"}, arg.Content)
return []database.ChatMessage{{ID: 91}}, nil
},
)
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
ID: chatID,
Title: wantTitle,
}).Return(updatedChat, nil)
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(updatedChat, nil)
unlockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(
gomock.Any(),
database.UpdateChatStatusPreserveUpdatedAtParams{
ID: updatedChat.ID,
Status: updatedChat.Status,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: updatedChat.LastError,
UpdatedAt: updatedChat.UpdatedAt,
},
).Return(unlockedChat, nil)
gotChat, err := server.RegenerateChatTitle(ctx, chat)
require.NoError(t, err)
require.Equal(t, updatedChat, gotChat)
select {
case event := <-messageEvents:
require.NoError(t, event.err)
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
require.Equal(t, chatID, event.payload.Chat.ID)
require.Equal(t, wantTitle, event.payload.Chat.Title)
case <-time.After(time.Second):
t.Fatal("timed out waiting for title change pubsub event")
}
}
func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
server := &Server{
db: db,
configCache: newChatConfigCache(
context.Background(),
db,
quartz.NewReal(),
),
providerAPIKeys: chatprovider.ProviderAPIKeys{
OpenAI: "openai-deployment-key",
Anthropic: "anthropic-deployment-key",
ByProvider: map[string]string{
"openai": "openai-deployment-key",
"anthropic": "anthropic-deployment-key",
},
BaseURLByProvider: map[string]string{
"openai": "https://openai.example.com",
"anthropic": "https://anthropic.example.com",
},
},
}
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "anthropic",
CentralApiKeyEnabled: true,
AllowCentralApiKeyFallback: true,
}}, nil)
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
require.NoError(t, err)
require.Empty(t, keys.OpenAI)
require.Empty(t, keys.APIKey("openai"))
require.Empty(t, keys.BaseURL("openai"))
require.Equal(t, "anthropic-deployment-key", keys.Anthropic)
require.Equal(t, "anthropic-deployment-key", keys.APIKey("anthropic"))
require.Equal(t, "https://anthropic.example.com", keys.BaseURL("anthropic"))
require.Equal(t, map[string]string{"anthropic": "anthropic-deployment-key"}, keys.ByProvider)
require.Equal(t, map[string]string{"anthropic": "https://anthropic.example.com"}, keys.BaseURLByProvider)
}
func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKeys(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
server := &Server{
db: db,
configCache: newChatConfigCache(
context.Background(),
db,
quartz.NewReal(),
),
providerAPIKeys: chatprovider.ProviderAPIKeys{
OpenAI: "openai-deployment-key",
ByProvider: map[string]string{
"openai": "openai-deployment-key",
},
},
}
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
}}, nil)
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
require.NoError(t, err)
require.Equal(t, "openai-deployment-key", keys.OpenAI)
require.Equal(t, "openai-deployment-key", keys.APIKey("openai"))
}
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
t.Parallel()
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
calls := 0
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(context.Context, uuid.UUID) (database.Chat, error) {
calls++
return database.Chat{}, nil
},
)
require.NoError(t, err)
require.Equal(t, chat, refreshed)
require.Equal(t, 0, calls)
}
func TestRefreshChatWorkspaceSnapshot_ReloadsWhenWorkspaceMissing(t *testing.T) {
t.Parallel()
chatID := uuid.New()
workspaceID := uuid.New()
chat := database.Chat{ID: chatID}
reloaded := database.Chat{
ID: chatID,
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
calls := 0
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(_ context.Context, id uuid.UUID) (database.Chat, error) {
calls++
require.Equal(t, chatID, id)
return reloaded, nil
},
)
require.NoError(t, err)
require.Equal(t, reloaded, refreshed)
require.Equal(t, 1, calls)
}
func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
t.Parallel()
chat := database.Chat{ID: uuid.New()}
loadErr := xerrors.New("boom")
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(context.Context, uuid.UUID) (database.Chat, error) {
return database.Chat{}, loadErr
},
)
require.Error(t, err)
require.ErrorContains(t, err, "reload chat workspace state")
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: agentID,
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentByID(
gomock.Any(),
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
if !ok || arg.ID != chat.ID {
return false
}
if !arg.LastInjectedContext.Valid {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
return false
}
// Expect at least one context-file part for the
// working-directory AGENTS.md, with internal fields
// stripped (no content, OS, or directory).
for _, p := range parts {
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFilePath != "" {
return p.ContextFileContent == "" &&
p.ContextFileOS == "" &&
p.ContextFileDirectory == ""
}
}
return false
}),
).Return(database.Chat{}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
Parts: []codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/project/AGENTS.md",
ContextFileContent: "# Project instructions",
}},
}, nil).AnyTimes()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
instructionLookupTimeout: 5 * time.Second,
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction, _, err := server.persistInstructionFiles(
ctx,
chat,
uuid.New(),
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.NoError(t, err)
require.Contains(t, instruction, "Operating System: linux")
require.Contains(t, instruction, "Working Directory: /home/coder/project")
}
func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
}
server := &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
}
instruction, _, err := server.persistInstructionFiles(
ctx,
chat,
uuid.New(),
func(context.Context) (database.WorkspaceAgent, error) {
return database.WorkspaceAgent{
ID: uuid.New(),
Directory: "/home/coder/project",
}, nil
},
func(context.Context) (workspacesdk.AgentConn, error) {
return nil, errChatHasNoWorkspaceAgent
},
)
require.NoError(t, err)
require.Empty(t, instruction)
}
func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: agentID,
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentByID(
gomock.Any(),
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.InsertChatMessagesParams)
if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil {
return false
}
foundMarker := false
foundSkill := false
for _, p := range parts {
switch p.Type {
case codersdk.ChatMessagePartTypeContextFile:
if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" {
foundMarker = true
}
case codersdk.ChatMessagePartTypeSkill:
if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
foundSkill = true
}
}
}
return foundMarker && foundSkill
}),
).Return(nil, nil).Times(1)
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
if !ok || arg.ID != chat.ID {
return false
}
if !arg.LastInjectedContext.Valid {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
return false
}
// The sentinel path should persist only skill parts
// with ContextFileAgentID set.
for _, p := range parts {
if p.Type == codersdk.ChatMessagePartTypeSkill &&
p.SkillName == "my-skill" &&
p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
return true
}
}
return false
}),
).Return(database.Chat{}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
// Agent returns pre-read content: no instruction files
// found but one skill discovered.
Parts: []codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "my-skill",
SkillDescription: "A test skill",
SkillDir: "/home/coder/project/.agents/skills/my-skill",
}},
}, nil).AnyTimes()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
instructionLookupTimeout: 5 * time.Second,
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction, skills, err := server.persistInstructionFiles(
ctx,
chat,
uuid.New(),
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.NoError(t, err)
// Sentinel path returns empty instruction string.
require.Empty(t, instruction)
// Skills are still discovered and returned.
require.Len(t, skills, 1)
require.Equal(t, "my-skill", skills[0].Name)
}
func TestPersistInstructionFilesSentinelNoSkillsClearsColumn(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: agentID,
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentByID(
gomock.Any(),
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
if !ok || arg.ID != chat.ID {
return false
}
// No skills discovered, so the column should be
// cleared to NULL.
return !arg.LastInjectedContext.Valid
}),
).Return(database.Chat{}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
// Agent returns pre-read content: no files, no skills.
Parts: []codersdk.ChatMessagePart{},
}, nil).AnyTimes()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
instructionLookupTimeout: 5 * time.Second,
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction, skills, err := server.persistInstructionFiles(
ctx,
chat,
uuid.New(),
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.NoError(t, err)
// Sentinel path: empty instruction, no skills.
require.Empty(t, instruction)
require.Empty(t, skills)
}
func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{ID: agentID}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(workspaceAgent, nil).Times(1)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, chat, chatSnapshot)
require.Equal(t, workspaceAgent, agent)
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, workspaceAgent, gotAgent)
require.Equal(t, chat, currentChat)
}
func TestTurnWorkspaceContext_NullBindingLazyBind(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
buildID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{ID: agentID}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: agentID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{workspaceAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: agentID, Valid: true},
ID: chat.ID,
}).Return(updatedChat, nil),
)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, updatedChat, chatSnapshot)
require.Equal(t, workspaceAgent, agent)
require.Equal(t, updatedChat, currentChat)
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, workspaceAgent, gotAgent)
}
func TestTurnWorkspaceContext_StaleBindingRepair(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
staleAgentID := uuid.New()
buildID := uuid.New()
currentAgentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: staleAgentID,
Valid: true,
},
}
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(database.WorkspaceAgent{}, xerrors.New("missing agent")),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
ID: chat.ID,
}).Return(updatedChat, nil),
)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, updatedChat, chatSnapshot)
require.Equal(t, currentAgent, agent)
require.Equal(t, updatedChat, currentChat)
}
func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
staleAgentID := uuid.New()
currentAgentID := uuid.New()
buildID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: staleAgentID,
Valid: true,
},
}
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(staleAgent, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), currentAgentID).Return(currentAgent, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
ID: chat.ID,
}).Return(updatedChat, nil),
)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
var dialed []uuid.UUID
server := &Server{db: db}
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialed = append(dialed, agentID)
if agentID == staleAgentID {
return nil, nil, xerrors.New("dial failed")
}
return conn, func() {}, nil
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, conn, gotConn)
require.Equal(t, []uuid.UUID{staleAgentID, currentAgentID}, dialed)
require.Equal(t, updatedChat, currentChat)
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, currentAgent, gotAgent)
}
func TestTurnWorkspaceContextGetWorkspaceConnFastFailsWithoutCurrentAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
staleAgentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: staleAgentID,
Valid: true,
},
}
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).
Return(staleAgent, nil).
Times(1)
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{}, nil).
Times(1)
server := &Server{db: db}
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return nil, nil, xerrors.New("dial failed")
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
defer workspaceCtx.close()
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.Nil(t, gotConn)
require.ErrorIs(t, err, errChatHasNoWorkspaceAgent)
workspaceCtx.mu.Lock()
defer workspaceCtx.mu.Unlock()
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
require.False(t, workspaceCtx.agentLoaded)
require.Nil(t, workspaceCtx.conn)
require.Nil(t, workspaceCtx.releaseConn)
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
}
func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
currentChat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
}
updatedChat := database.Chat{
ID: currentChat.ID,
WorkspaceID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
}
cachedConn := agentconnmock.NewMockAgentConn(ctrl)
releaseCalls := 0
workspaceCtx := turnWorkspaceContext{
chatStateMu: &sync.Mutex{},
currentChat: &currentChat,
}
workspaceCtx.agent = database.WorkspaceAgent{ID: uuid.New()}
workspaceCtx.agentLoaded = true
workspaceCtx.conn = cachedConn
workspaceCtx.cachedWorkspaceID = currentChat.WorkspaceID
workspaceCtx.releaseConn = func() {
releaseCalls++
}
workspaceCtx.selectWorkspace(updatedChat)
require.Equal(t, updatedChat, currentChat)
require.Equal(t, 1, releaseCalls)
workspaceCtx.mu.Lock()
defer workspaceCtx.mu.Unlock()
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
require.False(t, workspaceCtx.agentLoaded)
require.Nil(t, workspaceCtx.conn)
require.Nil(t, workspaceCtx.releaseConn)
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
}
func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferentWorkspace(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceOneID := uuid.New()
workspaceTwoID := uuid.New()
buildID := uuid.New()
cachedAgent := database.WorkspaceAgent{ID: uuid.New()}
resolvedAgent := database.WorkspaceAgent{ID: uuid.New()}
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceTwoID,
Valid: true,
},
}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return([]database.WorkspaceAgent{resolvedAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
ID: chat.ID,
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true},
}).Return(updatedChat, nil),
)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
workspaceCtx.agent = cachedAgent
workspaceCtx.agentLoaded = true
workspaceCtx.cachedWorkspaceID = uuid.NullUUID{UUID: workspaceOneID, Valid: true}
defer workspaceCtx.close()
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, updatedChat, chatSnapshot)
require.Equal(t, resolvedAgent, agent)
require.Equal(t, updatedChat, currentChat)
}
func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
localMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishMessage(chatID, localMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
cachedMessage := codersdk.ChatMessage{
ID: 2,
ChatID: chatID,
Role: codersdk.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &cachedMessage,
})
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
catchupMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 1,
}).Return([]database.ChatMessage{catchupMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
editedMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{editedMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishEditedMessage(chatID, editedMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(1), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
retryingAt := time.Unix(1_700_000_000, 0).UTC()
expected := &codersdk.ChatStreamRetry{
Attempt: 1,
DelayMs: (1500 * time.Millisecond).Milliseconds(),
Error: "OpenAI is rate limiting requests (HTTP 429).",
Kind: chaterror.KindRateLimit,
Provider: "openai",
StatusCode: 429,
RetryingAt: retryingAt,
}
server.publishRetry(chatID, expected)
event := requireStreamRetryEvent(t, events)
require.Equal(t, expected, event.Retry)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
classified := chaterror.ClassifiedError{
Message: "OpenAI is rate limiting requests (HTTP 429).",
Kind: chaterror.KindRateLimit,
Provider: "openai",
Retryable: true,
StatusCode: 429,
}
server.publishError(chatID, classified)
event := requireStreamErrorEvent(t, events)
require.Equal(t, chaterror.StreamErrorPayload(classified), event.Error)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
Error: "legacy error only",
})
event := requireStreamErrorEvent(t, events)
require.Equal(t, &codersdk.ChatStreamError{Message: "legacy error only"}, event.Error)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
t.Helper()
return &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
pubsub: dbpubsub.NewInMemory(),
}
}
func requireStreamMessageEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeMessage, event.Type)
require.NotNil(t, event.Message)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream message event")
return codersdk.ChatStreamEvent{}
}
}
func requireStreamRetryEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeRetry, event.Type)
require.NotNil(t, event.Retry)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream retry event")
return codersdk.ChatStreamEvent{}
}
}
func requireStreamErrorEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeError, event.Type)
require.NotNil(t, event.Error)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream error event")
return codersdk.ChatStreamEvent{}
}
}
func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) {
t.Helper()
select {
case event, ok := <-events:
if !ok {
t.Fatal("chat stream closed unexpectedly")
}
t.Fatalf("unexpected chat stream event: %+v", event)
case <-time.After(wait):
}
}
// TestPublishToStream_DropWarnRateLimiting walks through a
// realistic lifecycle: buffer fills up, subscriber channel fills
// up, counters get reset between steps. It verifies that WARN
// logs are rate-limited to at most once per streamDropWarnInterval
// and that counter resets re-enable an immediate WARN.
func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
t.Parallel()
sink := testutil.NewFakeSink(t)
mClock := quartz.NewMock(t)
server := &Server{
logger: sink.Logger(),
clock: mClock,
}
chatID := uuid.New()
subCh := make(chan codersdk.ChatStreamEvent, 1)
subCh <- codersdk.ChatStreamEvent{} // pre-fill so sends always drop
// Set up state that mirrors a running chat: buffer at capacity,
// buffering enabled, one saturated subscriber.
state := &chatStreamState{
buffering: true,
buffer: make([]codersdk.ChatStreamEvent, maxStreamBufferSize),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
uuid.New(): subCh,
},
}
server.chatStreams.Store(chatID, state)
bufferMsg := "chat stream buffer full, dropping oldest event"
subMsg := "dropping chat stream event"
filter := func(level slog.Level, msg string) func(slog.SinkEntry) bool {
return func(e slog.SinkEntry) bool {
return e.Level == level && e.Message == msg
}
}
// --- Phase 1: buffer-full rate limiting ---
// message_part events hit both the buffer-full and subscriber-full
// paths. The first publish triggers a WARN for each; the rest
// within the window are DEBUG.
partEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}
for i := 0; i < 50; i++ {
server.publishToStream(chatID, partEvent)
}
require.Len(t, sink.Entries(filter(slog.LevelWarn, bufferMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, bufferMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, bufferMsg))[0], "dropped_count", int64(1))
// Subscriber also saw 50 drops (one per publish).
require.Len(t, sink.Entries(filter(slog.LevelWarn, subMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, subMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, subMsg))[0], "dropped_count", int64(1))
// --- Phase 2: clock advance triggers second WARN with count ---
mClock.Advance(streamDropWarnInterval + time.Second)
server.publishToStream(chatID, partEvent)
bufWarn := sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 2)
requireFieldValue(t, bufWarn[1], "dropped_count", int64(50))
subWarn := sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 2)
requireFieldValue(t, subWarn[1], "dropped_count", int64(50))
// --- Phase 3: counter reset (simulates step persist) ---
state.mu.Lock()
state.buffer = make([]codersdk.ChatStreamEvent, maxStreamBufferSize)
state.resetDropCounters()
state.mu.Unlock()
// The very next drop should WARN immediately — the reset zeroed
// lastWarnAt so the interval check passes.
server.publishToStream(chatID, partEvent)
bufWarn = sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 3, "expected WARN immediately after counter reset")
requireFieldValue(t, bufWarn[2], "dropped_count", int64(1))
subWarn = sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 3, "expected subscriber WARN immediately after counter reset")
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
}
func TestResolveUserCompactionThreshold(t *testing.T) {
t.Parallel()
userID := uuid.New()
modelConfigID := uuid.New()
expectedKey := codersdk.CompactionThresholdKey(modelConfigID)
tests := []struct {
name string
dbReturn string
dbErr error
wantVal int32
wantOK bool
wantWarnLog bool
}{
{
name: "NoRowsReturnsDefault",
dbErr: sql.ErrNoRows,
wantOK: false,
},
{
name: "ValidOverride",
dbReturn: "75",
wantVal: 75,
wantOK: true,
},
{
name: "OutOfRangeValue",
dbReturn: "101",
wantOK: false,
},
{
name: "NonIntegerValue",
dbReturn: "abc",
wantOK: false,
},
{
name: "UnexpectedDBError",
dbErr: xerrors.New("connection refused"),
wantOK: false,
wantWarnLog: true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockDB := dbmock.NewMockStore(ctrl)
sink := testutil.NewFakeSink(t)
srv := &Server{
db: mockDB,
logger: sink.Logger(),
}
mockDB.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), database.GetUserChatCompactionThresholdParams{
UserID: userID,
Key: expectedKey,
}).Return(tc.dbReturn, tc.dbErr)
val, ok := srv.resolveUserCompactionThreshold(context.Background(), userID, modelConfigID)
require.Equal(t, tc.wantVal, val)
require.Equal(t, tc.wantOK, ok)
warns := sink.Entries(func(e slog.SinkEntry) bool {
return e.Level == slog.LevelWarn
})
if tc.wantWarnLog {
require.NotEmpty(t, warns, "expected a warning log entry")
return
}
require.Empty(t, warns, "unexpected warning log entry")
})
}
}
// requireFieldValue asserts that a SinkEntry contains a field with
// the given name and value.
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
t.Helper()
for _, f := range entry.Fields {
if f.Name == name {
require.Equal(t, expected, f.Value, "field %q value mismatch", name)
return
}
}
t.Fatalf("field %q not found in log entry", name)
}
func TestSkillsFromParts(t *testing.T) {
t.Parallel()
t.Run("Empty", func(t *testing.T) {
t.Parallel()
got := skillsFromParts(nil)
require.Empty(t, got)
})
t.Run("NoSkillParts", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{Type: codersdk.ChatMessagePartTypeText, Text: "hello"},
}),
}
got := skillsFromParts(msgs)
require.Empty(t, got)
})
t.Run("SingleSkill", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "deep-review",
SkillDescription: "Multi-reviewer code review",
SkillDir: "/home/coder/.agents/skills/deep-review",
},
}),
}
got := skillsFromParts(msgs)
require.Len(t, got, 1)
require.Equal(t, "deep-review", got[0].Name)
require.Equal(t, "Multi-reviewer code review", got[0].Description)
require.Equal(t, "/home/coder/.agents/skills/deep-review", got[0].Dir)
})
t.Run("MultipleSkillsAcrossMessages", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "pull-requests",
SkillDir: "/home/coder/.agents/skills/pull-requests",
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "deep-review",
SkillDir: "/home/coder/.agents/skills/deep-review",
},
}),
}
got := skillsFromParts(msgs)
require.Len(t, got, 2)
require.Equal(t, "pull-requests", got[0].Name)
require.Equal(t, "deep-review", got[1].Name)
})
t.Run("MixedPartTypes", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/.coder/AGENTS.md",
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "refine-plan",
SkillDir: "/home/coder/.agents/skills/refine-plan",
},
}),
// A text-only message should be skipped entirely.
chatMessageWithParts([]codersdk.ChatMessagePart{
{Type: codersdk.ChatMessagePartTypeText, Text: "user turn"},
}),
}
got := skillsFromParts(msgs)
require.Len(t, got, 1)
require.Equal(t, "refine-plan", got[0].Name)
require.Equal(t, "/home/coder/.agents/skills/refine-plan", got[0].Dir)
})
t.Run("OptionalDescriptionOmitted", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "refine-plan",
SkillDir: "/home/coder/.agents/skills/refine-plan",
},
}),
}
got := skillsFromParts(msgs)
require.Len(t, got, 1)
require.Equal(t, "refine-plan", got[0].Name)
require.Empty(t, got[0].Description)
})
t.Run("InvalidJSON", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
{
Content: pqtype.NullRawMessage{
RawMessage: []byte(`not valid json with "skill" in it`),
Valid: true,
},
},
}
got := skillsFromParts(msgs)
require.Empty(t, got)
})
t.Run("RoundTrip", func(t *testing.T) {
// Simulate persist -> reconstruct cycle: marshal skill
// parts the same way persistInstructionFiles does, then
// verify skillsFromParts recovers the metadata.
t.Parallel()
want := []chattool.SkillMeta{
{Name: "deep-review", Description: "Multi-reviewer review", Dir: "/skills/deep-review"},
{Name: "pull-requests", Description: "", Dir: "/skills/pull-requests"},
}
agentID := uuid.New()
var parts []codersdk.ChatMessagePart
for _, s := range want {
parts = append(parts, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: s.Name,
SkillDescription: s.Description,
SkillDir: s.Dir,
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
})
}
msgs := []database.ChatMessage{chatMessageWithParts(parts)}
got := skillsFromParts(msgs)
require.Len(t, got, len(want))
for i, w := range want {
require.Equal(t, w.Name, got[i].Name)
require.Equal(t, w.Description, got[i].Description)
require.Equal(t, w.Dir, got[i].Dir)
}
})
}
func TestContextFileAgentID(t *testing.T) {
t.Parallel()
t.Run("EmptyMessages", func(t *testing.T) {
t.Parallel()
id, ok := contextFileAgentID(nil)
require.Equal(t, uuid.Nil, id)
require.False(t, ok)
})
t.Run("NoContextFileParts", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{Type: codersdk.ChatMessagePartTypeText, Text: "hello"},
}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, uuid.Nil, id)
require.False(t, ok)
})
t.Run("SingleContextFile", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/some/path",
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
},
}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, agentID, id)
require.True(t, ok)
})
t.Run("MultipleContextFiles", func(t *testing.T) {
t.Parallel()
agentID1 := uuid.New()
agentID2 := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/first/path",
ContextFileAgentID: uuid.NullUUID{UUID: agentID1, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/second/path",
ContextFileAgentID: uuid.NullUUID{UUID: agentID2, Valid: true},
},
}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, agentID2, id)
require.True(t, ok)
})
t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) {
t.Parallel()
instructionAgentID := uuid.New()
sentinelAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/workspace/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: sentinelAgentID,
Valid: true,
},
}}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, instructionAgentID, id)
require.True(t, ok)
})
t.Run("SentinelWithoutAgentID", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFileAgentID: uuid.NullUUID{Valid: false},
},
}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, uuid.Nil, id)
require.False(t, ok)
})
}
func TestHasPersistedInstructionFiles(t *testing.T) {
t.Parallel()
t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}}),
}
require.False(t, hasPersistedInstructionFiles(msgs))
})
t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/workspace/AGENTS.md",
ContextFileContent: "repo instructions",
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
}}),
}
require.True(t, hasPersistedInstructionFiles(msgs))
})
}
func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileOS: "darwin",
ContextFileDirectory: "/old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}}),
}
got := instructionFromContextFiles(msgs)
require.Contains(t, got, "new instructions")
require.Contains(t, got, "Operating System: linux")
require.Contains(t, got, "Working Directory: /new")
require.NotContains(t, got, "old instructions")
require.NotContains(t, got, "Operating System: darwin")
}
func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/legacy/AGENTS.md",
ContextFileContent: "legacy instructions",
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileOS: "darwin",
ContextFileDirectory: "/old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}}),
}
got := instructionFromContextFiles(msgs)
require.Contains(t, got, "legacy instructions")
require.Contains(t, got, "new instructions")
require.Contains(t, got, "Operating System: linux")
require.Contains(t, got, "Working Directory: /new")
require.NotContains(t, got, "old instructions")
require.NotContains(t, got, "Operating System: darwin")
}
func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-legacy",
SkillDir: "/skills/repo-helper-legacy",
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-old",
SkillDir: "/skills/repo-helper-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: newAgentID,
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-new",
SkillDir: "/skills/repo-helper-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
}),
}
got := skillsFromParts(msgs)
require.Equal(t, []chattool.SkillMeta{
{Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"},
{Name: "repo-helper-new", Dir: "/skills/repo-helper-new"},
}, got)
}
func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-old",
SkillDir: "/skills/repo-helper-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: newAgentID,
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-new",
SkillDir: "/skills/repo-helper-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
}),
}
got := skillsFromParts(msgs)
require.Equal(t, []chattool.SkillMeta{{
Name: "repo-helper-new",
Dir: "/skills/repo-helper-new",
}}, got)
}
func TestMergeSkillMetas(t *testing.T) {
t.Parallel()
persisted := []chattool.SkillMeta{{
Name: "repo-helper",
Description: "Persisted skill",
Dir: "/skills/repo-helper-old",
}}
discovered := []chattool.SkillMeta{
{
Name: "repo-helper",
Description: "Discovered replacement",
Dir: "/skills/repo-helper-new",
MetaFile: "SKILL.md",
},
{
Name: "deep-review",
Description: "Discovered skill",
Dir: "/skills/deep-review",
},
}
got := mergeSkillMetas(persisted, discovered)
require.Equal(t, []chattool.SkillMeta{
discovered[0],
discovered[1],
}, got)
}
func TestSelectSkillMetasForInstructionRefresh(t *testing.T) {
t.Parallel()
persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}}
discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}}
currentAgentID := uuid.New()
otherAgentID := uuid.New()
t.Run("MergesCurrentAgentSkills", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
discovered,
uuid.NullUUID{UUID: currentAgentID, Valid: true},
uuid.NullUUID{UUID: currentAgentID, Valid: true},
)
require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got)
})
t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
discovered,
uuid.NullUUID{UUID: currentAgentID, Valid: true},
uuid.NullUUID{UUID: otherAgentID, Valid: true},
)
require.Equal(t, discovered, got)
})
t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
nil,
uuid.NullUUID{},
uuid.NullUUID{UUID: otherAgentID, Valid: true},
)
require.Equal(t, persisted, got)
})
}
func TestResolveChainModeIgnoresSkillOnlySentinelMessages(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
user := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "latest user message",
}})
user.Role = database.ChatMessageRoleUser
got := resolveChainMode([]database.ChatMessage{assistant, skillOnly, user})
require.Equal(t, "resp-123", got.previousResponseID)
require.Equal(t, modelConfigID, got.modelConfigID)
require.Equal(t, 2, got.trailingUserCount)
require.Equal(t, 1, got.contributingTrailingUserCount)
}
func TestFilterPromptForChainModeKeepsContributingUsersAcrossSkippedSentinelTurns(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "prior user message",
}})
priorUser.Role = database.ChatMessageRoleUser
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
firstTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "first trailing user",
}})
firstTrailingUser.Role = database.ChatMessageRoleUser
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
lastTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "last trailing user",
}})
lastTrailingUser.Role = database.ChatMessageRoleUser
chainInfo := resolveChainMode([]database.ChatMessage{
priorUser,
assistant,
firstTrailingUser,
skillOnly,
lastTrailingUser,
})
require.Equal(t, 3, chainInfo.trailingUserCount)
require.Equal(t, 2, chainInfo.contributingTrailingUserCount)
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "system instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "prior user message"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "assistant reply"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "first trailing user"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "last trailing user"},
},
},
}
got := filterPromptForChainMode(prompt, chainInfo)
require.Len(t, got, 3)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
firstPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "first trailing user", firstPart.Text)
lastPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[2].Content[0])
require.True(t, ok)
require.Equal(t, "last trailing user", lastPart.Text)
}
func TestFilterPromptForChainModeUsesContributingTrailingUsers(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "prior user message",
}})
priorUser.Role = database.ChatMessageRoleUser
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
latestUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "latest user message",
}})
latestUser.Role = database.ChatMessageRoleUser
chainInfo := resolveChainMode([]database.ChatMessage{
priorUser,
assistant,
skillOnly,
latestUser,
})
require.Equal(t, 2, chainInfo.trailingUserCount)
require.Equal(t, 1, chainInfo.contributingTrailingUserCount)
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "system instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "prior user message"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "assistant reply"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "latest user message"},
},
},
}
got := filterPromptForChainMode(prompt, chainInfo)
require.Len(t, got, 2)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "latest user message", part.Text)
}
func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage {
raw, _ := json.Marshal(parts)
return database.ChatMessage{
Content: pqtype.NullRawMessage{RawMessage: raw, Valid: true},
}
}
// TestProcessChat_IgnoresStaleControlNotification verifies that
// processChat is not interrupted by a "pending" notification
// published before processing begins. This is the race that caused
// TestOpenAIReasoningWithWebSearchRoundTripStoreFalse to flake:
// SendMessage publishes "pending" via PostgreSQL NOTIFY, and due
// to async delivery the notification can arrive at the control
// subscriber after it registers but before the processor publishes
// "running".
func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ps := dbpubsub.NewInMemory()
clock := quartz.NewMock(t)
chatID := uuid.New()
workerID := uuid.New()
server := &Server{
db: db,
logger: logger,
pubsub: ps,
clock: clock,
workerID: workerID,
chatHeartbeatInterval: time.Minute,
configCache: newChatConfigCache(ctx, db, clock),
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
metrics: chatloop.NopMetrics(),
}
// Publish a stale "pending" notification on the control channel
// BEFORE processChat subscribes. In production this is the
// notification from SendMessage that triggered the processing.
staleNotify, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusPending),
})
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), staleNotify)
require.NoError(t, err)
// Track which status processChat writes during cleanup.
var finalStatus database.ChatStatus
cleanupDone := make(chan struct{})
// The deferred cleanup in processChat runs a transaction.
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(db)
},
)
db.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(
database.Chat{ID: chatID, Status: database.ChatStatusRunning, WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}}, nil,
)
db.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, params database.UpdateChatStatusParams) (database.Chat, error) {
finalStatus = params.Status
close(cleanupDone)
return database.Chat{ID: chatID, Status: params.Status}, nil
},
)
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
database.Chat{ID: chatID, Status: database.ChatStatusError},
nil,
)
// resolveChatModel fails immediately — that's fine, we only
// need processChat to get past initialization without being
// interrupted by the stale notification.
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return(
database.ChatModelConfig{}, xerrors.New("no model configured"),
).AnyTimes()
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
).AnyTimes()
db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes()
chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New()}
go server.processChat(ctx, chat)
select {
case <-cleanupDone:
case <-ctx.Done():
t.Fatal("processChat did not complete")
}
// If the stale notification interrupted us, status would be
// "waiting" (the ErrInterrupted path). Since the gate blocked
// it, processChat reached runChat, which failed on model
// resolution → status is "error".
require.Equal(t, database.ChatStatusError, finalStatus,
"processChat should have reached runChat (error), not been interrupted (waiting)")
}
func TestShouldPublishFinishedChatState(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
workerID := uuid.New()
server := &Server{db: db}
updatedChat := database.Chat{
ID: chatID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
}
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{
ID: chatID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
}, nil)
require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat))
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{
ID: chatID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
}, nil)
require.False(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat))
}
// TestShouldPublishFinishedChatState_DBErrorPublishes pins the
// deliberate fail-open behavior when the re-read query errors: we
// surface the finished state anyway so watchers don't get stuck
// waiting for a status update that never arrives. The error path is
// easy to regress into a fail-closed default otherwise.
func TestShouldPublishFinishedChatState_DBErrorPublishes(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
server := &Server{db: db}
updatedChat := database.Chat{
ID: chatID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
}
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
database.Chat{}, xerrors.New("boom"),
)
require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat),
"fail-open: a re-read error must not swallow the status change")
}
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
// batch heartbeat UPDATE does not return a registered chat's ID
// (because another replica stole it or it was completed), the
// heartbeat tick cancels that chat's context with ErrInterrupted
// while leaving surviving chats untouched.
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
workerID := uuid.New()
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: workerID,
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
// Create three chats with independent cancel functions.
chat1 := uuid.New()
chat2 := uuid.New()
chat3 := uuid.New()
_, cancel1 := context.WithCancelCause(ctx)
_, cancel2 := context.WithCancelCause(ctx)
ctx3, cancel3 := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel1,
chatID: chat1,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel2,
chatID: chat2,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel3,
chatID: chat3,
logger: logger,
})
// The batch UPDATE returns only chat1 and chat2 —
// chat3 was "stolen" by another replica.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
require.Equal(t, workerID, params.WorkerID)
require.Len(t, params.IDs, 3)
// Return only chat1 and chat2 as surviving.
return []uuid.UUID{chat1, chat2}, nil
},
)
server.heartbeatTick(ctx)
// chat3's context should be canceled with ErrInterrupted.
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
"stolen chat should be interrupted")
// chat3 should have been removed from the registry by
// unregister (in production this happens via defer in
// processChat). The heartbeat tick itself does not
// unregister — it only cancels. Verify the entry is
// still present (processChat's defer would clean it up).
server.heartbeatMu.Lock()
_, chat1Exists := server.heartbeatRegistry[chat1]
_, chat2Exists := server.heartbeatRegistry[chat2]
_, chat3Exists := server.heartbeatRegistry[chat3]
server.heartbeatMu.Unlock()
require.True(t, chat1Exists, "surviving chat1 should remain registered")
require.True(t, chat2Exists, "surviving chat2 should remain registered")
require.True(t, chat3Exists,
"stolen chat3 should still be in registry (processChat defer removes it)")
}
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
// transient database failure causes the tick to log and return
// without canceling any registered chats.
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: uuid.New(),
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
chatID := uuid.New()
chatCtx, cancel := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel,
chatID: chatID,
logger: logger,
})
// Simulate a transient DB error.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
nil, xerrors.New("connection reset"),
)
server.heartbeatTick(ctx)
// Chat should NOT be interrupted — the tick logged and
// returned early.
require.NoError(t, chatCtx.Err(),
"chat context should not be canceled on transient DB error")
}
// TestSubscribeCancelDuringGrace_ReapedBySweep verifies that a
// subscriber detach inside bufferRetainGracePeriod (the OSS trigger
// for the retained-buffer leak) leaves the state mapped, and the
// next sweep past the grace window reaps it.
func TestSubscribeCancelDuringGrace_ReapedBySweep(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
mClock := quartz.NewMock(t)
server := &Server{
logger: logger,
clock: mClock,
}
chatID := uuid.New()
start := mClock.Now()
// Just-finished chat: processing done, buffer retained for
// late-connecting relay subscribers.
state := &chatStreamState{
buffering: false,
bufferRetainedAt: start,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: codersdk.ChatMessageRoleAssistant,
},
}},
}
server.chatStreams.Store(chatID, state)
// Real subscribeToStream cancel path: the WS subscriber detach
// that leaks in prod.
_, _, cancelSub := server.subscribeToStream(chatID)
mClock.Advance(bufferRetainGracePeriod / 2)
cancelSub()
_, ok := server.chatStreams.Load(chatID)
require.True(t, ok,
"entry should remain during grace window after subscriber detach")
mClock.Advance(bufferRetainGracePeriod)
server.sweepIdleStreams()
_, ok = server.chatStreams.Load(chatID)
require.False(t, ok,
"entry should be reaped after grace period expires and sweep runs")
}
// TestSweepIdleStreams_ReapsStaleRetainedBuffer: grace expired, no
// subscribers, not buffering -> reaped.
func TestSweepIdleStreams_ReapsStaleRetainedBuffer(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
state := &chatStreamState{
buffering: false,
bufferRetainedAt: mClock.Now(),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}},
}
server.chatStreams.Store(chatID, state)
mClock.Advance(bufferRetainGracePeriod + time.Second)
server.sweepIdleStreams()
_, ok := server.chatStreams.Load(chatID)
require.False(t, ok, "stale retained state should be reaped")
}
// TestSweepIdleStreams_DoesNotReapActiveBuffering: buffering=true
// blocks reap even long after any grace would have expired.
func TestSweepIdleStreams_DoesNotReapActiveBuffering(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
state := &chatStreamState{
buffering: true,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}},
}
server.chatStreams.Store(chatID, state)
mClock.Advance(time.Hour)
server.sweepIdleStreams()
_, ok := server.chatStreams.Load(chatID)
require.True(t, ok, "actively-buffering state must not be reaped")
}
// TestSweepIdleStreams_DoesNotReapWithSubscribers: attached
// subscribers block reap even when grace has expired.
func TestSweepIdleStreams_DoesNotReapWithSubscribers(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
state := &chatStreamState{
buffering: false,
bufferRetainedAt: mClock.Now(),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
uuid.New(): make(chan codersdk.ChatStreamEvent, 1),
},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}},
}
server.chatStreams.Store(chatID, state)
mClock.Advance(bufferRetainGracePeriod + time.Second)
server.sweepIdleStreams()
_, ok := server.chatStreams.Load(chatID)
require.True(t, ok, "state with subscribers must not be reaped")
}
// TestSweepIdleStreams_DefersDuringGracePeriod: sweep inside grace
// is a no-op; the next sweep past grace reaps.
func TestSweepIdleStreams_DefersDuringGracePeriod(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
start := mClock.Now()
state := &chatStreamState{
buffering: false,
bufferRetainedAt: start,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}},
}
server.chatStreams.Store(chatID, state)
mClock.Advance(bufferRetainGracePeriod / 2)
server.sweepIdleStreams()
_, ok := server.chatStreams.Load(chatID)
require.True(t, ok, "sweep inside grace window must not reap")
mClock.Advance(bufferRetainGracePeriod)
server.sweepIdleStreams()
_, ok = server.chatStreams.Load(chatID)
require.False(t, ok, "sweep after grace window must reap")
}
// TestPublishToStream_DropZeroesBackingSlot verifies that evicting
// the oldest buffered event at capacity zeroes the dropped slot so
// its *ChatStreamMessagePart becomes GC-eligible immediately.
func TestPublishToStream_DropZeroesBackingSlot(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
// Over-allocate by one so the post-drop append fits in place and
// exercises the backing-array reuse this test is checking.
buf := make([]codersdk.ChatStreamEvent, maxStreamBufferSize, maxStreamBufferSize+1)
for i := range buf {
buf[i] = codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}
}
// Sentinel in slot 0 distinguishes "slot was zeroed" from "slot
// was overwritten by a later append".
sentinel := &codersdk.ChatStreamMessagePart{
Role: codersdk.ChatMessageRoleAssistant,
}
buf[0] = codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: sentinel,
}
// Alias over the full backing array so we can still observe slot
// 0 after publishToStream reslices state.buffer forward.
origBacking := buf[:cap(buf)]
state := &chatStreamState{
buffering: true,
buffer: buf,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
}
server.chatStreams.Store(chatID, state)
newPart := &codersdk.ChatStreamMessagePart{
Role: codersdk.ChatMessageRoleAssistant,
}
server.publishToStream(chatID, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: newPart,
})
require.Equal(t, codersdk.ChatStreamEvent{}, origBacking[0],
"dropped slot must be zero-valued so its *ChatStreamMessagePart "+
"is eligible for GC; got %+v", origBacking[0])
// Sanity-check the in-place append path the fix targets: if Go's
// growth policy ever makes this append reallocate, this fails
// loudly so the test author revisits the setup.
require.Same(t, newPart, origBacking[len(origBacking)-1].MessagePart,
"append must have landed in the original backing array; the "+
"zero-out invariant only matters when cap > len")
}
// TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry covers
// the race where a caller holds a pointer to a no-longer-mapped
// state (e.g. a janitor Range callback racing a fresh
// getOrCreateStreamState) and would otherwise evict the fresh entry.
// With CompareAndDelete in cleanupStreamIfIdle the stale delete is
// a no-op.
func TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
// Stale pointer: reapable (not buffering, no subscribers, grace
// expired) but no longer the map's live entry.
stale := &chatStreamState{
buffering: false,
bufferRetainedAt: mClock.Now(),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
}
// Fresh entry: the state getOrCreateStreamState would install
// after a racing processChat run. Actively buffering, so not
// reapable. Only this state is in the map.
fresh := &chatStreamState{
buffering: true,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
}
server.chatStreams.Store(chatID, fresh)
mClock.Advance(bufferRetainGracePeriod + time.Second)
// Stale caller mirrors the janitor Range callback after the map
// entry has already been replaced.
stale.mu.Lock()
server.cleanupStreamIfIdle(chatID, stale)
stale.mu.Unlock()
got, ok := server.chatStreams.Load(chatID)
require.True(t, ok,
"fresh entry must remain mapped when cleanup is called with a stale pointer")
require.Same(t, fresh, got,
"cleanup must not replace the fresh entry with the stale one")
}
// TestSafeSweepIdleStreams_RecoversFromPanic verifies that an
// unexpected panic inside sweepIdleStreams is recovered rather than
// killing the janitor goroutine. Without this guard, a panic would
// silently reintroduce the very leak the janitor exists to prevent.
func TestSafeSweepIdleStreams_RecoversFromPanic(t *testing.T) {
t.Parallel()
server := &Server{
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
clock: quartz.NewMock(t),
}
chatID := uuid.New()
// A nil *chatStreamState passes the type assertion in sweepIdleStreams
// but panics on state.mu.Lock with a nil-pointer deref. Any future
// panic source in the sweep would trigger the same recovery path.
var nilState *chatStreamState
server.chatStreams.Store(chatID, nilState)
require.NotPanics(t, func() {
server.safeSweepIdleStreams(context.Background())
}, "safeSweepIdleStreams must recover panics so the janitor loop keeps running")
}