mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
ec1e861152
The subscriber advanced a single delivery cursor on each notify and trusted it for both lookups. Concurrent publishMessage calls and PG NOTIFY commit ordering let cache appends and notifies arrive out of ID order, after which a late notify would scan above its own message and drop it. The DB fallback was also skipped whenever the cache delivered anything, hiding cross-replica messages that only the DB held. The cursor becomes a high-water mark, not the lookup key. Notifies trigger a rescan over the gap they describe and dedupe per subscription, and the DB pass runs every time so cross-replica messages can't get eaten by a local cache hit. Closes coder/internal#1525 Closes CODAGT-357
6520 lines
199 KiB
Go
6520 lines
199 KiB
Go
package chatd
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
|
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
|
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
|
"github.com/coder/coder/v2/coderd/rbac"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
|
openaicomputeruse "github.com/coder/coder/v2/coderd/x/chatd/chatopenai/computeruse"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
|
skillspkg "github.com/coder/coder/v2/coderd/x/skills"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
type testAgentTool struct {
|
|
info fantasy.ToolInfo
|
|
providerOptions fantasy.ProviderOptions
|
|
}
|
|
|
|
func newTestAgentTool(name string) fantasy.AgentTool {
|
|
return &testAgentTool{info: fantasy.ToolInfo{Name: name}}
|
|
}
|
|
|
|
func (t *testAgentTool) Info() fantasy.ToolInfo {
|
|
return t.info
|
|
}
|
|
|
|
func (t *testAgentTool) Run(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
|
_ = t
|
|
return fantasy.ToolResponse{}, nil
|
|
}
|
|
|
|
func (t *testAgentTool) ProviderOptions() fantasy.ProviderOptions {
|
|
return t.providerOptions
|
|
}
|
|
|
|
func (t *testAgentTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
|
t.providerOptions = opts
|
|
}
|
|
|
|
type testMCPAgentTool struct {
|
|
*testAgentTool
|
|
configID uuid.UUID
|
|
}
|
|
|
|
func newTestMCPAgentTool(name string, configID uuid.UUID) fantasy.AgentTool {
|
|
return &testMCPAgentTool{
|
|
testAgentTool: &testAgentTool{info: fantasy.ToolInfo{Name: name}},
|
|
configID: configID,
|
|
}
|
|
}
|
|
|
|
func (t *testMCPAgentTool) MCPServerConfigID() uuid.UUID {
|
|
return t.configID
|
|
}
|
|
|
|
func TestComputerUseProviderAndModelFromConfig(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
rawProvider string
|
|
wantProvider string
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "DefaultAnthropic",
|
|
rawProvider: "",
|
|
wantProvider: chattool.ComputerUseProviderAnthropic,
|
|
},
|
|
{
|
|
name: "OpenAI",
|
|
rawProvider: " openai ",
|
|
wantProvider: chattool.ComputerUseProviderOpenAI,
|
|
},
|
|
{
|
|
name: "Unknown",
|
|
rawProvider: "bogus",
|
|
wantErr: `unknown computer-use provider "bogus" configured in agents_computer_use_provider`,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
server := &Server{db: db}
|
|
|
|
db.EXPECT().GetChatComputerUseProvider(gomock.Any()).DoAndReturn(
|
|
func(ctx context.Context) (string, error) {
|
|
_, ok := dbauthz.ActorFromContext(ctx)
|
|
require.True(t, ok, "config reads must have an actor")
|
|
return tt.rawProvider, nil
|
|
},
|
|
)
|
|
|
|
provider, modelProvider, modelName, err := server.computerUseProviderAndModelFromConfig(context.Background())
|
|
if tt.wantErr != "" {
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), tt.wantErr)
|
|
return
|
|
}
|
|
require.NoError(t, err)
|
|
require.Equal(t, tt.wantProvider, provider)
|
|
|
|
wantModelProvider, wantModelName, ok := chattool.DefaultComputerUseModel(tt.wantProvider)
|
|
require.True(t, ok)
|
|
require.Equal(t, wantModelProvider, modelProvider)
|
|
require.Equal(t, wantModelName, modelName)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolveComputerUseModel_OpenAIMissingCredentials(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := &Server{}
|
|
provider := chattool.ComputerUseProviderOpenAI
|
|
modelProvider, modelName, ok := chattool.DefaultComputerUseModel(provider)
|
|
require.True(t, ok)
|
|
|
|
model, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveComputerUseModel(
|
|
context.Background(),
|
|
database.Chat{ID: uuid.New(), OwnerID: uuid.New()},
|
|
chatprovider.ProviderAPIKeys{},
|
|
provider,
|
|
modelProvider,
|
|
modelName,
|
|
)
|
|
require.Error(t, err)
|
|
require.Nil(t, model)
|
|
require.False(t, debugEnabled)
|
|
require.Empty(t, resolvedProvider)
|
|
require.Empty(t, resolvedModel)
|
|
require.Contains(t, err.Error(), `provider "openai" model "gpt-5.5"`)
|
|
require.Contains(t, err.Error(), "OPENAI_API_KEY is not set")
|
|
require.NotContains(t, err.Error(), "ANTHROPIC_API_KEY")
|
|
}
|
|
|
|
func TestAppendComputerUseProviderTool(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
providerTools, err := appendComputerUseProviderTool(
|
|
nil,
|
|
computerUseProviderToolOptions{
|
|
provider: chattool.ComputerUseProviderOpenAI,
|
|
isComputerUse: true,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Len(t, providerTools, 1)
|
|
require.True(t, openaicomputeruse.IsTool(providerTools[0].Definition))
|
|
require.Equal(t, "computer", providerTools[0].Definition.GetName())
|
|
require.Equal(t, "computer", providerTools[0].Runner.Info().Name)
|
|
require.NotNil(t, providerTools[0].ResultProviderMetadata)
|
|
|
|
metadata := providerTools[0].ResultProviderMetadata(
|
|
fantasy.NewImageResponse([]byte("png"), "image/png"),
|
|
)
|
|
require.NotNil(t, metadata)
|
|
}
|
|
|
|
func TestAppendComputerUseProviderTool_Gates(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
baseTools := []chatloop.ProviderTool{{
|
|
Definition: fantasy.ProviderDefinedTool{
|
|
ID: "web_search",
|
|
Name: "web_search",
|
|
},
|
|
}}
|
|
|
|
tests := []struct {
|
|
name string
|
|
isPlanModeTurn bool
|
|
isComputerUse bool
|
|
}{
|
|
{name: "PlanMode", isPlanModeTurn: true, isComputerUse: true},
|
|
// Non-computer-use includes regular, master, general, and explore chats.
|
|
// Mode cannot be both ChatModeComputerUse and another chat mode.
|
|
{name: "NonComputerUseModes"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
providerTools, err := appendComputerUseProviderTool(
|
|
baseTools,
|
|
computerUseProviderToolOptions{
|
|
provider: chattool.ComputerUseProviderOpenAI,
|
|
isPlanModeTurn: tt.isPlanModeTurn,
|
|
isComputerUse: tt.isComputerUse,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Len(t, providerTools, 1)
|
|
require.Equal(t, "web_search", providerTools[0].Definition.GetName())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAppendComputerUseProviderTool_AnthropicHasNoResultMetadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
providerTools, err := appendComputerUseProviderTool(
|
|
nil,
|
|
computerUseProviderToolOptions{
|
|
provider: chattool.ComputerUseProviderAnthropic,
|
|
isComputerUse: true,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Len(t, providerTools, 1)
|
|
require.Equal(t, "computer", providerTools[0].Definition.GetName())
|
|
require.Nil(t, providerTools[0].ResultProviderMetadata)
|
|
}
|
|
|
|
func TestFilterExternalMCPConfigsForTurn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
approvedConfig := database.MCPServerConfig{ID: uuid.New(), AllowInPlanMode: true}
|
|
blockedConfig := database.MCPServerConfig{ID: uuid.New(), AllowInPlanMode: false}
|
|
configs := []database.MCPServerConfig{approvedConfig, blockedConfig}
|
|
planMode := database.NullChatPlanMode{
|
|
ChatPlanMode: database.ChatPlanModePlan,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("NonPlanModePassesThroughAllConfigs", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
|
|
configs,
|
|
database.NullChatPlanMode{},
|
|
uuid.NullUUID{},
|
|
)
|
|
|
|
require.Equal(t, configs, filtered)
|
|
require.Nil(t, approvedIDs)
|
|
})
|
|
|
|
t.Run("PlanModeSubagentsReturnNoConfigs", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
|
|
configs,
|
|
planMode,
|
|
uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
|
)
|
|
|
|
require.Nil(t, filtered)
|
|
require.NotNil(t, approvedIDs)
|
|
require.Empty(t, approvedIDs)
|
|
})
|
|
|
|
t.Run("PlanModeRootFiltersToApprovedConfigs", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
filtered, approvedIDs := filterExternalMCPConfigsForTurn(
|
|
configs,
|
|
planMode,
|
|
uuid.NullUUID{},
|
|
)
|
|
|
|
require.Equal(t, []database.MCPServerConfig{approvedConfig}, filtered)
|
|
require.Equal(t, map[uuid.UUID]struct{}{approvedConfig.ID: {}}, approvedIDs)
|
|
})
|
|
}
|
|
|
|
func TestChatWorkspaceRecoveryErrorsDifferentiateSignalStrength(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Disconnected recovery is gated by a DB-confirmed duration
|
|
// threshold, so the message can give direct stop/start guidance
|
|
// without asking the user.
|
|
disconnected := errChatAgentDisconnected.Error()
|
|
require.Contains(t, disconnected, "90 seconds")
|
|
require.Contains(t, disconnected, "stop_workspace")
|
|
require.Contains(t, disconnected, "start_workspace")
|
|
require.NotContains(t, disconnected, "ask_user_question")
|
|
|
|
// Dial timeout alone is a weak signal. The model should not
|
|
// escalate to lifecycle tools without DB-confirmed disconnect.
|
|
dialTimeout := errChatDialTimeout.Error()
|
|
require.NotContains(t, dialTimeout, "ask_user_question")
|
|
require.NotContains(t, dialTimeout, "stop_workspace")
|
|
require.NotContains(t, dialTimeout, "start_workspace")
|
|
}
|
|
|
|
func TestActiveToolNamesForTurn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
makeTools := func(names ...string) []fantasy.AgentTool {
|
|
tools := make([]fantasy.AgentTool, 0, len(names))
|
|
for _, name := range names {
|
|
tools = append(tools, newTestAgentTool(name))
|
|
}
|
|
return tools
|
|
}
|
|
|
|
planMode := database.NullChatPlanMode{
|
|
ChatPlanMode: database.ChatPlanModePlan,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("NormalModeReturnsAllRegisteredTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"read_file",
|
|
"propose_plan",
|
|
"custom_tool",
|
|
"execute",
|
|
), database.NullChatPlanMode{}, uuid.NullUUID{}, nil)
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"propose_plan",
|
|
"custom_tool",
|
|
"execute",
|
|
}, got)
|
|
})
|
|
|
|
t.Run("PlanModeIncludesOnlyAllowlistedBuiltIns", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"read_file",
|
|
"write_file",
|
|
"edit_files",
|
|
"execute",
|
|
"process_output",
|
|
"process_list",
|
|
"process_signal",
|
|
"list_templates",
|
|
"read_template",
|
|
"create_workspace",
|
|
"start_workspace",
|
|
"stop_workspace",
|
|
"propose_plan",
|
|
"spawn_agent",
|
|
"wait_agent",
|
|
"message_agent",
|
|
"close_agent",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
"ask_user_question",
|
|
), planMode, uuid.NullUUID{}, nil)
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"write_file",
|
|
"edit_files",
|
|
"execute",
|
|
"process_output",
|
|
"list_templates",
|
|
"read_template",
|
|
"create_workspace",
|
|
"start_workspace",
|
|
"stop_workspace",
|
|
"propose_plan",
|
|
"spawn_agent",
|
|
"wait_agent",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
"ask_user_question",
|
|
}, got)
|
|
})
|
|
|
|
t.Run("PlanModeChildChatsAllowExplorationOnly", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"read_file",
|
|
"write_file",
|
|
"edit_files",
|
|
"execute",
|
|
"process_output",
|
|
"list_templates",
|
|
"read_template",
|
|
"create_workspace",
|
|
"start_workspace",
|
|
"stop_workspace",
|
|
"propose_plan",
|
|
"spawn_agent",
|
|
"wait_agent",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
"ask_user_question",
|
|
), planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true}, nil)
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"execute",
|
|
"process_output",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
}, got)
|
|
require.NotContains(t, got, "write_file")
|
|
require.NotContains(t, got, "edit_files")
|
|
require.NotContains(t, got, "ask_user_question")
|
|
require.NotContains(t, got, "propose_plan")
|
|
require.NotContains(t, got, "start_workspace")
|
|
require.NotContains(t, got, "stop_workspace")
|
|
require.NotContains(t, got, "spawn_explore_agent")
|
|
})
|
|
|
|
t.Run("PlanModeStillExcludesDangerousTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"execute",
|
|
"process_output",
|
|
"message_agent",
|
|
"spawn_computer_use_agent",
|
|
"propose_plan",
|
|
), planMode, uuid.NullUUID{}, nil)
|
|
|
|
require.Equal(t, []string{"execute", "process_output", "propose_plan"}, got)
|
|
require.NotContains(t, got, "message_agent")
|
|
require.NotContains(t, got, "spawn_computer_use_agent")
|
|
})
|
|
|
|
t.Run("PlanModeExcludesUnknownTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := activeToolNamesForTurn(makeTools(
|
|
"read_file",
|
|
"custom_tool",
|
|
"another_custom_tool",
|
|
"propose_plan",
|
|
), planMode, uuid.NullUUID{}, nil)
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"propose_plan",
|
|
}, got)
|
|
require.NotContains(t, got, "custom_tool")
|
|
require.NotContains(t, got, "another_custom_tool")
|
|
})
|
|
|
|
t.Run("PlanModeIncludesOnlyApprovedExternalMCPTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
approvedConfigID := uuid.New()
|
|
blockedConfigID := uuid.New()
|
|
got := activeToolNamesForTurn([]fantasy.AgentTool{
|
|
newTestAgentTool("read_file"),
|
|
newTestMCPAgentTool("approved-mcp__echo", approvedConfigID),
|
|
newTestMCPAgentTool("blocked-mcp__echo", blockedConfigID),
|
|
newTestAgentTool("workspace-mcp__echo"),
|
|
}, planMode, uuid.NullUUID{}, map[uuid.UUID]struct{}{
|
|
approvedConfigID: {},
|
|
})
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"approved-mcp__echo",
|
|
}, got)
|
|
require.NotContains(t, got, "blocked-mcp__echo")
|
|
require.NotContains(t, got, "workspace-mcp__echo")
|
|
})
|
|
}
|
|
|
|
func TestAllowedExploreToolNames(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
externalConfigID := uuid.New()
|
|
got := allowedExploreToolNames([]fantasy.AgentTool{
|
|
newTestAgentTool("read_file"),
|
|
newTestAgentTool("write_file"),
|
|
newTestMCPAgentTool("external-mcp__echo", externalConfigID),
|
|
newTestAgentTool("workspace-mcp__echo"),
|
|
newTestAgentTool("start_workspace"),
|
|
newTestAgentTool("stop_workspace"),
|
|
newTestAgentTool("execute"),
|
|
newTestAgentTool("process_output"),
|
|
newTestAgentTool("process_list"),
|
|
newTestAgentTool("process_signal"),
|
|
newTestAgentTool("spawn_agent"),
|
|
newTestAgentTool("wait_agent"),
|
|
newTestAgentTool("read_skill"),
|
|
newTestAgentTool("read_skill_file"),
|
|
newTestAgentTool("ask_user_question"),
|
|
})
|
|
|
|
require.Equal(t, []string{
|
|
"read_file",
|
|
"external-mcp__echo",
|
|
"execute",
|
|
"process_output",
|
|
"read_skill",
|
|
"read_skill_file",
|
|
}, got)
|
|
require.NotContains(t, got, "workspace-mcp__echo")
|
|
require.NotContains(t, got, "start_workspace")
|
|
require.NotContains(t, got, "stop_workspace")
|
|
require.NotContains(t, got, "ask_user_question")
|
|
}
|
|
|
|
func TestAllowedBehaviorToolNames(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
makeTools := func(names ...string) []fantasy.AgentTool {
|
|
tools := make([]fantasy.AgentTool, 0, len(names))
|
|
for _, name := range names {
|
|
tools = append(tools, newTestAgentTool(name))
|
|
}
|
|
return tools
|
|
}
|
|
|
|
allTools := makeTools("read_file", "custom_tool", "spawn_agent")
|
|
exploreMode := database.NullChatMode{
|
|
ChatMode: database.ChatModeExplore,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("DefaultModeReturnsAllTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, []string{"read_file", "custom_tool", "spawn_agent"}, allowedBehaviorToolNames(
|
|
allTools,
|
|
database.NullChatMode{},
|
|
))
|
|
})
|
|
|
|
t.Run("ExploreModeUsesExploreAllowlist", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, []string{"read_file"}, allowedBehaviorToolNames(
|
|
allTools,
|
|
exploreMode,
|
|
))
|
|
})
|
|
}
|
|
|
|
func TestStopAfterPlanTools(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
planMode := database.NullChatPlanMode{
|
|
ChatPlanMode: database.ChatPlanModePlan,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("NormalModeReturnsNil", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Nil(t, stopAfterPlanTools(database.NullChatPlanMode{}, uuid.NullUUID{}))
|
|
})
|
|
|
|
t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, map[string]struct{}{
|
|
"propose_plan": {},
|
|
"ask_user_question": {},
|
|
}, stopAfterPlanTools(planMode, uuid.NullUUID{}))
|
|
})
|
|
|
|
t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, map[string]struct{}{
|
|
"propose_plan": {},
|
|
}, stopAfterPlanTools(planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true}))
|
|
})
|
|
}
|
|
|
|
func TestStopAfterBehaviorTools(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
planMode := database.NullChatPlanMode{
|
|
ChatPlanMode: database.ChatPlanModePlan,
|
|
Valid: true,
|
|
}
|
|
exploreMode := database.NullChatMode{
|
|
ChatMode: database.ChatModeExplore,
|
|
Valid: true,
|
|
}
|
|
|
|
t.Run("DefaultModeReturnsNil", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Nil(t, stopAfterBehaviorTools(
|
|
database.NullChatPlanMode{},
|
|
database.NullChatMode{},
|
|
uuid.NullUUID{},
|
|
))
|
|
})
|
|
|
|
t.Run("PlanModeDelegatesToPlanTools", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, stopAfterPlanTools(planMode, uuid.NullUUID{}), stopAfterBehaviorTools(
|
|
planMode,
|
|
database.NullChatMode{},
|
|
uuid.NullUUID{},
|
|
))
|
|
})
|
|
|
|
t.Run("ExploreModeReturnsNil", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Nil(t, stopAfterBehaviorTools(planMode, exploreMode, uuid.NullUUID{}))
|
|
})
|
|
}
|
|
|
|
// TestWaitForActiveChatStop and TestWaitForActiveChatStop_WaitsForReplacementRun
// were removed along with the process-local activeChats mechanism.
// Debug cleanup is now best-effort; stale finalization handles orphaned rows.

