Files
coder/coderd/x/chatd/chatd_internal_test.go
T
Mathias Fredriksson 78d9a220cf fix(coderd/x/chatd): detect disconnected agents in getWorkspaceConn (#24336)
Add agent status check and dial timeout to getWorkspaceConn to
prevent tool calls from hanging when a workspace agent disconnects.

Status check: call isAgentUnreachable on every getWorkspaceConn
call. On cache miss, check the freshly fetched agent row. On
cache hit, re-fetch the agent row by PK for a fresh heartbeat
timestamp. Disconnected and timed-out agents return a sentinel
immediately; connecting agents proceed to dial.

Dial timeout: wrap dialWithLazyValidation in a 30s
context.WithTimeoutCause (matching 8 other server-side AgentConn
callers). Parent context cancellation propagates unchanged so
the chatloop can detect ErrInterrupted.

Both sentinels tell the LLM the agent is unreachable and the
workspace may need restarting from the dashboard.

Closes CODAGT-149
2026-04-22 12:10:32 +00:00

4163 lines
124 KiB
Go

package chatd
import (
"context"
"database/sql"
"encoding/json"
"sync"
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
type testAgentTool struct {
info fantasy.ToolInfo
providerOptions fantasy.ProviderOptions
}
func newTestAgentTool(name string) fantasy.AgentTool {
return &testAgentTool{info: fantasy.ToolInfo{Name: name}}
}
func (t *testAgentTool) Info() fantasy.ToolInfo {
return t.info
}
func (t *testAgentTool) Run(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) {
_ = t
return fantasy.ToolResponse{}, nil
}
func (t *testAgentTool) ProviderOptions() fantasy.ProviderOptions {
return t.providerOptions
}
func (t *testAgentTool) SetProviderOptions(opts fantasy.ProviderOptions) {
t.providerOptions = opts
}
type testMCPAgentTool struct {
*testAgentTool
configID uuid.UUID
}
func newTestMCPAgentTool(name string, configID uuid.UUID) fantasy.AgentTool {
return &testMCPAgentTool{
testAgentTool: &testAgentTool{info: fantasy.ToolInfo{Name: name}},
configID: configID,
}
}
func (t *testMCPAgentTool) MCPServerConfigID() uuid.UUID {
return t.configID
}
func TestFilterExternalMCPConfigsForTurn(t *testing.T) {
t.Parallel()
approvedConfig := database.MCPServerConfig{ID: uuid.New(), AllowInPlanMode: true}
blockedConfig := database.MCPServerConfig{ID: uuid.New(), AllowInPlanMode: false}
configs := []database.MCPServerConfig{approvedConfig, blockedConfig}
planMode := database.NullChatPlanMode{
ChatPlanMode: database.ChatPlanModePlan,
Valid: true,
}
t.Run("NonPlanModePassesThroughAllConfigs", func(t *testing.T) {
t.Parallel()
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
configs,
database.NullChatPlanMode{},
uuid.NullUUID{},
)
require.Equal(t, configs, filtered)
require.Nil(t, approvedIDs)
})
t.Run("PlanModeSubagentsReturnNoConfigs", func(t *testing.T) {
t.Parallel()
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
configs,
planMode,
uuid.NullUUID{UUID: uuid.New(), Valid: true},
)
require.Nil(t, filtered)
require.NotNil(t, approvedIDs)
require.Empty(t, approvedIDs)
})
t.Run("PlanModeRootFiltersToApprovedConfigs", func(t *testing.T) {
t.Parallel()
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
configs,
planMode,
uuid.NullUUID{},
)
require.Equal(t, []database.MCPServerConfig{approvedConfig}, filtered)
require.Equal(t, map[uuid.UUID]struct{}{approvedConfig.ID: {}}, approvedIDs)
})
}
func TestActiveToolNamesForTurn(t *testing.T) {
t.Parallel()
makeTools := func(names ...string) []fantasy.AgentTool {
tools := make([]fantasy.AgentTool, 0, len(names))
for _, name := range names {
tools = append(tools, newTestAgentTool(name))
}
return tools
}
planMode := database.NullChatPlanMode{
ChatPlanMode: database.ChatPlanModePlan,
Valid: true,
}
t.Run("NormalModeReturnsAllRegisteredTools", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"read_file",
"propose_plan",
"custom_tool",
"execute",
), database.NullChatPlanMode{}, uuid.NullUUID{}, nil)
require.Equal(t, []string{
"read_file",
"propose_plan",
"custom_tool",
"execute",
}, got)
})
t.Run("PlanModeIncludesOnlyAllowlistedBuiltIns", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"read_file",
"write_file",
"edit_files",
"execute",
"process_output",
"process_list",
"process_signal",
"list_templates",
"read_template",
"create_workspace",
"start_workspace",
"propose_plan",
"spawn_agent",
"wait_agent",
"message_agent",
"close_agent",
"read_skill",
"read_skill_file",
"ask_user_question",
), planMode, uuid.NullUUID{}, nil)
require.Equal(t, []string{
"read_file",
"write_file",
"edit_files",
"execute",
"process_output",
"list_templates",
"read_template",
"create_workspace",
"start_workspace",
"propose_plan",
"spawn_agent",
"wait_agent",
"read_skill",
"read_skill_file",
"ask_user_question",
}, got)
})
t.Run("PlanModeChildChatsAllowExplorationOnly", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"read_file",
"write_file",
"edit_files",
"execute",
"process_output",
"list_templates",
"read_template",
"create_workspace",
"start_workspace",
"propose_plan",
"spawn_agent",
"wait_agent",
"read_skill",
"read_skill_file",
"ask_user_question",
), planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true}, nil)
require.Equal(t, []string{
"read_file",
"execute",
"process_output",
"read_skill",
"read_skill_file",
}, got)
require.NotContains(t, got, "write_file")
require.NotContains(t, got, "edit_files")
require.NotContains(t, got, "ask_user_question")
require.NotContains(t, got, "propose_plan")
require.NotContains(t, got, "spawn_explore_agent")
})
t.Run("PlanModeStillExcludesDangerousTools", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"execute",
"process_output",
"message_agent",
"spawn_computer_use_agent",
"propose_plan",
), planMode, uuid.NullUUID{}, nil)
require.Equal(t, []string{"execute", "process_output", "propose_plan"}, got)
require.NotContains(t, got, "message_agent")
require.NotContains(t, got, "spawn_computer_use_agent")
})
t.Run("PlanModeExcludesUnknownTools", func(t *testing.T) {
t.Parallel()
got := activeToolNamesForTurn(makeTools(
"read_file",
"custom_tool",
"another_custom_tool",
"propose_plan",
), planMode, uuid.NullUUID{}, nil)
require.Equal(t, []string{
"read_file",
"propose_plan",
}, got)
require.NotContains(t, got, "custom_tool")
require.NotContains(t, got, "another_custom_tool")
})
t.Run("PlanModeIncludesOnlyApprovedExternalMCPTools", func(t *testing.T) {
t.Parallel()
approvedConfigID := uuid.New()
blockedConfigID := uuid.New()
got := activeToolNamesForTurn([]fantasy.AgentTool{
newTestAgentTool("read_file"),
newTestMCPAgentTool("approved-mcp__echo", approvedConfigID),
newTestMCPAgentTool("blocked-mcp__echo", blockedConfigID),
newTestAgentTool("workspace-mcp__echo"),
}, planMode, uuid.NullUUID{}, map[uuid.UUID]struct{}{
approvedConfigID: {},
})
require.Equal(t, []string{
"read_file",
"approved-mcp__echo",
}, got)
require.NotContains(t, got, "blocked-mcp__echo")
require.NotContains(t, got, "workspace-mcp__echo")
})
}
func TestAllowedExploreToolNames(t *testing.T) {
t.Parallel()
makeTools := func(names ...string) []fantasy.AgentTool {
tools := make([]fantasy.AgentTool, 0, len(names))
for _, name := range names {
tools = append(tools, newTestAgentTool(name))
}
return tools
}
got := allowedExploreToolNames(makeTools(
"read_file",
"write_file",
"edit_files",
"execute",
"process_output",
"process_list",
"process_signal",
"spawn_agent",
"wait_agent",
"read_skill",
"read_skill_file",
"ask_user_question",
))
require.Equal(t, []string{
"read_file",
"execute",
"process_output",
"read_skill",
"read_skill_file",
}, got)
}
func TestAllowedBehaviorToolNames(t *testing.T) {
t.Parallel()
makeTools := func(names ...string) []fantasy.AgentTool {
tools := make([]fantasy.AgentTool, 0, len(names))
for _, name := range names {
tools = append(tools, newTestAgentTool(name))
}
return tools
}
allTools := makeTools("read_file", "custom_tool", "spawn_agent")
exploreMode := database.NullChatMode{
ChatMode: database.ChatModeExplore,
Valid: true,
}
t.Run("DefaultModeReturnsAllTools", func(t *testing.T) {
t.Parallel()
require.Equal(t, []string{"read_file", "custom_tool", "spawn_agent"}, allowedBehaviorToolNames(
allTools,
database.NullChatMode{},
))
})
t.Run("ExploreModeUsesExploreAllowlist", func(t *testing.T) {
t.Parallel()
require.Equal(t, []string{"read_file"}, allowedBehaviorToolNames(
allTools,
exploreMode,
))
})
}
func TestStopAfterPlanTools(t *testing.T) {
t.Parallel()
planMode := database.NullChatPlanMode{
ChatPlanMode: database.ChatPlanModePlan,
Valid: true,
}
t.Run("NormalModeReturnsNil", func(t *testing.T) {
t.Parallel()
require.Nil(t, stopAfterPlanTools(database.NullChatPlanMode{}, uuid.NullUUID{}))
})
t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
"ask_user_question": {},
}, stopAfterPlanTools(planMode, uuid.NullUUID{}))
})
t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
}, stopAfterPlanTools(planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true}))
})
}
func TestStopAfterBehaviorTools(t *testing.T) {
t.Parallel()
planMode := database.NullChatPlanMode{
ChatPlanMode: database.ChatPlanModePlan,
Valid: true,
}
exploreMode := database.NullChatMode{
ChatMode: database.ChatModeExplore,
Valid: true,
}
t.Run("DefaultModeReturnsNil", func(t *testing.T) {
t.Parallel()
require.Nil(t, stopAfterBehaviorTools(
database.NullChatPlanMode{},
database.NullChatMode{},
uuid.NullUUID{},
))
})
t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
"ask_user_question": {},
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{}))
})
t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{UUID: uuid.New(), Valid: true}))
})
t.Run("ExploreModeReturnsNil", func(t *testing.T) {
t.Parallel()
require.Nil(t, stopAfterBehaviorTools(planMode, exploreMode, uuid.NullUUID{}))
})
}
// TestWaitForActiveChatStop and TestWaitForActiveChatStop_WaitsForReplacementRun
// were removed along with the process-local activeChats mechanism.
// Debug cleanup is now best-effort; stale finalization handles orphaned rows.
// TestArchiveChatWaitsForActiveChatStop and
// TestArchiveChatWaitsForEveryInterruptedChat were removed along with
// the process-local activeChats mechanism. Archive cleanup is now
// best-effort; stale finalization handles any orphaned rows.
func TestRenameChatTitle(t *testing.T) {
t.Parallel()
setupRealWorkerLock := func(
db *dbmock.MockStore,
chatID uuid.UUID,
lockedChat database.Chat,
) {
lockTx := dbmock.NewMockStore(gomock.NewController(t))
unlockTx := dbmock.NewMockStore(gomock.NewController(t))
gomock.InOrder(
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(lockTx)
},
),
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(unlockTx)
},
),
)
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
}
t.Run("WritesAndReturnsWroteTrue", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
chatID := uuid.New()
workerID := uuid.New()
stored := database.Chat{
ID: chatID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
Title: "original",
}
updated := stored
updated.Title = "renamed"
server := &Server{db: db, logger: logger}
setupRealWorkerLock(db, chatID, stored)
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(stored, nil)
db.EXPECT().UpdateChatTitleByID(gomock.Any(), database.UpdateChatTitleByIDParams{
ID: chatID,
Title: "renamed",
}).Return(updated, nil)
got, wrote, err := server.RenameChatTitle(ctx, stored, "renamed")
require.NoError(t, err)
require.True(t, wrote, "fresh rename must report wrote=true")
require.Equal(t, updated, got)
})
t.Run("SkipsWriteWhenAlreadyAtNewTitle", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
chatID := uuid.New()
workerID := uuid.New()
stale := database.Chat{
ID: chatID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
Title: "pre-race",
}
landed := stale
landed.Title = "landed-concurrently"
server := &Server{db: db, logger: logger}
setupRealWorkerLock(db, chatID, landed)
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(landed, nil)
got, wrote, err := server.RenameChatTitle(ctx, stale, "landed-concurrently")
require.NoError(t, err)
require.False(t, wrote,
"must report wrote=false when the stored row already matches newTitle so the handler suppresses a redundant title_change event")
require.Equal(t, landed, got)
})
}
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
lockTx := dbmock.NewMockStore(ctrl)
usageTx := dbmock.NewMockStore(ctrl)
unlockTx := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
pubsub := dbpubsub.NewInMemory()
clock := quartz.NewReal()
ownerID := uuid.New()
chatID := uuid.New()
modelConfigID := uuid.New()
workerID := uuid.New()
userPrompt := "review pull request 23633 and fix review threads"
wantTitle := "Review PR 23633"
chat := database.Chat{
ID: chatID,
OwnerID: ownerID,
LastModelConfigID: modelConfigID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
Title: fallbackChatTitle(userPrompt),
}
modelConfig := database.ChatModelConfig{
ID: modelConfigID,
Provider: "openai",
Model: "gpt-4o-mini",
ContextLimit: 8192,
}
updatedChat := chat
updatedChat.Title = wantTitle
messageEvents := make(chan struct {
payload codersdk.ChatWatchEvent
err error
}, 1)
cancelSub, err := pubsub.SubscribeWithErr(
coderdpubsub.ChatWatchEventChannel(ownerID),
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
messageEvents <- struct {
payload codersdk.ChatWatchEvent
err error
}{payload: payload, err: err}
}),
)
require.NoError(t, err)
defer cancelSub()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
require.Equal(t, "gpt-4o-mini", req.Model)
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
})
server := &Server{
db: db,
logger: logger,
pubsub: pubsub,
configCache: newChatConfigCache(context.Background(), db, clock),
}
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
APIKey: "test-key",
BaseUrl: serverURL,
}}, nil)
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDAscPaginatedParams{
ChatID: chatID,
AfterID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return([]database.ChatMessage{
mustChatMessage(
t,
database.ChatMessageRoleUser,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText(userPrompt),
),
mustChatMessage(
t,
database.ChatMessageRoleAssistant,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText("checking the diff now"),
),
}, nil)
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDDescPaginatedParams{
ChatID: chatID,
BeforeID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return(nil, nil)
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
gomock.InOrder(
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
return fn(lockTx)
},
),
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Nil(t, opts)
return fn(usageTx)
},
),
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
return fn(unlockTx)
},
),
)
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
require.Equal(t, []string{"[]"}, arg.Content)
return []database.ChatMessage{{ID: 91}}, nil
},
)
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
ID: chatID,
Title: wantTitle,
}).Return(updatedChat, nil)
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(updatedChat, nil)
gotChat, err := server.RegenerateChatTitle(ctx, chat)
require.NoError(t, err)
require.Equal(t, updatedChat, gotChat)
select {
case event := <-messageEvents:
require.NoError(t, event.err)
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
require.Equal(t, chatID, event.payload.Chat.ID)
require.Equal(t, wantTitle, event.payload.Chat.Title)
case <-time.After(time.Second):
t.Fatal("timed out waiting for title change pubsub event")
}
}
func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
lockTx := dbmock.NewMockStore(ctrl)
usageTx := dbmock.NewMockStore(ctrl)
unlockTx := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
pubsub := dbpubsub.NewInMemory()
clock := quartz.NewReal()
ownerID := uuid.New()
chatID := uuid.New()
modelConfigID := uuid.New()
userPrompt := "review pull request 23633 and fix review threads"
wantTitle := "Review PR 23633"
chat := database.Chat{
ID: chatID,
OwnerID: ownerID,
LastModelConfigID: modelConfigID,
Status: database.ChatStatusCompleted,
Title: fallbackChatTitle(userPrompt),
}
lockedChat := chat
lockedChat.WorkerID = uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}
lockedChat.StartedAt = sql.NullTime{Time: time.Now(), Valid: true}
modelConfig := database.ChatModelConfig{
ID: modelConfigID,
Provider: "openai",
Model: "gpt-4o-mini",
ContextLimit: 8192,
}
updatedChat := lockedChat
updatedChat.Title = wantTitle
unlockedChat := updatedChat
unlockedChat.WorkerID = uuid.NullUUID{}
unlockedChat.StartedAt = sql.NullTime{}
messageEvents := make(chan struct {
payload codersdk.ChatWatchEvent
err error
}, 1)
cancelSub, err := pubsub.SubscribeWithErr(
coderdpubsub.ChatWatchEventChannel(ownerID),
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
messageEvents <- struct {
payload codersdk.ChatWatchEvent
err error
}{payload: payload, err: err}
}),
)
require.NoError(t, err)
defer cancelSub()
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
require.Equal(t, "gpt-4o-mini", req.Model)
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
})
server := &Server{
db: db,
logger: logger,
pubsub: pubsub,
configCache: newChatConfigCache(context.Background(), db, clock),
}
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
APIKey: "test-key",
BaseUrl: serverURL,
}}, nil)
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDAscPaginatedParams{
ChatID: chatID,
AfterID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return([]database.ChatMessage{
mustChatMessage(
t,
database.ChatMessageRoleUser,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText(userPrompt),
),
mustChatMessage(
t,
database.ChatMessageRoleAssistant,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText("checking the diff now"),
),
}, nil)
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDDescPaginatedParams{
ChatID: chatID,
BeforeID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return(nil, nil)
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
gomock.InOrder(
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
return fn(lockTx)
},
),
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Nil(t, opts)
return fn(usageTx)
},
),
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
return fn(unlockTx)
},
),
)
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
lockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(
gomock.Any(),
gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{}),
).DoAndReturn(func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
require.Equal(t, chat.ID, arg.ID)
require.Equal(t, chat.Status, arg.Status)
require.Equal(t, uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}, arg.WorkerID)
require.True(t, arg.StartedAt.Valid)
require.WithinDuration(t, time.Now(), arg.StartedAt.Time, time.Second)
require.False(t, arg.HeartbeatAt.Valid)
require.Equal(t, chat.LastError, arg.LastError)
require.Equal(t, chat.UpdatedAt, arg.UpdatedAt)
return lockedChat, nil
})
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
require.Equal(t, []string{"[]"}, arg.Content)
return []database.ChatMessage{{ID: 91}}, nil
},
)
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
ID: chatID,
Title: wantTitle,
}).Return(updatedChat, nil)
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(updatedChat, nil)
unlockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(
gomock.Any(),
database.UpdateChatStatusPreserveUpdatedAtParams{
ID: updatedChat.ID,
Status: updatedChat.Status,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: updatedChat.LastError,
UpdatedAt: updatedChat.UpdatedAt,
},
).Return(unlockedChat, nil)
gotChat, err := server.RegenerateChatTitle(ctx, chat)
require.NoError(t, err)
require.Equal(t, updatedChat, gotChat)
select {
case event := <-messageEvents:
require.NoError(t, event.err)
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
require.Equal(t, chatID, event.payload.Chat.ID)
require.Equal(t, wantTitle, event.payload.Chat.Title)
case <-time.After(time.Second):
t.Fatal("timed out waiting for title change pubsub event")
}
}
func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
server := &Server{
db: db,
configCache: newChatConfigCache(
context.Background(),
db,
quartz.NewReal(),
),
providerAPIKeys: chatprovider.ProviderAPIKeys{
OpenAI: "openai-deployment-key",
Anthropic: "anthropic-deployment-key",
ByProvider: map[string]string{
"openai": "openai-deployment-key",
"anthropic": "anthropic-deployment-key",
},
BaseURLByProvider: map[string]string{
"openai": "https://openai.example.com",
"anthropic": "https://anthropic.example.com",
},
},
}
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "anthropic",
CentralApiKeyEnabled: true,
AllowCentralApiKeyFallback: true,
}}, nil)
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
require.NoError(t, err)
require.Empty(t, keys.OpenAI)
require.Empty(t, keys.APIKey("openai"))
require.Empty(t, keys.BaseURL("openai"))
require.Equal(t, "anthropic-deployment-key", keys.Anthropic)
require.Equal(t, "anthropic-deployment-key", keys.APIKey("anthropic"))
require.Equal(t, "https://anthropic.example.com", keys.BaseURL("anthropic"))
require.Equal(t, map[string]string{"anthropic": "anthropic-deployment-key"}, keys.ByProvider)
require.Equal(t, map[string]string{"anthropic": "https://anthropic.example.com"}, keys.BaseURLByProvider)
}
func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKeys(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
server := &Server{
db: db,
configCache: newChatConfigCache(
context.Background(),
db,
quartz.NewReal(),
),
providerAPIKeys: chatprovider.ProviderAPIKeys{
OpenAI: "openai-deployment-key",
ByProvider: map[string]string{
"openai": "openai-deployment-key",
},
},
}
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "openai",
CentralApiKeyEnabled: true,
}}, nil)
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
require.NoError(t, err)
require.Equal(t, "openai-deployment-key", keys.OpenAI)
require.Equal(t, "openai-deployment-key", keys.APIKey("openai"))
}
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
t.Parallel()
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
calls := 0
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(context.Context, uuid.UUID) (database.Chat, error) {
calls++
return database.Chat{}, nil
},
)
require.NoError(t, err)
require.Equal(t, chat, refreshed)
require.Equal(t, 0, calls)
}
func TestRefreshChatWorkspaceSnapshot_ReloadsWhenWorkspaceMissing(t *testing.T) {
t.Parallel()
chatID := uuid.New()
workspaceID := uuid.New()
chat := database.Chat{ID: chatID}
reloaded := database.Chat{
ID: chatID,
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
calls := 0
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(_ context.Context, id uuid.UUID) (database.Chat, error) {
calls++
require.Equal(t, chatID, id)
return reloaded, nil
},
)
require.NoError(t, err)
require.Equal(t, reloaded, refreshed)
require.Equal(t, 1, calls)
}
func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
t.Parallel()
chat := database.Chat{ID: uuid.New()}
loadErr := xerrors.New("boom")
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(context.Context, uuid.UUID) (database.Chat, error) {
return database.Chat{}, loadErr
},
)
require.Error(t, err)
require.ErrorContains(t, err, "reload chat workspace state")
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: agentID,
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentByID(
gomock.Any(),
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
if !ok || arg.ID != chat.ID {
return false
}
if !arg.LastInjectedContext.Valid {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
return false
}
// Expect at least one context-file part for the
// working-directory AGENTS.md, with internal fields
// stripped (no content, OS, or directory).
for _, p := range parts {
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFilePath != "" {
return p.ContextFileContent == "" &&
p.ContextFileOS == "" &&
p.ContextFileDirectory == ""
}
}
return false
}),
).Return(database.Chat{}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
Parts: []codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/project/AGENTS.md",
ContextFileContent: "# Project instructions",
}},
}, nil).AnyTimes()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
clock: quartz.NewReal(),
instructionLookupTimeout: 5 * time.Second,
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 30 * time.Second,
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction, _, err := server.persistInstructionFiles(
ctx,
chat,
uuid.New(),
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.NoError(t, err)
require.Contains(t, instruction, "Operating System: linux")
require.Contains(t, instruction, "Working Directory: /home/coder/project")
}
func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
}
server := &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
}
instruction, _, err := server.persistInstructionFiles(
ctx,
chat,
uuid.New(),
func(context.Context) (database.WorkspaceAgent, error) {
return database.WorkspaceAgent{
ID: uuid.New(),
Directory: "/home/coder/project",
}, nil
},
func(context.Context) (workspacesdk.AgentConn, error) {
return nil, errChatHasNoWorkspaceAgent
},
)
require.NoError(t, err)
require.Empty(t, instruction)
}
func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: agentID,
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentByID(
gomock.Any(),
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.InsertChatMessagesParams)
if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil {
return false
}
foundMarker := false
foundSkill := false
for _, p := range parts {
switch p.Type {
case codersdk.ChatMessagePartTypeContextFile:
if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" {
foundMarker = true
}
case codersdk.ChatMessagePartTypeSkill:
if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
foundSkill = true
}
}
}
return foundMarker && foundSkill
}),
).Return(nil, nil).Times(1)
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
if !ok || arg.ID != chat.ID {
return false
}
if !arg.LastInjectedContext.Valid {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
return false
}
// The sentinel path should persist only skill parts
// with ContextFileAgentID set.
for _, p := range parts {
if p.Type == codersdk.ChatMessagePartTypeSkill &&
p.SkillName == "my-skill" &&
p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
return true
}
}
return false
}),
).Return(database.Chat{}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
// Agent returns pre-read content: no instruction files
// found but one skill discovered.
Parts: []codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "my-skill",
SkillDescription: "A test skill",
SkillDir: "/home/coder/project/.agents/skills/my-skill",
}},
}, nil).AnyTimes()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
clock: quartz.NewReal(),
instructionLookupTimeout: 5 * time.Second,
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 30 * time.Second,
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction, skills, err := server.persistInstructionFiles(
ctx,
chat,
uuid.New(),
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.NoError(t, err)
// Sentinel path returns empty instruction string.
require.Empty(t, instruction)
// Skills are still discovered and returned.
require.Len(t, skills, 1)
require.Equal(t, "my-skill", skills[0].Name)
}
func TestPersistInstructionFilesSentinelNoSkillsClearsColumn(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: agentID,
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentByID(
gomock.Any(),
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
if !ok || arg.ID != chat.ID {
return false
}
// No skills discovered, so the column should be
// cleared to NULL.
return !arg.LastInjectedContext.Valid
}),
).Return(database.Chat{}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
// Agent returns pre-read content: no files, no skills.
Parts: []codersdk.ChatMessagePart{},
}, nil).AnyTimes()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
clock: quartz.NewReal(),
instructionLookupTimeout: 5 * time.Second,
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 30 * time.Second,
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction, skills, err := server.persistInstructionFiles(
ctx,
chat,
uuid.New(),
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.NoError(t, err)
// Sentinel path: empty instruction, no skills.
require.Empty(t, instruction)
require.Empty(t, skills)
}
func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{ID: agentID}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(workspaceAgent, nil).Times(1)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, chat, chatSnapshot)
require.Equal(t, workspaceAgent, agent)
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, workspaceAgent, gotAgent)
require.Equal(t, chat, currentChat)
}
func TestTurnWorkspaceContext_NullBindingLazyBind(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
buildID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{ID: agentID}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: agentID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{workspaceAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: agentID, Valid: true},
ID: chat.ID,
}).Return(updatedChat, nil),
)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, updatedChat, chatSnapshot)
require.Equal(t, workspaceAgent, agent)
require.Equal(t, updatedChat, currentChat)
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, workspaceAgent, gotAgent)
}
func TestTurnWorkspaceContext_StaleBindingRepair(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
staleAgentID := uuid.New()
buildID := uuid.New()
currentAgentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: staleAgentID,
Valid: true,
},
}
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(database.WorkspaceAgent{}, xerrors.New("missing agent")),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
ID: chat.ID,
}).Return(updatedChat, nil),
)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, updatedChat, chatSnapshot)
require.Equal(t, currentAgent, agent)
require.Equal(t, updatedChat, currentChat)
}
func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
staleAgentID := uuid.New()
currentAgentID := uuid.New()
buildID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: staleAgentID,
Valid: true,
},
}
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(staleAgent, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), currentAgentID).Return(currentAgent, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
ID: chat.ID,
}).Return(updatedChat, nil),
)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
var dialed []uuid.UUID
server := &Server{
db: db,
clock: quartz.NewReal(),
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 30 * time.Second,
}
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialed = append(dialed, agentID)
if agentID == staleAgentID {
return nil, nil, xerrors.New("dial failed")
}
return conn, func() {}, nil
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, conn, gotConn)
require.Equal(t, []uuid.UUID{staleAgentID, currentAgentID}, dialed)
require.Equal(t, updatedChat, currentChat)
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, currentAgent, gotAgent)
}
func TestTurnWorkspaceContextGetWorkspaceConnFastFailsWithoutCurrentAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
staleAgentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: staleAgentID,
Valid: true,
},
}
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).
Return(staleAgent, nil).
Times(1)
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{}, nil).
Times(1)
server := &Server{
db: db,
clock: quartz.NewReal(),
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 30 * time.Second,
}
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return nil, nil, xerrors.New("dial failed")
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
defer workspaceCtx.close()
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.Nil(t, gotConn)
require.ErrorIs(t, err, errChatHasNoWorkspaceAgent)
workspaceCtx.mu.Lock()
defer workspaceCtx.mu.Unlock()
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
require.False(t, workspaceCtx.agentLoaded)
require.Nil(t, workspaceCtx.conn)
require.Nil(t, workspaceCtx.releaseConn)
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
}
func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
currentChat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
}
updatedChat := database.Chat{
ID: currentChat.ID,
WorkspaceID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
}
cachedConn := agentconnmock.NewMockAgentConn(ctrl)
releaseCalls := 0
workspaceCtx := turnWorkspaceContext{
chatStateMu: &sync.Mutex{},
currentChat: &currentChat,
}
workspaceCtx.agent = database.WorkspaceAgent{ID: uuid.New()}
workspaceCtx.agentLoaded = true
workspaceCtx.conn = cachedConn
workspaceCtx.cachedWorkspaceID = currentChat.WorkspaceID
workspaceCtx.releaseConn = func() {
releaseCalls++
}
workspaceCtx.selectWorkspace(updatedChat)
require.Equal(t, updatedChat, currentChat)
require.Equal(t, 1, releaseCalls)
workspaceCtx.mu.Lock()
defer workspaceCtx.mu.Unlock()
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
require.False(t, workspaceCtx.agentLoaded)
require.Nil(t, workspaceCtx.conn)
require.Nil(t, workspaceCtx.releaseConn)
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
}
func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferentWorkspace(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceOneID := uuid.New()
workspaceTwoID := uuid.New()
buildID := uuid.New()
cachedAgent := database.WorkspaceAgent{ID: uuid.New()}
resolvedAgent := database.WorkspaceAgent{ID: uuid.New()}
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceTwoID,
Valid: true,
},
}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return([]database.WorkspaceAgent{resolvedAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
ID: chat.ID,
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true},
}).Return(updatedChat, nil),
)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
workspaceCtx.agent = cachedAgent
workspaceCtx.agentLoaded = true
workspaceCtx.cachedWorkspaceID = uuid.NullUUID{UUID: workspaceOneID, Valid: true}
defer workspaceCtx.close()
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, updatedChat, chatSnapshot)
require.Equal(t, resolvedAgent, agent)
require.Equal(t, updatedChat, currentChat)
}
func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
localMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishMessage(chatID, localMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
cachedMessage := codersdk.ChatMessage{
ID: 2,
ChatID: chatID,
Role: codersdk.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &cachedMessage,
})
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
catchupMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 1,
}).Return([]database.ChatMessage{catchupMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
editedMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{editedMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishEditedMessage(chatID, editedMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(1), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
retryingAt := time.Unix(1_700_000_000, 0).UTC()
expected := &codersdk.ChatStreamRetry{
Attempt: 1,
DelayMs: (1500 * time.Millisecond).Milliseconds(),
Error: "OpenAI is rate limiting requests (HTTP 429).",
Kind: chaterror.KindRateLimit,
Provider: "openai",
StatusCode: 429,
RetryingAt: retryingAt,
}
server.publishRetry(chatID, expected)
event := requireStreamRetryEvent(t, events)
require.Equal(t, expected, event.Retry)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
classified := chaterror.ClassifiedError{
Message: "OpenAI is rate limiting requests (HTTP 429).",
Kind: chaterror.KindRateLimit,
Provider: "openai",
Retryable: true,
StatusCode: 429,
}
server.publishError(chatID, classified)
event := requireStreamErrorEvent(t, events)
require.Equal(t, chaterror.StreamErrorPayload(classified), event.Error)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
Error: "legacy error only",
})
event := requireStreamErrorEvent(t, events)
require.Equal(t, &codersdk.ChatStreamError{Message: "legacy error only"}, event.Error)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
t.Helper()
return &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
pubsub: dbpubsub.NewInMemory(),
}
}
func requireStreamMessageEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeMessage, event.Type)
require.NotNil(t, event.Message)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream message event")
return codersdk.ChatStreamEvent{}
}
}
func requireStreamRetryEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeRetry, event.Type)
require.NotNil(t, event.Retry)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream retry event")
return codersdk.ChatStreamEvent{}
}
}
func requireStreamErrorEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeError, event.Type)
require.NotNil(t, event.Error)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream error event")
return codersdk.ChatStreamEvent{}
}
}
func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) {
t.Helper()
select {
case event, ok := <-events:
if !ok {
t.Fatal("chat stream closed unexpectedly")
}
t.Fatalf("unexpected chat stream event: %+v", event)
case <-time.After(wait):
}
}
// TestPublishToStream_DropWarnRateLimiting walks through a
// realistic lifecycle: buffer fills up, subscriber channel fills
// up, counters get reset between steps. It verifies that WARN
// logs are rate-limited to at most once per streamDropWarnInterval
// and that counter resets re-enable an immediate WARN.
func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
t.Parallel()
sink := testutil.NewFakeSink(t)
mClock := quartz.NewMock(t)
server := &Server{
logger: sink.Logger(),
clock: mClock,
}
chatID := uuid.New()
subCh := make(chan codersdk.ChatStreamEvent, 1)
subCh <- codersdk.ChatStreamEvent{} // pre-fill so sends always drop
// Set up state that mirrors a running chat: buffer at capacity,
// buffering enabled, one saturated subscriber.
state := &chatStreamState{
buffering: true,
buffer: make([]codersdk.ChatStreamEvent, maxStreamBufferSize),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
uuid.New(): subCh,
},
}
server.chatStreams.Store(chatID, state)
bufferMsg := "chat stream buffer full, dropping oldest event"
subMsg := "dropping chat stream event"
filter := func(level slog.Level, msg string) func(slog.SinkEntry) bool {
return func(e slog.SinkEntry) bool {
return e.Level == level && e.Message == msg
}
}
// --- Phase 1: buffer-full rate limiting ---
// message_part events hit both the buffer-full and subscriber-full
// paths. The first publish triggers a WARN for each; the rest
// within the window are DEBUG.
partEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}
for i := 0; i < 50; i++ {
server.publishToStream(chatID, partEvent)
}
require.Len(t, sink.Entries(filter(slog.LevelWarn, bufferMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, bufferMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, bufferMsg))[0], "dropped_count", int64(1))
// Subscriber also saw 50 drops (one per publish).
require.Len(t, sink.Entries(filter(slog.LevelWarn, subMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, subMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, subMsg))[0], "dropped_count", int64(1))
// --- Phase 2: clock advance triggers second WARN with count ---
mClock.Advance(streamDropWarnInterval + time.Second)
server.publishToStream(chatID, partEvent)
bufWarn := sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 2)
requireFieldValue(t, bufWarn[1], "dropped_count", int64(50))
subWarn := sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 2)
requireFieldValue(t, subWarn[1], "dropped_count", int64(50))
// --- Phase 3: counter reset (simulates step persist) ---
state.mu.Lock()
state.buffer = make([]codersdk.ChatStreamEvent, maxStreamBufferSize)
state.resetDropCounters()
state.mu.Unlock()
// The very next drop should WARN immediately — the reset zeroed
// lastWarnAt so the interval check passes.
server.publishToStream(chatID, partEvent)
bufWarn = sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 3, "expected WARN immediately after counter reset")
requireFieldValue(t, bufWarn[2], "dropped_count", int64(1))
subWarn = sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 3, "expected subscriber WARN immediately after counter reset")
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
}
func TestResolveUserCompactionThreshold(t *testing.T) {
t.Parallel()
userID := uuid.New()
modelConfigID := uuid.New()
expectedKey := codersdk.CompactionThresholdKey(modelConfigID)
tests := []struct {
name string
dbReturn string
dbErr error
wantVal int32
wantOK bool
wantWarnLog bool
}{
{
name: "NoRowsReturnsDefault",
dbErr: sql.ErrNoRows,
wantOK: false,
},
{
name: "ValidOverride",
dbReturn: "75",
wantVal: 75,
wantOK: true,
},
{
name: "OutOfRangeValue",
dbReturn: "101",
wantOK: false,
},
{
name: "NonIntegerValue",
dbReturn: "abc",
wantOK: false,
},
{
name: "UnexpectedDBError",
dbErr: xerrors.New("connection refused"),
wantOK: false,
wantWarnLog: true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockDB := dbmock.NewMockStore(ctrl)
sink := testutil.NewFakeSink(t)
srv := &Server{
db: mockDB,
logger: sink.Logger(),
}
mockDB.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), database.GetUserChatCompactionThresholdParams{
UserID: userID,
Key: expectedKey,
}).Return(tc.dbReturn, tc.dbErr)
val, ok := srv.resolveUserCompactionThreshold(context.Background(), userID, modelConfigID)
require.Equal(t, tc.wantVal, val)
require.Equal(t, tc.wantOK, ok)
warns := sink.Entries(func(e slog.SinkEntry) bool {
return e.Level == slog.LevelWarn
})
if tc.wantWarnLog {
require.NotEmpty(t, warns, "expected a warning log entry")
return
}
require.Empty(t, warns, "unexpected warning log entry")
})
}
}
// requireFieldValue asserts that a SinkEntry contains a field with
// the given name and value.
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
t.Helper()
for _, f := range entry.Fields {
if f.Name == name {
require.Equal(t, expected, f.Value, "field %q value mismatch", name)
return
}
}
t.Fatalf("field %q not found in log entry", name)
}
func TestSkillsFromParts(t *testing.T) {
t.Parallel()
t.Run("Empty", func(t *testing.T) {
t.Parallel()
got := skillsFromParts(nil)
require.Empty(t, got)
})
t.Run("NoSkillParts", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{Type: codersdk.ChatMessagePartTypeText, Text: "hello"},
}),
}
got := skillsFromParts(msgs)
require.Empty(t, got)
})
t.Run("SingleSkill", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "deep-review",
SkillDescription: "Multi-reviewer code review",
SkillDir: "/home/coder/.agents/skills/deep-review",
},
}),
}
got := skillsFromParts(msgs)
require.Len(t, got, 1)
require.Equal(t, "deep-review", got[0].Name)
require.Equal(t, "Multi-reviewer code review", got[0].Description)
require.Equal(t, "/home/coder/.agents/skills/deep-review", got[0].Dir)
})
t.Run("MultipleSkillsAcrossMessages", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "pull-requests",
SkillDir: "/home/coder/.agents/skills/pull-requests",
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "deep-review",
SkillDir: "/home/coder/.agents/skills/deep-review",
},
}),
}
got := skillsFromParts(msgs)
require.Len(t, got, 2)
require.Equal(t, "pull-requests", got[0].Name)
require.Equal(t, "deep-review", got[1].Name)
})
t.Run("MixedPartTypes", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/.coder/AGENTS.md",
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "refine-plan",
SkillDir: "/home/coder/.agents/skills/refine-plan",
},
}),
// A text-only message should be skipped entirely.
chatMessageWithParts([]codersdk.ChatMessagePart{
{Type: codersdk.ChatMessagePartTypeText, Text: "user turn"},
}),
}
got := skillsFromParts(msgs)
require.Len(t, got, 1)
require.Equal(t, "refine-plan", got[0].Name)
require.Equal(t, "/home/coder/.agents/skills/refine-plan", got[0].Dir)
})
t.Run("OptionalDescriptionOmitted", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "refine-plan",
SkillDir: "/home/coder/.agents/skills/refine-plan",
},
}),
}
got := skillsFromParts(msgs)
require.Len(t, got, 1)
require.Equal(t, "refine-plan", got[0].Name)
require.Empty(t, got[0].Description)
})
t.Run("InvalidJSON", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
{
Content: pqtype.NullRawMessage{
RawMessage: []byte(`not valid json with "skill" in it`),
Valid: true,
},
},
}
got := skillsFromParts(msgs)
require.Empty(t, got)
})
t.Run("RoundTrip", func(t *testing.T) {
// Simulate persist -> reconstruct cycle: marshal skill
// parts the same way persistInstructionFiles does, then
// verify skillsFromParts recovers the metadata.
t.Parallel()
want := []chattool.SkillMeta{
{Name: "deep-review", Description: "Multi-reviewer review", Dir: "/skills/deep-review"},
{Name: "pull-requests", Description: "", Dir: "/skills/pull-requests"},
}
agentID := uuid.New()
var parts []codersdk.ChatMessagePart
for _, s := range want {
parts = append(parts, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: s.Name,
SkillDescription: s.Description,
SkillDir: s.Dir,
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
})
}
msgs := []database.ChatMessage{chatMessageWithParts(parts)}
got := skillsFromParts(msgs)
require.Len(t, got, len(want))
for i, w := range want {
require.Equal(t, w.Name, got[i].Name)
require.Equal(t, w.Description, got[i].Description)
require.Equal(t, w.Dir, got[i].Dir)
}
})
}
func TestContextFileAgentID(t *testing.T) {
t.Parallel()
t.Run("EmptyMessages", func(t *testing.T) {
t.Parallel()
id, ok := contextFileAgentID(nil)
require.Equal(t, uuid.Nil, id)
require.False(t, ok)
})
t.Run("NoContextFileParts", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{Type: codersdk.ChatMessagePartTypeText, Text: "hello"},
}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, uuid.Nil, id)
require.False(t, ok)
})
t.Run("SingleContextFile", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/some/path",
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
},
}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, agentID, id)
require.True(t, ok)
})
t.Run("MultipleContextFiles", func(t *testing.T) {
t.Parallel()
agentID1 := uuid.New()
agentID2 := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/first/path",
ContextFileAgentID: uuid.NullUUID{UUID: agentID1, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/second/path",
ContextFileAgentID: uuid.NullUUID{UUID: agentID2, Valid: true},
},
}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, agentID2, id)
require.True(t, ok)
})
t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) {
t.Parallel()
instructionAgentID := uuid.New()
sentinelAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/workspace/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: sentinelAgentID,
Valid: true,
},
}}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, instructionAgentID, id)
require.True(t, ok)
})
t.Run("SentinelWithoutAgentID", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFileAgentID: uuid.NullUUID{Valid: false},
},
}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, uuid.Nil, id)
require.False(t, ok)
})
}
func TestHasPersistedInstructionFiles(t *testing.T) {
t.Parallel()
t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}}),
}
require.False(t, hasPersistedInstructionFiles(msgs))
})
t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/workspace/AGENTS.md",
ContextFileContent: "repo instructions",
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
}}),
}
require.True(t, hasPersistedInstructionFiles(msgs))
})
}
func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileOS: "darwin",
ContextFileDirectory: "/old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}}),
}
got := instructionFromContextFiles(msgs)
require.Contains(t, got, "new instructions")
require.Contains(t, got, "Operating System: linux")
require.Contains(t, got, "Working Directory: /new")
require.NotContains(t, got, "old instructions")
require.NotContains(t, got, "Operating System: darwin")
}
func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/legacy/AGENTS.md",
ContextFileContent: "legacy instructions",
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileOS: "darwin",
ContextFileDirectory: "/old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}}),
}
got := instructionFromContextFiles(msgs)
require.Contains(t, got, "legacy instructions")
require.Contains(t, got, "new instructions")
require.Contains(t, got, "Operating System: linux")
require.Contains(t, got, "Working Directory: /new")
require.NotContains(t, got, "old instructions")
require.NotContains(t, got, "Operating System: darwin")
}
func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-legacy",
SkillDir: "/skills/repo-helper-legacy",
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-old",
SkillDir: "/skills/repo-helper-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: newAgentID,
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-new",
SkillDir: "/skills/repo-helper-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
}),
}
got := skillsFromParts(msgs)
require.Equal(t, []chattool.SkillMeta{
{Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"},
{Name: "repo-helper-new", Dir: "/skills/repo-helper-new"},
}, got)
}
func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-old",
SkillDir: "/skills/repo-helper-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: newAgentID,
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-new",
SkillDir: "/skills/repo-helper-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
}),
}
got := skillsFromParts(msgs)
require.Equal(t, []chattool.SkillMeta{{
Name: "repo-helper-new",
Dir: "/skills/repo-helper-new",
}}, got)
}
func TestMergeSkillMetas(t *testing.T) {
t.Parallel()
persisted := []chattool.SkillMeta{{
Name: "repo-helper",
Description: "Persisted skill",
Dir: "/skills/repo-helper-old",
}}
discovered := []chattool.SkillMeta{
{
Name: "repo-helper",
Description: "Discovered replacement",
Dir: "/skills/repo-helper-new",
MetaFile: "SKILL.md",
},
{
Name: "deep-review",
Description: "Discovered skill",
Dir: "/skills/deep-review",
},
}
got := mergeSkillMetas(persisted, discovered)
require.Equal(t, []chattool.SkillMeta{
discovered[0],
discovered[1],
}, got)
}
func TestSelectSkillMetasForInstructionRefresh(t *testing.T) {
t.Parallel()
persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}}
discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}}
currentAgentID := uuid.New()
otherAgentID := uuid.New()
t.Run("MergesCurrentAgentSkills", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
discovered,
uuid.NullUUID{UUID: currentAgentID, Valid: true},
uuid.NullUUID{UUID: currentAgentID, Valid: true},
)
require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got)
})
t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
discovered,
uuid.NullUUID{UUID: currentAgentID, Valid: true},
uuid.NullUUID{UUID: otherAgentID, Valid: true},
)
require.Equal(t, discovered, got)
})
t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
nil,
uuid.NullUUID{},
uuid.NullUUID{UUID: otherAgentID, Valid: true},
)
require.Equal(t, persisted, got)
})
}
func TestResolveChainModeIgnoresSkillOnlySentinelMessages(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
user := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "latest user message",
}})
user.Role = database.ChatMessageRoleUser
got := resolveChainMode([]database.ChatMessage{assistant, skillOnly, user})
require.Equal(t, "resp-123", got.previousResponseID)
require.Equal(t, modelConfigID, got.modelConfigID)
require.Equal(t, 2, got.trailingUserCount)
require.Equal(t, 1, got.contributingTrailingUserCount)
}
func TestFilterPromptForChainModeKeepsContributingUsersAcrossSkippedSentinelTurns(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "prior user message",
}})
priorUser.Role = database.ChatMessageRoleUser
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
firstTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "first trailing user",
}})
firstTrailingUser.Role = database.ChatMessageRoleUser
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
lastTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "last trailing user",
}})
lastTrailingUser.Role = database.ChatMessageRoleUser
chainInfo := resolveChainMode([]database.ChatMessage{
priorUser,
assistant,
firstTrailingUser,
skillOnly,
lastTrailingUser,
})
require.Equal(t, 3, chainInfo.trailingUserCount)
require.Equal(t, 2, chainInfo.contributingTrailingUserCount)
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "system instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "prior user message"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "assistant reply"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "first trailing user"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "last trailing user"},
},
},
}
got := filterPromptForChainMode(prompt, chainInfo)
require.Len(t, got, 3)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
firstPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "first trailing user", firstPart.Text)
lastPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[2].Content[0])
require.True(t, ok)
require.Equal(t, "last trailing user", lastPart.Text)
}
func TestFilterPromptForChainModeUsesContributingTrailingUsers(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "prior user message",
}})
priorUser.Role = database.ChatMessageRoleUser
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
latestUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "latest user message",
}})
latestUser.Role = database.ChatMessageRoleUser
chainInfo := resolveChainMode([]database.ChatMessage{
priorUser,
assistant,
skillOnly,
latestUser,
})
require.Equal(t, 2, chainInfo.trailingUserCount)
require.Equal(t, 1, chainInfo.contributingTrailingUserCount)
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "system instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "prior user message"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "assistant reply"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "latest user message"},
},
},
}
got := filterPromptForChainMode(prompt, chainInfo)
require.Len(t, got, 2)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "latest user message", part.Text)
}
func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage {
raw, _ := json.Marshal(parts)
return database.ChatMessage{
Content: pqtype.NullRawMessage{RawMessage: raw, Valid: true},
}
}
// TestProcessChat_IgnoresStaleControlNotification verifies that
// processChat is not interrupted by a "pending" notification
// published before processing begins. This is the race that caused
// TestOpenAIReasoningWithWebSearchRoundTripStoreFalse to flake:
// SendMessage publishes "pending" via PostgreSQL NOTIFY, and due
// to async delivery the notification can arrive at the control
// subscriber after it registers but before the processor publishes
// "running".
func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ps := dbpubsub.NewInMemory()
clock := quartz.NewMock(t)
chatID := uuid.New()
workerID := uuid.New()
server := &Server{
db: db,
logger: logger,
pubsub: ps,
clock: clock,
workerID: workerID,
chatHeartbeatInterval: time.Minute,
configCache: newChatConfigCache(ctx, db, clock),
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
metrics: chatloop.NopMetrics(),
}
// Publish a stale "pending" notification on the control channel
// BEFORE processChat subscribes. In production this is the
// notification from SendMessage that triggered the processing.
staleNotify, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
Status: string(database.ChatStatusPending),
})
require.NoError(t, err)
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), staleNotify)
require.NoError(t, err)
// Track which status processChat writes during cleanup.
var finalStatus database.ChatStatus
// The deferred cleanup in processChat runs a transaction.
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
func(fn func(database.Store) error, _ *database.TxOptions) error {
return fn(db)
},
)
db.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(
database.Chat{ID: chatID, Status: database.ChatStatusRunning, WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}}, nil,
)
db.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, params database.UpdateChatStatusParams) (database.Chat, error) {
finalStatus = params.Status
return database.Chat{ID: chatID, Status: params.Status}, nil
},
)
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
database.Chat{ID: chatID, Status: database.ChatStatusError},
nil,
)
// resolveChatModel fails immediately — that's fine, we only
// need processChat to get past initialization without being
// interrupted by the stale notification.
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return(
database.ChatModelConfig{}, xerrors.New("no model configured"),
).AnyTimes()
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
).AnyTimes()
db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes()
chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New()}
done := make(chan struct{})
go func() {
defer close(done)
server.processChat(ctx, chat)
}()
// Wait for processChat to finish entirely. It re-reads chat state and
// runs more cleanup after UpdateChatStatus, so signaling completion from
// the status update itself races test teardown.
testutil.TryReceive(ctx, t, done)
// If the stale notification interrupted us, status would be
// "waiting" (the ErrInterrupted path). Since the gate blocked
// it, processChat reached runChat, which failed on model
// resolution → status is "error".
require.Equal(t, database.ChatStatusError, finalStatus,
"processChat should have reached runChat (error), not been interrupted (waiting)")
}
func TestShouldPublishFinishedChatState(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
workerID := uuid.New()
server := &Server{db: db}
updatedChat := database.Chat{
ID: chatID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
}
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{
ID: chatID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
}, nil)
require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat))
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{
ID: chatID,
Status: database.ChatStatusRunning,
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
}, nil)
require.False(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat))
}
// TestShouldPublishFinishedChatState_DBErrorPublishes pins the
// deliberate fail-open behavior when the re-read query errors: we
// surface the finished state anyway so watchers don't get stuck
// waiting for a status update that never arrives. The error path is
// easy to regress into a fail-closed default otherwise.
func TestShouldPublishFinishedChatState_DBErrorPublishes(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
server := &Server{db: db}
updatedChat := database.Chat{
ID: chatID,
Status: database.ChatStatusWaiting,
WorkerID: uuid.NullUUID{},
}
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
database.Chat{}, xerrors.New("boom"),
)
require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat),
"fail-open: a re-read error must not swallow the status change")
}
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
// batch heartbeat UPDATE does not return a registered chat's ID
// (because another replica stole it or it was completed), the
// heartbeat tick cancels that chat's context with ErrInterrupted
// while leaving surviving chats untouched.
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
workerID := uuid.New()
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: workerID,
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
// Create three chats with independent cancel functions.
chat1 := uuid.New()
chat2 := uuid.New()
chat3 := uuid.New()
_, cancel1 := context.WithCancelCause(ctx)
_, cancel2 := context.WithCancelCause(ctx)
ctx3, cancel3 := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel1,
chatID: chat1,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel2,
chatID: chat2,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel3,
chatID: chat3,
logger: logger,
})
// The batch UPDATE returns only chat1 and chat2 —
// chat3 was "stolen" by another replica.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
require.Equal(t, workerID, params.WorkerID)
require.Len(t, params.IDs, 3)
// Return only chat1 and chat2 as surviving.
return []uuid.UUID{chat1, chat2}, nil
},
)
server.heartbeatTick(ctx)
// chat3's context should be canceled with ErrInterrupted.
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
"stolen chat should be interrupted")
// chat3 should have been removed from the registry by
// unregister (in production this happens via defer in
// processChat). The heartbeat tick itself does not
// unregister — it only cancels. Verify the entry is
// still present (processChat's defer would clean it up).
server.heartbeatMu.Lock()
_, chat1Exists := server.heartbeatRegistry[chat1]
_, chat2Exists := server.heartbeatRegistry[chat2]
_, chat3Exists := server.heartbeatRegistry[chat3]
server.heartbeatMu.Unlock()
require.True(t, chat1Exists, "surviving chat1 should remain registered")
require.True(t, chat2Exists, "surviving chat2 should remain registered")
require.True(t, chat3Exists,
"stolen chat3 should still be in registry (processChat defer removes it)")
}
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
// transient database failure causes the tick to log and return
// without canceling any registered chats.
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: uuid.New(),
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
chatID := uuid.New()
chatCtx, cancel := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel,
chatID: chatID,
logger: logger,
})
// Simulate a transient DB error.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
nil, xerrors.New("connection reset"),
)
server.heartbeatTick(ctx)
// Chat should NOT be interrupted — the tick logged and
// returned early.
require.NoError(t, chatCtx.Err(),
"chat context should not be canceled on transient DB error")
}
// TestSubscribeCancelDuringGrace_ReapedBySweep verifies that a
// subscriber detach inside bufferRetainGracePeriod (the OSS trigger
// for the retained-buffer leak) leaves the state mapped, and the
// next sweep past the grace window reaps it.
func TestSubscribeCancelDuringGrace_ReapedBySweep(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
mClock := quartz.NewMock(t)
server := &Server{
logger: logger,
clock: mClock,
}
chatID := uuid.New()
start := mClock.Now()
// Just-finished chat: processing done, buffer retained for
// late-connecting relay subscribers.
state := &chatStreamState{
buffering: false,
bufferRetainedAt: start,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{
Role: codersdk.ChatMessageRoleAssistant,
},
}},
}
server.chatStreams.Store(chatID, state)
// Real subscribeToStream cancel path: the WS subscriber detach
// that leaks in prod.
_, _, cancelSub := server.subscribeToStream(chatID)
mClock.Advance(bufferRetainGracePeriod / 2)
cancelSub()
_, ok := server.chatStreams.Load(chatID)
require.True(t, ok,
"entry should remain during grace window after subscriber detach")
mClock.Advance(bufferRetainGracePeriod)
server.sweepIdleStreams()
_, ok = server.chatStreams.Load(chatID)
require.False(t, ok,
"entry should be reaped after grace period expires and sweep runs")
}
// TestSweepIdleStreams_ReapsStaleRetainedBuffer: grace expired, no
// subscribers, not buffering -> reaped.
func TestSweepIdleStreams_ReapsStaleRetainedBuffer(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
state := &chatStreamState{
buffering: false,
bufferRetainedAt: mClock.Now(),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}},
}
server.chatStreams.Store(chatID, state)
mClock.Advance(bufferRetainGracePeriod + time.Second)
server.sweepIdleStreams()
_, ok := server.chatStreams.Load(chatID)
require.False(t, ok, "stale retained state should be reaped")
}
// TestSweepIdleStreams_DoesNotReapActiveBuffering: buffering=true
// blocks reap even long after any grace would have expired.
func TestSweepIdleStreams_DoesNotReapActiveBuffering(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
state := &chatStreamState{
buffering: true,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}},
}
server.chatStreams.Store(chatID, state)
mClock.Advance(time.Hour)
server.sweepIdleStreams()
_, ok := server.chatStreams.Load(chatID)
require.True(t, ok, "actively-buffering state must not be reaped")
}
// TestSweepIdleStreams_DoesNotReapWithSubscribers: attached
// subscribers block reap even when grace has expired.
func TestSweepIdleStreams_DoesNotReapWithSubscribers(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
state := &chatStreamState{
buffering: false,
bufferRetainedAt: mClock.Now(),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
uuid.New(): make(chan codersdk.ChatStreamEvent, 1),
},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}},
}
server.chatStreams.Store(chatID, state)
mClock.Advance(bufferRetainGracePeriod + time.Second)
server.sweepIdleStreams()
_, ok := server.chatStreams.Load(chatID)
require.True(t, ok, "state with subscribers must not be reaped")
}
// TestSweepIdleStreams_DefersDuringGracePeriod: sweep inside grace
// is a no-op; the next sweep past grace reaps.
func TestSweepIdleStreams_DefersDuringGracePeriod(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
start := mClock.Now()
state := &chatStreamState{
buffering: false,
bufferRetainedAt: start,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
buffer: []codersdk.ChatStreamEvent{{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}},
}
server.chatStreams.Store(chatID, state)
mClock.Advance(bufferRetainGracePeriod / 2)
server.sweepIdleStreams()
_, ok := server.chatStreams.Load(chatID)
require.True(t, ok, "sweep inside grace window must not reap")
mClock.Advance(bufferRetainGracePeriod)
server.sweepIdleStreams()
_, ok = server.chatStreams.Load(chatID)
require.False(t, ok, "sweep after grace window must reap")
}
// TestPublishToStream_DropZeroesBackingSlot verifies that evicting
// the oldest buffered event at capacity zeroes the dropped slot so
// its *ChatStreamMessagePart becomes GC-eligible immediately.
func TestPublishToStream_DropZeroesBackingSlot(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
// Over-allocate by one so the post-drop append fits in place and
// exercises the backing-array reuse this test is checking.
buf := make([]codersdk.ChatStreamEvent, maxStreamBufferSize, maxStreamBufferSize+1)
for i := range buf {
buf[i] = codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}
}
// Sentinel in slot 0 distinguishes "slot was zeroed" from "slot
// was overwritten by a later append".
sentinel := &codersdk.ChatStreamMessagePart{
Role: codersdk.ChatMessageRoleAssistant,
}
buf[0] = codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: sentinel,
}
// Alias over the full backing array so we can still observe slot
// 0 after publishToStream reslices state.buffer forward.
origBacking := buf[:cap(buf)]
state := &chatStreamState{
buffering: true,
buffer: buf,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
}
server.chatStreams.Store(chatID, state)
newPart := &codersdk.ChatStreamMessagePart{
Role: codersdk.ChatMessageRoleAssistant,
}
server.publishToStream(chatID, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: newPart,
})
require.Equal(t, codersdk.ChatStreamEvent{}, origBacking[0],
"dropped slot must be zero-valued so its *ChatStreamMessagePart "+
"is eligible for GC; got %+v", origBacking[0])
// Sanity-check the in-place append path the fix targets: if Go's
// growth policy ever makes this append reallocate, this fails
// loudly so the test author revisits the setup.
require.Same(t, newPart, origBacking[len(origBacking)-1].MessagePart,
"append must have landed in the original backing array; the "+
"zero-out invariant only matters when cap > len")
}
// TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry covers
// the race where a caller holds a pointer to a no-longer-mapped
// state (e.g. a janitor Range callback racing a fresh
// getOrCreateStreamState) and would otherwise evict the fresh entry.
// With CompareAndDelete in cleanupStreamIfIdle the stale delete is
// a no-op.
func TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
server := &Server{
logger: slogtest.Make(t, nil),
clock: mClock,
}
chatID := uuid.New()
// Stale pointer: reapable (not buffering, no subscribers, grace
// expired) but no longer the map's live entry.
stale := &chatStreamState{
buffering: false,
bufferRetainedAt: mClock.Now(),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
}
// Fresh entry: the state getOrCreateStreamState would install
// after a racing processChat run. Actively buffering, so not
// reapable. Only this state is in the map.
fresh := &chatStreamState{
buffering: true,
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
}
server.chatStreams.Store(chatID, fresh)
mClock.Advance(bufferRetainGracePeriod + time.Second)
// Stale caller mirrors the janitor Range callback after the map
// entry has already been replaced.
stale.mu.Lock()
server.cleanupStreamIfIdle(chatID, stale)
stale.mu.Unlock()
got, ok := server.chatStreams.Load(chatID)
require.True(t, ok,
"fresh entry must remain mapped when cleanup is called with a stale pointer")
require.Same(t, fresh, got,
"cleanup must not replace the fresh entry with the stale one")
}
// TestSafeSweepIdleStreams_RecoversFromPanic verifies that an
// unexpected panic inside sweepIdleStreams is recovered rather than
// killing the janitor goroutine. Without this guard, a panic would
// silently reintroduce the very leak the janitor exists to prevent.
func TestSafeSweepIdleStreams_RecoversFromPanic(t *testing.T) {
t.Parallel()
server := &Server{
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
clock: quartz.NewMock(t),
}
chatID := uuid.New()
// A nil *chatStreamState passes the type assertion in sweepIdleStreams
// but panics on state.mu.Lock with a nil-pointer deref. Any future
// panic source in the sweep would trigger the same recovery path.
var nilState *chatStreamState
server.chatStreams.Store(chatID, nilState)
require.NotPanics(t, func() {
server.safeSweepIdleStreams(context.Background())
}, "safeSweepIdleStreams must recover panics so the janitor loop keeps running")
}
func TestGetWorkspaceConn_StatusCheck(t *testing.T) {
t.Parallel()
type testCase struct {
name string
agent database.WorkspaceAgent
cacheHit bool
dbError bool
wantErr error
wantDialCalled bool
wantReleaseCalled bool
}
tests := []testCase{
{
name: "DisconnectedAgentCacheMiss",
agent: database.WorkspaceAgent{
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
},
wantErr: errChatAgentDisconnected,
},
{
name: "DisconnectedAgentCacheHit",
agent: database.WorkspaceAgent{
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
},
cacheHit: true,
wantErr: errChatAgentDisconnected,
wantReleaseCalled: true,
},
{
name: "TimedOutAgentCacheMiss",
agent: database.WorkspaceAgent{
CreatedAt: time.Now().Add(-10 * time.Minute),
ConnectionTimeoutSeconds: 60,
},
wantErr: errChatAgentDisconnected,
},
{
// A "connecting" agent (never connected, normal after
// fresh build) must NOT be blocked by the status check.
name: "ConnectingAgentProceeds",
agent: database.WorkspaceAgent{},
wantDialCalled: true,
},
{
name: "CacheHitHealthyAgent",
agent: database.WorkspaceAgent{
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-5 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
},
cacheHit: true,
},
{
// When GetWorkspaceAgentByID returns an error on
// cache hit, the cached connection should be returned.
name: "CacheHitDBError",
agent: database.WorkspaceAgent{
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-5 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
},
cacheHit: true,
dbError: true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
// Stamp the agent with the generated ID.
agent := tc.agent
agent.ID = agentID
// Set up the DB mock for GetWorkspaceAgentByID.
if tc.dbError {
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(database.WorkspaceAgent{}, xerrors.New("connection reset")).
Times(1)
} else {
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(agent, nil).
Times(1)
}
var dialCalled bool
var releaseCalled bool
// For ConnectingAgentProceeds the dial returns a real
// mock conn; for all others it should not be reached.
var dialConn *agentconnmock.MockAgentConn
if tc.wantDialCalled {
dialConn = agentconnmock.NewMockAgentConn(ctrl)
dialConn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
}
server := &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
clock: quartz.NewReal(),
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: defaultDialTimeout,
}
server.agentConnFn = func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialCalled = true
if dialConn != nil {
return dialConn, func() {}, nil
}
return nil, nil, xerrors.New("should not be called")
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) {
return database.Chat{}, nil
},
}
defer workspaceCtx.close()
// For cache-hit tests, pre-populate the cached
// connection state.
var cachedConn *agentconnmock.MockAgentConn
if tc.cacheHit {
cachedConn = agentconnmock.NewMockAgentConn(ctrl)
workspaceCtx.agent = agent
workspaceCtx.agentLoaded = true
workspaceCtx.conn = cachedConn
workspaceCtx.releaseConn = func() { releaseCalled = true }
workspaceCtx.cachedWorkspaceID = chat.WorkspaceID
}
ctx := testutil.Context(t, testutil.WaitShort)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
switch {
case tc.wantErr != nil:
require.Nil(t, gotConn)
require.ErrorIs(t, err, tc.wantErr)
case tc.cacheHit:
require.NoError(t, err)
require.Same(t, cachedConn, gotConn)
default:
// Cache-miss success (ConnectingAgentProceeds).
require.NoError(t, err)
require.Same(t, dialConn, gotConn)
}
require.Equal(t, tc.wantDialCalled, dialCalled, "dial called")
require.Equal(t, tc.wantReleaseCalled, releaseCalled, "release called")
// For error cases on cache-miss, the cache should be
// cleared.
if tc.wantErr != nil && !tc.cacheHit {
workspaceCtx.mu.Lock()
defer workspaceCtx.mu.Unlock()
require.False(t, workspaceCtx.agentLoaded)
require.Nil(t, workspaceCtx.conn)
}
// For cache-hit disconnect, the cache should also be
// cleared.
if tc.wantErr != nil && tc.cacheHit {
workspaceCtx.mu.Lock()
defer workspaceCtx.mu.Unlock()
require.False(t, workspaceCtx.agentLoaded)
require.Nil(t, workspaceCtx.conn)
}
})
}
}
func TestGetWorkspaceConn_DialTimeout(t *testing.T) {
// When dialWithLazyValidation blocks beyond the dial
// timeout, getWorkspaceConn should return
// errChatDialTimeout instead of hanging indefinitely.
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
// Agent appears connected so the status check passes.
connectedAgent := database.WorkspaceAgent{
ID: agentID,
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-1 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(connectedAgent, nil).
Times(1)
server := &Server{
db: db,
clock: quartz.NewReal(),
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 10 * time.Millisecond,
}
// Dial blocks forever (simulates unreachable agent).
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
<-ctx.Done()
return nil, nil, ctx.Err()
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
defer workspaceCtx.close()
ctx := testutil.Context(t, testutil.WaitShort)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.Nil(t, gotConn)
require.ErrorIs(t, err, errChatDialTimeout)
}
func TestGetWorkspaceConn_DialTimeoutParentCanceled(t *testing.T) {
// When the parent context is canceled, the parent's error
// must propagate unchanged (not wrapped as a dial timeout).
// This is critical because the chatloop checks
// context.Cause(ctx) for ErrInterrupted.
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
connectedAgent := database.WorkspaceAgent{
ID: agentID,
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-1 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(connectedAgent, nil).
Times(1)
parentErr := xerrors.New("parent canceled")
ctx, cancel := context.WithCancelCause(testutil.Context(t, testutil.WaitShort))
server := &Server{
db: db,
clock: quartz.NewReal(),
agentInactiveDisconnectTimeout: 30 * time.Second,
// Use a very long dial timeout so the parent cancel fires
// first.
dialTimeout: 10 * time.Minute,
}
// Signal when the dial goroutine has started so we can
// cancel the parent at the right time without time.Sleep.
dialStarted := make(chan struct{})
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
close(dialStarted)
<-ctx.Done()
return nil, nil, ctx.Err()
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
defer workspaceCtx.close()
// Cancel the parent after the dial starts.
go func() {
<-dialStarted
cancel(parentErr)
}()
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.Nil(t, gotConn)
// The error must NOT be errChatDialTimeout.
require.NotErrorIs(t, err, errChatDialTimeout)
// The parent context's error should propagate.
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
}
func TestGetWorkspaceConn_DialErrorNotMisclassifiedAsTimeout(t *testing.T) {
// Regression test: a non-timeout dial error (e.g. auth
// failure) with the parent context still alive must NOT be
// converted to errChatDialTimeout. Before the fix,
// dialCancel() poisoned dialCtx.Err(), causing all errors
// to be misclassified.
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
connectedAgent := database.WorkspaceAgent{
ID: agentID,
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-1 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(connectedAgent, nil).
Times(1)
// When the initial dial fails immediately, dialWithLazyValidation
// calls resolveFastFailure which validates the binding. Mock the
// validation to return the same agent, triggering a synchronous
// redial that also returns the error.
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{connectedAgent}, nil).
AnyTimes()
dialErr := xerrors.New("authentication failed")
server := &Server{
db: db,
clock: quartz.NewReal(),
agentInactiveDisconnectTimeout: 30 * time.Second,
// Generous timeout so the dial error fires well before
// the timeout.
dialTimeout: defaultDialTimeout,
}
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
// Return an error immediately (not a timeout).
return nil, nil, dialErr
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
defer workspaceCtx.close()
ctx := testutil.Context(t, testutil.WaitShort)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.Nil(t, gotConn)
// Must NOT be misclassified as a dial timeout.
require.NotErrorIs(t, err, errChatDialTimeout)
// The original dial error should propagate.
require.ErrorContains(t, err, "authentication failed")
}