mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
4751416b29
**Breaking change for changelog:**
> `codersdk.Chat.last_error` now returns a structured `ChatError` object
(`{message, kind, provider, retryable, status_code, detail}`) instead of
a plain string. The chats API is experimental
(`/api/experimental/chats`), so this ships without a deprecation cycle;
consumers reading `chat.last_error` as a string must update to read
`chat.last_error.message`. SDK/generated TypeScript terminal error
payloads now use the single `ChatError` type; the live stream error
payload type is renamed from `ChatStreamError` to `ChatError`.
Persisted chat errors now carry the same provider-specific detail (kind,
provider, retryable, HTTP status, optional detail) as the live stream,
so refreshing a failed chat rehydrates with the full structured error
instead of a one-line headline.
Existing rows are migrated in place: legacy text errors are wrapped into
`{message, kind: "generic"}` so already-errored chats still render, and
rows with `last_error IS NULL` stay NULL. Internally, persisted fallback
decoding now reuses the existing `chaterror.KindGeneric` constant, with
no JSON value change.
Closes CODAGT-239
4744 lines
143 KiB
Go
4744 lines
143 KiB
Go
package chatd
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
|
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
|
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
|
openaicomputeruse "github.com/coder/coder/v2/coderd/x/chatd/chatopenai/computeruse"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
type testAgentTool struct {
|
|
info fantasy.ToolInfo
|
|
providerOptions fantasy.ProviderOptions
|
|
}
|
|
|
|
func newTestAgentTool(name string) fantasy.AgentTool {
|
|
return &testAgentTool{info: fantasy.ToolInfo{Name: name}}
|
|
}
|
|
|
|
func (t *testAgentTool) Info() fantasy.ToolInfo {
|
|
return t.info
|
|
}
|
|
|
|
func (t *testAgentTool) Run(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
|
_ = t
|
|
return fantasy.ToolResponse{}, nil
|
|
}
|
|
|
|
func (t *testAgentTool) ProviderOptions() fantasy.ProviderOptions {
|
|
return t.providerOptions
|
|
}
|
|
|
|
func (t *testAgentTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
|
t.providerOptions = opts
|
|
}
|
|
|
|
type testMCPAgentTool struct {
|
|
*testAgentTool
|
|
configID uuid.UUID
|
|
}
|
|
|
|
func newTestMCPAgentTool(name string, configID uuid.UUID) fantasy.AgentTool {
|
|
return &testMCPAgentTool{
|
|
testAgentTool: &testAgentTool{info: fantasy.ToolInfo{Name: name}},
|
|
configID: configID,
|
|
}
|
|
}
|
|
|
|
func (t *testMCPAgentTool) MCPServerConfigID() uuid.UUID {
|
|
return t.configID
|
|
}
|
|
|
|
func TestComputerUseProviderAndModelFromConfig(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
rawProvider string
|
|
wantProvider string
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "DefaultAnthropic",
|
|
rawProvider: "",
|
|
wantProvider: chattool.ComputerUseProviderAnthropic,
|
|
},
|
|
{
|
|
name: "OpenAI",
|
|
rawProvider: " openai ",
|
|
wantProvider: chattool.ComputerUseProviderOpenAI,
|
|
},
|
|
{
|
|
name: "Unknown",
|
|
rawProvider: "bogus",
|
|
wantErr: `unknown computer-use provider "bogus" configured in agents_computer_use_provider`,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
server := &Server{db: db}
|
|
|
|
db.EXPECT().GetChatComputerUseProvider(gomock.Any()).DoAndReturn(
|
|
func(ctx context.Context) (string, error) {
|
|
_, ok := dbauthz.ActorFromContext(ctx)
|
|
require.True(t, ok, "config reads must have an actor")
|
|
return tt.rawProvider, nil
|
|
},
|
|
)
|
|
|
|
provider, modelProvider, modelName, err := server.computerUseProviderAndModelFromConfig(context.Background())
|
|
if tt.wantErr != "" {
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), tt.wantErr)
|
|
return
|
|
}
|
|
require.NoError(t, err)
|
|
require.Equal(t, tt.wantProvider, provider)
|
|
|
|
wantModelProvider, wantModelName, ok := chattool.DefaultComputerUseModel(tt.wantProvider)
|
|
require.True(t, ok)
|
|
require.Equal(t, wantModelProvider, modelProvider)
|
|
require.Equal(t, wantModelName, modelName)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolveComputerUseModel_OpenAIMissingCredentials(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := &Server{}
|
|
provider := chattool.ComputerUseProviderOpenAI
|
|
modelProvider, modelName, ok := chattool.DefaultComputerUseModel(provider)
|
|
require.True(t, ok)
|
|
|
|
model, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveComputerUseModel(
|
|
context.Background(),
|
|
database.Chat{ID: uuid.New(), OwnerID: uuid.New()},
|
|
chatprovider.ProviderAPIKeys{},
|
|
provider,
|
|
modelProvider,
|
|
modelName,
|
|
)
|
|
require.Error(t, err)
|
|
require.Nil(t, model)
|
|
require.False(t, debugEnabled)
|
|
require.Empty(t, resolvedProvider)
|
|
require.Empty(t, resolvedModel)
|
|
require.Contains(t, err.Error(), `provider "openai" model "gpt-5.5"`)
|
|
require.Contains(t, err.Error(), "OPENAI_API_KEY is not set")
|
|
require.NotContains(t, err.Error(), "ANTHROPIC_API_KEY")
|
|
}
|
|
|
|
func TestAppendComputerUseProviderTool(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
providerTools, err := appendComputerUseProviderTool(
|
|
nil,
|
|
computerUseProviderToolOptions{
|
|
provider: chattool.ComputerUseProviderOpenAI,
|
|
isComputerUse: true,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Len(t, providerTools, 1)
|
|
require.True(t, openaicomputeruse.IsTool(providerTools[0].Definition))
|
|
require.Equal(t, "computer", providerTools[0].Definition.GetName())
|
|
require.Equal(t, "computer", providerTools[0].Runner.Info().Name)
|
|
require.NotNil(t, providerTools[0].ResultProviderMetadata)
|
|
|
|
metadata := providerTools[0].ResultProviderMetadata(
|
|
fantasy.NewImageResponse([]byte("png"), "image/png"),
|
|
)
|
|
require.NotNil(t, metadata)
|
|
}
|
|
|
|
func TestAppendComputerUseProviderTool_Gates(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
baseTools := []chatloop.ProviderTool{{
|
|
Definition: fantasy.ProviderDefinedTool{
|
|
ID: "web_search",
|
|
Name: "web_search",
|
|
},
|
|
}}
|
|
|
|
tests := []struct {
|
|
name string
|
|
isPlanModeTurn bool
|
|
isComputerUse bool
|
|
}{
|
|
{name: "PlanMode", isPlanModeTurn: true, isComputerUse: true},
|
|
// Non-computer-use includes regular, master, general, and explore chats.
|
|
// Mode cannot be both ChatModeComputerUse and another chat mode.
|
|
{name: "NonComputerUseModes"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
providerTools, err := appendComputerUseProviderTool(
|
|
baseTools,
|
|
computerUseProviderToolOptions{
|
|
provider: chattool.ComputerUseProviderOpenAI,
|
|
isPlanModeTurn: tt.isPlanModeTurn,
|
|
isComputerUse: tt.isComputerUse,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Len(t, providerTools, 1)
|
|
require.Equal(t, "web_search", providerTools[0].Definition.GetName())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAppendComputerUseProviderTool_AnthropicHasNoResultMetadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
providerTools, err := appendComputerUseProviderTool(
|
|
nil,
|
|
computerUseProviderToolOptions{
|
|
provider: chattool.ComputerUseProviderAnthropic,
|
|
isComputerUse: true,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Len(t, providerTools, 1)
|
|
require.Equal(t, "computer", providerTools[0].Definition.GetName())
|
|
require.Nil(t, providerTools[0].ResultProviderMetadata)
|
|
}
|
|
|
|
func TestFilterExternalMCPConfigsForTurn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
approvedConfig := database.MCPServerConfig{ID: uuid.New(), AllowInPlanMode: true}
|
|
blockedConfig := database.MCPServerConfig{ID: uuid.New(), AllowInPlanMode: false}
|
|
configs := []database.MCPServerConfig{approvedConfig, blockedConfig}
|
|
planMode := database.NullChatPlanMode{
|
|
ChatPlanMode: database.ChatPlanModePlan,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("NonPlanModePassesThroughAllConfigs", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
|
|
configs,
|
|
database.NullChatPlanMode{},
|
|
uuid.NullUUID{},
|
|
)
|
|
|
|
require.Equal(t, configs, filtered)
|
|
require.Nil(t, approvedIDs)
|
|
})
|
|
|
|
t.Run("PlanModeSubagentsReturnNoConfigs", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
|
|
configs,
|
|
planMode,
|
|
uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
)
|
|
|
|
require.Nil(t, filtered)
|
|
require.NotNil(t, approvedIDs)
|
|
require.Empty(t, approvedIDs)
|
|
})
|
|
|
|
t.Run("PlanModeRootFiltersToApprovedConfigs", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
|
|
configs,
|
|
planMode,
|
|
uuid.NullUUID{},
|
|
)
|
|
|
|
require.Equal(t, []database.MCPServerConfig{approvedConfig}, filtered)
|
|
require.Equal(t, map[uuid.UUID]struct{}{approvedConfig.ID: {}}, approvedIDs)
|
|
})
|
|
}
|
|
|
|
func TestActiveToolNamesForTurn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
makeTools := func(names ...string) []fantasy.AgentTool {
|
|
tools := make([]fantasy.AgentTool, 0, len(names))
|
|
for _, name := range names {
|
|
tools = append(tools, newTestAgentTool(name))
|
|
}
|
|
return tools
|
|
}
|
|
|
|
planMode := database.NullChatPlanMode{
|
|
ChatPlanMode: database.ChatPlanModePlan,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("NormalModeReturnsAllRegisteredTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"read_file",
|
|
"propose_plan",
|
|
"custom_tool",
|
|
"execute",
|
|
), database.NullChatPlanMode{}, uuid.NullUUID{}, nil)
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"propose_plan",
|
|
"custom_tool",
|
|
"execute",
|
|
}, got)
|
|
})
|
|
|
|
t.Run("PlanModeIncludesOnlyAllowlistedBuiltIns", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"read_file",
|
|
"write_file",
|
|
"edit_files",
|
|
"execute",
|
|
"process_output",
|
|
"process_list",
|
|
"process_signal",
|
|
"list_templates",
|
|
"read_template",
|
|
"create_workspace",
|
|
"start_workspace",
|
|
"propose_plan",
|
|
"spawn_agent",
|
|
"wait_agent",
|
|
"message_agent",
|
|
"close_agent",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
"ask_user_question",
|
|
), planMode, uuid.NullUUID{}, nil)
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"write_file",
|
|
"edit_files",
|
|
"execute",
|
|
"process_output",
|
|
"list_templates",
|
|
"read_template",
|
|
"create_workspace",
|
|
"start_workspace",
|
|
"propose_plan",
|
|
"spawn_agent",
|
|
"wait_agent",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
"ask_user_question",
|
|
}, got)
|
|
})
|
|
|
|
t.Run("PlanModeChildChatsAllowExplorationOnly", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"read_file",
|
|
"write_file",
|
|
"edit_files",
|
|
"execute",
|
|
"process_output",
|
|
"list_templates",
|
|
"read_template",
|
|
"create_workspace",
|
|
"start_workspace",
|
|
"propose_plan",
|
|
"spawn_agent",
|
|
"wait_agent",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
"ask_user_question",
|
|
), planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true}, nil)
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"execute",
|
|
"process_output",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
}, got)
|
|
require.NotContains(t, got, "write_file")
|
|
require.NotContains(t, got, "edit_files")
|
|
require.NotContains(t, got, "ask_user_question")
|
|
require.NotContains(t, got, "propose_plan")
|
|
require.NotContains(t, got, "spawn_explore_agent")
|
|
})
|
|
|
|
t.Run("PlanModeStillExcludesDangerousTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"execute",
|
|
"process_output",
|
|
"message_agent",
|
|
"spawn_computer_use_agent",
|
|
"propose_plan",
|
|
), planMode, uuid.NullUUID{}, nil)
|
|
|
|
require.Equal(t, []string{"execute", "process_output", "propose_plan"}, got)
|
|
require.NotContains(t, got, "message_agent")
|
|
require.NotContains(t, got, "spawn_computer_use_agent")
|
|
})
|
|
|
|
t.Run("PlanModeExcludesUnknownTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"read_file",
|
|
"custom_tool",
|
|
"another_custom_tool",
|
|
"propose_plan",
|
|
), planMode, uuid.NullUUID{}, nil)
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"propose_plan",
|
|
}, got)
|
|
require.NotContains(t, got, "custom_tool")
|
|
require.NotContains(t, got, "another_custom_tool")
|
|
})
|
|
|
|
t.Run("PlanModeIncludesOnlyApprovedExternalMCPTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
approvedConfigID := uuid.New()
|
|
blockedConfigID := uuid.New()
|
|
got := activeToolNamesForTurn([]fantasy.AgentTool{
|
|
newTestAgentTool("read_file"),
|
|
newTestMCPAgentTool("approved-mcp__echo", approvedConfigID),
|
|
newTestMCPAgentTool("blocked-mcp__echo", blockedConfigID),
|
|
newTestAgentTool("workspace-mcp__echo"),
|
|
}, planMode, uuid.NullUUID{}, map[uuid.UUID]struct{}{
|
|
approvedConfigID: {},
|
|
})
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"approved-mcp__echo",
|
|
}, got)
|
|
require.NotContains(t, got, "blocked-mcp__echo")
|
|
require.NotContains(t, got, "workspace-mcp__echo")
|
|
})
|
|
}
|
|
|
|
func TestAllowedExploreToolNames(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
externalConfigID := uuid.New()
|
|
got := allowedExploreToolNames([]fantasy.AgentTool{
|
|
newTestAgentTool("read_file"),
|
|
newTestAgentTool("write_file"),
|
|
newTestMCPAgentTool("external-mcp__echo", externalConfigID),
|
|
newTestAgentTool("workspace-mcp__echo"),
|
|
newTestAgentTool("execute"),
|
|
newTestAgentTool("process_output"),
|
|
newTestAgentTool("process_list"),
|
|
newTestAgentTool("process_signal"),
|
|
newTestAgentTool("spawn_agent"),
|
|
newTestAgentTool("wait_agent"),
|
|
newTestAgentTool("read_skill"),
|
|
newTestAgentTool("read_skill_file"),
|
|
newTestAgentTool("ask_user_question"),
|
|
})
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"external-mcp__echo",
|
|
"execute",
|
|
"process_output",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
}, got)
|
|
require.NotContains(t, got, "workspace-mcp__echo")
|
|
}
|
|
|
|
func TestAllowedBehaviorToolNames(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
makeTools := func(names ...string) []fantasy.AgentTool {
|
|
tools := make([]fantasy.AgentTool, 0, len(names))
|
|
for _, name := range names {
|
|
tools = append(tools, newTestAgentTool(name))
|
|
}
|
|
return tools
|
|
}
|
|
|
|
allTools := makeTools("read_file", "custom_tool", "spawn_agent")
|
|
exploreMode := database.NullChatMode{
|
|
ChatMode: database.ChatModeExplore,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("DefaultModeReturnsAllTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, []string{"read_file", "custom_tool", "spawn_agent"}, allowedBehaviorToolNames(
|
|
allTools,
|
|
database.NullChatMode{},
|
|
))
|
|
})
|
|
|
|
t.Run("ExploreModeUsesExploreAllowlist", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, []string{"read_file"}, allowedBehaviorToolNames(
|
|
allTools,
|
|
exploreMode,
|
|
))
|
|
})
|
|
}
|
|
|
|
func TestStopAfterPlanTools(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
planMode := database.NullChatPlanMode{
|
|
ChatPlanMode: database.ChatPlanModePlan,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("NormalModeReturnsNil", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Nil(t, stopAfterPlanTools(database.NullChatPlanMode{}, uuid.NullUUID{}))
|
|
})
|
|
|
|
t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, map[string]struct{}{
|
|
"propose_plan": {},
|
|
"ask_user_question": {},
|
|
}, stopAfterPlanTools(planMode, uuid.NullUUID{}))
|
|
})
|
|
|
|
t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, map[string]struct{}{
|
|
"propose_plan": {},
|
|
}, stopAfterPlanTools(planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true}))
|
|
})
|
|
}
|
|
|
|
func TestStopAfterBehaviorTools(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
planMode := database.NullChatPlanMode{
|
|
ChatPlanMode: database.ChatPlanModePlan,
|
|
Valid: true,
|
|
}
|
|
exploreMode := database.NullChatMode{
|
|
ChatMode: database.ChatModeExplore,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("DefaultModeReturnsNil", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Nil(t, stopAfterBehaviorTools(
|
|
database.NullChatPlanMode{},
|
|
database.NullChatMode{},
|
|
uuid.NullUUID{},
|
|
))
|
|
})
|
|
|
|
t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, map[string]struct{}{
|
|
"propose_plan": {},
|
|
"ask_user_question": {},
|
|
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{}))
|
|
})
|
|
|
|
t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, map[string]struct{}{
|
|
"propose_plan": {},
|
|
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{UUID: uuid.New(), Valid: true}))
|
|
})
|
|
|
|
t.Run("ExploreModeReturnsNil", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Nil(t, stopAfterBehaviorTools(planMode, exploreMode, uuid.NullUUID{}))
|
|
})
|
|
}
|
|
|
|
// TestWaitForActiveChatStop and TestWaitForActiveChatStop_WaitsForReplacementRun
|
|
// were removed along with the process-local activeChats mechanism.
|
|
// Debug cleanup is now best-effort; stale finalization handles orphaned rows.
|
|
|
|
// TestArchiveChatWaitsForActiveChatStop and
|
|
// TestArchiveChatWaitsForEveryInterruptedChat were removed along with
|
|
// the process-local activeChats mechanism. Archive cleanup is now
|
|
// best-effort; stale finalization handles any orphaned rows.
|
|
|
|
func TestRenameChatTitle(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
setupRealWorkerLock := func(
|
|
db *dbmock.MockStore,
|
|
chatID uuid.UUID,
|
|
lockedChat database.Chat,
|
|
) {
|
|
lockTx := dbmock.NewMockStore(gomock.NewController(t))
|
|
unlockTx := dbmock.NewMockStore(gomock.NewController(t))
|
|
gomock.InOrder(
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
return fn(lockTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
return fn(unlockTx)
|
|
},
|
|
),
|
|
)
|
|
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
|
|
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
|
|
}
|
|
|
|
t.Run("WritesAndReturnsWroteTrue", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
stored := database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
Title: "original",
|
|
}
|
|
updated := stored
|
|
updated.Title = "renamed"
|
|
|
|
server := &Server{db: db, logger: logger}
|
|
|
|
setupRealWorkerLock(db, chatID, stored)
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(stored, nil)
|
|
db.EXPECT().UpdateChatTitleByID(gomock.Any(), database.UpdateChatTitleByIDParams{
|
|
ID: chatID,
|
|
Title: "renamed",
|
|
}).Return(updated, nil)
|
|
|
|
got, wrote, err := server.RenameChatTitle(ctx, stored, "renamed")
|
|
require.NoError(t, err)
|
|
require.True(t, wrote, "fresh rename must report wrote=true")
|
|
require.Equal(t, updated, got)
|
|
})
|
|
|
|
t.Run("SkipsWriteWhenAlreadyAtNewTitle", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
stale := database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
Title: "pre-race",
|
|
}
|
|
landed := stale
|
|
landed.Title = "landed-concurrently"
|
|
|
|
server := &Server{db: db, logger: logger}
|
|
|
|
setupRealWorkerLock(db, chatID, landed)
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(landed, nil)
|
|
|
|
got, wrote, err := server.RenameChatTitle(ctx, stale, "landed-concurrently")
|
|
require.NoError(t, err)
|
|
require.False(t, wrote,
|
|
"must report wrote=false when the stored row already matches newTitle so the handler suppresses a redundant title_change event")
|
|
require.Equal(t, landed, got)
|
|
})
|
|
}
|
|
|
|
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
lockTx := dbmock.NewMockStore(ctrl)
|
|
usageTx := dbmock.NewMockStore(ctrl)
|
|
unlockTx := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
pubsub := dbpubsub.NewInMemory()
|
|
clock := quartz.NewReal()
|
|
|
|
ownerID := uuid.New()
|
|
chatID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
workerID := uuid.New()
|
|
userPrompt := "review pull request 23633 and fix review threads"
|
|
wantTitle := "Review PR 23633"
|
|
|
|
chat := database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelConfigID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
Title: fallbackChatTitle(userPrompt),
|
|
}
|
|
modelConfig := database.ChatModelConfig{
|
|
ID: modelConfigID,
|
|
Provider: "openai",
|
|
Model: "gpt-4o-mini",
|
|
ContextLimit: 8192,
|
|
}
|
|
updatedChat := chat
|
|
updatedChat.Title = wantTitle
|
|
|
|
messageEvents := make(chan struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}, 1)
|
|
cancelSub, err := pubsub.SubscribeWithErr(
|
|
coderdpubsub.ChatWatchEventChannel(ownerID),
|
|
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
|
messageEvents <- struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}{payload: payload, err: err}
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
defer cancelSub()
|
|
|
|
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
require.Equal(t, "gpt-4o-mini", req.Model)
|
|
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
|
|
})
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: pubsub,
|
|
configCache: newChatConfigCache(context.Background(), db, clock),
|
|
}
|
|
|
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
|
Provider: "openai",
|
|
CentralApiKeyEnabled: true,
|
|
APIKey: "test-key",
|
|
BaseUrl: serverURL,
|
|
}}, nil)
|
|
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
|
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
|
gomock.Any(),
|
|
database.GetChatMessagesByChatIDAscPaginatedParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
).Return([]database.ChatMessage{
|
|
mustChatMessage(
|
|
t,
|
|
database.ChatMessageRoleUser,
|
|
database.ChatMessageVisibilityBoth,
|
|
codersdk.ChatMessageText(userPrompt),
|
|
),
|
|
mustChatMessage(
|
|
t,
|
|
database.ChatMessageRoleAssistant,
|
|
database.ChatMessageVisibilityBoth,
|
|
codersdk.ChatMessageText("checking the diff now"),
|
|
),
|
|
}, nil)
|
|
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
|
|
gomock.Any(),
|
|
database.GetChatMessagesByChatIDDescPaginatedParams{
|
|
ChatID: chatID,
|
|
BeforeID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
).Return(nil, nil)
|
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
|
|
return fn(lockTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Nil(t, opts)
|
|
return fn(usageTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
|
|
return fn(unlockTx)
|
|
},
|
|
),
|
|
)
|
|
|
|
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
|
|
|
|
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
|
|
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
|
|
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
|
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
|
|
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
|
|
require.Equal(t, []string{"[]"}, arg.Content)
|
|
return []database.ChatMessage{{ID: 91}}, nil
|
|
},
|
|
)
|
|
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
|
|
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
|
ID: chatID,
|
|
Title: wantTitle,
|
|
}).Return(updatedChat, nil)
|
|
|
|
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(updatedChat, nil)
|
|
|
|
gotChat, err := server.RegenerateChatTitle(ctx, chat)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, gotChat)
|
|
|
|
select {
|
|
case event := <-messageEvents:
|
|
require.NoError(t, event.err)
|
|
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
|
require.Equal(t, chatID, event.payload.Chat.ID)
|
|
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for title change pubsub event")
|
|
}
|
|
}
|
|
|
|
func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
lockTx := dbmock.NewMockStore(ctrl)
|
|
usageTx := dbmock.NewMockStore(ctrl)
|
|
unlockTx := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
pubsub := dbpubsub.NewInMemory()
|
|
clock := quartz.NewReal()
|
|
|
|
ownerID := uuid.New()
|
|
chatID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
userPrompt := "review pull request 23633 and fix review threads"
|
|
wantTitle := "Review PR 23633"
|
|
|
|
chat := database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelConfigID,
|
|
Status: database.ChatStatusCompleted,
|
|
Title: fallbackChatTitle(userPrompt),
|
|
}
|
|
lockedChat := chat
|
|
lockedChat.WorkerID = uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}
|
|
lockedChat.StartedAt = sql.NullTime{Time: time.Now(), Valid: true}
|
|
modelConfig := database.ChatModelConfig{
|
|
ID: modelConfigID,
|
|
Provider: "openai",
|
|
Model: "gpt-4o-mini",
|
|
ContextLimit: 8192,
|
|
}
|
|
updatedChat := lockedChat
|
|
updatedChat.Title = wantTitle
|
|
unlockedChat := updatedChat
|
|
unlockedChat.WorkerID = uuid.NullUUID{}
|
|
unlockedChat.StartedAt = sql.NullTime{}
|
|
|
|
messageEvents := make(chan struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}, 1)
|
|
cancelSub, err := pubsub.SubscribeWithErr(
|
|
coderdpubsub.ChatWatchEventChannel(ownerID),
|
|
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
|
messageEvents <- struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}{payload: payload, err: err}
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
defer cancelSub()
|
|
|
|
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
require.Equal(t, "gpt-4o-mini", req.Model)
|
|
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
|
|
})
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: pubsub,
|
|
configCache: newChatConfigCache(context.Background(), db, clock),
|
|
}
|
|
|
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
|
Provider: "openai",
|
|
CentralApiKeyEnabled: true,
|
|
APIKey: "test-key",
|
|
BaseUrl: serverURL,
|
|
}}, nil)
|
|
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
|
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
|
gomock.Any(),
|
|
database.GetChatMessagesByChatIDAscPaginatedParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
).Return([]database.ChatMessage{
|
|
mustChatMessage(
|
|
t,
|
|
database.ChatMessageRoleUser,
|
|
database.ChatMessageVisibilityBoth,
|
|
codersdk.ChatMessageText(userPrompt),
|
|
),
|
|
mustChatMessage(
|
|
t,
|
|
database.ChatMessageRoleAssistant,
|
|
database.ChatMessageVisibilityBoth,
|
|
codersdk.ChatMessageText("checking the diff now"),
|
|
),
|
|
}, nil)
|
|
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
|
|
gomock.Any(),
|
|
database.GetChatMessagesByChatIDDescPaginatedParams{
|
|
ChatID: chatID,
|
|
BeforeID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
).Return(nil, nil)
|
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
|
|
return fn(lockTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Nil(t, opts)
|
|
return fn(usageTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
|
|
return fn(unlockTx)
|
|
},
|
|
),
|
|
)
|
|
|
|
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
|
|
lockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(
|
|
gomock.Any(),
|
|
gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{}),
|
|
).DoAndReturn(func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
|
require.Equal(t, chat.ID, arg.ID)
|
|
require.Equal(t, chat.Status, arg.Status)
|
|
require.Equal(t, uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}, arg.WorkerID)
|
|
require.True(t, arg.StartedAt.Valid)
|
|
require.WithinDuration(t, time.Now(), arg.StartedAt.Time, time.Second)
|
|
require.False(t, arg.HeartbeatAt.Valid)
|
|
require.Equal(t, chat.LastError, arg.LastError)
|
|
require.Equal(t, chat.UpdatedAt, arg.UpdatedAt)
|
|
return lockedChat, nil
|
|
})
|
|
|
|
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
|
|
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
|
|
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
|
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
|
|
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
|
|
require.Equal(t, []string{"[]"}, arg.Content)
|
|
return []database.ChatMessage{{ID: 91}}, nil
|
|
},
|
|
)
|
|
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
|
|
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
|
ID: chatID,
|
|
Title: wantTitle,
|
|
}).Return(updatedChat, nil)
|
|
|
|
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(updatedChat, nil)
|
|
unlockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(
|
|
gomock.Any(),
|
|
database.UpdateChatStatusPreserveUpdatedAtParams{
|
|
ID: updatedChat.ID,
|
|
Status: updatedChat.Status,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: updatedChat.LastError,
|
|
UpdatedAt: updatedChat.UpdatedAt,
|
|
},
|
|
).Return(unlockedChat, nil)
|
|
|
|
gotChat, err := server.RegenerateChatTitle(ctx, chat)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, gotChat)
|
|
|
|
select {
|
|
case event := <-messageEvents:
|
|
require.NoError(t, event.err)
|
|
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
|
require.Equal(t, chatID, event.payload.Chat.ID)
|
|
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for title change pubsub event")
|
|
}
|
|
}
|
|
|
|
func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
ownerID := uuid.New()
|
|
|
|
server := &Server{
|
|
db: db,
|
|
configCache: newChatConfigCache(
|
|
context.Background(),
|
|
db,
|
|
quartz.NewReal(),
|
|
),
|
|
providerAPIKeys: chatprovider.ProviderAPIKeys{
|
|
OpenAI: "openai-deployment-key",
|
|
Anthropic: "anthropic-deployment-key",
|
|
ByProvider: map[string]string{
|
|
"openai": "openai-deployment-key",
|
|
"anthropic": "anthropic-deployment-key",
|
|
},
|
|
BaseURLByProvider: map[string]string{
|
|
"openai": "https://openai.example.com",
|
|
"anthropic": "https://anthropic.example.com",
|
|
},
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
|
Provider: "anthropic",
|
|
CentralApiKeyEnabled: true,
|
|
AllowCentralApiKeyFallback: true,
|
|
}}, nil)
|
|
|
|
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, keys.OpenAI)
|
|
require.Empty(t, keys.APIKey("openai"))
|
|
require.Empty(t, keys.BaseURL("openai"))
|
|
require.Equal(t, "anthropic-deployment-key", keys.Anthropic)
|
|
require.Equal(t, "anthropic-deployment-key", keys.APIKey("anthropic"))
|
|
require.Equal(t, "https://anthropic.example.com", keys.BaseURL("anthropic"))
|
|
require.Equal(t, map[string]string{"anthropic": "anthropic-deployment-key"}, keys.ByProvider)
|
|
require.Equal(t, map[string]string{"anthropic": "https://anthropic.example.com"}, keys.BaseURLByProvider)
|
|
}
|
|
|
|
func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKeys(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
ownerID := uuid.New()
|
|
|
|
server := &Server{
|
|
db: db,
|
|
configCache: newChatConfigCache(
|
|
context.Background(),
|
|
db,
|
|
quartz.NewReal(),
|
|
),
|
|
providerAPIKeys: chatprovider.ProviderAPIKeys{
|
|
OpenAI: "openai-deployment-key",
|
|
ByProvider: map[string]string{
|
|
"openai": "openai-deployment-key",
|
|
},
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
|
Provider: "openai",
|
|
CentralApiKeyEnabled: true,
|
|
}}, nil)
|
|
|
|
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "openai-deployment-key", keys.OpenAI)
|
|
require.Equal(t, "openai-deployment-key", keys.APIKey("openai"))
|
|
}
|
|
|
|
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
workspaceID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
calls := 0
|
|
refreshed, err := refreshChatWorkspaceSnapshot(
|
|
context.Background(),
|
|
chat,
|
|
func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
calls++
|
|
return database.Chat{}, nil
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Equal(t, chat, refreshed)
|
|
require.Equal(t, 0, calls)
|
|
}
|
|
|
|
func TestRefreshChatWorkspaceSnapshot_ReloadsWhenWorkspaceMissing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
chatID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
chat := database.Chat{ID: chatID}
|
|
reloaded := database.Chat{
|
|
ID: chatID,
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
calls := 0
|
|
refreshed, err := refreshChatWorkspaceSnapshot(
|
|
context.Background(),
|
|
chat,
|
|
func(_ context.Context, id uuid.UUID) (database.Chat, error) {
|
|
calls++
|
|
require.Equal(t, chatID, id)
|
|
return reloaded, nil
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Equal(t, reloaded, refreshed)
|
|
require.Equal(t, 1, calls)
|
|
}
|
|
|
|
func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
chat := database.Chat{ID: uuid.New()}
|
|
loadErr := xerrors.New("boom")
|
|
|
|
refreshed, err := refreshChatWorkspaceSnapshot(
|
|
context.Background(),
|
|
chat,
|
|
func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
return database.Chat{}, loadErr
|
|
},
|
|
)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "reload chat workspace state")
|
|
require.ErrorContains(t, err, loadErr.Error())
|
|
require.Equal(t, chat, refreshed)
|
|
}
|
|
|
|
func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
OperatingSystem: "linux",
|
|
Directory: "/home/coder/project",
|
|
ExpandedDirectory: "/home/coder/project",
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(
|
|
gomock.Any(),
|
|
agentID,
|
|
).Return(workspaceAgent, nil).Times(1)
|
|
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
|
gomock.Cond(func(x any) bool {
|
|
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
|
if !ok || arg.ID != chat.ID {
|
|
return false
|
|
}
|
|
if !arg.LastInjectedContext.Valid {
|
|
return false
|
|
}
|
|
var parts []codersdk.ChatMessagePart
|
|
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
|
|
return false
|
|
}
|
|
// Expect at least one context-file part for the
|
|
// working-directory AGENTS.md, with internal fields
|
|
// stripped (no content, OS, or directory).
|
|
for _, p := range parts {
|
|
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFilePath != "" {
|
|
return p.ContextFileContent == "" &&
|
|
p.ContextFileOS == "" &&
|
|
p.ContextFileDirectory == ""
|
|
}
|
|
}
|
|
return false
|
|
}),
|
|
).Return(database.Chat{}, nil).Times(1)
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
|
|
Parts: []codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/home/coder/project/AGENTS.md",
|
|
ContextFileContent: "# Project instructions",
|
|
}},
|
|
}, nil).AnyTimes()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: quartz.NewReal(),
|
|
instructionLookupTimeout: 5 * time.Second,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
instruction, _, err := server.persistInstructionFiles(
|
|
ctx,
|
|
chat,
|
|
uuid.New(),
|
|
workspaceCtx.getWorkspaceAgent,
|
|
workspaceCtx.getWorkspaceConn,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Contains(t, instruction, "Operating System: linux")
|
|
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
|
}
|
|
|
|
func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
}
|
|
|
|
instruction, _, err := server.persistInstructionFiles(
|
|
ctx,
|
|
chat,
|
|
uuid.New(),
|
|
func(context.Context) (database.WorkspaceAgent, error) {
|
|
return database.WorkspaceAgent{
|
|
ID: uuid.New(),
|
|
Directory: "/home/coder/project",
|
|
}, nil
|
|
},
|
|
func(context.Context) (workspacesdk.AgentConn, error) {
|
|
return nil, errChatHasNoWorkspaceAgent
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Empty(t, instruction)
|
|
}
|
|
|
|
func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
OperatingSystem: "linux",
|
|
Directory: "/home/coder/project",
|
|
ExpandedDirectory: "/home/coder/project",
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(
|
|
gomock.Any(),
|
|
agentID,
|
|
).Return(workspaceAgent, nil).Times(1)
|
|
db.EXPECT().InsertChatMessages(gomock.Any(),
|
|
gomock.Cond(func(x any) bool {
|
|
arg, ok := x.(database.InsertChatMessagesParams)
|
|
if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 {
|
|
return false
|
|
}
|
|
var parts []codersdk.ChatMessagePart
|
|
if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil {
|
|
return false
|
|
}
|
|
foundMarker := false
|
|
foundSkill := false
|
|
for _, p := range parts {
|
|
switch p.Type {
|
|
case codersdk.ChatMessagePartTypeContextFile:
|
|
if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" {
|
|
foundMarker = true
|
|
}
|
|
case codersdk.ChatMessagePartTypeSkill:
|
|
if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
|
|
foundSkill = true
|
|
}
|
|
}
|
|
}
|
|
return foundMarker && foundSkill
|
|
}),
|
|
).Return(nil, nil).Times(1)
|
|
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
|
gomock.Cond(func(x any) bool {
|
|
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
|
if !ok || arg.ID != chat.ID {
|
|
return false
|
|
}
|
|
if !arg.LastInjectedContext.Valid {
|
|
return false
|
|
}
|
|
var parts []codersdk.ChatMessagePart
|
|
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
|
|
return false
|
|
}
|
|
// The sentinel path should persist only skill parts
|
|
// with ContextFileAgentID set.
|
|
for _, p := range parts {
|
|
if p.Type == codersdk.ChatMessagePartTypeSkill &&
|
|
p.SkillName == "my-skill" &&
|
|
p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}),
|
|
).Return(database.Chat{}, nil).Times(1)
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
|
|
// Agent returns pre-read content: no instruction files
|
|
// found but one skill discovered.
|
|
Parts: []codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "my-skill",
|
|
SkillDescription: "A test skill",
|
|
SkillDir: "/home/coder/project/.agents/skills/my-skill",
|
|
}},
|
|
}, nil).AnyTimes()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: quartz.NewReal(),
|
|
instructionLookupTimeout: 5 * time.Second,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
instruction, skills, err := server.persistInstructionFiles(
|
|
ctx,
|
|
chat,
|
|
uuid.New(),
|
|
workspaceCtx.getWorkspaceAgent,
|
|
workspaceCtx.getWorkspaceConn,
|
|
)
|
|
require.NoError(t, err)
|
|
// Sentinel path returns empty instruction string.
|
|
require.Empty(t, instruction)
|
|
// Skills are still discovered and returned.
|
|
require.Len(t, skills, 1)
|
|
require.Equal(t, "my-skill", skills[0].Name)
|
|
}
|
|
|
|
func TestPersistInstructionFilesSentinelNoSkillsClearsColumn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
OperatingSystem: "linux",
|
|
Directory: "/home/coder/project",
|
|
ExpandedDirectory: "/home/coder/project",
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(
|
|
gomock.Any(),
|
|
agentID,
|
|
).Return(workspaceAgent, nil).Times(1)
|
|
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
|
gomock.Cond(func(x any) bool {
|
|
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
|
if !ok || arg.ID != chat.ID {
|
|
return false
|
|
}
|
|
// No skills discovered, so the column should be
|
|
// cleared to NULL.
|
|
return !arg.LastInjectedContext.Valid
|
|
}),
|
|
).Return(database.Chat{}, nil).Times(1)
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
|
|
// Agent returns pre-read content: no files, no skills.
|
|
Parts: []codersdk.ChatMessagePart{},
|
|
}, nil).AnyTimes()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: quartz.NewReal(),
|
|
instructionLookupTimeout: 5 * time.Second,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
instruction, skills, err := server.persistInstructionFiles(
|
|
ctx,
|
|
chat,
|
|
uuid.New(),
|
|
workspaceCtx.getWorkspaceAgent,
|
|
workspaceCtx.getWorkspaceConn,
|
|
)
|
|
require.NoError(t, err)
|
|
// Sentinel path: empty instruction, no skills.
|
|
require.Empty(t, instruction)
|
|
require.Empty(t, skills)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{ID: agentID}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(workspaceAgent, nil).Times(1)
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: &Server{db: db},
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, chat, chatSnapshot)
|
|
require.Equal(t, workspaceAgent, agent)
|
|
|
|
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workspaceAgent, gotAgent)
|
|
require.Equal(t, chat, currentChat)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_NullBindingLazyBind(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
buildID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{ID: agentID}
|
|
updatedChat := chat
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: agentID, Valid: true}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{workspaceAgent}, nil),
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
|
ID: chat.ID,
|
|
}).Return(updatedChat, nil),
|
|
)
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: &Server{db: db},
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, chatSnapshot)
|
|
require.Equal(t, workspaceAgent, agent)
|
|
require.Equal(t, updatedChat, currentChat)
|
|
|
|
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workspaceAgent, gotAgent)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_StaleBindingRepair(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
staleAgentID := uuid.New()
|
|
buildID := uuid.New()
|
|
currentAgentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: staleAgentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
|
|
updatedChat := chat
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(database.WorkspaceAgent{}, xerrors.New("missing agent")),
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
ID: chat.ID,
|
|
}).Return(updatedChat, nil),
|
|
)
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: &Server{db: db},
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, chatSnapshot)
|
|
require.Equal(t, currentAgent, agent)
|
|
require.Equal(t, updatedChat, currentChat)
|
|
}
|
|
|
|
func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
staleAgentID := uuid.New()
|
|
currentAgentID := uuid.New()
|
|
buildID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: staleAgentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
|
|
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
|
|
updatedChat := chat
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(staleAgent, nil),
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), currentAgentID).Return(currentAgent, nil),
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
ID: chat.ID,
|
|
}).Return(updatedChat, nil),
|
|
)
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
|
|
var dialed []uuid.UUID
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
}
|
|
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
dialed = append(dialed, agentID)
|
|
if agentID == staleAgentID {
|
|
return nil, nil, xerrors.New("dial failed")
|
|
}
|
|
return conn, func() {}, nil
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.NoError(t, err)
|
|
require.Same(t, conn, gotConn)
|
|
require.Equal(t, []uuid.UUID{staleAgentID, currentAgentID}, dialed)
|
|
require.Equal(t, updatedChat, currentChat)
|
|
|
|
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, currentAgent, gotAgent)
|
|
}
|
|
|
|
func TestTurnWorkspaceContextGetWorkspaceConnFastFailsWithoutCurrentAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
staleAgentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: staleAgentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).
|
|
Return(staleAgent, nil).
|
|
Times(1)
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{}, nil).
|
|
Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
}
|
|
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return nil, nil, xerrors.New("dial failed")
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
require.ErrorIs(t, err, errChatHasNoWorkspaceAgent)
|
|
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
|
|
require.False(t, workspaceCtx.agentLoaded)
|
|
require.Nil(t, workspaceCtx.conn)
|
|
require.Nil(t, workspaceCtx.releaseConn)
|
|
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
currentChat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
updatedChat := database.Chat{
|
|
ID: currentChat.ID,
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
cachedConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
releaseCalls := 0
|
|
|
|
workspaceCtx := turnWorkspaceContext{
|
|
chatStateMu: &sync.Mutex{},
|
|
currentChat: ¤tChat,
|
|
}
|
|
workspaceCtx.agent = database.WorkspaceAgent{ID: uuid.New()}
|
|
workspaceCtx.agentLoaded = true
|
|
workspaceCtx.conn = cachedConn
|
|
workspaceCtx.cachedWorkspaceID = currentChat.WorkspaceID
|
|
workspaceCtx.releaseConn = func() {
|
|
releaseCalls++
|
|
}
|
|
|
|
workspaceCtx.selectWorkspace(updatedChat)
|
|
|
|
require.Equal(t, updatedChat, currentChat)
|
|
require.Equal(t, 1, releaseCalls)
|
|
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
|
|
require.False(t, workspaceCtx.agentLoaded)
|
|
require.Nil(t, workspaceCtx.conn)
|
|
require.Nil(t, workspaceCtx.releaseConn)
|
|
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferentWorkspace(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceOneID := uuid.New()
|
|
workspaceTwoID := uuid.New()
|
|
buildID := uuid.New()
|
|
cachedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
|
resolvedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceTwoID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
updatedChat := chat
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return([]database.WorkspaceAgent{resolvedAgent}, nil),
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
ID: chat.ID,
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true},
|
|
}).Return(updatedChat, nil),
|
|
)
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: &Server{db: db},
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
workspaceCtx.agent = cachedAgent
|
|
workspaceCtx.agentLoaded = true
|
|
workspaceCtx.cachedWorkspaceID = uuid.NullUUID{UUID: workspaceOneID, Valid: true}
|
|
defer workspaceCtx.close()
|
|
|
|
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, chatSnapshot)
|
|
require.Equal(t, resolvedAgent, agent)
|
|
require.Equal(t, updatedChat, currentChat)
|
|
}
|
|
|
|
func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
initialMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
localMessage := database.ChatMessage{
|
|
ID: 2,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{initialMessage}, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishMessage(chatID, localMessage)
|
|
|
|
event := requireStreamMessageEvent(t, events)
|
|
require.Equal(t, int64(2), event.Message.ID)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
initialMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
cachedMessage := codersdk.ChatMessage{
|
|
ID: 2,
|
|
ChatID: chatID,
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{initialMessage}, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessage,
|
|
ChatID: chatID,
|
|
Message: &cachedMessage,
|
|
})
|
|
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
AfterMessageID: 1,
|
|
})
|
|
|
|
event := requireStreamMessageEvent(t, events)
|
|
require.Equal(t, int64(2), event.Message.ID)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
initialMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
catchupMessage := database.ChatMessage{
|
|
ID: 2,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{initialMessage}, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 1,
|
|
}).Return([]database.ChatMessage{catchupMessage}, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
AfterMessageID: 1,
|
|
})
|
|
|
|
event := requireStreamMessageEvent(t, events)
|
|
require.Equal(t, int64(2), event.Message.ID)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
initialMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
editedMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{initialMessage}, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{editedMessage}, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishEditedMessage(chatID, editedMessage)
|
|
|
|
event := requireStreamMessageEvent(t, events)
|
|
require.Equal(t, int64(1), event.Message.ID)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
expected := newTestRetryPayload()
|
|
|
|
server.publishRetry(chatID, expected)
|
|
|
|
event := requireStreamRetryEvent(t, events)
|
|
require.Equal(t, expected, event.Retry)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeReplaysCurrentRetryPhaseInSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
server := newBufferedSubscribeTestServer(t, db, chatID)
|
|
|
|
expected := newTestRetryPayload()
|
|
server.publishRetry(chatID, expected)
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
require.Len(t, snapshot, 2)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeStatus, snapshot[0].Type)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeRetry, snapshot[1].Type)
|
|
event := requireSnapshotRetryEvent(t, snapshot)
|
|
require.Equal(t, expected, event.Retry)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeCapturesRetryPhaseAtSubscriptionBoundary(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
|
expected := newTestRetryPayload()
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).DoAndReturn(func(context.Context, database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
|
server.publishRetry(chatID, expected)
|
|
return nil, nil
|
|
}),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
requireNoSnapshotRetryEvent(t, snapshot)
|
|
event := requireStreamRetryEvent(t, events)
|
|
require.Equal(t, expected, event.Retry)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeDoesNotReplayRetryAfterStreamResumes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
server := newBufferedSubscribeTestServer(t, db, chatID)
|
|
|
|
server.publishRetry(chatID, newTestRetryPayload())
|
|
server.publishMessagePart(chatID, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("retry recovered"))
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
requireNoSnapshotRetryEvent(t, snapshot)
|
|
requireSnapshotMessagePartEvent(t, snapshot)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeDoesNotReplayRetryAfterTerminalError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
server := newBufferedSubscribeTestServer(t, db, chatID)
|
|
|
|
server.publishRetry(chatID, newTestRetryPayload())
|
|
server.publishError(chatID, chaterror.ClassifiedError{
|
|
Message: "OpenAI is rate limiting requests.",
|
|
Kind: chaterror.KindRateLimit,
|
|
Provider: "openai",
|
|
Retryable: true,
|
|
StatusCode: 429,
|
|
})
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
requireNoSnapshotRetryEvent(t, snapshot)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeDoesNotReplayRetryAfterTerminalStatus(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusCompleted}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
server := newBufferedSubscribeTestServer(t, db, chatID)
|
|
|
|
server.publishRetry(chatID, newTestRetryPayload())
|
|
server.publishStatus(chatID, database.ChatStatusCompleted, uuid.NullUUID{})
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
requireNoSnapshotRetryEvent(t, snapshot)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
classified := chaterror.ClassifiedError{
|
|
Message: "OpenAI is rate limiting requests.",
|
|
Kind: chaterror.KindRateLimit,
|
|
Provider: "openai",
|
|
Retryable: true,
|
|
StatusCode: 429,
|
|
}
|
|
server.publishError(chatID, classified)
|
|
|
|
event := requireStreamErrorEvent(t, events)
|
|
require.Equal(t, chaterror.TerminalErrorPayload(classified), event.Error)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
Error: "legacy error only",
|
|
})
|
|
|
|
event := requireStreamErrorEvent(t, events)
|
|
require.Equal(t, &codersdk.ChatError{Message: "legacy error only"}, event.Error)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func newTestRetryPayload() *codersdk.ChatStreamRetry {
|
|
payload := chaterror.StreamRetryPayload(1, 1500*time.Millisecond, chaterror.ClassifiedError{
|
|
Message: "OpenAI is rate limiting requests.",
|
|
Kind: chaterror.KindRateLimit,
|
|
Provider: "openai",
|
|
Retryable: true,
|
|
StatusCode: 429,
|
|
})
|
|
if payload == nil {
|
|
panic("expected retry payload")
|
|
}
|
|
payload.RetryingAt = time.Unix(1_700_000_000, 0).UTC()
|
|
return payload
|
|
}
|
|
|
|
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
|
|
t.Helper()
|
|
|
|
return &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
pubsub: dbpubsub.NewInMemory(),
|
|
}
|
|
}
|
|
|
|
func newBufferedSubscribeTestServer(t *testing.T, db database.Store, chatID uuid.UUID) *Server {
|
|
t.Helper()
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
state := server.getOrCreateStreamState(chatID)
|
|
state.mu.Lock()
|
|
state.buffering = true
|
|
state.mu.Unlock()
|
|
return server
|
|
}
|
|
|
|
func requireStreamMessageEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
select {
|
|
case event, ok := <-events:
|
|
require.True(t, ok, "chat stream closed before delivering an event")
|
|
require.Equal(t, codersdk.ChatStreamEventTypeMessage, event.Type)
|
|
require.NotNil(t, event.Message)
|
|
return event
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for chat stream message event")
|
|
return codersdk.ChatStreamEvent{}
|
|
}
|
|
}
|
|
|
|
func requireStreamRetryEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
select {
|
|
case event, ok := <-events:
|
|
require.True(t, ok, "chat stream closed before delivering an event")
|
|
require.Equal(t, codersdk.ChatStreamEventTypeRetry, event.Type)
|
|
require.NotNil(t, event.Retry)
|
|
return event
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for chat stream retry event")
|
|
return codersdk.ChatStreamEvent{}
|
|
}
|
|
}
|
|
|
|
func requireSnapshotRetryEvent(t *testing.T, snapshot []codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
var retryEvents []codersdk.ChatStreamEvent
|
|
for _, event := range snapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeRetry {
|
|
retryEvents = append(retryEvents, event)
|
|
}
|
|
}
|
|
|
|
require.Len(t, retryEvents, 1, "expected exactly one retry event in snapshot")
|
|
require.NotNil(t, retryEvents[0].Retry)
|
|
return retryEvents[0]
|
|
}
|
|
|
|
func requireNoSnapshotRetryEvent(t *testing.T, snapshot []codersdk.ChatStreamEvent) {
|
|
t.Helper()
|
|
|
|
for _, event := range snapshot {
|
|
require.NotEqual(t, codersdk.ChatStreamEventTypeRetry, event.Type,
|
|
"unexpected retry event in snapshot: %+v", event)
|
|
}
|
|
}
|
|
|
|
func requireSnapshotMessagePartEvent(t *testing.T, snapshot []codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
for _, event := range snapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
require.NotNil(t, event.MessagePart)
|
|
return event
|
|
}
|
|
}
|
|
|
|
t.Fatal("expected message_part event in snapshot")
|
|
return codersdk.ChatStreamEvent{}
|
|
}
|
|
|
|
func requireStreamErrorEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
select {
|
|
case event, ok := <-events:
|
|
require.True(t, ok, "chat stream closed before delivering an event")
|
|
require.Equal(t, codersdk.ChatStreamEventTypeError, event.Type)
|
|
require.NotNil(t, event.Error)
|
|
return event
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for chat stream error event")
|
|
return codersdk.ChatStreamEvent{}
|
|
}
|
|
}
|
|
|
|
func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) {
|
|
t.Helper()
|
|
|
|
select {
|
|
case event, ok := <-events:
|
|
if !ok {
|
|
t.Fatal("chat stream closed unexpectedly")
|
|
}
|
|
t.Fatalf("unexpected chat stream event: %+v", event)
|
|
case <-time.After(wait):
|
|
}
|
|
}
|
|
|
|
// TestPublishToStream_DropWarnRateLimiting walks through a
|
|
// realistic lifecycle: buffer fills up, subscriber channel fills
|
|
// up, counters get reset between steps. It verifies that WARN
|
|
// logs are rate-limited to at most once per streamDropWarnInterval
|
|
// and that counter resets re-enable an immediate WARN.
|
|
func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sink := testutil.NewFakeSink(t)
|
|
mClock := quartz.NewMock(t)
|
|
|
|
server := &Server{
|
|
logger: sink.Logger(),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
subCh := make(chan codersdk.ChatStreamEvent, 1)
|
|
subCh <- codersdk.ChatStreamEvent{} // pre-fill so sends always drop
|
|
|
|
// Set up state that mirrors a running chat: buffer at capacity,
|
|
// buffering enabled, one saturated subscriber.
|
|
state := &chatStreamState{
|
|
buffering: true,
|
|
buffer: make([]codersdk.ChatStreamEvent, maxStreamBufferSize),
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
|
|
uuid.New(): subCh,
|
|
},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
bufferMsg := "chat stream buffer full, dropping oldest event"
|
|
subMsg := "dropping chat stream event"
|
|
|
|
filter := func(level slog.Level, msg string) func(slog.SinkEntry) bool {
|
|
return func(e slog.SinkEntry) bool {
|
|
return e.Level == level && e.Message == msg
|
|
}
|
|
}
|
|
|
|
// --- Phase 1: buffer-full rate limiting ---
|
|
// message_part events hit both the buffer-full and subscriber-full
|
|
// paths. The first publish triggers a WARN for each; the rest
|
|
// within the window are DEBUG.
|
|
partEvent := codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
}
|
|
for i := 0; i < 50; i++ {
|
|
server.publishToStream(chatID, partEvent)
|
|
}
|
|
|
|
require.Len(t, sink.Entries(filter(slog.LevelWarn, bufferMsg)), 1)
|
|
require.Empty(t, sink.Entries(filter(slog.LevelDebug, bufferMsg)))
|
|
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, bufferMsg))[0], "dropped_count", int64(1))
|
|
|
|
// Subscriber also saw 50 drops (one per publish).
|
|
require.Len(t, sink.Entries(filter(slog.LevelWarn, subMsg)), 1)
|
|
require.Empty(t, sink.Entries(filter(slog.LevelDebug, subMsg)))
|
|
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, subMsg))[0], "dropped_count", int64(1))
|
|
|
|
// --- Phase 2: clock advance triggers second WARN with count ---
|
|
mClock.Advance(streamDropWarnInterval + time.Second)
|
|
server.publishToStream(chatID, partEvent)
|
|
|
|
bufWarn := sink.Entries(filter(slog.LevelWarn, bufferMsg))
|
|
require.Len(t, bufWarn, 2)
|
|
requireFieldValue(t, bufWarn[1], "dropped_count", int64(50))
|
|
|
|
subWarn := sink.Entries(filter(slog.LevelWarn, subMsg))
|
|
require.Len(t, subWarn, 2)
|
|
requireFieldValue(t, subWarn[1], "dropped_count", int64(50))
|
|
|
|
// --- Phase 3: counter reset (simulates step persist) ---
|
|
state.mu.Lock()
|
|
state.buffer = make([]codersdk.ChatStreamEvent, maxStreamBufferSize)
|
|
state.resetDropCounters()
|
|
state.mu.Unlock()
|
|
|
|
// The very next drop should WARN immediately — the reset zeroed
|
|
// lastWarnAt so the interval check passes.
|
|
server.publishToStream(chatID, partEvent)
|
|
|
|
bufWarn = sink.Entries(filter(slog.LevelWarn, bufferMsg))
|
|
require.Len(t, bufWarn, 3, "expected WARN immediately after counter reset")
|
|
requireFieldValue(t, bufWarn[2], "dropped_count", int64(1))
|
|
|
|
subWarn = sink.Entries(filter(slog.LevelWarn, subMsg))
|
|
require.Len(t, subWarn, 3, "expected subscriber WARN immediately after counter reset")
|
|
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
|
|
}
|
|
|
|
func TestResolveUserCompactionThreshold(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
userID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
expectedKey := codersdk.CompactionThresholdKey(modelConfigID)
|
|
|
|
tests := []struct {
|
|
name string
|
|
dbReturn string
|
|
dbErr error
|
|
wantVal int32
|
|
wantOK bool
|
|
wantWarnLog bool
|
|
}{
|
|
{
|
|
name: "NoRowsReturnsDefault",
|
|
dbErr: sql.ErrNoRows,
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "ValidOverride",
|
|
dbReturn: "75",
|
|
wantVal: 75,
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "OutOfRangeValue",
|
|
dbReturn: "101",
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "NonIntegerValue",
|
|
dbReturn: "abc",
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "UnexpectedDBError",
|
|
dbErr: xerrors.New("connection refused"),
|
|
wantOK: false,
|
|
wantWarnLog: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockDB := dbmock.NewMockStore(ctrl)
|
|
sink := testutil.NewFakeSink(t)
|
|
|
|
srv := &Server{
|
|
db: mockDB,
|
|
logger: sink.Logger(),
|
|
}
|
|
|
|
mockDB.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), database.GetUserChatCompactionThresholdParams{
|
|
UserID: userID,
|
|
Key: expectedKey,
|
|
}).Return(tc.dbReturn, tc.dbErr)
|
|
|
|
val, ok := srv.resolveUserCompactionThreshold(context.Background(), userID, modelConfigID)
|
|
require.Equal(t, tc.wantVal, val)
|
|
require.Equal(t, tc.wantOK, ok)
|
|
|
|
warns := sink.Entries(func(e slog.SinkEntry) bool {
|
|
return e.Level == slog.LevelWarn
|
|
})
|
|
if tc.wantWarnLog {
|
|
require.NotEmpty(t, warns, "expected a warning log entry")
|
|
return
|
|
}
|
|
require.Empty(t, warns, "unexpected warning log entry")
|
|
})
|
|
}
|
|
}
|
|
|
|
// requireFieldValue asserts that a SinkEntry contains a field with
|
|
// the given name and value.
|
|
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
|
|
t.Helper()
|
|
for _, f := range entry.Fields {
|
|
if f.Name == name {
|
|
require.Equal(t, expected, f.Value, "field %q value mismatch", name)
|
|
return
|
|
}
|
|
}
|
|
t.Fatalf("field %q not found in log entry", name)
|
|
}
|
|
|
|
func TestSkillsFromParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Empty", func(t *testing.T) {
|
|
t.Parallel()
|
|
got := skillsFromParts(nil)
|
|
require.Empty(t, got)
|
|
})
|
|
|
|
t.Run("NoSkillParts", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{Type: codersdk.ChatMessagePartTypeText, Text: "hello"},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Empty(t, got)
|
|
})
|
|
|
|
t.Run("SingleSkill", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "deep-review",
|
|
SkillDescription: "Multi-reviewer code review",
|
|
SkillDir: "/home/coder/.agents/skills/deep-review",
|
|
},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, 1)
|
|
require.Equal(t, "deep-review", got[0].Name)
|
|
require.Equal(t, "Multi-reviewer code review", got[0].Description)
|
|
require.Equal(t, "/home/coder/.agents/skills/deep-review", got[0].Dir)
|
|
})
|
|
|
|
t.Run("MultipleSkillsAcrossMessages", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "pull-requests",
|
|
SkillDir: "/home/coder/.agents/skills/pull-requests",
|
|
},
|
|
}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "deep-review",
|
|
SkillDir: "/home/coder/.agents/skills/deep-review",
|
|
},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, 2)
|
|
require.Equal(t, "pull-requests", got[0].Name)
|
|
require.Equal(t, "deep-review", got[1].Name)
|
|
})
|
|
|
|
t.Run("MixedPartTypes", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/home/coder/.coder/AGENTS.md",
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "refine-plan",
|
|
SkillDir: "/home/coder/.agents/skills/refine-plan",
|
|
},
|
|
}),
|
|
// A text-only message should be skipped entirely.
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{Type: codersdk.ChatMessagePartTypeText, Text: "user turn"},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, 1)
|
|
require.Equal(t, "refine-plan", got[0].Name)
|
|
require.Equal(t, "/home/coder/.agents/skills/refine-plan", got[0].Dir)
|
|
})
|
|
|
|
t.Run("OptionalDescriptionOmitted", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "refine-plan",
|
|
SkillDir: "/home/coder/.agents/skills/refine-plan",
|
|
},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, 1)
|
|
require.Equal(t, "refine-plan", got[0].Name)
|
|
require.Empty(t, got[0].Description)
|
|
})
|
|
|
|
t.Run("InvalidJSON", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
{
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: []byte(`not valid json with "skill" in it`),
|
|
Valid: true,
|
|
},
|
|
},
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Empty(t, got)
|
|
})
|
|
|
|
t.Run("RoundTrip", func(t *testing.T) {
|
|
// Simulate persist -> reconstruct cycle: marshal skill
|
|
// parts the same way persistInstructionFiles does, then
|
|
// verify skillsFromParts recovers the metadata.
|
|
t.Parallel()
|
|
want := []chattool.SkillMeta{
|
|
{Name: "deep-review", Description: "Multi-reviewer review", Dir: "/skills/deep-review"},
|
|
{Name: "pull-requests", Description: "", Dir: "/skills/pull-requests"},
|
|
}
|
|
agentID := uuid.New()
|
|
var parts []codersdk.ChatMessagePart
|
|
for _, s := range want {
|
|
parts = append(parts, codersdk.ChatMessagePart{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: s.Name,
|
|
SkillDescription: s.Description,
|
|
SkillDir: s.Dir,
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
|
})
|
|
}
|
|
msgs := []database.ChatMessage{chattest.ChatMessageWithParts(parts)}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, len(want))
|
|
for i, w := range want {
|
|
require.Equal(t, w.Name, got[i].Name)
|
|
require.Equal(t, w.Description, got[i].Description)
|
|
require.Equal(t, w.Dir, got[i].Dir)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestContextFileAgentID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("EmptyMessages", func(t *testing.T) {
|
|
t.Parallel()
|
|
id, ok := contextFileAgentID(nil)
|
|
require.Equal(t, uuid.Nil, id)
|
|
require.False(t, ok)
|
|
})
|
|
|
|
t.Run("NoContextFileParts", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{Type: codersdk.ChatMessagePartTypeText, Text: "hello"},
|
|
}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, uuid.Nil, id)
|
|
require.False(t, ok)
|
|
})
|
|
|
|
t.Run("SingleContextFile", func(t *testing.T) {
|
|
t.Parallel()
|
|
agentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/some/path",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
|
},
|
|
}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, agentID, id)
|
|
require.True(t, ok)
|
|
})
|
|
|
|
t.Run("MultipleContextFiles", func(t *testing.T) {
|
|
t.Parallel()
|
|
agentID1 := uuid.New()
|
|
agentID2 := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/first/path",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID1, Valid: true},
|
|
},
|
|
}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/second/path",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID2, Valid: true},
|
|
},
|
|
}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, agentID2, id)
|
|
require.True(t, ok)
|
|
})
|
|
|
|
t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) {
|
|
t.Parallel()
|
|
instructionAgentID := uuid.New()
|
|
sentinelAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/workspace/AGENTS.md",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true},
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: AgentChatContextSentinelPath,
|
|
ContextFileAgentID: uuid.NullUUID{
|
|
UUID: sentinelAgentID,
|
|
Valid: true,
|
|
},
|
|
}}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, instructionAgentID, id)
|
|
require.True(t, ok)
|
|
})
|
|
|
|
t.Run("SentinelWithoutAgentID", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFileAgentID: uuid.NullUUID{Valid: false},
|
|
},
|
|
}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, uuid.Nil, id)
|
|
require.False(t, ok)
|
|
})
|
|
}
|
|
|
|
func TestHasPersistedInstructionFiles(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) {
|
|
t.Parallel()
|
|
agentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: AgentChatContextSentinelPath,
|
|
ContextFileAgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}}),
|
|
}
|
|
require.False(t, hasPersistedInstructionFiles(msgs))
|
|
})
|
|
|
|
t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) {
|
|
t.Parallel()
|
|
agentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/workspace/AGENTS.md",
|
|
ContextFileContent: "repo instructions",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
|
}}),
|
|
}
|
|
require.True(t, hasPersistedInstructionFiles(msgs))
|
|
})
|
|
}
|
|
|
|
func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/old/AGENTS.md",
|
|
ContextFileContent: "old instructions",
|
|
ContextFileOS: "darwin",
|
|
ContextFileDirectory: "/old",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/new/AGENTS.md",
|
|
ContextFileContent: "new instructions",
|
|
ContextFileOS: "linux",
|
|
ContextFileDirectory: "/new",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
}}),
|
|
}
|
|
|
|
got := instructionFromContextFiles(msgs)
|
|
require.Contains(t, got, "new instructions")
|
|
require.Contains(t, got, "Operating System: linux")
|
|
require.Contains(t, got, "Working Directory: /new")
|
|
require.NotContains(t, got, "old instructions")
|
|
require.NotContains(t, got, "Operating System: darwin")
|
|
}
|
|
|
|
func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/legacy/AGENTS.md",
|
|
ContextFileContent: "legacy instructions",
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/old/AGENTS.md",
|
|
ContextFileContent: "old instructions",
|
|
ContextFileOS: "darwin",
|
|
ContextFileDirectory: "/old",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/new/AGENTS.md",
|
|
ContextFileContent: "new instructions",
|
|
ContextFileOS: "linux",
|
|
ContextFileDirectory: "/new",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
}}),
|
|
}
|
|
|
|
got := instructionFromContextFiles(msgs)
|
|
require.Contains(t, got, "legacy instructions")
|
|
require.Contains(t, got, "new instructions")
|
|
require.Contains(t, got, "Operating System: linux")
|
|
require.Contains(t, got, "Working Directory: /new")
|
|
require.NotContains(t, got, "old instructions")
|
|
require.NotContains(t, got, "Operating System: darwin")
|
|
}
|
|
|
|
func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-legacy",
|
|
SkillDir: "/skills/repo-helper-legacy",
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/old/AGENTS.md",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-old",
|
|
SkillDir: "/skills/repo-helper-old",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
},
|
|
}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: AgentChatContextSentinelPath,
|
|
ContextFileAgentID: uuid.NullUUID{
|
|
UUID: newAgentID,
|
|
Valid: true,
|
|
},
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-new",
|
|
SkillDir: "/skills/repo-helper-new",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
},
|
|
}),
|
|
}
|
|
|
|
got := skillsFromParts(msgs)
|
|
require.Equal(t, []chattool.SkillMeta{
|
|
{Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"},
|
|
{Name: "repo-helper-new", Dir: "/skills/repo-helper-new"},
|
|
}, got)
|
|
}
|
|
|
|
func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/old/AGENTS.md",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-old",
|
|
SkillDir: "/skills/repo-helper-old",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
},
|
|
}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: AgentChatContextSentinelPath,
|
|
ContextFileAgentID: uuid.NullUUID{
|
|
UUID: newAgentID,
|
|
Valid: true,
|
|
},
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-new",
|
|
SkillDir: "/skills/repo-helper-new",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
},
|
|
}),
|
|
}
|
|
|
|
got := skillsFromParts(msgs)
|
|
require.Equal(t, []chattool.SkillMeta{{
|
|
Name: "repo-helper-new",
|
|
Dir: "/skills/repo-helper-new",
|
|
}}, got)
|
|
}
|
|
|
|
func TestMergeSkillMetas(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
persisted := []chattool.SkillMeta{{
|
|
Name: "repo-helper",
|
|
Description: "Persisted skill",
|
|
Dir: "/skills/repo-helper-old",
|
|
}}
|
|
discovered := []chattool.SkillMeta{
|
|
{
|
|
Name: "repo-helper",
|
|
Description: "Discovered replacement",
|
|
Dir: "/skills/repo-helper-new",
|
|
MetaFile: "SKILL.md",
|
|
},
|
|
{
|
|
Name: "deep-review",
|
|
Description: "Discovered skill",
|
|
Dir: "/skills/deep-review",
|
|
},
|
|
}
|
|
|
|
got := mergeSkillMetas(persisted, discovered)
|
|
require.Equal(t, []chattool.SkillMeta{
|
|
discovered[0],
|
|
discovered[1],
|
|
}, got)
|
|
}
|
|
|
|
func TestSelectSkillMetasForInstructionRefresh(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}}
|
|
discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}}
|
|
currentAgentID := uuid.New()
|
|
otherAgentID := uuid.New()
|
|
|
|
t.Run("MergesCurrentAgentSkills", func(t *testing.T) {
|
|
t.Parallel()
|
|
got := selectSkillMetasForInstructionRefresh(
|
|
persisted,
|
|
discovered,
|
|
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
)
|
|
require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got)
|
|
})
|
|
|
|
t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) {
|
|
t.Parallel()
|
|
got := selectSkillMetasForInstructionRefresh(
|
|
persisted,
|
|
discovered,
|
|
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
|
)
|
|
require.Equal(t, discovered, got)
|
|
})
|
|
|
|
t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) {
|
|
t.Parallel()
|
|
got := selectSkillMetasForInstructionRefresh(
|
|
persisted,
|
|
nil,
|
|
uuid.NullUUID{},
|
|
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
|
)
|
|
require.Equal(t, persisted, got)
|
|
})
|
|
}
|
|
|
|
// TestProcessChat_IgnoresStaleControlNotification verifies that
|
|
// processChat is not interrupted by a "pending" notification
|
|
// published before processing begins. This is the race that caused
|
|
// TestOpenAIReasoningWithWebSearchRoundTripStoreFalse to flake:
|
|
// SendMessage publishes "pending" via PostgreSQL NOTIFY, and due
|
|
// to async delivery the notification can arrive at the control
|
|
// subscriber after it registers but before the processor publishes
|
|
// "running".
|
|
func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
ps := dbpubsub.NewInMemory()
|
|
clock := quartz.NewMock(t)
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: ps,
|
|
clock: clock,
|
|
workerID: workerID,
|
|
chatHeartbeatInterval: time.Minute,
|
|
metrics: chatloop.NopMetrics(),
|
|
configCache: newChatConfigCache(ctx, db, clock),
|
|
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
|
}
|
|
|
|
// Publish a stale "pending" notification on the control channel
|
|
// BEFORE processChat subscribes. In production this is the
|
|
// notification from SendMessage that triggered the processing.
|
|
staleNotify, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
|
|
Status: string(database.ChatStatusPending),
|
|
})
|
|
require.NoError(t, err)
|
|
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), staleNotify)
|
|
require.NoError(t, err)
|
|
|
|
// Track which status processChat writes during cleanup.
|
|
var finalStatus database.ChatStatus
|
|
|
|
// The deferred cleanup in processChat runs a transaction.
|
|
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
return fn(db)
|
|
},
|
|
)
|
|
db.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(
|
|
database.Chat{ID: chatID, Status: database.ChatStatusRunning, WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}}, nil,
|
|
)
|
|
db.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, params database.UpdateChatStatusParams) (database.Chat, error) {
|
|
finalStatus = params.Status
|
|
return database.Chat{ID: chatID, Status: params.Status}, nil
|
|
},
|
|
)
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
|
|
database.Chat{ID: chatID, Status: database.ChatStatusError},
|
|
nil,
|
|
)
|
|
|
|
// resolveChatModel fails immediately — that's fine, we only
|
|
// need processChat to get past initialization without being
|
|
// interrupted by the stale notification.
|
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return(
|
|
database.ChatModelConfig{}, xerrors.New("no model configured"),
|
|
).AnyTimes()
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
|
|
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
|
|
).AnyTimes()
|
|
db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes()
|
|
|
|
chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New()}
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
server.processChat(ctx, chat)
|
|
}()
|
|
|
|
// Wait for processChat to finish entirely. It re-reads chat state and
|
|
// runs more cleanup after UpdateChatStatus, so signaling completion from
|
|
// the status update itself races test teardown.
|
|
testutil.TryReceive(ctx, t, done)
|
|
|
|
// If the stale notification interrupted us, status would be
|
|
// "waiting" (the ErrInterrupted path). Since the gate blocked
|
|
// it, processChat reached runChat, which failed on model
|
|
// resolution → status is "error".
|
|
require.Equal(t, database.ChatStatusError, finalStatus,
|
|
"processChat should have reached runChat (error), not been interrupted (waiting)")
|
|
}
|
|
|
|
func TestShouldPublishFinishedChatState(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
|
|
server := &Server{db: db}
|
|
updatedChat := database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
}
|
|
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
}, nil)
|
|
|
|
require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat))
|
|
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
}, nil)
|
|
|
|
require.False(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat))
|
|
}
|
|
|
|
// TestShouldPublishFinishedChatState_DBErrorPublishes pins the
|
|
// deliberate fail-open behavior when the re-read query errors: we
|
|
// surface the finished state anyway so watchers don't get stuck
|
|
// waiting for a status update that never arrives. The error path is
|
|
// easy to regress into a fail-closed default otherwise.
|
|
func TestShouldPublishFinishedChatState_DBErrorPublishes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
|
|
server := &Server{db: db}
|
|
updatedChat := database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
}
|
|
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
|
|
database.Chat{}, xerrors.New("boom"),
|
|
)
|
|
|
|
require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat),
|
|
"fail-open: a re-read error must not swallow the status change")
|
|
}
|
|
|
|
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
|
|
// batch heartbeat UPDATE does not return a registered chat's ID
|
|
// (because another replica stole it or it was completed), the
|
|
// heartbeat tick cancels that chat's context with ErrInterrupted
|
|
// while leaving surviving chats untouched.
|
|
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
clock := quartz.NewMock(t)
|
|
|
|
workerID := uuid.New()
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: clock,
|
|
workerID: workerID,
|
|
chatHeartbeatInterval: time.Minute,
|
|
metrics: chatloop.NopMetrics(),
|
|
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
|
}
|
|
|
|
// Create three chats with independent cancel functions.
|
|
chat1 := uuid.New()
|
|
chat2 := uuid.New()
|
|
chat3 := uuid.New()
|
|
|
|
_, cancel1 := context.WithCancelCause(ctx)
|
|
_, cancel2 := context.WithCancelCause(ctx)
|
|
ctx3, cancel3 := context.WithCancelCause(ctx)
|
|
|
|
server.registerHeartbeat(&heartbeatEntry{
|
|
cancelWithCause: cancel1,
|
|
chatID: chat1,
|
|
logger: logger,
|
|
})
|
|
server.registerHeartbeat(&heartbeatEntry{
|
|
cancelWithCause: cancel2,
|
|
chatID: chat2,
|
|
logger: logger,
|
|
})
|
|
server.registerHeartbeat(&heartbeatEntry{
|
|
cancelWithCause: cancel3,
|
|
chatID: chat3,
|
|
logger: logger,
|
|
})
|
|
|
|
// The batch UPDATE returns only chat1 and chat2 —
|
|
// chat3 was "stolen" by another replica.
|
|
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
|
require.Equal(t, workerID, params.WorkerID)
|
|
require.Len(t, params.IDs, 3)
|
|
// Return only chat1 and chat2 as surviving.
|
|
return []uuid.UUID{chat1, chat2}, nil
|
|
},
|
|
)
|
|
|
|
server.heartbeatTick(ctx)
|
|
|
|
// chat3's context should be canceled with ErrInterrupted.
|
|
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
|
|
"stolen chat should be interrupted")
|
|
|
|
// chat3 should have been removed from the registry by
|
|
// unregister (in production this happens via defer in
|
|
// processChat). The heartbeat tick itself does not
|
|
// unregister — it only cancels. Verify the entry is
|
|
// still present (processChat's defer would clean it up).
|
|
server.heartbeatMu.Lock()
|
|
_, chat1Exists := server.heartbeatRegistry[chat1]
|
|
_, chat2Exists := server.heartbeatRegistry[chat2]
|
|
_, chat3Exists := server.heartbeatRegistry[chat3]
|
|
server.heartbeatMu.Unlock()
|
|
|
|
require.True(t, chat1Exists, "surviving chat1 should remain registered")
|
|
require.True(t, chat2Exists, "surviving chat2 should remain registered")
|
|
require.True(t, chat3Exists,
|
|
"stolen chat3 should still be in registry (processChat defer removes it)")
|
|
}
|
|
|
|
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
|
|
// transient database failure causes the tick to log and return
|
|
// without canceling any registered chats.
|
|
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
clock := quartz.NewMock(t)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: clock,
|
|
workerID: uuid.New(),
|
|
chatHeartbeatInterval: time.Minute,
|
|
metrics: chatloop.NopMetrics(),
|
|
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
chatCtx, cancel := context.WithCancelCause(ctx)
|
|
|
|
server.registerHeartbeat(&heartbeatEntry{
|
|
cancelWithCause: cancel,
|
|
chatID: chatID,
|
|
logger: logger,
|
|
})
|
|
|
|
// Simulate a transient DB error.
|
|
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
|
|
nil, xerrors.New("connection reset"),
|
|
)
|
|
|
|
server.heartbeatTick(ctx)
|
|
|
|
// Chat should NOT be interrupted — the tick logged and
|
|
// returned early.
|
|
require.NoError(t, chatCtx.Err(),
|
|
"chat context should not be canceled on transient DB error")
|
|
}
|
|
|
|
// TestSubscribeCancelDuringGrace_ReapedBySweep verifies that a
|
|
// subscriber detach inside bufferRetainGracePeriod (the OSS trigger
|
|
// for the retained-buffer leak) leaves the state mapped, and the
|
|
// next sweep past the grace window reaps it.
|
|
func TestSubscribeCancelDuringGrace_ReapedBySweep(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := slogtest.Make(t, nil)
|
|
mClock := quartz.NewMock(t)
|
|
|
|
server := &Server{
|
|
logger: logger,
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
start := mClock.Now()
|
|
|
|
// Just-finished chat: processing done, buffer retained for
|
|
// late-connecting relay subscribers.
|
|
state := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: start,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
buffer: []codersdk.ChatStreamEvent{{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
// Real subscribeToStream cancel path: the WS subscriber detach
|
|
// that leaks in prod.
|
|
snapshot, currentRetry, events, cancelSub := server.subscribeToStream(chatID)
|
|
require.Len(t, snapshot, 1)
|
|
require.Nil(t, currentRetry)
|
|
require.NotNil(t, events)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod / 2)
|
|
cancelSub()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok,
|
|
"entry should remain during grace window after subscriber detach")
|
|
|
|
mClock.Advance(bufferRetainGracePeriod)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok = server.chatStreams.Load(chatID)
|
|
require.False(t, ok,
|
|
"entry should be reaped after grace period expires and sweep runs")
|
|
}
|
|
|
|
// TestSweepIdleStreams_ReapsStaleRetainedBuffer: grace expired, no
|
|
// subscribers, not buffering -> reaped.
|
|
func TestSweepIdleStreams_ReapsStaleRetainedBuffer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
state := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: mClock.Now(),
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
buffer: []codersdk.ChatStreamEvent{{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod + time.Second)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.False(t, ok, "stale retained state should be reaped")
|
|
}
|
|
|
|
// TestSweepIdleStreams_DoesNotReapActiveBuffering: buffering=true
|
|
// blocks reap even long after any grace would have expired.
|
|
func TestSweepIdleStreams_DoesNotReapActiveBuffering(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
state := &chatStreamState{
|
|
buffering: true,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
buffer: []codersdk.ChatStreamEvent{{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
mClock.Advance(time.Hour)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok, "actively-buffering state must not be reaped")
|
|
}
|
|
|
|
// TestSweepIdleStreams_DoesNotReapWithSubscribers: attached
|
|
// subscribers block reap even when grace has expired.
|
|
func TestSweepIdleStreams_DoesNotReapWithSubscribers(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
state := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: mClock.Now(),
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
|
|
uuid.New(): make(chan codersdk.ChatStreamEvent, 1),
|
|
},
|
|
buffer: []codersdk.ChatStreamEvent{{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod + time.Second)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok, "state with subscribers must not be reaped")
|
|
}
|
|
|
|
// TestSweepIdleStreams_DefersDuringGracePeriod: sweep inside grace
|
|
// is a no-op; the next sweep past grace reaps.
|
|
func TestSweepIdleStreams_DefersDuringGracePeriod(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
start := mClock.Now()
|
|
state := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: start,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
buffer: []codersdk.ChatStreamEvent{{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod / 2)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok, "sweep inside grace window must not reap")
|
|
|
|
mClock.Advance(bufferRetainGracePeriod)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok = server.chatStreams.Load(chatID)
|
|
require.False(t, ok, "sweep after grace window must reap")
|
|
}
|
|
|
|
// TestPublishToStream_DropZeroesBackingSlot verifies that evicting
|
|
// the oldest buffered event at capacity zeroes the dropped slot so
|
|
// its *ChatStreamMessagePart becomes GC-eligible immediately.
|
|
func TestPublishToStream_DropZeroesBackingSlot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
|
|
// Over-allocate by one so the post-drop append fits in place and
|
|
// exercises the backing-array reuse this test is checking.
|
|
buf := make([]codersdk.ChatStreamEvent, maxStreamBufferSize, maxStreamBufferSize+1)
|
|
for i := range buf {
|
|
buf[i] = codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
}
|
|
}
|
|
// Sentinel in slot 0 distinguishes "slot was zeroed" from "slot
|
|
// was overwritten by a later append".
|
|
sentinel := &codersdk.ChatStreamMessagePart{
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
}
|
|
buf[0] = codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: sentinel,
|
|
}
|
|
// Alias over the full backing array so we can still observe slot
|
|
// 0 after publishToStream reslices state.buffer forward.
|
|
origBacking := buf[:cap(buf)]
|
|
|
|
state := &chatStreamState{
|
|
buffering: true,
|
|
buffer: buf,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
newPart := &codersdk.ChatStreamMessagePart{
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
}
|
|
server.publishToStream(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: newPart,
|
|
})
|
|
|
|
require.Equal(t, codersdk.ChatStreamEvent{}, origBacking[0],
|
|
"dropped slot must be zero-valued so its *ChatStreamMessagePart "+
|
|
"is eligible for GC; got %+v", origBacking[0])
|
|
|
|
// Sanity-check the in-place append path the fix targets: if Go's
|
|
// growth policy ever makes this append reallocate, this fails
|
|
// loudly so the test author revisits the setup.
|
|
require.Same(t, newPart, origBacking[len(origBacking)-1].MessagePart,
|
|
"append must have landed in the original backing array; the "+
|
|
"zero-out invariant only matters when cap > len")
|
|
}
|
|
|
|
// TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry covers
|
|
// the race where a caller holds a pointer to a no-longer-mapped
|
|
// state (e.g. a janitor Range callback racing a fresh
|
|
// getOrCreateStreamState) and would otherwise evict the fresh entry.
|
|
// With CompareAndDelete in cleanupStreamIfIdle the stale delete is
|
|
// a no-op.
|
|
func TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
|
|
// Stale pointer: reapable (not buffering, no subscribers, grace
|
|
// expired) but no longer the map's live entry.
|
|
stale := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: mClock.Now(),
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
}
|
|
|
|
// Fresh entry: the state getOrCreateStreamState would install
|
|
// after a racing processChat run. Actively buffering, so not
|
|
// reapable. Only this state is in the map.
|
|
fresh := &chatStreamState{
|
|
buffering: true,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
}
|
|
server.chatStreams.Store(chatID, fresh)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod + time.Second)
|
|
|
|
// Stale caller mirrors the janitor Range callback after the map
|
|
// entry has already been replaced.
|
|
stale.mu.Lock()
|
|
server.cleanupStreamIfIdle(chatID, stale)
|
|
stale.mu.Unlock()
|
|
|
|
got, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok,
|
|
"fresh entry must remain mapped when cleanup is called with a stale pointer")
|
|
require.Same(t, fresh, got,
|
|
"cleanup must not replace the fresh entry with the stale one")
|
|
}
|
|
|
|
// TestSafeSweepIdleStreams_RecoversFromPanic verifies that an
|
|
// unexpected panic inside sweepIdleStreams is recovered rather than
|
|
// killing the janitor goroutine. Without this guard, a panic would
|
|
// silently reintroduce the very leak the janitor exists to prevent.
|
|
func TestSafeSweepIdleStreams_RecoversFromPanic(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := &Server{
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewMock(t),
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
// A nil *chatStreamState passes the type assertion in sweepIdleStreams
|
|
// but panics on state.mu.Lock with a nil-pointer deref. Any future
|
|
// panic source in the sweep would trigger the same recovery path.
|
|
var nilState *chatStreamState
|
|
server.chatStreams.Store(chatID, nilState)
|
|
|
|
require.NotPanics(t, func() {
|
|
server.safeSweepIdleStreams(context.Background())
|
|
}, "safeSweepIdleStreams must recover panics so the janitor loop keeps running")
|
|
}
|
|
|
|
func TestGetWorkspaceConn_StaleAgentRecovery(t *testing.T) {
|
|
// Regression test: when a workspace is rebuilt, the chat's stored
|
|
// agent ID points to a disconnected agent from the old build. The
|
|
// cache-miss path must let dialWithLazyValidation discover the new
|
|
// agent instead of rejecting the old one immediately.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
buildID := uuid.New()
|
|
|
|
// Old agent: disconnected (from previous build).
|
|
oldAgent := database.WorkspaceAgent{
|
|
ID: oldAgentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
DisconnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-9 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// New agent: connected (from latest build).
|
|
newAgent := database.WorkspaceAgent{
|
|
ID: newAgentID,
|
|
Name: "main",
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-1 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: oldAgentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// ensureWorkspaceAgent fetches the stale agent.
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), oldAgentID).
|
|
Return(oldAgent, nil).Times(1)
|
|
// Lazy validation discovers the new agent.
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{newAgent}, nil).Times(1)
|
|
// Post-switch: persist the new binding.
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return(database.WorkspaceBuild{ID: buildID}, nil).Times(1)
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), newAgentID).
|
|
Return(newAgent, nil).Times(1)
|
|
|
|
updatedChat := chat
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: newAgentID, Valid: true}
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
ID: chat.ID,
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
}).Return(updatedChat, nil).Times(1)
|
|
|
|
newConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
newConn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
switch id {
|
|
case oldAgentID:
|
|
return nil, nil, xerrors.New("agent is not connected")
|
|
case newAgentID:
|
|
return newConn, func() {}, nil
|
|
default:
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID: %s", id)
|
|
}
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
return database.Chat{}, nil
|
|
},
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.NoError(t, err, "getWorkspaceConn should recover stale agent binding")
|
|
require.Same(t, newConn, gotConn, "should return the connection to the new agent")
|
|
|
|
// Verify the cache was updated to the new agent so subsequent
|
|
// cache-hit calls use the correct agent ID.
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.Equal(t, newAgentID, workspaceCtx.agent.ID, "cached agent should be the new agent")
|
|
require.True(t, workspaceCtx.agentLoaded)
|
|
require.Same(t, newConn, workspaceCtx.conn, "connection should be cached for subsequent calls")
|
|
}
|
|
|
|
func TestGetWorkspaceConn_SameBuildAgentCrash(t *testing.T) {
|
|
// When an agent crashes on the same build (disconnected, but still
|
|
// in the latest build), dialWithLazyValidation dials, fails fast,
|
|
// validation finds the same agent, and the retry also fails. The
|
|
// wrapped dial error propagates (not errChatAgentDisconnected).
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
|
|
// Agent: disconnected (crashed on current build).
|
|
agent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
Name: "main",
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
DisconnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-9 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// ensureWorkspaceAgent fetches the (crashed) agent.
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(agent, nil).Times(1)
|
|
// Validation finds the same agent in the latest build.
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{agent}, nil).Times(1)
|
|
|
|
dialErr := xerrors.New("agent is not connected")
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(_ context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return nil, nil, dialErr
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
return database.Chat{}, nil
|
|
},
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
require.Error(t, err)
|
|
// The error should be a wrapped dial error, not the
|
|
// agent-disconnected sentinel.
|
|
require.NotErrorIs(t, err, errChatAgentDisconnected)
|
|
require.ErrorIs(t, err, dialErr)
|
|
|
|
// Cache should not have a connection, but the agent should
|
|
// still be loaded (ensureWorkspaceAgent cached it).
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.True(t, workspaceCtx.agentLoaded)
|
|
require.Nil(t, workspaceCtx.conn)
|
|
}
|
|
|
|
func TestGetWorkspaceConn_StatusCheck(t *testing.T) {
|
|
// The cache-hit status check re-fetches the agent row for a fresh
|
|
// heartbeat timestamp. These tests verify that path detects
|
|
// disconnected or timed-out agents and that healthy or DB-error
|
|
// paths return the cached connection.
|
|
t.Parallel()
|
|
|
|
type testCase struct {
|
|
name string
|
|
agent database.WorkspaceAgent
|
|
dbError bool
|
|
wantErr error
|
|
wantReleaseCalled bool
|
|
}
|
|
|
|
tests := []testCase{
|
|
{
|
|
name: "DisconnectedAgentCacheHit",
|
|
agent: database.WorkspaceAgent{
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
},
|
|
wantErr: errChatAgentDisconnected,
|
|
wantReleaseCalled: true,
|
|
},
|
|
{
|
|
// Agent never connected and the connection timeout
|
|
// has elapsed. This is the cache-hit timeout branch
|
|
// of isAgentUnreachable.
|
|
name: "TimedOutAgentCacheHit",
|
|
agent: database.WorkspaceAgent{
|
|
CreatedAt: time.Now().Add(-10 * time.Minute),
|
|
ConnectionTimeoutSeconds: 60,
|
|
},
|
|
wantErr: errChatAgentDisconnected,
|
|
wantReleaseCalled: true,
|
|
},
|
|
{
|
|
name: "CacheHitHealthyAgent",
|
|
agent: database.WorkspaceAgent{
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-5 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
// When GetWorkspaceAgentByID returns an error on
|
|
// cache hit, the cached connection should be returned.
|
|
name: "CacheHitDBError",
|
|
agent: database.WorkspaceAgent{
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-5 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
},
|
|
dbError: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// Stamp the agent with the generated ID.
|
|
agent := tc.agent
|
|
agent.ID = agentID
|
|
|
|
// Set up the DB mock for GetWorkspaceAgentByID.
|
|
if tc.dbError {
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(database.WorkspaceAgent{}, xerrors.New("connection reset")).
|
|
Times(1)
|
|
} else {
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(agent, nil).
|
|
Times(1)
|
|
}
|
|
|
|
var releaseCalled bool
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return nil, nil, xerrors.New("should not be called")
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
cachedConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
return database.Chat{}, nil
|
|
},
|
|
agent: agent,
|
|
agentLoaded: true,
|
|
conn: cachedConn,
|
|
releaseConn: func() { releaseCalled = true },
|
|
cachedWorkspaceID: chat.WorkspaceID,
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
|
|
if tc.wantErr != nil {
|
|
require.Nil(t, gotConn)
|
|
require.ErrorIs(t, err, tc.wantErr)
|
|
} else {
|
|
require.NoError(t, err)
|
|
require.Same(t, cachedConn, gotConn)
|
|
}
|
|
|
|
require.Equal(t, tc.wantReleaseCalled, releaseCalled, "release called")
|
|
|
|
// For cache-hit disconnect, the cache should be cleared.
|
|
if tc.wantErr != nil {
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.False(t, workspaceCtx.agentLoaded)
|
|
require.Nil(t, workspaceCtx.conn)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetWorkspaceConn_DialTimeout(t *testing.T) {
|
|
// When dialWithLazyValidation blocks beyond the dial
|
|
// timeout, getWorkspaceConn should return
|
|
// errChatDialTimeout instead of hanging indefinitely.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// Agent appears connected so the status check passes.
|
|
connectedAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-1 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(connectedAgent, nil).
|
|
Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 10 * time.Millisecond,
|
|
}
|
|
// Dial blocks forever (simulates unreachable agent).
|
|
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
<-ctx.Done()
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
require.ErrorIs(t, err, errChatDialTimeout)
|
|
}
|
|
|
|
func TestGetWorkspaceConn_DialTimeoutParentCanceled(t *testing.T) {
|
|
// When the parent context is canceled, the parent's error
|
|
// must propagate unchanged (not wrapped as a dial timeout).
|
|
// This is critical because the chatloop checks
|
|
// context.Cause(ctx) for ErrInterrupted.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
connectedAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-1 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(connectedAgent, nil).
|
|
Times(1)
|
|
|
|
parentErr := xerrors.New("parent canceled")
|
|
ctx, cancel := context.WithCancelCause(testutil.Context(t, testutil.WaitShort))
|
|
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
// Use a very long dial timeout so the parent cancel fires
|
|
// first.
|
|
dialTimeout: 10 * time.Minute,
|
|
}
|
|
// Signal when the dial goroutine has started so we can
|
|
// cancel the parent at the right time without time.Sleep.
|
|
dialStarted := make(chan struct{})
|
|
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
close(dialStarted)
|
|
<-ctx.Done()
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
// Cancel the parent after the dial starts.
|
|
go func() {
|
|
<-dialStarted
|
|
cancel(parentErr)
|
|
}()
|
|
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
// The error must NOT be errChatDialTimeout.
|
|
require.NotErrorIs(t, err, errChatDialTimeout)
|
|
// The parent context's error should propagate.
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
}
|
|
|
|
func TestGetWorkspaceConn_DialErrorNotMisclassifiedAsTimeout(t *testing.T) {
|
|
// Regression test: a non-timeout dial error (e.g. auth
|
|
// failure) with the parent context still alive must NOT be
|
|
// converted to errChatDialTimeout. Before the fix,
|
|
// dialCancel() poisoned dialCtx.Err(), causing all errors
|
|
// to be misclassified.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
connectedAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-1 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(connectedAgent, nil).
|
|
Times(1)
|
|
// When the initial dial fails immediately, dialWithLazyValidation
|
|
// calls resolveFastFailure which validates the binding. Mock the
|
|
// validation to return the same agent, triggering a synchronous
|
|
// redial that also returns the error.
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{connectedAgent}, nil).
|
|
AnyTimes()
|
|
|
|
dialErr := xerrors.New("authentication failed")
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
// Generous timeout so the dial error fires well before
|
|
// the timeout.
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
// Return an error immediately (not a timeout).
|
|
return nil, nil, dialErr
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
// Must NOT be misclassified as a dial timeout.
|
|
require.NotErrorIs(t, err, errChatDialTimeout)
|
|
// The original dial error should propagate.
|
|
require.ErrorContains(t, err, "authentication failed")
|
|
}
|
|
|
|
// TestAutoPromote_InsertFailureRollsBackTransaction verifies that when
|
|
// tryAutoPromoteQueuedMessage pops a queued message but the subsequent
|
|
// insert fails, the error propagates to the InTx callback, causing the
|
|
// transaction to roll back and preserving the queued message.
|
|
func TestAutoPromote_InsertFailureRollsBackTransaction(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
tx := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ps := dbpubsub.NewInMemory()
|
|
clock := quartz.NewReal()
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
ownerID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
|
|
waitingChat := database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelConfigID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
}
|
|
queuedMsg := database.ChatQueuedMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Content: []byte(`[{"type":"text","text":"queued"}]`),
|
|
}
|
|
insertErr := xerrors.New("insert failed")
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: ps,
|
|
configCache: newChatConfigCache(ctx, db, clock),
|
|
}
|
|
|
|
// The caller runs tryAutoPromoteQueuedMessage inside InTx.
|
|
// Wire the mock to execute the callback against the TX mock.
|
|
var txErr error
|
|
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
txErr = fn(tx)
|
|
return txErr
|
|
},
|
|
)
|
|
|
|
// Inside the TX: lock chat, get queued messages, resolve model
|
|
// config, pop queued message, insert fails.
|
|
tx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(waitingChat, nil)
|
|
tx.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return([]database.ChatQueuedMessage{queuedMsg}, nil)
|
|
tx.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(database.ChatModelConfig{ID: modelConfigID}, nil)
|
|
tx.EXPECT().PopNextQueuedMessage(gomock.Any(), chatID).Return(queuedMsg, nil)
|
|
tx.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, insertErr)
|
|
|
|
// Invoke tryAutoPromoteQueuedMessage through the same InTx
|
|
// pattern the processChat defer uses. The test directly calls
|
|
// the production path to verify error propagation.
|
|
_ = db.InTx(func(txStore database.Store) error {
|
|
latestChat, err := txStore.GetChatByIDForUpdate(ctx, chatID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, _, _, promoteErr := server.tryAutoPromoteQueuedMessage(ctx, txStore, latestChat)
|
|
if promoteErr != nil {
|
|
return promoteErr
|
|
}
|
|
|
|
// This code path should not be reached when the insert
|
|
// fails, because promoteErr should be non-nil.
|
|
return nil
|
|
}, nil)
|
|
|
|
// The InTx callback must return a non-nil error so the
|
|
// transaction rolls back, preserving the queued message.
|
|
require.Error(t, txErr, "InTx callback should return error when insert fails")
|
|
}
|
|
|
|
// TestAutoPromote_WakesRunLoopAfterPromotion verifies that after the
|
|
func TestAutoPromote_InsertFailureSkipsStatusUpdate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
tx := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ps := dbpubsub.NewInMemory()
|
|
clock := quartz.NewReal()
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
ownerID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
|
|
waitingChat := database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelConfigID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
}
|
|
queuedMsg := database.ChatQueuedMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Content: []byte(`[{"type":"text","text":"queued"}]`),
|
|
}
|
|
|
|
wakeCh := make(chan struct{}, 1)
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: ps,
|
|
clock: clock,
|
|
workerID: workerID,
|
|
wakeCh: wakeCh,
|
|
chatHeartbeatInterval: time.Minute,
|
|
metrics: chatloop.NopMetrics(),
|
|
configCache: newChatConfigCache(ctx, db, clock),
|
|
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
|
}
|
|
|
|
// Hold model resolution until the interrupt has canceled the chat
|
|
// context. Returning ErrInterrupted keeps processChat on the
|
|
// interrupted path regardless of whether the cache singleflight sees
|
|
// the caller cancellation or the DB fetch result first.
|
|
modelBlocked := make(chan struct{})
|
|
modelRelease := make(chan struct{})
|
|
var modelBlockedOnce sync.Once
|
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, _ uuid.UUID) (database.ChatModelConfig, error) {
|
|
modelBlockedOnce.Do(func() { close(modelBlocked) })
|
|
<-modelRelease
|
|
return database.ChatModelConfig{}, chatloop.ErrInterrupted
|
|
},
|
|
).AnyTimes()
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
|
|
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
|
|
).AnyTimes()
|
|
db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes()
|
|
|
|
// The deferred cleanup transaction: InsertChatMessages fails,
|
|
// so UpdateChatStatus must NOT be called.
|
|
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
return fn(tx)
|
|
},
|
|
)
|
|
tx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(waitingChat, nil)
|
|
tx.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return([]database.ChatQueuedMessage{queuedMsg}, nil)
|
|
tx.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(database.ChatModelConfig{ID: modelConfigID}, nil)
|
|
tx.EXPECT().PopNextQueuedMessage(gomock.Any(), chatID).Return(queuedMsg, nil)
|
|
tx.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(
|
|
nil, xerrors.New("insert failed"),
|
|
)
|
|
tx.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).Times(0)
|
|
|
|
// Subscribe BEFORE launching the goroutine.
|
|
runningCh := make(chan struct{}, 1)
|
|
unsubRunning, err := ps.SubscribeWithErr(
|
|
coderdpubsub.ChatStreamNotifyChannel(chatID),
|
|
func(_ context.Context, msg []byte, err error) {
|
|
if err != nil {
|
|
return
|
|
}
|
|
var notify coderdpubsub.ChatStreamNotifyMessage
|
|
if json.Unmarshal(msg, ¬ify) != nil {
|
|
return
|
|
}
|
|
if notify.Status == string(database.ChatStatusRunning) {
|
|
select {
|
|
case runningCh <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
defer unsubRunning()
|
|
|
|
chat := database.Chat{ID: chatID, OwnerID: ownerID, LastModelConfigID: modelConfigID}
|
|
processDone := make(chan struct{})
|
|
go func() {
|
|
defer close(processDone)
|
|
server.processChat(ctx, chat)
|
|
}()
|
|
|
|
select {
|
|
case <-runningCh:
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for running status")
|
|
}
|
|
|
|
select {
|
|
case <-modelBlocked:
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for model resolution")
|
|
}
|
|
|
|
// Publish an interrupt so processChat exits runChat.
|
|
interruptMsg, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
|
|
Status: string(database.ChatStatusWaiting),
|
|
})
|
|
require.NoError(t, err)
|
|
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), interruptMsg)
|
|
require.NoError(t, err)
|
|
close(modelRelease)
|
|
|
|
select {
|
|
case <-processDone:
|
|
case <-ctx.Done():
|
|
t.Fatal("processChat did not complete")
|
|
}
|
|
|
|
// The wake channel should NOT have a signal because the
|
|
// transaction failed before reaching UpdateChatStatus.
|
|
select {
|
|
case <-wakeCh:
|
|
t.Fatal("wake channel should not have a signal after insert failure")
|
|
default:
|
|
// No signal, as expected.
|
|
}
|
|
}
|