// TestArchiveChatWaitsForActiveChatStop and
// TestArchiveChatWaitsForEveryInterruptedChat were removed along with
// the process-local activeChats mechanism. Archive cleanup is now
// best-effort; stale finalization handles any orphaned rows.
|
|
|
|
func TestRenameChatTitle(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
setupRealWorkerLock := func(
|
|
db *dbmock.MockStore,
|
|
chatID uuid.UUID,
|
|
lockedChat database.Chat,
|
|
) {
|
|
lockTx := dbmock.NewMockStore(gomock.NewController(t))
|
|
unlockTx := dbmock.NewMockStore(gomock.NewController(t))
|
|
gomock.InOrder(
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
return fn(lockTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
return fn(unlockTx)
|
|
},
|
|
),
|
|
)
|
|
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
|
|
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
|
|
}
|
|
|
|
t.Run("WritesAndReturnsWroteTrue", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
stored := database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
Title: "original",
|
|
}
|
|
updated := stored
|
|
updated.Title = "renamed"
|
|
|
|
server := &Server{db: db, logger: logger}
|
|
|
|
setupRealWorkerLock(db, chatID, stored)
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(stored, nil)
|
|
db.EXPECT().UpdateChatTitleByID(gomock.Any(), database.UpdateChatTitleByIDParams{
|
|
ID: chatID,
|
|
Title: "renamed",
|
|
}).Return(updated, nil)
|
|
|
|
got, wrote, err := server.RenameChatTitle(ctx, stored, "renamed")
|
|
require.NoError(t, err)
|
|
require.True(t, wrote, "fresh rename must report wrote=true")
|
|
require.Equal(t, updated, got)
|
|
})
|
|
|
|
t.Run("SkipsWriteWhenAlreadyAtNewTitle", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
stale := database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
Title: "pre-race",
|
|
}
|
|
landed := stale
|
|
landed.Title = "landed-concurrently"
|
|
|
|
server := &Server{db: db, logger: logger}
|
|
|
|
setupRealWorkerLock(db, chatID, landed)
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(landed, nil)
|
|
|
|
got, wrote, err := server.RenameChatTitle(ctx, stale, "landed-concurrently")
|
|
require.NoError(t, err)
|
|
require.False(t, wrote,
|
|
"must report wrote=false when the stored row already matches newTitle so the handler suppresses a redundant title_change event")
|
|
require.Equal(t, landed, got)
|
|
})
|
|
}
|
|
|
|
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
lockTx := dbmock.NewMockStore(ctrl)
|
|
usageTx := dbmock.NewMockStore(ctrl)
|
|
unlockTx := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
pubsub := dbpubsub.NewInMemory()
|
|
clock := quartz.NewReal()
|
|
|
|
ownerID := uuid.New()
|
|
chatID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
workerID := uuid.New()
|
|
userPrompt := "review pull request 23633 and fix review threads"
|
|
wantTitle := "Review PR 23633"
|
|
|
|
chat := database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelConfigID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
Title: fallbackChatTitle(userPrompt),
|
|
}
|
|
modelConfig := database.ChatModelConfig{
|
|
ID: modelConfigID,
|
|
Provider: "openai",
|
|
Model: "gpt-4o-mini",
|
|
ContextLimit: 8192,
|
|
}
|
|
updatedChat := chat
|
|
updatedChat.Title = wantTitle
|
|
|
|
messageEvents := make(chan struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}, 1)
|
|
cancelSub, err := pubsub.SubscribeWithErr(
|
|
coderdpubsub.ChatWatchEventChannel(ownerID),
|
|
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
|
messageEvents <- struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}{payload: payload, err: err}
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
defer cancelSub()
|
|
|
|
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
require.Equal(t, "gpt-4o-mini", req.Model)
|
|
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
|
|
})
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: pubsub,
|
|
configCache: newChatConfigCache(context.Background(), db, clock),
|
|
}
|
|
|
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
|
Provider: "openai",
|
|
CentralApiKeyEnabled: true,
|
|
APIKey: "test-key",
|
|
BaseUrl: serverURL,
|
|
}}, nil)
|
|
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
|
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
|
gomock.Any(),
|
|
database.GetChatMessagesByChatIDAscPaginatedParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
).Return([]database.ChatMessage{
|
|
mustChatMessage(
|
|
t,
|
|
database.ChatMessageRoleUser,
|
|
database.ChatMessageVisibilityBoth,
|
|
codersdk.ChatMessageText(userPrompt),
|
|
),
|
|
mustChatMessage(
|
|
t,
|
|
database.ChatMessageRoleAssistant,
|
|
database.ChatMessageVisibilityBoth,
|
|
codersdk.ChatMessageText("checking the diff now"),
|
|
),
|
|
}, nil)
|
|
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
|
|
gomock.Any(),
|
|
database.GetChatMessagesByChatIDDescPaginatedParams{
|
|
ChatID: chatID,
|
|
BeforeID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
).Return(nil, nil)
|
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
|
|
return fn(lockTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Nil(t, opts)
|
|
return fn(usageTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
|
|
return fn(unlockTx)
|
|
},
|
|
),
|
|
)
|
|
|
|
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
|
|
|
|
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
|
|
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
|
|
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
|
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
|
|
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
|
|
require.Equal(t, []string{"[]"}, arg.Content)
|
|
return []database.ChatMessage{{ID: 91}}, nil
|
|
},
|
|
)
|
|
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
|
|
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
|
ID: chatID,
|
|
Title: wantTitle,
|
|
}).Return(updatedChat, nil)
|
|
|
|
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(updatedChat, nil)
|
|
|
|
gotChat, err := server.RegenerateChatTitle(ctx, chat)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, gotChat)
|
|
|
|
select {
|
|
case event := <-messageEvents:
|
|
require.NoError(t, event.err)
|
|
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
|
require.Equal(t, chatID, event.payload.Chat.ID)
|
|
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for title change pubsub event")
|
|
}
|
|
}
|
|
|
|
func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
lockTx := dbmock.NewMockStore(ctrl)
|
|
usageTx := dbmock.NewMockStore(ctrl)
|
|
unlockTx := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
pubsub := dbpubsub.NewInMemory()
|
|
clock := quartz.NewReal()
|
|
|
|
ownerID := uuid.New()
|
|
chatID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
userPrompt := "review pull request 23633 and fix review threads"
|
|
wantTitle := "Review PR 23633"
|
|
|
|
chat := database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelConfigID,
|
|
Status: database.ChatStatusCompleted,
|
|
Title: fallbackChatTitle(userPrompt),
|
|
}
|
|
lockedChat := chat
|
|
lockedChat.WorkerID = uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}
|
|
lockedChat.StartedAt = sql.NullTime{Time: time.Now(), Valid: true}
|
|
modelConfig := database.ChatModelConfig{
|
|
ID: modelConfigID,
|
|
Provider: "openai",
|
|
Model: "gpt-4o-mini",
|
|
ContextLimit: 8192,
|
|
}
|
|
updatedChat := lockedChat
|
|
updatedChat.Title = wantTitle
|
|
unlockedChat := updatedChat
|
|
unlockedChat.WorkerID = uuid.NullUUID{}
|
|
unlockedChat.StartedAt = sql.NullTime{}
|
|
|
|
messageEvents := make(chan struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}, 1)
|
|
cancelSub, err := pubsub.SubscribeWithErr(
|
|
coderdpubsub.ChatWatchEventChannel(ownerID),
|
|
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
|
messageEvents <- struct {
|
|
payload codersdk.ChatWatchEvent
|
|
err error
|
|
}{payload: payload, err: err}
|
|
}),
|
|
)
|
|
require.NoError(t, err)
|
|
defer cancelSub()
|
|
|
|
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
|
require.Equal(t, "gpt-4o-mini", req.Model)
|
|
return chattest.OpenAINonStreamingResponse("{\"title\":\"" + wantTitle + "\"}")
|
|
})
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: pubsub,
|
|
configCache: newChatConfigCache(context.Background(), db, clock),
|
|
}
|
|
|
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
|
Provider: "openai",
|
|
CentralApiKeyEnabled: true,
|
|
APIKey: "test-key",
|
|
BaseUrl: serverURL,
|
|
}}, nil)
|
|
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
|
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
|
gomock.Any(),
|
|
database.GetChatMessagesByChatIDAscPaginatedParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
).Return([]database.ChatMessage{
|
|
mustChatMessage(
|
|
t,
|
|
database.ChatMessageRoleUser,
|
|
database.ChatMessageVisibilityBoth,
|
|
codersdk.ChatMessageText(userPrompt),
|
|
),
|
|
mustChatMessage(
|
|
t,
|
|
database.ChatMessageRoleAssistant,
|
|
database.ChatMessageVisibilityBoth,
|
|
codersdk.ChatMessageText("checking the diff now"),
|
|
),
|
|
}, nil)
|
|
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
|
|
gomock.Any(),
|
|
database.GetChatMessagesByChatIDDescPaginatedParams{
|
|
ChatID: chatID,
|
|
BeforeID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
).Return(nil, nil)
|
|
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil)
|
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
|
|
return fn(lockTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Nil(t, opts)
|
|
return fn(usageTx)
|
|
},
|
|
),
|
|
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
|
|
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
|
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
|
|
return fn(unlockTx)
|
|
},
|
|
),
|
|
)
|
|
|
|
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
|
|
lockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(
|
|
gomock.Any(),
|
|
gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{}),
|
|
).DoAndReturn(func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
|
require.Equal(t, chat.ID, arg.ID)
|
|
require.Equal(t, chat.Status, arg.Status)
|
|
require.Equal(t, uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}, arg.WorkerID)
|
|
require.True(t, arg.StartedAt.Valid)
|
|
require.WithinDuration(t, time.Now(), arg.StartedAt.Time, time.Second)
|
|
require.False(t, arg.HeartbeatAt.Valid)
|
|
require.Equal(t, chat.LastError, arg.LastError)
|
|
require.Equal(t, chat.UpdatedAt, arg.UpdatedAt)
|
|
return lockedChat, nil
|
|
})
|
|
|
|
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChat, nil)
|
|
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
|
|
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
|
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
|
|
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
|
|
require.Equal(t, []string{"[]"}, arg.Content)
|
|
return []database.ChatMessage{{ID: 91}}, nil
|
|
},
|
|
)
|
|
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
|
|
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
|
ID: chatID,
|
|
Title: wantTitle,
|
|
}).Return(updatedChat, nil)
|
|
|
|
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(updatedChat, nil)
|
|
unlockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(
|
|
gomock.Any(),
|
|
database.UpdateChatStatusPreserveUpdatedAtParams{
|
|
ID: updatedChat.ID,
|
|
Status: updatedChat.Status,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: updatedChat.LastError,
|
|
UpdatedAt: updatedChat.UpdatedAt,
|
|
},
|
|
).Return(unlockedChat, nil)
|
|
|
|
gotChat, err := server.RegenerateChatTitle(ctx, chat)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, gotChat)
|
|
|
|
select {
|
|
case event := <-messageEvents:
|
|
require.NoError(t, event.err)
|
|
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
|
require.Equal(t, chatID, event.payload.Chat.ID)
|
|
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for title change pubsub event")
|
|
}
|
|
}
|
|
|
|
func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
ownerID := uuid.New()
|
|
|
|
server := &Server{
|
|
db: db,
|
|
configCache: newChatConfigCache(
|
|
context.Background(),
|
|
db,
|
|
quartz.NewReal(),
|
|
),
|
|
providerAPIKeys: chatprovider.ProviderAPIKeys{
|
|
OpenAI: "openai-deployment-key",
|
|
Anthropic: "anthropic-deployment-key",
|
|
ByProvider: map[string]string{
|
|
"openai": "openai-deployment-key",
|
|
"anthropic": "anthropic-deployment-key",
|
|
},
|
|
BaseURLByProvider: map[string]string{
|
|
"openai": "https://openai.example.com",
|
|
"anthropic": "https://anthropic.example.com",
|
|
},
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
|
Provider: "anthropic",
|
|
CentralApiKeyEnabled: true,
|
|
AllowCentralApiKeyFallback: true,
|
|
}}, nil)
|
|
|
|
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
|
|
require.NoError(t, err)
|
|
require.Empty(t, keys.OpenAI)
|
|
require.Empty(t, keys.APIKey("openai"))
|
|
require.Empty(t, keys.BaseURL("openai"))
|
|
require.Equal(t, "anthropic-deployment-key", keys.Anthropic)
|
|
require.Equal(t, "anthropic-deployment-key", keys.APIKey("anthropic"))
|
|
require.Equal(t, "https://anthropic.example.com", keys.BaseURL("anthropic"))
|
|
require.Equal(t, map[string]string{"anthropic": "anthropic-deployment-key"}, keys.ByProvider)
|
|
require.Equal(t, map[string]string{"anthropic": "https://anthropic.example.com"}, keys.BaseURLByProvider)
|
|
}
|
|
|
|
func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKeys(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
ownerID := uuid.New()
|
|
|
|
server := &Server{
|
|
db: db,
|
|
configCache: newChatConfigCache(
|
|
context.Background(),
|
|
db,
|
|
quartz.NewReal(),
|
|
),
|
|
providerAPIKeys: chatprovider.ProviderAPIKeys{
|
|
OpenAI: "openai-deployment-key",
|
|
ByProvider: map[string]string{
|
|
"openai": "openai-deployment-key",
|
|
},
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
|
Provider: "openai",
|
|
CentralApiKeyEnabled: true,
|
|
}}, nil)
|
|
|
|
keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "openai-deployment-key", keys.OpenAI)
|
|
require.Equal(t, "openai-deployment-key", keys.APIKey("openai"))
|
|
}
|
|
|
|
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
workspaceID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
calls := 0
|
|
refreshed, err := refreshChatWorkspaceSnapshot(
|
|
context.Background(),
|
|
chat,
|
|
func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
calls++
|
|
return database.Chat{}, nil
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Equal(t, chat, refreshed)
|
|
require.Equal(t, 0, calls)
|
|
}
|
|
|
|
func TestRefreshChatWorkspaceSnapshot_ReloadsWhenWorkspaceMissing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
chatID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
chat := database.Chat{ID: chatID}
|
|
reloaded := database.Chat{
|
|
ID: chatID,
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
calls := 0
|
|
refreshed, err := refreshChatWorkspaceSnapshot(
|
|
context.Background(),
|
|
chat,
|
|
func(_ context.Context, id uuid.UUID) (database.Chat, error) {
|
|
calls++
|
|
require.Equal(t, chatID, id)
|
|
return reloaded, nil
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Equal(t, reloaded, refreshed)
|
|
require.Equal(t, 1, calls)
|
|
}
|
|
|
|
func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
chat := database.Chat{ID: uuid.New()}
|
|
loadErr := xerrors.New("boom")
|
|
|
|
refreshed, err := refreshChatWorkspaceSnapshot(
|
|
context.Background(),
|
|
chat,
|
|
func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
return database.Chat{}, loadErr
|
|
},
|
|
)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "reload chat workspace state")
|
|
require.ErrorContains(t, err, loadErr.Error())
|
|
require.Equal(t, chat, refreshed)
|
|
}
|
|
|
|
func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
OperatingSystem: "linux",
|
|
Directory: "/home/coder/project",
|
|
ExpandedDirectory: "/home/coder/project",
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(
|
|
gomock.Any(),
|
|
agentID,
|
|
).Return(workspaceAgent, nil).Times(1)
|
|
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
|
gomock.Cond(func(x any) bool {
|
|
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
|
if !ok || arg.ID != chat.ID {
|
|
return false
|
|
}
|
|
if !arg.LastInjectedContext.Valid {
|
|
return false
|
|
}
|
|
var parts []codersdk.ChatMessagePart
|
|
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
|
|
return false
|
|
}
|
|
// Expect at least one context-file part for the
|
|
// working-directory AGENTS.md, with internal fields
|
|
// stripped (no content, OS, or directory).
|
|
for _, p := range parts {
|
|
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFilePath != "" {
|
|
return p.ContextFileContent == "" &&
|
|
p.ContextFileOS == "" &&
|
|
p.ContextFileDirectory == ""
|
|
}
|
|
}
|
|
return false
|
|
}),
|
|
).Return(database.Chat{}, nil).Times(1)
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
|
|
Parts: []codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/home/coder/project/AGENTS.md",
|
|
ContextFileContent: "# Project instructions",
|
|
}},
|
|
}, nil).AnyTimes()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: quartz.NewReal(),
|
|
instructionLookupTimeout: 5 * time.Second,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
instruction, _, err := server.persistInstructionFiles(
|
|
ctx,
|
|
chat,
|
|
uuid.New(),
|
|
workspaceCtx.getWorkspaceAgent,
|
|
workspaceCtx.getWorkspaceConn,
|
|
)
|
|
require.NoError(t, err)
|
|
require.Contains(t, instruction, "Operating System: linux")
|
|
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
|
}
|
|
|
|
func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
}
|
|
|
|
instruction, _, err := server.persistInstructionFiles(
|
|
ctx,
|
|
chat,
|
|
uuid.New(),
|
|
func(context.Context) (database.WorkspaceAgent, error) {
|
|
return database.WorkspaceAgent{
|
|
ID: uuid.New(),
|
|
Directory: "/home/coder/project",
|
|
}, nil
|
|
},
|
|
func(context.Context) (workspacesdk.AgentConn, error) {
|
|
return nil, errChatHasNoWorkspaceAgent
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
require.Empty(t, instruction)
|
|
}
|
|
|
|
func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
OperatingSystem: "linux",
|
|
Directory: "/home/coder/project",
|
|
ExpandedDirectory: "/home/coder/project",
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(
|
|
gomock.Any(),
|
|
agentID,
|
|
).Return(workspaceAgent, nil).Times(1)
|
|
db.EXPECT().InsertChatMessages(gomock.Any(),
|
|
gomock.Cond(func(x any) bool {
|
|
arg, ok := x.(database.InsertChatMessagesParams)
|
|
if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 {
|
|
return false
|
|
}
|
|
var parts []codersdk.ChatMessagePart
|
|
if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil {
|
|
return false
|
|
}
|
|
foundMarker := false
|
|
foundSkill := false
|
|
for _, p := range parts {
|
|
switch p.Type {
|
|
case codersdk.ChatMessagePartTypeContextFile:
|
|
if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" {
|
|
foundMarker = true
|
|
}
|
|
case codersdk.ChatMessagePartTypeSkill:
|
|
if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
|
|
foundSkill = true
|
|
}
|
|
}
|
|
}
|
|
return foundMarker && foundSkill
|
|
}),
|
|
).Return(nil, nil).Times(1)
|
|
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
|
gomock.Cond(func(x any) bool {
|
|
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
|
if !ok || arg.ID != chat.ID {
|
|
return false
|
|
}
|
|
if !arg.LastInjectedContext.Valid {
|
|
return false
|
|
}
|
|
var parts []codersdk.ChatMessagePart
|
|
if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil {
|
|
return false
|
|
}
|
|
// The sentinel path should persist only skill parts
|
|
// with ContextFileAgentID set.
|
|
for _, p := range parts {
|
|
if p.Type == codersdk.ChatMessagePartTypeSkill &&
|
|
p.SkillName == "my-skill" &&
|
|
p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}),
|
|
).Return(database.Chat{}, nil).Times(1)
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
|
|
// Agent returns pre-read content: no instruction files
|
|
// found but one skill discovered.
|
|
Parts: []codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "my-skill",
|
|
SkillDescription: "A test skill",
|
|
SkillDir: "/home/coder/project/.agents/skills/my-skill",
|
|
}},
|
|
}, nil).AnyTimes()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: quartz.NewReal(),
|
|
instructionLookupTimeout: 5 * time.Second,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
instruction, skills, err := server.persistInstructionFiles(
|
|
ctx,
|
|
chat,
|
|
uuid.New(),
|
|
workspaceCtx.getWorkspaceAgent,
|
|
workspaceCtx.getWorkspaceConn,
|
|
)
|
|
require.NoError(t, err)
|
|
// Sentinel path returns empty instruction string.
|
|
require.Empty(t, instruction)
|
|
// Skills are still discovered and returned.
|
|
require.Len(t, skills, 1)
|
|
require.Equal(t, "my-skill", skills[0].Name)
|
|
}
|
|
|
|
func TestPersistInstructionFilesSentinelNoSkillsClearsColumn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
OperatingSystem: "linux",
|
|
Directory: "/home/coder/project",
|
|
ExpandedDirectory: "/home/coder/project",
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(
|
|
gomock.Any(),
|
|
agentID,
|
|
).Return(workspaceAgent, nil).Times(1)
|
|
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
|
gomock.Cond(func(x any) bool {
|
|
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
|
if !ok || arg.ID != chat.ID {
|
|
return false
|
|
}
|
|
// No skills discovered, so the column should be
|
|
// cleared to NULL.
|
|
return !arg.LastInjectedContext.Valid
|
|
}),
|
|
).Return(database.Chat{}, nil).Times(1)
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
conn.EXPECT().ContextConfig(gomock.Any()).Return(workspacesdk.ContextConfigResponse{
|
|
// Agent returns pre-read content: no files, no skills.
|
|
Parts: []codersdk.ChatMessagePart{},
|
|
}, nil).AnyTimes()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: quartz.NewReal(),
|
|
instructionLookupTimeout: 5 * time.Second,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
instruction, skills, err := server.persistInstructionFiles(
|
|
ctx,
|
|
chat,
|
|
uuid.New(),
|
|
workspaceCtx.getWorkspaceAgent,
|
|
workspaceCtx.getWorkspaceConn,
|
|
)
|
|
require.NoError(t, err)
|
|
// Sentinel path: empty instruction, no skills.
|
|
require.Empty(t, instruction)
|
|
require.Empty(t, skills)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{ID: agentID}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(workspaceAgent, nil).Times(1)
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: &Server{db: db},
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, chat, chatSnapshot)
|
|
require.Equal(t, workspaceAgent, agent)
|
|
|
|
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workspaceAgent, gotAgent)
|
|
require.Equal(t, chat, currentChat)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_NullBindingLazyBind(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
buildID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
workspaceAgent := database.WorkspaceAgent{ID: agentID}
|
|
updatedChat := chat
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: agentID, Valid: true}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{workspaceAgent}, nil),
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
|
ID: chat.ID,
|
|
}).Return(updatedChat, nil),
|
|
)
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: &Server{db: db},
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, chatSnapshot)
|
|
require.Equal(t, workspaceAgent, agent)
|
|
require.Equal(t, updatedChat, currentChat)
|
|
|
|
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workspaceAgent, gotAgent)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_StaleBindingRepair(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
staleAgentID := uuid.New()
|
|
buildID := uuid.New()
|
|
currentAgentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: staleAgentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
|
|
updatedChat := chat
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(database.WorkspaceAgent{}, xerrors.New("missing agent")),
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
ID: chat.ID,
|
|
}).Return(updatedChat, nil),
|
|
)
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: &Server{db: db},
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, chatSnapshot)
|
|
require.Equal(t, currentAgent, agent)
|
|
require.Equal(t, updatedChat, currentChat)
|
|
}
|
|
|
|
func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
staleAgentID := uuid.New()
|
|
currentAgentID := uuid.New()
|
|
buildID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: staleAgentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
|
|
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
|
|
updatedChat := chat
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(staleAgent, nil),
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), currentAgentID).Return(currentAgent, nil),
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
ID: chat.ID,
|
|
}).Return(updatedChat, nil),
|
|
)
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
|
|
var dialed []uuid.UUID
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
}
|
|
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
dialed = append(dialed, agentID)
|
|
if agentID == staleAgentID {
|
|
return nil, nil, xerrors.New("dial failed")
|
|
}
|
|
return conn, func() {}, nil
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.NoError(t, err)
|
|
require.Same(t, conn, gotConn)
|
|
require.Equal(t, []uuid.UUID{staleAgentID, currentAgentID}, dialed)
|
|
require.Equal(t, updatedChat, currentChat)
|
|
|
|
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, currentAgent, gotAgent)
|
|
}
|
|
|
|
func TestTurnWorkspaceContextGetWorkspaceConnFastFailsWithoutCurrentAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
staleAgentID := uuid.New()
|
|
resourceID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: staleAgentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
staleAgent := database.WorkspaceAgent{ID: staleAgentID, ResourceID: resourceID}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).
|
|
Return(staleAgent, nil).
|
|
Times(1)
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{}, nil).
|
|
Times(1)
|
|
db.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).
|
|
Return(database.WorkspaceResource{
|
|
ID: resourceID,
|
|
Type: chattool.ExternalAgentResourceType,
|
|
}, nil).
|
|
AnyTimes()
|
|
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 30 * time.Second,
|
|
}
|
|
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return nil, nil, xerrors.New("dial failed")
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
require.ErrorIs(t, err, errChatHasNoWorkspaceAgent)
|
|
require.NotErrorIs(t, err, errChatExternalAgentUnavailable)
|
|
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
|
|
require.False(t, workspaceCtx.agentLoaded)
|
|
require.Nil(t, workspaceCtx.conn)
|
|
require.Nil(t, workspaceCtx.releaseConn)
|
|
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
currentChat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
updatedChat := database.Chat{
|
|
ID: currentChat.ID,
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: uuid.New(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
cachedConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
releaseCalls := 0
|
|
|
|
workspaceCtx := turnWorkspaceContext{
|
|
chatStateMu: &sync.Mutex{},
|
|
currentChat: ¤tChat,
|
|
}
|
|
workspaceCtx.agent = database.WorkspaceAgent{ID: uuid.New()}
|
|
workspaceCtx.agentLoaded = true
|
|
workspaceCtx.conn = cachedConn
|
|
workspaceCtx.cachedWorkspaceID = currentChat.WorkspaceID
|
|
workspaceCtx.releaseConn = func() {
|
|
releaseCalls++
|
|
}
|
|
|
|
workspaceCtx.selectWorkspace(updatedChat)
|
|
|
|
require.Equal(t, updatedChat, currentChat)
|
|
require.Equal(t, 1, releaseCalls)
|
|
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
|
|
require.False(t, workspaceCtx.agentLoaded)
|
|
require.Nil(t, workspaceCtx.conn)
|
|
require.Nil(t, workspaceCtx.releaseConn)
|
|
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
|
|
}
|
|
|
|
func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferentWorkspace(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceOneID := uuid.New()
|
|
workspaceTwoID := uuid.New()
|
|
buildID := uuid.New()
|
|
cachedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
|
resolvedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceTwoID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
updatedChat := chat
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return([]database.WorkspaceAgent{resolvedAgent}, nil),
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
ID: chat.ID,
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true},
|
|
}).Return(updatedChat, nil),
|
|
)
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: &Server{db: db},
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
workspaceCtx.agent = cachedAgent
|
|
workspaceCtx.agentLoaded = true
|
|
workspaceCtx.cachedWorkspaceID = uuid.NullUUID{UUID: workspaceOneID, Valid: true}
|
|
defer workspaceCtx.close()
|
|
|
|
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, updatedChat, chatSnapshot)
|
|
require.Equal(t, resolvedAgent, agent)
|
|
require.Equal(t, updatedChat, currentChat)
|
|
}
|
|
|
|
func TestSubscribeDedupesLocallyDeliveredMessageOnNotifyCatchup(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
initialMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
localMessage := database.ChatMessage{
|
|
ID: 2,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
}
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{initialMessage}, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
// DB catchup runs unconditionally on every notify; the delivered
|
|
// set dedupes against locally-delivered messages.
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 1,
|
|
}).Return(nil, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishMessage(chatID, localMessage)
|
|
|
|
event := requireStreamMessageEvent(t, events)
|
|
require.Equal(t, int64(2), event.Message.ID)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
initialMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
cachedMessage := codersdk.ChatMessage{
|
|
ID: 2,
|
|
ChatID: chatID,
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
}
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{initialMessage}, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
// DB catchup runs unconditionally; cached id=2 is deduped via
|
|
// the delivered set so this query returning nil is sufficient.
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 1,
|
|
}).Return(nil, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessage,
|
|
ChatID: chatID,
|
|
Message: &cachedMessage,
|
|
})
|
|
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
AfterMessageID: 1,
|
|
})
|
|
|
|
event := requireStreamMessageEvent(t, events)
|
|
require.Equal(t, int64(2), event.Message.ID)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
initialMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
catchupMessage := database.ChatMessage{
|
|
ID: 2,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
}
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{initialMessage}, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 1,
|
|
}).Return([]database.ChatMessage{catchupMessage}, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
AfterMessageID: 1,
|
|
})
|
|
|
|
event := requireStreamMessageEvent(t, events)
|
|
require.Equal(t, int64(2), event.Message.ID)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
initialMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
editedMessage := database.ChatMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Role: database.ChatMessageRoleUser,
|
|
}
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{initialMessage}, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return([]database.ChatMessage{editedMessage}, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishEditedMessage(chatID, editedMessage)
|
|
|
|
event := requireStreamMessageEvent(t, events)
|
|
require.Equal(t, int64(1), event.Message.ID)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
expected := newTestRetryPayload()
|
|
|
|
server.publishRetry(chatID, expected)
|
|
|
|
event := requireStreamRetryEvent(t, events)
|
|
require.Equal(t, expected, event.Retry)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeReplaysCurrentRetryPhaseInSnapshot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
server := newBufferedSubscribeTestServer(t, db, chatID)
|
|
|
|
expected := newTestRetryPayload()
|
|
server.publishRetry(chatID, expected)
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
require.Len(t, snapshot, 2)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeStatus, snapshot[0].Type)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeRetry, snapshot[1].Type)
|
|
event := requireSnapshotRetryEvent(t, snapshot)
|
|
require.Equal(t, expected, event.Retry)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeCapturesRetryPhaseAtSubscriptionBoundary(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
|
expected := newTestRetryPayload()
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).DoAndReturn(func(context.Context, database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
|
server.publishRetry(chatID, expected)
|
|
return nil, nil
|
|
}),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
requireNoSnapshotRetryEvent(t, snapshot)
|
|
event := requireStreamRetryEvent(t, events)
|
|
require.Equal(t, expected, event.Retry)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeDoesNotReplayRetryAfterStreamResumes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
server := newBufferedSubscribeTestServer(t, db, chatID)
|
|
|
|
server.publishRetry(chatID, newTestRetryPayload())
|
|
server.publishMessagePart(chatID, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("retry recovered"))
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
requireNoSnapshotRetryEvent(t, snapshot)
|
|
requireSnapshotMessagePartEvent(t, snapshot)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeDoesNotReplayFailedAttemptPartsAfterRetry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
server := newBufferedSubscribeTestServer(t, db, chatID)
|
|
|
|
server.publishMessagePart(chatID, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("failed partial"))
|
|
server.clearProvisionalStreamParts(chatID)
|
|
server.publishRetry(chatID, newTestRetryPayload())
|
|
server.publishMessagePart(chatID, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("retry recovered"))
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
requireNoSnapshotRetryEvent(t, snapshot)
|
|
partEvent := requireSnapshotMessagePartEvent(t, snapshot)
|
|
require.Equal(t, "retry recovered", partEvent.MessagePart.Part.Text)
|
|
for _, event := range snapshot {
|
|
if event.Type != codersdk.ChatStreamEventTypeMessagePart {
|
|
continue
|
|
}
|
|
require.NotEqual(t, "failed partial", event.MessagePart.Part.Text)
|
|
}
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeDoesNotReplayRetryAfterTerminalError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
server := newBufferedSubscribeTestServer(t, db, chatID)
|
|
|
|
server.publishRetry(chatID, newTestRetryPayload())
|
|
server.publishError(chatID, chaterror.ClassifiedError{
|
|
Message: "OpenAI is rate limiting requests.",
|
|
Kind: codersdk.ChatErrorKindRateLimit,
|
|
Provider: "openai",
|
|
Retryable: true,
|
|
StatusCode: 429,
|
|
})
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
requireNoSnapshotRetryEvent(t, snapshot)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeDoesNotReplayRetryAfterTerminalStatus(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusCompleted}
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
server := newBufferedSubscribeTestServer(t, db, chatID)
|
|
|
|
server.publishRetry(chatID, newTestRetryPayload())
|
|
server.publishStatus(chatID, database.ChatStatusCompleted, uuid.NullUUID{})
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
requireNoSnapshotRetryEvent(t, snapshot)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
classified := chaterror.ClassifiedError{
|
|
Message: "OpenAI is rate limiting requests.",
|
|
Kind: codersdk.ChatErrorKindRateLimit,
|
|
Provider: "openai",
|
|
Retryable: true,
|
|
StatusCode: 429,
|
|
}
|
|
server.publishError(chatID, classified)
|
|
|
|
event := requireStreamErrorEvent(t, events)
|
|
require.Equal(t, chaterror.TerminalErrorPayload(classified), event.Error)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancelCtx := context.WithCancel(context.Background())
|
|
defer cancelCtx()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
chatID := uuid.New()
|
|
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
Error: "legacy error only",
|
|
})
|
|
|
|
event := requireStreamErrorEvent(t, events)
|
|
require.Equal(t, &codersdk.ChatError{Message: "legacy error only"}, event.Error)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func newTestRetryPayload() *codersdk.ChatStreamRetry {
|
|
payload := chaterror.StreamRetryPayload(1, 1500*time.Millisecond, chaterror.ClassifiedError{
|
|
Message: "OpenAI is rate limiting requests.",
|
|
Kind: codersdk.ChatErrorKindRateLimit,
|
|
Provider: "openai",
|
|
Retryable: true,
|
|
StatusCode: 429,
|
|
})
|
|
if payload == nil {
|
|
panic("expected retry payload")
|
|
}
|
|
payload.RetryingAt = time.Unix(1_700_000_000, 0).UTC()
|
|
return payload
|
|
}
|
|
|
|
func TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
server := newSubscribeTestServer(t, db)
|
|
|
|
chatID := uuid.New()
|
|
staleChat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
|
|
|
state := server.getOrCreateStreamState(chatID)
|
|
state.mu.Lock()
|
|
state.buffer = []bufferedStreamPart{{
|
|
event: codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
ChatID: chatID,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: "assistant",
|
|
Part: codersdk.ChatMessageText("thinking"),
|
|
},
|
|
},
|
|
}}
|
|
state.mu.Unlock()
|
|
|
|
gomock.InOrder(
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, xerrors.New("refresh failed")),
|
|
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: 0,
|
|
}).Return(nil, nil),
|
|
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
|
)
|
|
|
|
initialSnapshot, events, cancel, ok := server.SubscribeAuthorized(ctx, staleChat, nil, 0)
|
|
require.True(t, ok)
|
|
defer cancel()
|
|
|
|
require.Len(t, initialSnapshot, 2)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeStatus, initialSnapshot[0].Type)
|
|
require.NotNil(t, initialSnapshot[0].Status)
|
|
require.Equal(t, codersdk.ChatStatusPending, initialSnapshot[0].Status.Status)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, initialSnapshot[1].Type)
|
|
require.NotNil(t, initialSnapshot[1].MessagePart)
|
|
require.Equal(t, "thinking", initialSnapshot[1].MessagePart.Part.Text)
|
|
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
|
}
|
|
|
|
func TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
server := newSubscribeTestServer(t, db)
|
|
|
|
chatID := uuid.New()
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).
|
|
Return(database.Chat{}, dbauthz.NotAuthorizedError{Err: xerrors.New("not authorized")})
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.False(t, ok)
|
|
require.Nil(t, snapshot)
|
|
require.Nil(t, events)
|
|
require.Nil(t, cancel)
|
|
|
|
_, exists := server.chatStreams.Load(chatID)
|
|
require.False(t, exists)
|
|
}
|
|
|
|
func TestSubscribeSurfacesTransientLookupFailureAsInitialError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
server := newSubscribeTestServer(t, db)
|
|
|
|
chatID := uuid.New()
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).
|
|
Return(database.Chat{}, xerrors.New("transient lookup failure"))
|
|
|
|
snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
|
require.True(t, ok)
|
|
require.NotNil(t, cancel)
|
|
require.Len(t, snapshot, 1)
|
|
require.Equal(t, codersdk.ChatStreamEventTypeError, snapshot[0].Type)
|
|
require.Equal(t, chatID, snapshot[0].ChatID)
|
|
require.Equal(t, "failed to load initial snapshot", snapshot[0].Error.Message)
|
|
|
|
_, open := <-events
|
|
require.False(t, open)
|
|
|
|
_, exists := server.chatStreams.Load(chatID)
|
|
require.False(t, exists)
|
|
}
|
|
|
|
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
|
|
t.Helper()
|
|
|
|
return &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
pubsub: dbpubsub.NewInMemory(),
|
|
}
|
|
}
|
|
|
|
func newBufferedSubscribeTestServer(t *testing.T, db database.Store, chatID uuid.UUID) *Server {
|
|
t.Helper()
|
|
|
|
server := newSubscribeTestServer(t, db)
|
|
state := server.getOrCreateStreamState(chatID)
|
|
state.mu.Lock()
|
|
state.buffering = true
|
|
state.mu.Unlock()
|
|
return server
|
|
}
|
|
|
|
func requireStreamMessageEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
select {
|
|
case event, ok := <-events:
|
|
require.True(t, ok, "chat stream closed before delivering an event")
|
|
require.Equal(t, codersdk.ChatStreamEventTypeMessage, event.Type)
|
|
require.NotNil(t, event.Message)
|
|
return event
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for chat stream message event")
|
|
return codersdk.ChatStreamEvent{}
|
|
}
|
|
}
|
|
|
|
func requireStreamRetryEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
select {
|
|
case event, ok := <-events:
|
|
require.True(t, ok, "chat stream closed before delivering an event")
|
|
require.Equal(t, codersdk.ChatStreamEventTypeRetry, event.Type)
|
|
require.NotNil(t, event.Retry)
|
|
return event
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for chat stream retry event")
|
|
return codersdk.ChatStreamEvent{}
|
|
}
|
|
}
|
|
|
|
func requireSnapshotRetryEvent(t *testing.T, snapshot []codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
var retryEvents []codersdk.ChatStreamEvent
|
|
for _, event := range snapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeRetry {
|
|
retryEvents = append(retryEvents, event)
|
|
}
|
|
}
|
|
|
|
require.Len(t, retryEvents, 1, "expected exactly one retry event in snapshot")
|
|
require.NotNil(t, retryEvents[0].Retry)
|
|
return retryEvents[0]
|
|
}
|
|
|
|
func requireNoSnapshotRetryEvent(t *testing.T, snapshot []codersdk.ChatStreamEvent) {
|
|
t.Helper()
|
|
|
|
for _, event := range snapshot {
|
|
require.NotEqual(t, codersdk.ChatStreamEventTypeRetry, event.Type,
|
|
"unexpected retry event in snapshot: %+v", event)
|
|
}
|
|
}
|
|
|
|
func requireSnapshotMessagePartEvent(t *testing.T, snapshot []codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
for _, event := range snapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
require.NotNil(t, event.MessagePart)
|
|
return event
|
|
}
|
|
}
|
|
|
|
t.Fatal("expected message_part event in snapshot")
|
|
return codersdk.ChatStreamEvent{}
|
|
}
|
|
|
|
func requireStreamErrorEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
|
t.Helper()
|
|
|
|
select {
|
|
case event, ok := <-events:
|
|
require.True(t, ok, "chat stream closed before delivering an event")
|
|
require.Equal(t, codersdk.ChatStreamEventTypeError, event.Type)
|
|
require.NotNil(t, event.Error)
|
|
return event
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for chat stream error event")
|
|
return codersdk.ChatStreamEvent{}
|
|
}
|
|
}
|
|
|
|
func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) {
|
|
t.Helper()
|
|
|
|
select {
|
|
case event, ok := <-events:
|
|
if !ok {
|
|
t.Fatal("chat stream closed unexpectedly")
|
|
}
|
|
t.Fatalf("unexpected chat stream event: %+v", event)
|
|
case <-time.After(wait):
|
|
}
|
|
}
|
|
|
|
// TestPublishToStream_DropWarnRateLimiting walks through a
|
|
// realistic lifecycle: buffer fills up, subscriber channel fills
|
|
// up, counters get reset between steps. It verifies that WARN
|
|
// logs are rate-limited to at most once per streamDropWarnInterval
|
|
// and that counter resets re-enable an immediate WARN.
|
|
func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sink := testutil.NewFakeSink(t)
|
|
mClock := quartz.NewMock(t)
|
|
|
|
server := &Server{
|
|
logger: sink.Logger(),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
subCh := make(chan codersdk.ChatStreamEvent, 1)
|
|
subCh <- codersdk.ChatStreamEvent{} // pre-fill so sends always drop
|
|
|
|
// Set up state that mirrors a running chat: buffer at capacity,
|
|
// buffering enabled, one saturated subscriber.
|
|
state := &chatStreamState{
|
|
buffering: true,
|
|
buffer: make([]bufferedStreamPart, maxStreamBufferSize),
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
|
|
uuid.New(): subCh,
|
|
},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
bufferMsg := "chat stream buffer full, dropping oldest event"
|
|
subMsg := "dropping chat stream event"
|
|
|
|
filter := func(level slog.Level, msg string) func(slog.SinkEntry) bool {
|
|
return func(e slog.SinkEntry) bool {
|
|
return e.Level == level && e.Message == msg
|
|
}
|
|
}
|
|
|
|
// --- Phase 1: buffer-full rate limiting ---
|
|
// message_part events hit both the buffer-full and subscriber-full
|
|
// paths. The first publish triggers a WARN for each; the rest
|
|
// within the window are DEBUG.
|
|
partEvent := codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
}
|
|
for i := 0; i < 50; i++ {
|
|
server.publishToStream(chatID, partEvent)
|
|
}
|
|
|
|
require.Len(t, sink.Entries(filter(slog.LevelWarn, bufferMsg)), 1)
|
|
require.Empty(t, sink.Entries(filter(slog.LevelDebug, bufferMsg)))
|
|
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, bufferMsg))[0], "dropped_count", int64(1))
|
|
|
|
// Subscriber also saw 50 drops (one per publish).
|
|
require.Len(t, sink.Entries(filter(slog.LevelWarn, subMsg)), 1)
|
|
require.Empty(t, sink.Entries(filter(slog.LevelDebug, subMsg)))
|
|
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, subMsg))[0], "dropped_count", int64(1))
|
|
|
|
// --- Phase 2: clock advance triggers second WARN with count ---
|
|
mClock.Advance(streamDropWarnInterval + time.Second)
|
|
server.publishToStream(chatID, partEvent)
|
|
|
|
bufWarn := sink.Entries(filter(slog.LevelWarn, bufferMsg))
|
|
require.Len(t, bufWarn, 2)
|
|
requireFieldValue(t, bufWarn[1], "dropped_count", int64(50))
|
|
|
|
subWarn := sink.Entries(filter(slog.LevelWarn, subMsg))
|
|
require.Len(t, subWarn, 2)
|
|
requireFieldValue(t, subWarn[1], "dropped_count", int64(50))
|
|
|
|
// --- Phase 3: counter reset (simulates step persist) ---
|
|
state.mu.Lock()
|
|
state.buffer = make([]bufferedStreamPart, maxStreamBufferSize)
|
|
state.resetDropCounters()
|
|
state.mu.Unlock()
|
|
|
|
// The very next drop should WARN immediately — the reset zeroed
|
|
// lastWarnAt so the interval check passes.
|
|
server.publishToStream(chatID, partEvent)
|
|
|
|
bufWarn = sink.Entries(filter(slog.LevelWarn, bufferMsg))
|
|
require.Len(t, bufWarn, 3, "expected WARN immediately after counter reset")
|
|
requireFieldValue(t, bufWarn[2], "dropped_count", int64(1))
|
|
|
|
subWarn = sink.Entries(filter(slog.LevelWarn, subMsg))
|
|
require.Len(t, subWarn, 3, "expected subscriber WARN immediately after counter reset")
|
|
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
|
|
}
|
|
|
|
func TestResolveUserCompactionThreshold(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
userID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
expectedKey := codersdk.CompactionThresholdKey(modelConfigID)
|
|
|
|
tests := []struct {
|
|
name string
|
|
dbReturn string
|
|
dbErr error
|
|
wantVal int32
|
|
wantOK bool
|
|
wantWarnLog bool
|
|
}{
|
|
{
|
|
name: "NoRowsReturnsDefault",
|
|
dbErr: sql.ErrNoRows,
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "ValidOverride",
|
|
dbReturn: "75",
|
|
wantVal: 75,
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "OutOfRangeValue",
|
|
dbReturn: "101",
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "NonIntegerValue",
|
|
dbReturn: "abc",
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "UnexpectedDBError",
|
|
dbErr: xerrors.New("connection refused"),
|
|
wantOK: false,
|
|
wantWarnLog: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockDB := dbmock.NewMockStore(ctrl)
|
|
sink := testutil.NewFakeSink(t)
|
|
|
|
srv := &Server{
|
|
db: mockDB,
|
|
logger: sink.Logger(),
|
|
}
|
|
|
|
mockDB.EXPECT().GetUserChatCompactionThreshold(gomock.Any(), database.GetUserChatCompactionThresholdParams{
|
|
UserID: userID,
|
|
Key: expectedKey,
|
|
}).Return(tc.dbReturn, tc.dbErr)
|
|
|
|
val, ok := srv.resolveUserCompactionThreshold(context.Background(), userID, modelConfigID)
|
|
require.Equal(t, tc.wantVal, val)
|
|
require.Equal(t, tc.wantOK, ok)
|
|
|
|
warns := sink.Entries(func(e slog.SinkEntry) bool {
|
|
return e.Level == slog.LevelWarn
|
|
})
|
|
if tc.wantWarnLog {
|
|
require.NotEmpty(t, warns, "expected a warning log entry")
|
|
return
|
|
}
|
|
require.Empty(t, warns, "unexpected warning log entry")
|
|
})
|
|
}
|
|
}
|
|
|
|
// requireFieldValue asserts that a SinkEntry contains a field with
|
|
// the given name and value.
|
|
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
|
|
t.Helper()
|
|
for _, f := range entry.Fields {
|
|
if f.Name == name {
|
|
require.Equal(t, expected, f.Value, "field %q value mismatch", name)
|
|
return
|
|
}
|
|
}
|
|
t.Fatalf("field %q not found in log entry", name)
|
|
}
|
|
|
|
func TestSkillsFromParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Empty", func(t *testing.T) {
|
|
t.Parallel()
|
|
got := skillsFromParts(nil)
|
|
require.Empty(t, got)
|
|
})
|
|
|
|
t.Run("NoSkillParts", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{Type: codersdk.ChatMessagePartTypeText, Text: "hello"},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Empty(t, got)
|
|
})
|
|
|
|
t.Run("SingleSkill", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "deep-review",
|
|
SkillDescription: "Multi-reviewer code review",
|
|
SkillDir: "/home/coder/.agents/skills/deep-review",
|
|
},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, 1)
|
|
require.Equal(t, "deep-review", got[0].Name)
|
|
require.Equal(t, "Multi-reviewer code review", got[0].Description)
|
|
require.Equal(t, "/home/coder/.agents/skills/deep-review", got[0].Dir)
|
|
})
|
|
|
|
t.Run("MultipleSkillsAcrossMessages", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "pull-requests",
|
|
SkillDir: "/home/coder/.agents/skills/pull-requests",
|
|
},
|
|
}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "deep-review",
|
|
SkillDir: "/home/coder/.agents/skills/deep-review",
|
|
},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, 2)
|
|
require.Equal(t, "pull-requests", got[0].Name)
|
|
require.Equal(t, "deep-review", got[1].Name)
|
|
})
|
|
|
|
t.Run("MixedPartTypes", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/home/coder/.coder/AGENTS.md",
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "refine-plan",
|
|
SkillDir: "/home/coder/.agents/skills/refine-plan",
|
|
},
|
|
}),
|
|
// A text-only message should be skipped entirely.
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{Type: codersdk.ChatMessagePartTypeText, Text: "user turn"},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, 1)
|
|
require.Equal(t, "refine-plan", got[0].Name)
|
|
require.Equal(t, "/home/coder/.agents/skills/refine-plan", got[0].Dir)
|
|
})
|
|
|
|
t.Run("OptionalDescriptionOmitted", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "refine-plan",
|
|
SkillDir: "/home/coder/.agents/skills/refine-plan",
|
|
},
|
|
}),
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, 1)
|
|
require.Equal(t, "refine-plan", got[0].Name)
|
|
require.Empty(t, got[0].Description)
|
|
})
|
|
|
|
t.Run("InvalidJSON", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
{
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: []byte(`not valid json with "skill" in it`),
|
|
Valid: true,
|
|
},
|
|
},
|
|
}
|
|
got := skillsFromParts(msgs)
|
|
require.Empty(t, got)
|
|
})
|
|
|
|
t.Run("RoundTrip", func(t *testing.T) {
|
|
// Simulate persist -> reconstruct cycle: marshal skill
|
|
// parts the same way persistInstructionFiles does, then
|
|
// verify skillsFromParts recovers the metadata.
|
|
t.Parallel()
|
|
want := []chattool.SkillMeta{
|
|
{Name: "deep-review", Description: "Multi-reviewer review", Dir: "/skills/deep-review"},
|
|
{Name: "pull-requests", Description: "", Dir: "/skills/pull-requests"},
|
|
}
|
|
agentID := uuid.New()
|
|
var parts []codersdk.ChatMessagePart
|
|
for _, s := range want {
|
|
parts = append(parts, codersdk.ChatMessagePart{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: s.Name,
|
|
SkillDescription: s.Description,
|
|
SkillDir: s.Dir,
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
|
})
|
|
}
|
|
msgs := []database.ChatMessage{chattest.ChatMessageWithParts(parts)}
|
|
got := skillsFromParts(msgs)
|
|
require.Len(t, got, len(want))
|
|
for i, w := range want {
|
|
require.Equal(t, w.Name, got[i].Name)
|
|
require.Equal(t, w.Description, got[i].Description)
|
|
require.Equal(t, w.Dir, got[i].Dir)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestPersonalSkillsInSystemPrompt(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
prompt := buildSystemPrompt(
|
|
nil,
|
|
"",
|
|
"",
|
|
mergeTurnSkills(
|
|
[]skillspkg.Skill{{
|
|
Name: "personal-review",
|
|
Description: "Personal review process",
|
|
Source: skillspkg.SourcePersonal,
|
|
}},
|
|
nil,
|
|
),
|
|
"",
|
|
systemPromptBehaviorContext{},
|
|
)
|
|
|
|
text := systemPromptText(t, prompt)
|
|
require.Contains(t, text, "<available-skills>")
|
|
require.Contains(t, text, "- personal-review: Personal review process")
|
|
require.NotContains(t, text, `"skill"`)
|
|
}
|
|
|
|
func TestPersonalAndWorkspaceSkillCollisionInSystemPrompt(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
resolved := mergeTurnSkills(
|
|
[]skillspkg.Skill{{
|
|
Name: "deploy",
|
|
Description: "Personal deployment process",
|
|
Source: skillspkg.SourcePersonal,
|
|
}},
|
|
[]chattool.SkillMeta{{
|
|
Name: "deploy",
|
|
Description: "Workspace deployment process",
|
|
Dir: "/skills/deploy",
|
|
}},
|
|
)
|
|
prompt := buildSystemPrompt(
|
|
nil,
|
|
"",
|
|
"",
|
|
resolved,
|
|
"",
|
|
systemPromptBehaviorContext{},
|
|
)
|
|
|
|
text := systemPromptText(t, prompt)
|
|
require.Contains(t, text, "<available-skills>")
|
|
require.Contains(t, text, "- personal/deploy: Personal deployment process")
|
|
require.Contains(t, text, "- workspace/deploy: Workspace deployment process")
|
|
require.NotContains(t, text, "\n- deploy: ")
|
|
require.NotContains(t, text, "\n- deploy\n")
|
|
|
|
personal, err := skillspkg.Lookup(resolved, "personal/deploy")
|
|
require.NoError(t, err)
|
|
require.Equal(t, "deploy", personal.Name)
|
|
require.Equal(t, skillspkg.SourcePersonal, personal.Source)
|
|
|
|
workspace, err := skillspkg.Lookup(resolved, "workspace/deploy")
|
|
require.NoError(t, err)
|
|
require.Equal(t, "deploy", workspace.Name)
|
|
require.Equal(t, skillspkg.SourceWorkspace, workspace.Source)
|
|
|
|
_, err = skillspkg.Lookup(resolved, "deploy")
|
|
require.ErrorIs(t, err, skillspkg.ErrSkillAmbiguous)
|
|
require.ErrorContains(t, err, "personal/deploy")
|
|
require.ErrorContains(t, err, "workspace/deploy")
|
|
}
|
|
|
|
func TestSkillIndexRefreshReplacesStaleAliases(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
initialResolved := mergeTurnSkills(
|
|
[]skillspkg.Skill{{
|
|
Name: "deploy",
|
|
Description: "Personal deployment process",
|
|
Source: skillspkg.SourcePersonal,
|
|
}},
|
|
nil,
|
|
)
|
|
prompt := buildSystemPrompt(
|
|
[]fantasy.Message{{
|
|
Role: fantasy.MessageRoleUser,
|
|
Content: []fantasy.MessagePart{
|
|
fantasy.TextPart{Text: "Create a workspace."},
|
|
},
|
|
}},
|
|
"",
|
|
"",
|
|
initialResolved,
|
|
"",
|
|
systemPromptBehaviorContext{},
|
|
)
|
|
|
|
mergedIndex := chattool.FormatResolvedSkillIndex(mergeTurnSkills(
|
|
[]skillspkg.Skill{{
|
|
Name: "deploy",
|
|
Description: "Personal deployment process",
|
|
Source: skillspkg.SourcePersonal,
|
|
}},
|
|
[]chattool.SkillMeta{{
|
|
Name: "deploy",
|
|
Description: "Workspace deployment process",
|
|
Dir: "/skills/deploy",
|
|
}},
|
|
))
|
|
prompt = removeSkillIndexMessages(prompt)
|
|
prompt = chatprompt.InsertSystem(prompt, mergedIndex)
|
|
|
|
text := systemPromptText(t, prompt)
|
|
require.Equal(t, 1, strings.Count(text, "<available-skills>"))
|
|
require.NotContains(t, text, "\n- deploy: Personal deployment process")
|
|
require.Contains(t, text, "- personal/deploy: Personal deployment process")
|
|
require.Contains(t, text, "- workspace/deploy: Workspace deployment process")
|
|
}
|
|
|
|
func requireUserSkillContextActor(ctx context.Context, t *testing.T, userID uuid.UUID) {
|
|
t.Helper()
|
|
actor, ok := dbauthz.ActorFromContext(ctx)
|
|
require.True(t, ok)
|
|
require.Equal(t, rbac.SubjectTypeUser, actor.Type)
|
|
require.Equal(t, userID.String(), actor.ID)
|
|
require.Equal(t, rbac.RoleIdentifiers{rbac.RoleMember()}, actor.Roles)
|
|
}
|
|
|
|
func TestFetchPersonalSkillMetadata(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Success", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
server := &Server{db: db}
|
|
userID := uuid.New()
|
|
|
|
db.EXPECT().ListUserSkillMetadataByUserID(gomock.Any(), userID).DoAndReturn(
|
|
func(ctx context.Context, gotUserID uuid.UUID) ([]database.ListUserSkillMetadataByUserIDRow, error) {
|
|
requireUserSkillContextActor(ctx, t, userID)
|
|
require.Equal(t, userID, gotUserID)
|
|
return []database.ListUserSkillMetadataByUserIDRow{{
|
|
UserID: userID,
|
|
Name: "personal-review",
|
|
Description: "Personal review process",
|
|
}}, nil
|
|
},
|
|
)
|
|
|
|
got := server.fetchPersonalSkillMetadata(context.Background(), userID, logger)
|
|
require.Equal(t, []skillspkg.Skill{{
|
|
Name: "personal-review",
|
|
Description: "Personal review process",
|
|
Source: skillspkg.SourcePersonal,
|
|
}}, got)
|
|
})
|
|
|
|
t.Run("ListFailure", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
sink := testutil.NewFakeSink(t)
|
|
logger := sink.Logger().Leveled(slog.LevelDebug)
|
|
server := &Server{db: db}
|
|
userID := uuid.New()
|
|
|
|
db.EXPECT().ListUserSkillMetadataByUserID(gomock.Any(), userID).Return(nil, xerrors.New("boom"))
|
|
|
|
got := server.fetchPersonalSkillMetadata(context.Background(), userID, logger)
|
|
require.Empty(t, got)
|
|
warns := sink.Entries(func(e slog.SinkEntry) bool {
|
|
return e.Level == slog.LevelWarn && strings.Contains(e.Message, "personal skill metadata")
|
|
})
|
|
require.NotEmpty(t, warns)
|
|
})
|
|
}
|
|
|
|
func TestLoadPersonalSkillBody(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("ParsesCurrentContent", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
server := &Server{db: db}
|
|
userID := uuid.New()
|
|
params := database.GetUserSkillByUserIDAndNameParams{
|
|
UserID: userID,
|
|
Name: "personal-review",
|
|
}
|
|
|
|
db.EXPECT().GetUserSkillByUserIDAndName(gomock.Any(), params).DoAndReturn(
|
|
func(ctx context.Context, gotParams database.GetUserSkillByUserIDAndNameParams) (database.UserSkill, error) {
|
|
requireUserSkillContextActor(ctx, t, userID)
|
|
require.Equal(t, params, gotParams)
|
|
return database.UserSkill{
|
|
UserID: userID,
|
|
Name: "personal-review",
|
|
Content: "---\nname: personal-review\ndescription: Personal review process\n---\n\nUpdated instructions.\n",
|
|
}, nil
|
|
},
|
|
)
|
|
|
|
got, err := server.loadPersonalSkillBody(context.Background(), userID, "personal-review")
|
|
require.NoError(t, err)
|
|
require.Equal(t, "personal-review", got.Name)
|
|
require.Equal(t, "Personal review process", got.Description)
|
|
require.Equal(t, skillspkg.SourcePersonal, got.Source)
|
|
require.Contains(t, got.Body, "Updated instructions.")
|
|
})
|
|
|
|
t.Run("DeletedSkill", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
server := &Server{db: db}
|
|
userID := uuid.New()
|
|
params := database.GetUserSkillByUserIDAndNameParams{
|
|
UserID: userID,
|
|
Name: "missing-skill",
|
|
}
|
|
|
|
db.EXPECT().GetUserSkillByUserIDAndName(gomock.Any(), params).DoAndReturn(
|
|
func(ctx context.Context, gotParams database.GetUserSkillByUserIDAndNameParams) (database.UserSkill, error) {
|
|
requireUserSkillContextActor(ctx, t, userID)
|
|
require.Equal(t, params, gotParams)
|
|
return database.UserSkill{}, sql.ErrNoRows
|
|
},
|
|
)
|
|
|
|
_, err := server.loadPersonalSkillBody(context.Background(), userID, "missing-skill")
|
|
require.ErrorIs(t, err, skillspkg.ErrSkillNotFound)
|
|
})
|
|
|
|
t.Run("DatabaseError", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
sink := testutil.NewFakeSink(t)
|
|
server := &Server{db: db, logger: sink.Logger()}
|
|
userID := uuid.New()
|
|
params := database.GetUserSkillByUserIDAndNameParams{
|
|
UserID: userID,
|
|
Name: "error-skill",
|
|
}
|
|
dbErr := xerrors.New("database unavailable")
|
|
|
|
db.EXPECT().GetUserSkillByUserIDAndName(gomock.Any(), params).DoAndReturn(
|
|
func(ctx context.Context, gotParams database.GetUserSkillByUserIDAndNameParams) (database.UserSkill, error) {
|
|
requireUserSkillContextActor(ctx, t, userID)
|
|
require.Equal(t, params, gotParams)
|
|
return database.UserSkill{}, dbErr
|
|
},
|
|
)
|
|
|
|
_, err := server.loadPersonalSkillBody(context.Background(), userID, "error-skill")
|
|
|
|
require.ErrorContains(t, err, "load personal skill body")
|
|
require.ErrorIs(t, err, dbErr)
|
|
entries := sink.Entries(func(e slog.SinkEntry) bool {
|
|
return e.Level == slog.LevelError && e.Message == "load personal skill body failed"
|
|
})
|
|
require.Len(t, entries, 1)
|
|
requireFieldValue(t, entries[0], "error", dbErr)
|
|
})
|
|
|
|
t.Run("ParseError", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
sink := testutil.NewFakeSink(t)
|
|
server := &Server{db: db, logger: sink.Logger()}
|
|
userID := uuid.New()
|
|
params := database.GetUserSkillByUserIDAndNameParams{
|
|
UserID: userID,
|
|
Name: "broken-skill",
|
|
}
|
|
|
|
db.EXPECT().GetUserSkillByUserIDAndName(gomock.Any(), params).DoAndReturn(
|
|
func(ctx context.Context, gotParams database.GetUserSkillByUserIDAndNameParams) (database.UserSkill, error) {
|
|
requireUserSkillContextActor(ctx, t, userID)
|
|
require.Equal(t, params, gotParams)
|
|
return database.UserSkill{
|
|
UserID: userID,
|
|
Name: "broken-skill",
|
|
Content: "---\nname: broken-skill\ndescription: Broken\n---\n\n \n",
|
|
}, nil
|
|
},
|
|
)
|
|
|
|
_, err := server.loadPersonalSkillBody(context.Background(), userID, "broken-skill")
|
|
|
|
require.ErrorContains(t, err, "parse personal skill body")
|
|
require.ErrorIs(t, err, skillspkg.ErrSkillBodyRequired)
|
|
entries := sink.Entries(func(e slog.SinkEntry) bool {
|
|
return e.Level == slog.LevelError && e.Message == "parse personal skill body failed"
|
|
})
|
|
require.Len(t, entries, 1)
|
|
requireFieldValue(t, entries[0], "user_id", userID)
|
|
requireFieldValue(t, entries[0], "name", "broken-skill")
|
|
})
|
|
}
|
|
|
|
func systemPromptText(t *testing.T, prompt []fantasy.Message) string {
|
|
t.Helper()
|
|
|
|
var b strings.Builder
|
|
for _, msg := range prompt {
|
|
if msg.Role != fantasy.MessageRoleSystem {
|
|
continue
|
|
}
|
|
for _, part := range msg.Content {
|
|
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
|
|
if ok {
|
|
_, _ = b.WriteString(textPart.Text)
|
|
_, _ = b.WriteString("\n")
|
|
}
|
|
}
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
func TestContextFileAgentID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("EmptyMessages", func(t *testing.T) {
|
|
t.Parallel()
|
|
id, ok := contextFileAgentID(nil)
|
|
require.Equal(t, uuid.Nil, id)
|
|
require.False(t, ok)
|
|
})
|
|
|
|
t.Run("NoContextFileParts", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{Type: codersdk.ChatMessagePartTypeText, Text: "hello"},
|
|
}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, uuid.Nil, id)
|
|
require.False(t, ok)
|
|
})
|
|
|
|
t.Run("SingleContextFile", func(t *testing.T) {
|
|
t.Parallel()
|
|
agentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/some/path",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
|
},
|
|
}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, agentID, id)
|
|
require.True(t, ok)
|
|
})
|
|
|
|
t.Run("MultipleContextFiles", func(t *testing.T) {
|
|
t.Parallel()
|
|
agentID1 := uuid.New()
|
|
agentID2 := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/first/path",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID1, Valid: true},
|
|
},
|
|
}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/second/path",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID2, Valid: true},
|
|
},
|
|
}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, agentID2, id)
|
|
require.True(t, ok)
|
|
})
|
|
|
|
t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) {
|
|
t.Parallel()
|
|
instructionAgentID := uuid.New()
|
|
sentinelAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/workspace/AGENTS.md",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true},
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: AgentChatContextSentinelPath,
|
|
ContextFileAgentID: uuid.NullUUID{
|
|
UUID: sentinelAgentID,
|
|
Valid: true,
|
|
},
|
|
}}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, instructionAgentID, id)
|
|
require.True(t, ok)
|
|
})
|
|
|
|
t.Run("SentinelWithoutAgentID", func(t *testing.T) {
|
|
t.Parallel()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFileAgentID: uuid.NullUUID{Valid: false},
|
|
},
|
|
}),
|
|
}
|
|
id, ok := contextFileAgentID(msgs)
|
|
require.Equal(t, uuid.Nil, id)
|
|
require.False(t, ok)
|
|
})
|
|
}
|
|
|
|
func TestHasPersistedInstructionFiles(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) {
|
|
t.Parallel()
|
|
agentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: AgentChatContextSentinelPath,
|
|
ContextFileAgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}}),
|
|
}
|
|
require.False(t, hasPersistedInstructionFiles(msgs))
|
|
})
|
|
|
|
t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) {
|
|
t.Parallel()
|
|
agentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/workspace/AGENTS.md",
|
|
ContextFileContent: "repo instructions",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
|
}}),
|
|
}
|
|
require.True(t, hasPersistedInstructionFiles(msgs))
|
|
})
|
|
}
|
|
|
|
func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/old/AGENTS.md",
|
|
ContextFileContent: "old instructions",
|
|
ContextFileOS: "darwin",
|
|
ContextFileDirectory: "/old",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/new/AGENTS.md",
|
|
ContextFileContent: "new instructions",
|
|
ContextFileOS: "linux",
|
|
ContextFileDirectory: "/new",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
}}),
|
|
}
|
|
|
|
got := instructionFromContextFiles(msgs)
|
|
require.Contains(t, got, "new instructions")
|
|
require.Contains(t, got, "Operating System: linux")
|
|
require.Contains(t, got, "Working Directory: /new")
|
|
require.NotContains(t, got, "old instructions")
|
|
require.NotContains(t, got, "Operating System: darwin")
|
|
}
|
|
|
|
func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/legacy/AGENTS.md",
|
|
ContextFileContent: "legacy instructions",
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/old/AGENTS.md",
|
|
ContextFileContent: "old instructions",
|
|
ContextFileOS: "darwin",
|
|
ContextFileDirectory: "/old",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/new/AGENTS.md",
|
|
ContextFileContent: "new instructions",
|
|
ContextFileOS: "linux",
|
|
ContextFileDirectory: "/new",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
}}),
|
|
}
|
|
|
|
got := instructionFromContextFiles(msgs)
|
|
require.Contains(t, got, "legacy instructions")
|
|
require.Contains(t, got, "new instructions")
|
|
require.Contains(t, got, "Operating System: linux")
|
|
require.Contains(t, got, "Working Directory: /new")
|
|
require.NotContains(t, got, "old instructions")
|
|
require.NotContains(t, got, "Operating System: darwin")
|
|
}
|
|
|
|
func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-legacy",
|
|
SkillDir: "/skills/repo-helper-legacy",
|
|
}}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/old/AGENTS.md",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-old",
|
|
SkillDir: "/skills/repo-helper-old",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
},
|
|
}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: AgentChatContextSentinelPath,
|
|
ContextFileAgentID: uuid.NullUUID{
|
|
UUID: newAgentID,
|
|
Valid: true,
|
|
},
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-new",
|
|
SkillDir: "/skills/repo-helper-new",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
},
|
|
}),
|
|
}
|
|
|
|
got := skillsFromParts(msgs)
|
|
require.Equal(t, []chattool.SkillMeta{
|
|
{Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"},
|
|
{Name: "repo-helper-new", Dir: "/skills/repo-helper-new"},
|
|
}, got)
|
|
}
|
|
|
|
func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
msgs := []database.ChatMessage{
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: "/old/AGENTS.md",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-old",
|
|
SkillDir: "/skills/repo-helper-old",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
|
},
|
|
}),
|
|
chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFilePath: AgentChatContextSentinelPath,
|
|
ContextFileAgentID: uuid.NullUUID{
|
|
UUID: newAgentID,
|
|
Valid: true,
|
|
},
|
|
},
|
|
{
|
|
Type: codersdk.ChatMessagePartTypeSkill,
|
|
SkillName: "repo-helper-new",
|
|
SkillDir: "/skills/repo-helper-new",
|
|
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
},
|
|
}),
|
|
}
|
|
|
|
got := skillsFromParts(msgs)
|
|
require.Equal(t, []chattool.SkillMeta{{
|
|
Name: "repo-helper-new",
|
|
Dir: "/skills/repo-helper-new",
|
|
}}, got)
|
|
}
|
|
|
|
func TestMergeSkillMetas(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
persisted := []chattool.SkillMeta{{
|
|
Name: "repo-helper",
|
|
Description: "Persisted skill",
|
|
Dir: "/skills/repo-helper-old",
|
|
}}
|
|
discovered := []chattool.SkillMeta{
|
|
{
|
|
Name: "repo-helper",
|
|
Description: "Discovered replacement",
|
|
Dir: "/skills/repo-helper-new",
|
|
MetaFile: "SKILL.md",
|
|
},
|
|
{
|
|
Name: "deep-review",
|
|
Description: "Discovered skill",
|
|
Dir: "/skills/deep-review",
|
|
},
|
|
}
|
|
|
|
got := mergeSkillMetas(persisted, discovered)
|
|
require.Equal(t, []chattool.SkillMeta{
|
|
discovered[0],
|
|
discovered[1],
|
|
}, got)
|
|
}
|
|
|
|
func TestSelectSkillMetasForInstructionRefresh(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}}
|
|
discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}}
|
|
currentAgentID := uuid.New()
|
|
otherAgentID := uuid.New()
|
|
|
|
t.Run("MergesCurrentAgentSkills", func(t *testing.T) {
|
|
t.Parallel()
|
|
got := selectSkillMetasForInstructionRefresh(
|
|
persisted,
|
|
discovered,
|
|
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
)
|
|
require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got)
|
|
})
|
|
|
|
t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) {
|
|
t.Parallel()
|
|
got := selectSkillMetasForInstructionRefresh(
|
|
persisted,
|
|
discovered,
|
|
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
|
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
|
)
|
|
require.Equal(t, discovered, got)
|
|
})
|
|
|
|
t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) {
|
|
t.Parallel()
|
|
got := selectSkillMetasForInstructionRefresh(
|
|
persisted,
|
|
nil,
|
|
uuid.NullUUID{},
|
|
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
|
)
|
|
require.Equal(t, persisted, got)
|
|
})
|
|
}
|
|
|
|
// TestProcessChat_IgnoresStaleControlNotification verifies that
|
|
// processChat is not interrupted by a "pending" notification
|
|
// published before processing begins. This is the race that caused
|
|
// TestOpenAIReasoningWithWebSearchRoundTripStoreFalse to flake:
|
|
// SendMessage publishes "pending" via PostgreSQL NOTIFY, and due
|
|
// to async delivery the notification can arrive at the control
|
|
// subscriber after it registers but before the processor publishes
|
|
// "running".
|
|
func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
ps := dbpubsub.NewInMemory()
|
|
clock := quartz.NewMock(t)
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: ps,
|
|
clock: clock,
|
|
workerID: workerID,
|
|
chatHeartbeatInterval: time.Minute,
|
|
metrics: chatloop.NopMetrics(),
|
|
configCache: newChatConfigCache(ctx, db, clock),
|
|
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
|
}
|
|
|
|
// Publish a stale "pending" notification on the control channel
|
|
// BEFORE processChat subscribes. In production this is the
|
|
// notification from SendMessage that triggered the processing.
|
|
staleNotify, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
|
|
Status: string(database.ChatStatusPending),
|
|
})
|
|
require.NoError(t, err)
|
|
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), staleNotify)
|
|
require.NoError(t, err)
|
|
|
|
// Track which status processChat writes during cleanup.
|
|
var finalStatus database.ChatStatus
|
|
|
|
// The deferred cleanup in processChat runs a transaction.
|
|
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
return fn(db)
|
|
},
|
|
)
|
|
db.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(
|
|
database.Chat{ID: chatID, Status: database.ChatStatusRunning, WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}}, nil,
|
|
)
|
|
db.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, params database.UpdateChatStatusParams) (database.Chat, error) {
|
|
finalStatus = params.Status
|
|
return database.Chat{
|
|
ID: chatID,
|
|
Status: params.Status,
|
|
LastTurnSummary: sql.NullString{String: "previous summary", Valid: true},
|
|
}, nil
|
|
},
|
|
)
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
|
|
database.Chat{ID: chatID, Status: database.ChatStatusError},
|
|
nil,
|
|
)
|
|
|
|
db.EXPECT().UpdateChatLastTurnSummary(gomock.Any(), gomock.Any()).Return(int64(1), nil)
|
|
|
|
// resolveChatModel fails immediately — that's fine, we only
|
|
// need processChat to get past initialization without being
|
|
// interrupted by the stale notification.
|
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return(
|
|
database.ChatModelConfig{}, xerrors.New("no model configured"),
|
|
).AnyTimes()
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
|
|
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
|
|
).AnyTimes()
|
|
db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes()
|
|
|
|
chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New()}
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
server.processChat(ctx, chat)
|
|
}()
|
|
|
|
// Wait for processChat to finish entirely. It re-reads chat state and
|
|
// runs more cleanup after UpdateChatStatus, so signaling completion from
|
|
// the status update itself races test teardown.
|
|
testutil.TryReceive(ctx, t, done)
|
|
|
|
WaitUntilIdleForTest(server)
|
|
|
|
// If the stale notification interrupted us, status would be
|
|
// "waiting" (the ErrInterrupted path). Since the gate blocked
|
|
// it, processChat reached runChat, which failed on model
|
|
// resolution → status is "error".
|
|
require.Equal(t, database.ChatStatusError, finalStatus,
|
|
"processChat should have reached runChat (error), not been interrupted (waiting)")
|
|
}
|
|
|
|
func TestShouldPublishFinishedChatState(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
|
|
server := &Server{db: db}
|
|
updatedChat := database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
}
|
|
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
}, nil)
|
|
|
|
require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat))
|
|
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusRunning,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
}, nil)
|
|
|
|
require.False(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat))
|
|
}
|
|
|
|
// TestShouldPublishFinishedChatState_DBErrorPublishes pins the
|
|
// deliberate fail-open behavior when the re-read query errors: we
|
|
// surface the finished state anyway so watchers don't get stuck
|
|
// waiting for a status update that never arrives. The error path is
|
|
// easy to regress into a fail-closed default otherwise.
|
|
func TestShouldPublishFinishedChatState_DBErrorPublishes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
|
|
server := &Server{db: db}
|
|
updatedChat := database.Chat{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
}
|
|
|
|
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(
|
|
database.Chat{}, xerrors.New("boom"),
|
|
)
|
|
|
|
require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat),
|
|
"fail-open: a re-read error must not swallow the status change")
|
|
}
|
|
|
|
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
|
|
// batch heartbeat UPDATE does not return a registered chat's ID
|
|
// (because another replica stole it or it was completed), the
|
|
// heartbeat tick cancels that chat's context with ErrInterrupted
|
|
// while leaving surviving chats untouched.
|
|
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
clock := quartz.NewMock(t)
|
|
|
|
workerID := uuid.New()
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: clock,
|
|
workerID: workerID,
|
|
chatHeartbeatInterval: time.Minute,
|
|
metrics: chatloop.NopMetrics(),
|
|
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
|
}
|
|
|
|
// Create three chats with independent cancel functions.
|
|
chat1 := uuid.New()
|
|
chat2 := uuid.New()
|
|
chat3 := uuid.New()
|
|
|
|
_, cancel1 := context.WithCancelCause(ctx)
|
|
_, cancel2 := context.WithCancelCause(ctx)
|
|
ctx3, cancel3 := context.WithCancelCause(ctx)
|
|
|
|
server.registerHeartbeat(&heartbeatEntry{
|
|
cancelWithCause: cancel1,
|
|
chatID: chat1,
|
|
logger: logger,
|
|
})
|
|
server.registerHeartbeat(&heartbeatEntry{
|
|
cancelWithCause: cancel2,
|
|
chatID: chat2,
|
|
logger: logger,
|
|
})
|
|
server.registerHeartbeat(&heartbeatEntry{
|
|
cancelWithCause: cancel3,
|
|
chatID: chat3,
|
|
logger: logger,
|
|
})
|
|
|
|
// The batch UPDATE returns only chat1 and chat2 —
|
|
// chat3 was "stolen" by another replica.
|
|
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
|
require.Equal(t, workerID, params.WorkerID)
|
|
require.Len(t, params.IDs, 3)
|
|
// Return only chat1 and chat2 as surviving.
|
|
return []uuid.UUID{chat1, chat2}, nil
|
|
},
|
|
)
|
|
|
|
server.heartbeatTick(ctx)
|
|
|
|
// chat3's context should be canceled with ErrInterrupted.
|
|
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
|
|
"stolen chat should be interrupted")
|
|
|
|
// chat3 should have been removed from the registry by
|
|
// unregister (in production this happens via defer in
|
|
// processChat). The heartbeat tick itself does not
|
|
// unregister — it only cancels. Verify the entry is
|
|
// still present (processChat's defer would clean it up).
|
|
server.heartbeatMu.Lock()
|
|
_, chat1Exists := server.heartbeatRegistry[chat1]
|
|
_, chat2Exists := server.heartbeatRegistry[chat2]
|
|
_, chat3Exists := server.heartbeatRegistry[chat3]
|
|
server.heartbeatMu.Unlock()
|
|
|
|
require.True(t, chat1Exists, "surviving chat1 should remain registered")
|
|
require.True(t, chat2Exists, "surviving chat2 should remain registered")
|
|
require.True(t, chat3Exists,
|
|
"stolen chat3 should still be in registry (processChat defer removes it)")
|
|
}
|
|
|
|
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
|
|
// transient database failure causes the tick to log and return
|
|
// without canceling any registered chats.
|
|
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
clock := quartz.NewMock(t)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
clock: clock,
|
|
workerID: uuid.New(),
|
|
chatHeartbeatInterval: time.Minute,
|
|
metrics: chatloop.NopMetrics(),
|
|
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
chatCtx, cancel := context.WithCancelCause(ctx)
|
|
|
|
server.registerHeartbeat(&heartbeatEntry{
|
|
cancelWithCause: cancel,
|
|
chatID: chatID,
|
|
logger: logger,
|
|
})
|
|
|
|
// Simulate a transient DB error.
|
|
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
|
|
nil, xerrors.New("connection reset"),
|
|
)
|
|
|
|
server.heartbeatTick(ctx)
|
|
|
|
// Chat should NOT be interrupted — the tick logged and
|
|
// returned early.
|
|
require.NoError(t, chatCtx.Err(),
|
|
"chat context should not be canceled on transient DB error")
|
|
}
|
|
|
|
// TestSubscribeCancelDuringGrace_ReapedBySweep verifies that a
|
|
// subscriber detach inside bufferRetainGracePeriod (the OSS trigger
|
|
// for the retained-buffer leak) leaves the state mapped, and the
|
|
// next sweep past the grace window reaps it.
|
|
func TestSubscribeCancelDuringGrace_ReapedBySweep(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := slogtest.Make(t, nil)
|
|
mClock := quartz.NewMock(t)
|
|
|
|
server := &Server{
|
|
logger: logger,
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
start := mClock.Now()
|
|
|
|
// Just-finished chat: processing done, buffer retained for
|
|
// late-connecting relay subscribers.
|
|
state := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: start,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
buffer: []bufferedStreamPart{{
|
|
event: codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
},
|
|
},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
// Real subscribeToStream cancel path: the WS subscriber detach
|
|
// that leaks in prod.
|
|
snapshot, currentRetry, events, cancelSub := server.subscribeToStream(chatID)
|
|
require.Len(t, snapshot, 1)
|
|
require.Nil(t, currentRetry)
|
|
require.NotNil(t, events)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod / 2)
|
|
cancelSub()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok,
|
|
"entry should remain during grace window after subscriber detach")
|
|
|
|
mClock.Advance(bufferRetainGracePeriod)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok = server.chatStreams.Load(chatID)
|
|
require.False(t, ok,
|
|
"entry should be reaped after grace period expires and sweep runs")
|
|
}
|
|
|
|
// TestSweepIdleStreams_ReapsStaleRetainedBuffer: grace expired, no
|
|
// subscribers, not buffering -> reaped.
|
|
func TestSweepIdleStreams_ReapsStaleRetainedBuffer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
state := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: mClock.Now(),
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
buffer: []bufferedStreamPart{{
|
|
event: codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod + time.Second)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.False(t, ok, "stale retained state should be reaped")
|
|
}
|
|
|
|
// TestSweepIdleStreams_DoesNotReapActiveBuffering: buffering=true
|
|
// blocks reap even long after any grace would have expired.
|
|
func TestSweepIdleStreams_DoesNotReapActiveBuffering(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
state := &chatStreamState{
|
|
buffering: true,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
buffer: []bufferedStreamPart{{
|
|
event: codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
mClock.Advance(time.Hour)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok, "actively-buffering state must not be reaped")
|
|
}
|
|
|
|
// TestSweepIdleStreams_DoesNotReapWithSubscribers: attached
|
|
// subscribers block reap even when grace has expired.
|
|
func TestSweepIdleStreams_DoesNotReapWithSubscribers(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
state := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: mClock.Now(),
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
|
|
uuid.New(): make(chan codersdk.ChatStreamEvent, 1),
|
|
},
|
|
buffer: []bufferedStreamPart{{
|
|
event: codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod + time.Second)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok, "state with subscribers must not be reaped")
|
|
}
|
|
|
|
// TestSweepIdleStreams_DefersDuringGracePeriod: sweep inside grace
|
|
// is a no-op; the next sweep past grace reaps.
|
|
func TestSweepIdleStreams_DefersDuringGracePeriod(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
start := mClock.Now()
|
|
state := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: start,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
buffer: []bufferedStreamPart{{
|
|
event: codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
},
|
|
}},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod / 2)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok, "sweep inside grace window must not reap")
|
|
|
|
mClock.Advance(bufferRetainGracePeriod)
|
|
server.sweepIdleStreams()
|
|
|
|
_, ok = server.chatStreams.Load(chatID)
|
|
require.False(t, ok, "sweep after grace window must reap")
|
|
}
|
|
|
|
// TestPublishToStream_DropZeroesBackingSlot verifies that evicting
|
|
// the oldest buffered event at capacity zeroes the dropped slot so
|
|
// its *ChatStreamMessagePart becomes GC-eligible immediately.
|
|
func TestPublishToStream_DropZeroesBackingSlot(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
|
|
// Over-allocate by one so the post-drop append fits in place and
|
|
// exercises the backing-array reuse this test is checking.
|
|
buf := make([]bufferedStreamPart, maxStreamBufferSize, maxStreamBufferSize+1)
|
|
for i := range buf {
|
|
buf[i] = bufferedStreamPart{
|
|
event: codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{},
|
|
},
|
|
}
|
|
}
|
|
// Sentinel in slot 0 distinguishes "slot was zeroed" from "slot
|
|
// was overwritten by a later append".
|
|
sentinel := &codersdk.ChatStreamMessagePart{
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
}
|
|
buf[0] = bufferedStreamPart{
|
|
event: codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: sentinel,
|
|
},
|
|
}
|
|
// Alias over the full backing array so we can still observe slot
|
|
// 0 after publishToStream reslices state.buffer forward.
|
|
origBacking := buf[:cap(buf)]
|
|
|
|
state := &chatStreamState{
|
|
buffering: true,
|
|
buffer: buf,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
newPart := &codersdk.ChatStreamMessagePart{
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
}
|
|
server.publishToStream(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: newPart,
|
|
})
|
|
|
|
require.Equal(t, bufferedStreamPart{}, origBacking[0],
|
|
"dropped slot must be zero-valued so its *ChatStreamMessagePart "+
|
|
"is eligible for GC; got %+v", origBacking[0])
|
|
|
|
// Sanity-check the in-place append path the fix targets: if Go's
|
|
// growth policy ever makes this append reallocate, this fails
|
|
// loudly so the test author revisits the setup.
|
|
require.Same(t, newPart, origBacking[len(origBacking)-1].event.MessagePart,
|
|
"append must have landed in the original backing array; the "+
|
|
"zero-out invariant only matters when cap > len")
|
|
}
|
|
|
|
// TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry covers
|
|
// the race where a caller holds a pointer to a no-longer-mapped
|
|
// state (e.g. a janitor Range callback racing a fresh
|
|
// getOrCreateStreamState) and would otherwise evict the fresh entry.
|
|
// With CompareAndDelete in cleanupStreamIfIdle the stale delete is
|
|
// a no-op.
|
|
func TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
|
|
// Stale pointer: reapable (not buffering, no subscribers, grace
|
|
// expired) but no longer the map's live entry.
|
|
stale := &chatStreamState{
|
|
buffering: false,
|
|
bufferRetainedAt: mClock.Now(),
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
}
|
|
|
|
// Fresh entry: the state getOrCreateStreamState would install
|
|
// after a racing processChat run. Actively buffering, so not
|
|
// reapable. Only this state is in the map.
|
|
fresh := &chatStreamState{
|
|
buffering: true,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
}
|
|
server.chatStreams.Store(chatID, fresh)
|
|
|
|
mClock.Advance(bufferRetainGracePeriod + time.Second)
|
|
|
|
// Stale caller mirrors the janitor Range callback after the map
|
|
// entry has already been replaced.
|
|
stale.mu.Lock()
|
|
server.cleanupStreamIfIdle(chatID, stale)
|
|
stale.mu.Unlock()
|
|
|
|
got, ok := server.chatStreams.Load(chatID)
|
|
require.True(t, ok,
|
|
"fresh entry must remain mapped when cleanup is called with a stale pointer")
|
|
require.Same(t, fresh, got,
|
|
"cleanup must not replace the fresh entry with the stale one")
|
|
}
|
|
|
|
// TestSafeSweepIdleStreams_RecoversFromPanic verifies that an
|
|
// unexpected panic inside sweepIdleStreams is recovered rather than
|
|
// killing the janitor goroutine. Without this guard, a panic would
|
|
// silently reintroduce the very leak the janitor exists to prevent.
|
|
func TestSafeSweepIdleStreams_RecoversFromPanic(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := &Server{
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewMock(t),
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
// A nil *chatStreamState passes the type assertion in sweepIdleStreams
|
|
// but panics on state.mu.Lock with a nil-pointer deref. Any future
|
|
// panic source in the sweep would trigger the same recovery path.
|
|
var nilState *chatStreamState
|
|
server.chatStreams.Store(chatID, nilState)
|
|
|
|
require.NotPanics(t, func() {
|
|
server.safeSweepIdleStreams(context.Background())
|
|
}, "safeSweepIdleStreams must recover panics so the janitor loop keeps running")
|
|
}
|
|
|
|
func TestGetWorkspaceConn_StaleAgentRecovery(t *testing.T) {
|
|
// Regression test: when a workspace is rebuilt, the chat's stored
|
|
// agent ID points to a disconnected agent from the old build. The
|
|
// cache-miss path must let dialWithLazyValidation discover the new
|
|
// agent instead of rejecting the old one immediately.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
oldAgentID := uuid.New()
|
|
newAgentID := uuid.New()
|
|
buildID := uuid.New()
|
|
|
|
// Old agent: disconnected (from previous build).
|
|
oldAgent := database.WorkspaceAgent{
|
|
ID: oldAgentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
DisconnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-9 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// New agent: connected (from latest build).
|
|
newAgent := database.WorkspaceAgent{
|
|
ID: newAgentID,
|
|
Name: "main",
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-1 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: oldAgentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// ensureWorkspaceAgent fetches the stale agent.
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), oldAgentID).
|
|
Return(oldAgent, nil).Times(1)
|
|
// Lazy validation discovers the new agent.
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{newAgent}, nil).Times(1)
|
|
// Post-switch: persist the new binding.
|
|
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return(database.WorkspaceBuild{ID: buildID}, nil).Times(1)
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), newAgentID).
|
|
Return(newAgent, nil).Times(1)
|
|
|
|
updatedChat := chat
|
|
updatedChat.AgentID = uuid.NullUUID{UUID: newAgentID, Valid: true}
|
|
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
|
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
|
ID: chat.ID,
|
|
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
|
AgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
|
}).Return(updatedChat, nil).Times(1)
|
|
|
|
newConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
newConn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
switch id {
|
|
case oldAgentID:
|
|
return nil, nil, xerrors.New("agent is not connected")
|
|
case newAgentID:
|
|
return newConn, func() {}, nil
|
|
default:
|
|
return nil, nil, xerrors.Errorf("unexpected agent ID: %s", id)
|
|
}
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
return database.Chat{}, nil
|
|
},
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.NoError(t, err, "getWorkspaceConn should recover stale agent binding")
|
|
require.Same(t, newConn, gotConn, "should return the connection to the new agent")
|
|
|
|
// Verify the cache was updated to the new agent so subsequent
|
|
// cache-hit calls use the correct agent ID.
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.Equal(t, newAgentID, workspaceCtx.agent.ID, "cached agent should be the new agent")
|
|
require.True(t, workspaceCtx.agentLoaded)
|
|
require.Same(t, newConn, workspaceCtx.conn, "connection should be cached for subsequent calls")
|
|
}
|
|
|
|
func TestGetWorkspaceConn_SameBuildAgentCrash(t *testing.T) {
|
|
// When an agent crashes on the same build (disconnected, but still
|
|
// in the latest build), dialWithLazyValidation dials, fails fast,
|
|
// validation finds the same agent, and the retry also fails. The
|
|
// wrapped dial error propagates (not errChatAgentDisconnected).
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
|
|
// Agent: disconnected (crashed on current build).
|
|
agent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
Name: "main",
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
DisconnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-9 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// ensureWorkspaceAgent fetches the (crashed) agent.
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(agent, nil).Times(1)
|
|
// Validation finds the same agent in the latest build.
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{agent}, nil).Times(1)
|
|
|
|
dialErr := xerrors.New("agent is not connected")
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(_ context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return nil, nil, dialErr
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
return database.Chat{}, nil
|
|
},
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
require.Error(t, err)
|
|
// The error should be a wrapped dial error, not the
|
|
// agent-disconnected sentinel.
|
|
require.NotErrorIs(t, err, errChatAgentDisconnected)
|
|
require.ErrorIs(t, err, dialErr)
|
|
|
|
// Cache should not have a connection, but the agent should
|
|
// still be loaded (ensureWorkspaceAgent cached it).
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.True(t, workspaceCtx.agentLoaded)
|
|
require.Nil(t, workspaceCtx.conn)
|
|
}
|
|
|
|
func TestGetWorkspaceConn_StatusCheck(t *testing.T) {
|
|
// The cache-hit status check re-fetches the agent row for a fresh
|
|
// heartbeat timestamp. Healthy, timed-out, and DB-error paths return
|
|
// the cached connection. Disconnected agents are covered separately
|
|
// because they now trigger a fresh dial before recovery.
|
|
t.Parallel()
|
|
|
|
type testCase struct {
|
|
name string
|
|
buildAgent func(now time.Time) database.WorkspaceAgent
|
|
dbError bool
|
|
}
|
|
|
|
tests := []testCase{
|
|
{
|
|
// Agent never connected and the connection timeout
|
|
// has elapsed. This should not trigger lifecycle
|
|
// recovery because the agent did not connect and
|
|
// then disconnect.
|
|
name: "TimedOutAgentCacheHit",
|
|
buildAgent: func(now time.Time) database.WorkspaceAgent {
|
|
return database.WorkspaceAgent{
|
|
CreatedAt: now.Add(-10 * time.Minute),
|
|
ConnectionTimeoutSeconds: 60,
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "CacheHitHealthyAgent",
|
|
buildAgent: func(now time.Time) database.WorkspaceAgent {
|
|
return database.WorkspaceAgent{
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: now.Add(-5 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
}
|
|
},
|
|
},
|
|
{
|
|
// When GetWorkspaceAgentByID returns an error on
|
|
// cache hit, the cached connection should be returned.
|
|
name: "CacheHitDBError",
|
|
buildAgent: func(now time.Time) database.WorkspaceAgent {
|
|
return database.WorkspaceAgent{
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: now.Add(-5 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
}
|
|
},
|
|
dbError: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// Stamp the agent with the generated ID. Use the
|
|
// subtest's mock clock so the agent's timestamps are
|
|
// anchored to the same `now` the server uses. Using
|
|
// time.Now() at slice-literal construction time
|
|
// produced a Windows-CI flake because a slow scheduler
|
|
// could insert more than agentInactiveDisconnectTimeout
|
|
// of wall-clock delay between the literal and the
|
|
// subtest body.
|
|
clock := quartz.NewMock(t)
|
|
now := clock.Now()
|
|
agent := tc.buildAgent(now)
|
|
agent.ID = agentID
|
|
|
|
// Set up the DB mock for GetWorkspaceAgentByID.
|
|
if tc.dbError {
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(database.WorkspaceAgent{}, xerrors.New("connection reset")).
|
|
Times(1)
|
|
} else {
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(agent, nil).
|
|
Times(1)
|
|
}
|
|
|
|
var releaseCalled bool
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: clock,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return nil, nil, xerrors.New("should not be called")
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
cachedConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) {
|
|
return database.Chat{}, nil
|
|
},
|
|
agent: agent,
|
|
agentLoaded: true,
|
|
conn: cachedConn,
|
|
releaseConn: func() { releaseCalled = true },
|
|
cachedWorkspaceID: chat.WorkspaceID,
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.NoError(t, err)
|
|
require.Same(t, cachedConn, gotConn)
|
|
require.False(t, releaseCalled, "release called")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetWorkspaceConn_DialTimeoutDisconnectedRecoveryThreshold(t *testing.T) {
|
|
// The recovery sentinel requires a failed dial and a fresh
|
|
// disconnected status check past the recovery threshold. A
|
|
// disconnected DB row alone is not enough to trigger stop/start
|
|
// recovery.
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
disconnectedFor time.Duration
|
|
wantErr error
|
|
wantRecovery bool
|
|
}{
|
|
{
|
|
name: "RecentDisconnectReturnsDialTimeout",
|
|
disconnectedFor: agentDisconnectedRecoveryThreshold / 2,
|
|
wantErr: errChatDialTimeout,
|
|
wantRecovery: false,
|
|
},
|
|
{
|
|
name: "PastThresholdEscalates",
|
|
disconnectedFor: agentDisconnectedRecoveryThreshold,
|
|
wantErr: errChatAgentDisconnected,
|
|
wantRecovery: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
clock := quartz.NewMock(t)
|
|
now := clock.Now()
|
|
disconnectedAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: now.Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: now.Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
DisconnectedAt: sql.NullTime{
|
|
Time: now.Add(-tc.disconnectedFor),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(disconnectedAgent, nil).
|
|
Times(2)
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{disconnectedAgent}, nil).
|
|
Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: clock,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 10 * time.Millisecond,
|
|
}
|
|
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
<-ctx.Done()
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
require.ErrorIs(t, err, tc.wantErr)
|
|
if tc.wantRecovery {
|
|
require.ErrorIs(t, err, errChatAgentDisconnected)
|
|
} else {
|
|
require.NotErrorIs(t, err, errChatAgentDisconnected)
|
|
}
|
|
|
|
workspaceCtx.mu.Lock()
|
|
defer workspaceCtx.mu.Unlock()
|
|
require.False(t, workspaceCtx.agentLoaded)
|
|
require.Nil(t, workspaceCtx.conn)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetWorkspaceConn_DisconnectedStatusDialSuccessDoesNotEscalate(t *testing.T) {
|
|
// A stale disconnected row must not prompt stop/start if the
|
|
// agent can still be dialed successfully.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
disconnectedAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(disconnectedAgent, nil).
|
|
Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 10 * time.Millisecond,
|
|
}
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
var dialCalled bool
|
|
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
dialCalled = true
|
|
return conn, nil, nil
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.NoError(t, err)
|
|
require.Same(t, conn, gotConn)
|
|
require.True(t, dialCalled, "dial called")
|
|
}
|
|
|
|
func TestGetWorkspaceConn_CacheHitDisconnectedRetriesDialBeforeEscalating(t *testing.T) {
|
|
// A disconnected cached connection is discarded first. Recovery is
|
|
// only surfaced if the replacement dial also times out.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
disconnectedAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-10 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(disconnectedAgent, nil).
|
|
Times(2)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 10 * time.Millisecond,
|
|
}
|
|
newConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
newConn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
var dialCalled bool
|
|
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
dialCalled = true
|
|
return newConn, nil, nil
|
|
}
|
|
|
|
var releaseCalled bool
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
oldConn := agentconnmock.NewMockAgentConn(ctrl)
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
agent: disconnectedAgent,
|
|
agentLoaded: true,
|
|
conn: oldConn,
|
|
releaseConn: func() { releaseCalled = true },
|
|
cachedWorkspaceID: chat.WorkspaceID,
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.NoError(t, err)
|
|
require.Same(t, newConn, gotConn)
|
|
require.True(t, releaseCalled, "release called")
|
|
require.True(t, dialCalled, "dial called")
|
|
}
|
|
|
|
func TestGetWorkspaceConn_DialTimeout(t *testing.T) {
|
|
// When dialWithLazyValidation blocks beyond the dial
|
|
// timeout, getWorkspaceConn should return
|
|
// errChatDialTimeout instead of hanging indefinitely.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
// Agent appears connected so the status check passes.
|
|
connectedAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-1 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(connectedAgent, nil).
|
|
Times(2)
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{connectedAgent}, nil).
|
|
Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 10 * time.Millisecond,
|
|
}
|
|
// Dial blocks forever (simulates unreachable agent).
|
|
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
<-ctx.Done()
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
require.ErrorIs(t, err, errChatDialTimeout)
|
|
}
|
|
|
|
func TestGetWorkspaceConn_DialTimeoutStatusTimeoutDoesNotEscalate(t *testing.T) {
|
|
// Agents that never connected are startup failures, not
|
|
// disconnected recovery cases. A dial timeout should stay a
|
|
// retry/escalation error rather than stop/start guidance.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
timedOutAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
CreatedAt: time.Now().Add(-10 * time.Minute),
|
|
ConnectionTimeoutSeconds: 60,
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(timedOutAgent, nil).
|
|
Times(2)
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{timedOutAgent}, nil).
|
|
Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: 10 * time.Millisecond,
|
|
}
|
|
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
<-ctx.Done()
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
require.ErrorIs(t, err, errChatDialTimeout)
|
|
require.NotErrorIs(t, err, errChatAgentDisconnected)
|
|
}
|
|
|
|
func TestGetWorkspaceConn_DialTimeoutParentCanceled(t *testing.T) {
|
|
// When the parent context is canceled, the parent's error
|
|
// must propagate unchanged (not wrapped as a dial timeout).
|
|
// This is critical because the chatloop checks
|
|
// context.Cause(ctx) for ErrInterrupted.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
connectedAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-1 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(connectedAgent, nil).
|
|
Times(1)
|
|
|
|
parentErr := xerrors.New("parent canceled")
|
|
ctx, cancel := context.WithCancelCause(testutil.Context(t, testutil.WaitShort))
|
|
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
// Use a very long dial timeout so the parent cancel fires
|
|
// first.
|
|
dialTimeout: 10 * time.Minute,
|
|
}
|
|
// Signal when the dial goroutine has started so we can
|
|
// cancel the parent at the right time without time.Sleep.
|
|
dialStarted := make(chan struct{})
|
|
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
close(dialStarted)
|
|
<-ctx.Done()
|
|
return nil, nil, ctx.Err()
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
// Cancel the parent after the dial starts.
|
|
go func() {
|
|
<-dialStarted
|
|
cancel(parentErr)
|
|
}()
|
|
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
// The error must NOT be errChatDialTimeout.
|
|
require.NotErrorIs(t, err, errChatDialTimeout)
|
|
// The parent context's error should propagate.
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
}
|
|
|
|
func TestGetWorkspaceConn_PreflightExternalAgentTimedOut(t *testing.T) {
|
|
// External agent never connected and the connection window has
|
|
// elapsed (Timeout). Preflight must short-circuit before any
|
|
// dial attempt and return the external-agent error.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
resourceID := uuid.New()
|
|
agent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
Name: "main",
|
|
ResourceID: resourceID,
|
|
CreatedAt: time.Now().Add(-10 * time.Minute),
|
|
ConnectionTimeoutSeconds: 60,
|
|
}
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(agent, nil).
|
|
Times(1)
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{agent}, nil).
|
|
Times(1)
|
|
db.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).
|
|
Return(database.WorkspaceResource{
|
|
ID: resourceID,
|
|
Type: chattool.ExternalAgentResourceType,
|
|
}, nil).
|
|
Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
t.Fatal("unexpected agent dial for external agent preflight")
|
|
return nil, nil, xerrors.New("unexpected agent dial")
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
require.ErrorIs(t, err, errChatExternalAgentUnavailable)
|
|
require.Equal(t, chattool.ExternalAgentUnavailableMessage(agent), err.Error())
|
|
}
|
|
|
|
func TestGetWorkspaceConn_PreflightExternalAgentConnectingDials(t *testing.T) {
|
|
// External agent in the Connecting state (never connected yet,
|
|
// still inside ConnectionTimeoutSeconds) must fall through to the
|
|
// dial so the user can succeed in the same turn if they just
|
|
// started the agent on their host.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
resourceID := uuid.New()
|
|
agent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
Name: "main",
|
|
ResourceID: resourceID,
|
|
CreatedAt: time.Now().Add(-1 * time.Second),
|
|
ConnectionTimeoutSeconds: 600,
|
|
}
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(agent, nil).
|
|
Times(1)
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
|
|
|
dialed := false
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
dialed = true
|
|
require.Equal(t, agentID, id)
|
|
return conn, func() {}, nil
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.NoError(t, err)
|
|
require.Same(t, conn, gotConn)
|
|
require.True(t, dialed, "preflight must let Connecting external agents reach the dial")
|
|
}
|
|
|
|
func TestGetWorkspaceConn_DialErrorNotMisclassifiedAsTimeout(t *testing.T) {
|
|
// Regression test: a non-timeout dial error (e.g. auth
|
|
// failure) with the parent context still alive must NOT be
|
|
// converted to errChatDialTimeout or masked as external-agent
|
|
// unavailability.
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
resourceID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
connectedAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
ResourceID: resourceID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: time.Now().Add(-1 * time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: time.Now(),
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(connectedAgent, nil).
|
|
Times(1)
|
|
// When the initial dial fails immediately, dialWithLazyValidation
|
|
// calls resolveFastFailure which validates the binding. Mock the
|
|
// validation to return the same agent, triggering a synchronous
|
|
// redial that also returns the error.
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{connectedAgent}, nil).
|
|
AnyTimes()
|
|
db.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).
|
|
Return(database.WorkspaceResource{
|
|
ID: resourceID,
|
|
Type: chattool.ExternalAgentResourceType,
|
|
}, nil).
|
|
AnyTimes()
|
|
|
|
dialErr := xerrors.New("authentication failed")
|
|
server := &Server{
|
|
db: db,
|
|
clock: quartz.NewReal(),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
// Generous timeout so the dial error fires well before
|
|
// the timeout.
|
|
dialTimeout: defaultDialTimeout,
|
|
}
|
|
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
// Return an error immediately (not a timeout).
|
|
return nil, nil, dialErr
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
|
}
|
|
defer workspaceCtx.close()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
|
require.Nil(t, gotConn)
|
|
// Must NOT be misclassified as a dial timeout or external-agent outage.
|
|
require.NotErrorIs(t, err, errChatDialTimeout)
|
|
require.NotErrorIs(t, err, errChatExternalAgentUnavailable)
|
|
// The original dial error should propagate.
|
|
require.ErrorIs(t, err, dialErr)
|
|
require.ErrorContains(t, err, "authentication failed")
|
|
}
|
|
|
|
// TestAutoPromote_InsertFailureRollsBackTransaction verifies that when
|
|
// tryAutoPromoteQueuedMessage pops a queued message but the subsequent
|
|
// insert fails, the error propagates to the InTx callback, causing the
|
|
// transaction to roll back and preserving the queued message.
|
|
func TestAutoPromote_InsertFailureRollsBackTransaction(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
tx := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ps := dbpubsub.NewInMemory()
|
|
clock := quartz.NewReal()
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
ownerID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
|
|
waitingChat := database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelConfigID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
}
|
|
queuedMsg := database.ChatQueuedMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Content: []byte(`[{"type":"text","text":"queued"}]`),
|
|
}
|
|
insertErr := xerrors.New("insert failed")
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: ps,
|
|
configCache: newChatConfigCache(ctx, db, clock),
|
|
}
|
|
|
|
// The caller runs tryAutoPromoteQueuedMessage inside InTx.
|
|
// Wire the mock to execute the callback against the TX mock.
|
|
var txErr error
|
|
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
txErr = fn(tx)
|
|
return txErr
|
|
},
|
|
)
|
|
|
|
// Inside the TX: lock chat, get queued messages, resolve model
|
|
// config, pop queued message, insert fails.
|
|
tx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(waitingChat, nil)
|
|
tx.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return([]database.ChatQueuedMessage{queuedMsg}, nil)
|
|
tx.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(database.ChatModelConfig{ID: modelConfigID}, nil)
|
|
tx.EXPECT().PopNextQueuedMessage(gomock.Any(), chatID).Return(queuedMsg, nil)
|
|
tx.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, insertErr)
|
|
|
|
// Invoke tryAutoPromoteQueuedMessage through the same InTx
|
|
// pattern the processChat defer uses. The test directly calls
|
|
// the production path to verify error propagation.
|
|
_ = db.InTx(func(txStore database.Store) error {
|
|
latestChat, err := txStore.GetChatByIDForUpdate(ctx, chatID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, _, _, promoteErr := server.tryAutoPromoteQueuedMessage(ctx, txStore, latestChat)
|
|
if promoteErr != nil {
|
|
return promoteErr
|
|
}
|
|
|
|
// This code path should not be reached when the insert
|
|
// fails, because promoteErr should be non-nil.
|
|
return nil
|
|
}, nil)
|
|
|
|
// The InTx callback must return a non-nil error so the
|
|
// transaction rolls back, preserving the queued message.
|
|
require.Error(t, txErr, "InTx callback should return error when insert fails")
|
|
}
|
|
|
|
// TestAutoPromote_WakesRunLoopAfterPromotion verifies that after the
|
|
func TestAutoPromote_InsertFailureSkipsStatusUpdate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
tx := dbmock.NewMockStore(ctrl)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
ps := dbpubsub.NewInMemory()
|
|
clock := quartz.NewReal()
|
|
|
|
chatID := uuid.New()
|
|
workerID := uuid.New()
|
|
ownerID := uuid.New()
|
|
modelConfigID := uuid.New()
|
|
|
|
waitingChat := database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
LastModelConfigID: modelConfigID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{UUID: workerID, Valid: true},
|
|
}
|
|
queuedMsg := database.ChatQueuedMessage{
|
|
ID: 1,
|
|
ChatID: chatID,
|
|
Content: []byte(`[{"type":"text","text":"queued"}]`),
|
|
}
|
|
|
|
wakeCh := make(chan struct{}, 1)
|
|
server := &Server{
|
|
db: db,
|
|
logger: logger,
|
|
pubsub: ps,
|
|
clock: clock,
|
|
workerID: workerID,
|
|
wakeCh: wakeCh,
|
|
chatHeartbeatInterval: time.Minute,
|
|
metrics: chatloop.NopMetrics(),
|
|
configCache: newChatConfigCache(ctx, db, clock),
|
|
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
|
}
|
|
|
|
// Hold model resolution until the interrupt has canceled the chat
|
|
// context. Returning ErrInterrupted keeps processChat on the
|
|
// interrupted path regardless of whether the cache singleflight sees
|
|
// the caller cancellation or the DB fetch result first.
|
|
modelBlocked := make(chan struct{})
|
|
modelRelease := make(chan struct{})
|
|
var modelBlockedOnce sync.Once
|
|
db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, _ uuid.UUID) (database.ChatModelConfig, error) {
|
|
modelBlockedOnce.Do(func() { close(modelBlocked) })
|
|
<-modelRelease
|
|
return database.ChatModelConfig{}, chatloop.ErrInterrupted
|
|
},
|
|
).AnyTimes()
|
|
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes()
|
|
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(
|
|
database.ChatUsageLimitConfig{}, sql.ErrNoRows,
|
|
).AnyTimes()
|
|
db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes()
|
|
|
|
// The deferred cleanup transaction: InsertChatMessages fails,
|
|
// so UpdateChatStatus must NOT be called.
|
|
db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn(
|
|
func(fn func(database.Store) error, _ *database.TxOptions) error {
|
|
return fn(tx)
|
|
},
|
|
)
|
|
tx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(waitingChat, nil)
|
|
tx.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return([]database.ChatQueuedMessage{queuedMsg}, nil)
|
|
tx.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(database.ChatModelConfig{ID: modelConfigID}, nil)
|
|
tx.EXPECT().PopNextQueuedMessage(gomock.Any(), chatID).Return(queuedMsg, nil)
|
|
tx.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(
|
|
nil, xerrors.New("insert failed"),
|
|
)
|
|
tx.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).Times(0)
|
|
|
|
// Subscribe BEFORE launching the goroutine.
|
|
runningCh := make(chan struct{}, 1)
|
|
unsubRunning, err := ps.SubscribeWithErr(
|
|
coderdpubsub.ChatStreamNotifyChannel(chatID),
|
|
func(_ context.Context, msg []byte, err error) {
|
|
if err != nil {
|
|
return
|
|
}
|
|
var notify coderdpubsub.ChatStreamNotifyMessage
|
|
if json.Unmarshal(msg, ¬ify) != nil {
|
|
return
|
|
}
|
|
if notify.Status == string(database.ChatStatusRunning) {
|
|
select {
|
|
case runningCh <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
defer unsubRunning()
|
|
|
|
chat := database.Chat{ID: chatID, OwnerID: ownerID, LastModelConfigID: modelConfigID}
|
|
processDone := make(chan struct{})
|
|
go func() {
|
|
defer close(processDone)
|
|
server.processChat(ctx, chat)
|
|
}()
|
|
|
|
select {
|
|
case <-runningCh:
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for running status")
|
|
}
|
|
|
|
select {
|
|
case <-modelBlocked:
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for model resolution")
|
|
}
|
|
|
|
// Publish an interrupt so processChat exits runChat.
|
|
interruptMsg, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{
|
|
Status: string(database.ChatStatusWaiting),
|
|
})
|
|
require.NoError(t, err)
|
|
err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), interruptMsg)
|
|
require.NoError(t, err)
|
|
close(modelRelease)
|
|
|
|
select {
|
|
case <-processDone:
|
|
case <-ctx.Done():
|
|
t.Fatal("processChat did not complete")
|
|
}
|
|
|
|
// The wake channel should NOT have a signal because the
|
|
// transaction failed before reaching UpdateChatStatus.
|
|
select {
|
|
case <-wakeCh:
|
|
t.Fatal("wake channel should not have a signal after insert failure")
|
|
default:
|
|
// No signal, as expected.
|
|
}
|
|
}
|
|
|
|
// makeInProgressPart is a small constructor for buffered message_part
|
|
// fixtures used by snapshotBufferLocked / subscribeToStream tests. It
|
|
// builds an in-progress part (committedMessageID == 0) with a
|
|
// recognizable text body so failing assertions can identify which
|
|
// part survived the filter.
|
|
func makeInProgressPart(text string) bufferedStreamPart {
|
|
return bufferedStreamPart{
|
|
event: codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
Part: codersdk.ChatMessageText(text),
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// makeCommittedPart builds a part already claimed by the given
|
|
// durable assistant message ID.
|
|
func makeCommittedPart(committedID int64, text string) bufferedStreamPart {
|
|
p := makeInProgressPart(text)
|
|
p.committedMessageID = committedID
|
|
return p
|
|
}
|
|
|
|
func partText(event codersdk.ChatStreamEvent) string {
|
|
if event.MessagePart == nil {
|
|
return ""
|
|
}
|
|
return event.MessagePart.Part.Text
|
|
}
|
|
|
|
// TestSnapshotBufferLocked_DropsCommittedParts asserts the core
|
|
// dedup contract: parts that were claimed by a durable assistant
|
|
// message (committedMessageID != 0) are dropped from the snapshot
|
|
// because the subscriber will receive that durable message through
|
|
// the REST snapshot, the initial DB query, or pubsub.
|
|
func TestSnapshotBufferLocked_DropsCommittedParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := []bufferedStreamPart{
|
|
makeCommittedPart(100, "turnA-1"),
|
|
makeCommittedPart(100, "turnA-2"),
|
|
makeCommittedPart(200, "turnB-1"),
|
|
makeInProgressPart("in-progress-1"),
|
|
makeInProgressPart("in-progress-2"),
|
|
}
|
|
|
|
snapshot := snapshotBufferLocked(buffer)
|
|
|
|
require.Len(t, snapshot, 2,
|
|
"only in-progress (committedMessageID == 0) parts should be kept")
|
|
require.Equal(t, "in-progress-1", partText(snapshot[0]))
|
|
require.Equal(t, "in-progress-2", partText(snapshot[1]))
|
|
}
|
|
|
|
// TestSnapshotBufferLocked_AllInProgressReturnsAll covers the
|
|
// fresh-load convention: when no assistant message has committed
|
|
// yet, every buffered part is in-progress and must be delivered.
|
|
func TestSnapshotBufferLocked_AllInProgressReturnsAll(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := []bufferedStreamPart{
|
|
makeInProgressPart("a"),
|
|
makeInProgressPart("b"),
|
|
makeInProgressPart("c"),
|
|
}
|
|
|
|
snapshot := snapshotBufferLocked(buffer)
|
|
|
|
require.Len(t, snapshot, 3,
|
|
"all in-progress parts must be delivered to the subscriber")
|
|
require.Equal(t, "a", partText(snapshot[0]))
|
|
require.Equal(t, "b", partText(snapshot[1]))
|
|
require.Equal(t, "c", partText(snapshot[2]))
|
|
}
|
|
|
|
// TestSnapshotBufferLocked_EmptyBufferReturnsNil documents that
|
|
// snapshotBufferLocked returns nil (not an empty slice) for an
|
|
// empty buffer, matching the prior append-from-nil behavior.
|
|
func TestSnapshotBufferLocked_EmptyBufferReturnsNil(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
require.Nil(t, snapshotBufferLocked(nil))
|
|
require.Nil(t, snapshotBufferLocked([]bufferedStreamPart{}))
|
|
}
|
|
|
|
// TestSnapshotBufferLocked_AllCommittedReturnsEmpty covers the
|
|
// natural resting point after an assistant turn commits and before
|
|
// the next turn starts streaming: every buffered part has been
|
|
// claimed and must be filtered out. The snapshot must be empty so
|
|
// reconnecting subscribers do not re-render content that is already
|
|
// available as a durable message.
|
|
func TestSnapshotBufferLocked_AllCommittedReturnsEmpty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
buffer := []bufferedStreamPart{
|
|
makeCommittedPart(100, "a"),
|
|
makeCommittedPart(100, "b"),
|
|
makeCommittedPart(200, "c"),
|
|
}
|
|
|
|
require.Empty(t, snapshotBufferLocked(buffer))
|
|
}
|
|
|
|
// TestPublishToStream_AppendsAsInProgress verifies that parts
|
|
// buffered while the chat is streaming are tagged as in-progress
|
|
// (committedMessageID == 0) until publishMessage claims them via a
|
|
// committed assistant message.
|
|
func TestPublishToStream_AppendsAsInProgress(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
|
|
chatID := uuid.New()
|
|
state := &chatStreamState{
|
|
buffering: true,
|
|
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{},
|
|
}
|
|
server.chatStreams.Store(chatID, state)
|
|
|
|
server.publishToStream(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
Part: codersdk.ChatMessageText("hello"),
|
|
},
|
|
})
|
|
|
|
state.mu.Lock()
|
|
defer state.mu.Unlock()
|
|
require.Len(t, state.buffer, 1)
|
|
require.Equal(t, int64(0), state.buffer[0].committedMessageID,
|
|
"newly buffered parts must be in-progress until publishMessage claims them")
|
|
require.Equal(t, "hello", partText(state.buffer[0].event))
|
|
}
|
|
|
|
// TestClaimCommittedParts covers the per-role behavior of
|
|
// claimCommittedParts:
|
|
// - assistant messages claim every in-progress part with the
|
|
// committed message ID.
|
|
// - tool / user messages do not claim parts.
|
|
// - parts already claimed by an earlier assistant message are not
|
|
// re-claimed.
|
|
// - a chat with no live state is a no-op (does not panic).
|
|
func TestClaimCommittedParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("AssistantClaimsAllInProgressParts", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: quartz.NewMock(t),
|
|
}
|
|
chatID := uuid.New()
|
|
state := server.getOrCreateStreamState(chatID)
|
|
state.mu.Lock()
|
|
state.buffer = []bufferedStreamPart{
|
|
makeCommittedPart(100, "old-1"),
|
|
makeInProgressPart("new-1"),
|
|
makeInProgressPart("new-2"),
|
|
}
|
|
state.mu.Unlock()
|
|
|
|
server.claimCommittedParts(chatID, database.ChatMessage{
|
|
ID: 200,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
})
|
|
|
|
state.mu.Lock()
|
|
defer state.mu.Unlock()
|
|
require.Equal(t, int64(100), state.buffer[0].committedMessageID,
|
|
"already-claimed parts must keep their original message ID")
|
|
require.Equal(t, int64(200), state.buffer[1].committedMessageID,
|
|
"in-progress parts must be claimed by the new message ID")
|
|
require.Equal(t, int64(200), state.buffer[2].committedMessageID,
|
|
"in-progress parts must be claimed by the new message ID")
|
|
})
|
|
|
|
t.Run("ToolMessageIsNoOp", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: quartz.NewMock(t),
|
|
}
|
|
chatID := uuid.New()
|
|
state := server.getOrCreateStreamState(chatID)
|
|
state.mu.Lock()
|
|
state.buffer = []bufferedStreamPart{
|
|
makeInProgressPart("a"),
|
|
makeInProgressPart("b"),
|
|
}
|
|
state.mu.Unlock()
|
|
|
|
server.claimCommittedParts(chatID, database.ChatMessage{
|
|
ID: 300,
|
|
Role: database.ChatMessageRoleTool,
|
|
})
|
|
|
|
state.mu.Lock()
|
|
defer state.mu.Unlock()
|
|
require.Equal(t, int64(0), state.buffer[0].committedMessageID,
|
|
"tool messages must not claim buffered parts")
|
|
require.Equal(t, int64(0), state.buffer[1].committedMessageID,
|
|
"tool messages must not claim buffered parts")
|
|
})
|
|
|
|
t.Run("UserMessageIsNoOp", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: quartz.NewMock(t),
|
|
}
|
|
chatID := uuid.New()
|
|
state := server.getOrCreateStreamState(chatID)
|
|
state.mu.Lock()
|
|
state.buffer = []bufferedStreamPart{
|
|
makeInProgressPart("a"),
|
|
}
|
|
state.mu.Unlock()
|
|
|
|
server.claimCommittedParts(chatID, database.ChatMessage{
|
|
ID: 400,
|
|
Role: database.ChatMessageRoleUser,
|
|
})
|
|
|
|
state.mu.Lock()
|
|
defer state.mu.Unlock()
|
|
require.Equal(t, int64(0), state.buffer[0].committedMessageID,
|
|
"user messages must not claim buffered parts")
|
|
})
|
|
|
|
t.Run("NoLiveStateIsNoOp", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: quartz.NewMock(t),
|
|
}
|
|
chatID := uuid.New()
|
|
|
|
// No state stored: claimCommittedParts must not panic and
|
|
// must not allocate a new state for an unknown chat.
|
|
require.NotPanics(t, func() {
|
|
server.claimCommittedParts(chatID, database.ChatMessage{
|
|
ID: 500,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
})
|
|
})
|
|
_, ok := server.chatStreams.Load(chatID)
|
|
require.False(t, ok,
|
|
"claimCommittedParts must not create stream state for a chat that has none")
|
|
})
|
|
}
|
|
|
|
// TestSubscribeToStream_FiltersBufferedParts_Integration wires
|
|
// publishToStream, claimCommittedParts (via publishMessage), and
|
|
// subscribeToStream together to confirm the end-to-end contract: a
|
|
// reconnecting subscriber only receives parts that belong to the
|
|
// current in-progress turn, not parts that were already committed
|
|
// to durable assistant messages.
|
|
func TestSubscribeToStream_FiltersBufferedParts_Integration(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
server := &Server{
|
|
logger: slogtest.Make(t, nil),
|
|
clock: mClock,
|
|
}
|
|
chatID := uuid.New()
|
|
|
|
// Simulate the lifecycle:
|
|
// 1. Stream parts of turn A (still in-progress, no commit yet).
|
|
// 2. Commit turn A; its parts are claimed by message 100.
|
|
// 3. Stream parts of turn B (in-progress).
|
|
// 4. Commit turn B; its parts are claimed by message 200.
|
|
// 5. Stream parts of turn C (in-progress, never committed).
|
|
state := server.getOrCreateStreamState(chatID)
|
|
state.mu.Lock()
|
|
state.buffering = true
|
|
state.mu.Unlock()
|
|
|
|
publishPart := func(text string) {
|
|
server.publishToStream(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: codersdk.ChatMessageRoleAssistant,
|
|
Part: codersdk.ChatMessageText(text),
|
|
},
|
|
})
|
|
}
|
|
|
|
publishPart("A-1")
|
|
publishPart("A-2")
|
|
server.claimCommittedParts(chatID, database.ChatMessage{
|
|
ID: 100,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
})
|
|
publishPart("B-1")
|
|
publishPart("B-2")
|
|
server.claimCommittedParts(chatID, database.ChatMessage{
|
|
ID: 200,
|
|
Role: database.ChatMessageRoleAssistant,
|
|
})
|
|
publishPart("C-1")
|
|
|
|
// Reconnecting subscriber: only the currently in-progress turn
|
|
// (turn C) survives the filter, no matter what cursor the
|
|
// client passes through SubscribeAuthorized (the filter no
|
|
// longer depends on the cursor).
|
|
snapshot, _, _, cancel := server.subscribeToStream(chatID)
|
|
defer cancel()
|
|
|
|
texts := make([]string, 0, len(snapshot))
|
|
for _, ev := range snapshot {
|
|
texts = append(texts, partText(ev))
|
|
}
|
|
require.Equal(t, []string{"C-1"}, texts,
|
|
"only in-progress (un-claimed) buffered parts must survive the filter")
|
|
}
|
|
|
|
// TestPrimeWorkspaceMCPCache_SuccessOnFirstAttempt verifies the
|
|
// onChatUpdated cache primer path: when create_workspace /
|
|
// start_workspace finish waitForAgentReady and the agent's MCP
|
|
// server is already advertising tools, a single ListMCPTools call
|
|
// populates the cache so the next PrepareTools step is a cache hit
|
|
// and does not need to dial.
|
|
func TestPrimeWorkspaceMCPCache_SuccessOnFirstAttempt(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
now := time.Now()
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: now.Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(workspaceAgent, nil).AnyTimes()
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{workspaceAgent}, nil).AnyTimes()
|
|
|
|
toolName := "workspace-mcp__echo"
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
conn.EXPECT().ListMCPTools(gomock.Any()).Return(workspacesdk.ListMCPToolsResponse{
|
|
Tools: []workspacesdk.MCPToolInfo{{
|
|
ServerName: "workspace-mcp",
|
|
Name: toolName,
|
|
Schema: map[string]any{},
|
|
}},
|
|
}, nil).Times(1)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: quartz.NewMock(t),
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return chat, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
server.primeWorkspaceMCPCache(ctx, server.logger, chat.ID, &workspaceCtx)
|
|
|
|
cached, ok := server.workspaceMCPToolsCache.Load(chat.ID)
|
|
require.True(t, ok, "primer must populate the cache on success")
|
|
entry, ok := cached.(*cachedWorkspaceMCPTools)
|
|
require.True(t, ok)
|
|
require.Equal(t, agentID, entry.agentID)
|
|
require.Len(t, entry.tools, 1)
|
|
require.Equal(t, toolName, entry.tools[0].Name)
|
|
}
|
|
|
|
// TestPrimeWorkspaceMCPCache_RetriesUntilToolsAppear simulates the
|
|
// race between agent reachability and the agent's MCP Connect: the
|
|
// first ListMCPTools call returns an empty list (no error), the
|
|
// second returns the workspace tools. The primer must retry after
|
|
// workspaceMCPPrimeRetryInterval and write the cache on the second
|
|
// attempt.
|
|
func TestPrimeWorkspaceMCPCache_RetriesUntilToolsAppear(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
now := time.Now()
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: now.Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(workspaceAgent, nil).AnyTimes()
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{workspaceAgent}, nil).AnyTimes()
|
|
|
|
toolName := "workspace-mcp__echo"
|
|
var listCalls atomic.Int32
|
|
emptyOnce := make(chan struct{}, 1)
|
|
emptyOnce <- struct{}{}
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
conn.EXPECT().ListMCPTools(gomock.Any()).DoAndReturn(
|
|
func(context.Context) (workspacesdk.ListMCPToolsResponse, error) {
|
|
listCalls.Add(1)
|
|
select {
|
|
case <-emptyOnce:
|
|
return workspacesdk.ListMCPToolsResponse{}, nil
|
|
default:
|
|
return workspacesdk.ListMCPToolsResponse{
|
|
Tools: []workspacesdk.MCPToolInfo{{
|
|
ServerName: "workspace-mcp",
|
|
Name: toolName,
|
|
Schema: map[string]any{},
|
|
}},
|
|
}, nil
|
|
}
|
|
},
|
|
).AnyTimes()
|
|
|
|
mockClock := quartz.NewMock(t)
|
|
timerTrap := mockClock.Trap().NewTimer("chatd", "workspace-mcp-prime")
|
|
t.Cleanup(timerTrap.Close)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: mockClock,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return chat, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
server.primeWorkspaceMCPCache(ctx, server.logger, chat.ID, &workspaceCtx)
|
|
}()
|
|
|
|
// First attempt returns empty. The primer arms a timer; release
|
|
// it and advance the clock so the second attempt fires.
|
|
call := timerTrap.MustWait(ctx)
|
|
call.MustRelease(ctx)
|
|
mockClock.Advance(workspaceMCPPrimeRetryInterval).MustWait(ctx)
|
|
|
|
select {
|
|
case <-done:
|
|
case <-ctx.Done():
|
|
t.Fatal("primer did not finish after second attempt")
|
|
}
|
|
|
|
require.GreaterOrEqual(t, listCalls.Load(), int32(2),
|
|
"primer must retry after empty result")
|
|
cached, ok := server.workspaceMCPToolsCache.Load(chat.ID)
|
|
require.True(t, ok, "primer must populate the cache on retry success")
|
|
entry, ok := cached.(*cachedWorkspaceMCPTools)
|
|
require.True(t, ok)
|
|
require.Equal(t, agentID, entry.agentID)
|
|
require.Len(t, entry.tools, 1)
|
|
require.Equal(t, toolName, entry.tools[0].Name)
|
|
}
|
|
|
|
// TestPrimeWorkspaceMCPCache_GivesUpAfterDeadline verifies the
|
|
// bounded-wait guarantee: when ListMCPTools always returns an empty
|
|
// list (e.g. the agent's MCP server never advertises tools), the
|
|
// primer stops trying at workspaceMCPPrimeMaxWait and does not cache
|
|
// the empty result. PrepareTools is then free to retry on the next
|
|
// chat step.
|
|
func TestPrimeWorkspaceMCPCache_GivesUpAfterDeadline(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
now := time.Now()
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: now.Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(workspaceAgent, nil).AnyTimes()
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{workspaceAgent}, nil).AnyTimes()
|
|
|
|
var listCalls atomic.Int32
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
conn.EXPECT().ListMCPTools(gomock.Any()).DoAndReturn(
|
|
func(context.Context) (workspacesdk.ListMCPToolsResponse, error) {
|
|
listCalls.Add(1)
|
|
return workspacesdk.ListMCPToolsResponse{}, nil
|
|
},
|
|
).AnyTimes()
|
|
|
|
mockClock := quartz.NewMock(t)
|
|
timerTrap := mockClock.Trap().NewTimer("chatd", "workspace-mcp-prime")
|
|
t.Cleanup(timerTrap.Close)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: mockClock,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return chat, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
server.primeWorkspaceMCPCache(ctx, server.logger, chat.ID, &workspaceCtx)
|
|
}()
|
|
|
|
// Drive the retry loop forward until the primer gives up. Each
|
|
// iteration: release the trapped NewTimer call, then advance the
|
|
// clock past the retry interval. The primer exits when
|
|
// p.clock.Now() is no longer before deadline. The loop bounds
|
|
// itself on maxIterations and uses a done-aware wait context so
|
|
// the test fails cleanly instead of hanging when the primer
|
|
// shuts down between iterations.
|
|
maxIterations := int(workspaceMCPPrimeMaxWait/workspaceMCPPrimeRetryInterval) + 2
|
|
Loop:
|
|
for i := 0; i < maxIterations; i++ {
|
|
waitCtx, cancel := context.WithCancel(ctx)
|
|
go func() {
|
|
select {
|
|
case <-done:
|
|
cancel()
|
|
case <-waitCtx.Done():
|
|
}
|
|
}()
|
|
call, err := timerTrap.Wait(waitCtx)
|
|
cancel()
|
|
if err != nil {
|
|
break Loop
|
|
}
|
|
call.MustRelease(ctx)
|
|
mockClock.Advance(workspaceMCPPrimeRetryInterval).MustWait(ctx)
|
|
}
|
|
|
|
// expectedAttempts is the floor on how many times the primer
|
|
// should call discoverWorkspaceMCPTools before the deadline
|
|
// expires. The primer makes one attempt before sleeping, then
|
|
// one per workspaceMCPPrimeRetryInterval until the deadline.
|
|
// We assert a high-water mark (rather than exact equality) so
|
|
// the test is robust to off-by-one boundaries while still
|
|
// catching deadline miscomputations: a primer that exits after a
|
|
// handful of attempts would suggest the deadline was set with a
|
|
// shorter window than workspaceMCPPrimeMaxWait.
|
|
expectedAttempts := int32(workspaceMCPPrimeMaxWait/workspaceMCPPrimeRetryInterval) / 2
|
|
require.GreaterOrEqual(t, listCalls.Load(), expectedAttempts,
|
|
"primer must retry enough times to consume the full budget")
|
|
_, ok := server.workspaceMCPToolsCache.Load(chat.ID)
|
|
require.False(t, ok,
|
|
"primer must not cache an empty result; PrepareTools needs to retry on the next step")
|
|
}
|
|
|
|
// TestPrimeWorkspaceMCPCache_ExitsOnContextCancel verifies the
|
|
// primer's context.Done() branch: the retry loop must exit promptly
|
|
// when the chat ctx is canceled (runChat cancels its primerCtx
|
|
// before workspaceCtx.close runs to prevent a primer from re-dialing
|
|
// the freed conn).
|
|
func TestPrimeWorkspaceMCPCache_ExitsOnContextCancel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
|
|
workspaceID := uuid.New()
|
|
agentID := uuid.New()
|
|
chat := database.Chat{
|
|
ID: uuid.New(),
|
|
WorkspaceID: uuid.NullUUID{
|
|
UUID: workspaceID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
}
|
|
now := time.Now()
|
|
workspaceAgent := database.WorkspaceAgent{
|
|
ID: agentID,
|
|
FirstConnectedAt: sql.NullTime{
|
|
Time: now.Add(-time.Minute),
|
|
Valid: true,
|
|
},
|
|
LastConnectedAt: sql.NullTime{
|
|
Time: now,
|
|
Valid: true,
|
|
},
|
|
}
|
|
|
|
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
|
Return(workspaceAgent, nil).AnyTimes()
|
|
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
|
Return([]database.WorkspaceAgent{workspaceAgent}, nil).AnyTimes()
|
|
|
|
conn := agentconnmock.NewMockAgentConn(ctrl)
|
|
conn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes()
|
|
conn.EXPECT().ListMCPTools(gomock.Any()).
|
|
Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes()
|
|
|
|
mockClock := quartz.NewMock(t)
|
|
timerTrap := mockClock.Trap().NewTimer("chatd", "workspace-mcp-prime")
|
|
t.Cleanup(timerTrap.Close)
|
|
|
|
server := &Server{
|
|
db: db,
|
|
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
|
clock: mockClock,
|
|
agentInactiveDisconnectTimeout: 30 * time.Second,
|
|
dialTimeout: time.Second,
|
|
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
|
return conn, func() {}, nil
|
|
},
|
|
}
|
|
|
|
chatStateMu := &sync.Mutex{}
|
|
currentChat := chat
|
|
workspaceCtx := turnWorkspaceContext{
|
|
server: server,
|
|
chatStateMu: chatStateMu,
|
|
currentChat: ¤tChat,
|
|
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return chat, nil },
|
|
}
|
|
t.Cleanup(workspaceCtx.close)
|
|
|
|
primerCtx, primerCancel := context.WithCancel(ctx)
|
|
t.Cleanup(primerCancel)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
server.primeWorkspaceMCPCache(primerCtx, server.logger, chat.ID, &workspaceCtx)
|
|
}()
|
|
|
|
// Let the primer arm at least one retry timer so we know it is
|
|
// blocked in the select. Canceling before this would race with
|
|
// the loop entering the retry path.
|
|
call := timerTrap.MustWait(ctx)
|
|
call.MustRelease(ctx)
|
|
|
|
primerCancel()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-ctx.Done():
|
|
t.Fatal("primer did not exit after context cancellation")
|
|
}
|
|
|
|
_, ok := server.workspaceMCPToolsCache.Load(chat.ID)
|
|
require.False(t, ok, "primer must not cache anything when canceled")
|
|
}
|