mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
perf(coderd/x/chatd): persist workspace agent binding across chat turns (#23274)
## Summary This change removes the steady-state "resolve the latest workspace agent" query from chat execution. Instead of asking the database for the latest build's agent on every turn, a chat now persists the workspace/build/agent binding it actually uses and reuses that binding across subsequent turns. The common path becomes "load the bound agent by ID and dial it", with fallback paths to repair the binding when it is missing, stale, or intentionally changed. ## What changes - add `workspace_id`, `build_id`, and `agent_id` binding fields to `chats` - expose those fields through the chat API / SDK so the execution context is explicit - load the persisted binding first in chatd, instead of always resolving the latest build's agent - persist a refreshed binding when chatd has to re-resolve the workspace agent - keep child / subagent chats on the same bound workspace context by inheriting the parent binding - leave `build_id` / `agent_id` unset for flows like `create_workspace`, then bind them lazily on the next agent-backed turn ## Runtime behavior The binding is treated as an optimistic cache of the agent a chat should use: - if the bound agent still exists and dials successfully, we use it without a latest-build lookup - if the bound agent is missing or no longer reachable, chatd re-resolves against the latest build and persists the new binding - if a workspace mutation changes the chat's target workspace, the binding is updated as part of that mutation To avoid reintroducing a hot-path query, dialing uses lazy validation: - start dialing the cached agent immediately - only validate against the latest build if the dial is still pending after a short delay - if validation finds a different agent, cancel the stale dial, switch to the current agent, and persist the repaired binding ## Result The hot path stops issuing `GetWorkspaceAgentsInLatestBuildByWorkspaceID` for every user message, which is the source of the DB pressure this PR is addressing. At the same time, chats still converge to the correct workspace agent when the binding becomes stale due to rebuilds or explicit workspace changes.
This commit is contained in:
@@ -5619,6 +5619,18 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
|
||||
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
return q.db.UpdateChatBuildAgentBinding(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
@@ -5706,7 +5718,7 @@ func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatS
|
||||
return q.db.UpdateChatStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
@@ -5715,15 +5727,7 @@ func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateCh
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace is manually implemented for chat tables and may not be
|
||||
// present on every wrapped store interface yet.
|
||||
chatWorkspaceUpdater, ok := q.db.(interface {
|
||||
UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error)
|
||||
})
|
||||
if !ok {
|
||||
return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store")
|
||||
}
|
||||
return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg)
|
||||
return q.db.UpdateChatWorkspaceBinding(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
||||
|
||||
@@ -819,15 +819,29 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
s.Run("UpdateChatBuildAgentBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatWorkspaceParams{
|
||||
ID: chat.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
arg := database.UpdateChatBuildAgentBindingParams{
|
||||
ID: chat.ID,
|
||||
BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
}
|
||||
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
|
||||
}))
|
||||
s.Run("UpdateChatWorkspaceBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatWorkspaceBindingParams{
|
||||
ID: chat.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
}
|
||||
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatWorkspaceBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
|
||||
}))
|
||||
s.Run("UnsetDefaultChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
|
||||
@@ -4000,6 +4000,14 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatBuildAgentBinding(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatBuildAgentBinding").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatBuildAgentBinding").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatByID(ctx, arg)
|
||||
@@ -4064,11 +4072,11 @@ func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.Up
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
func (m queryMetricsStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatWorkspace(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc()
|
||||
r0, r1 := m.s.UpdateChatWorkspaceBinding(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatWorkspaceBinding").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspaceBinding").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
|
||||
@@ -7552,6 +7552,21 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatBuildAgentBinding mocks base method.
|
||||
func (m *MockStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatBuildAgentBinding", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatBuildAgentBinding indicates an expected call of UpdateChatBuildAgentBinding.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatBuildAgentBinding(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatBuildAgentBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatBuildAgentBinding), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatByID mocks base method.
|
||||
func (m *MockStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7672,19 +7687,19 @@ func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace mocks base method.
|
||||
func (m *MockStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
|
||||
// UpdateChatWorkspaceBinding mocks base method.
|
||||
func (m *MockStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatWorkspace", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "UpdateChatWorkspaceBinding", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatWorkspace indicates an expected call of UpdateChatWorkspace.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatWorkspace(ctx, arg any) *gomock.Call {
|
||||
// UpdateChatWorkspaceBinding indicates an expected call of UpdateChatWorkspaceBinding.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatWorkspaceBinding(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspace", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspace), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspaceBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspaceBinding), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateCryptoKeyDeletesAt mocks base method.
|
||||
|
||||
Generated
+9
-1
@@ -1399,7 +1399,9 @@ CREATE TABLE chats (
|
||||
last_error text,
|
||||
mode chat_mode,
|
||||
mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL,
|
||||
labels jsonb DEFAULT '{}'::jsonb NOT NULL
|
||||
labels jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
build_id uuid,
|
||||
agent_id uuid
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -4033,6 +4035,12 @@ ALTER TABLE ONLY chat_providers
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ const (
|
||||
ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsLastModelConfigID ForeignKeyConstraint = "chats_last_model_config_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
|
||||
ForeignKeyChatsOwnerID ForeignKeyConstraint = "chats_owner_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsParentChatID ForeignKeyConstraint = "chats_parent_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE chats
|
||||
DROP COLUMN IF EXISTS build_id,
|
||||
DROP COLUMN IF EXISTS agent_id;
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE chats
|
||||
ADD COLUMN build_id UUID REFERENCES workspace_builds(id) ON DELETE SET NULL,
|
||||
ADD COLUMN agent_id UUID REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
@@ -791,6 +791,8 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4171,6 +4171,8 @@ type Chat struct {
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
Labels StringMap `db:"labels" json:"labels"`
|
||||
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
|
||||
@@ -819,6 +819,7 @@ type sqlcQuerier interface {
|
||||
UnsetDefaultChatModelConfigs(ctx context.Context) error
|
||||
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
@@ -829,7 +830,7 @@ type sqlcQuerier interface {
|
||||
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
|
||||
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
|
||||
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
|
||||
UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error)
|
||||
UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error)
|
||||
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
|
||||
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
|
||||
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
|
||||
+104
-25
@@ -3823,7 +3823,7 @@ WHERE
|
||||
$3::int
|
||||
)
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
`
|
||||
|
||||
type AcquireChatsParams struct {
|
||||
@@ -3862,6 +3862,8 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) (
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4095,7 +4097,7 @@ func (q *sqlQuerier) DeleteChatUsageLimitUserOverride(ctx context.Context, userI
|
||||
|
||||
const getChatByID = `-- name: GetChatByID :one
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -4124,12 +4126,14 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels FROM chats WHERE id = $1::uuid FOR UPDATE
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id FROM chats WHERE id = $1::uuid FOR UPDATE
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) {
|
||||
@@ -4154,6 +4158,8 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -4998,7 +5004,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u
|
||||
|
||||
const getChats = `-- name: GetChats :many
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -5089,6 +5095,8 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -5154,7 +5162,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
|
||||
|
||||
const getStaleChats = `-- name: GetStaleChats :many
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -5192,6 +5200,8 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -5250,6 +5260,8 @@ const insertChat = `-- name: InsertChat :one
|
||||
INSERT INTO chats (
|
||||
owner_id,
|
||||
workspace_id,
|
||||
build_id,
|
||||
agent_id,
|
||||
parent_chat_id,
|
||||
root_chat_id,
|
||||
last_model_config_id,
|
||||
@@ -5263,18 +5275,22 @@ INSERT INTO chats (
|
||||
$3::uuid,
|
||||
$4::uuid,
|
||||
$5::uuid,
|
||||
$6::text,
|
||||
$7::chat_mode,
|
||||
COALESCE($8::uuid[], '{}'::uuid[]),
|
||||
COALESCE($9::jsonb, '{}'::jsonb)
|
||||
$6::uuid,
|
||||
$7::uuid,
|
||||
$8::text,
|
||||
$9::chat_mode,
|
||||
COALESCE($10::uuid[], '{}'::uuid[]),
|
||||
COALESCE($11::jsonb, '{}'::jsonb)
|
||||
)
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
`
|
||||
|
||||
type InsertChatParams struct {
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
@@ -5288,6 +5304,8 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
|
||||
row := q.db.QueryRowContext(ctx, insertChat,
|
||||
arg.OwnerID,
|
||||
arg.WorkspaceID,
|
||||
arg.BuildID,
|
||||
arg.AgentID,
|
||||
arg.ParentChatID,
|
||||
arg.RootChatID,
|
||||
arg.LastModelConfigID,
|
||||
@@ -5316,6 +5334,8 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -5702,6 +5722,50 @@ func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error
|
||||
return err
|
||||
}
|
||||
|
||||
const updateChatBuildAgentBinding = `-- name: UpdateChatBuildAgentBinding :one
|
||||
UPDATE chats SET
|
||||
build_id = $1::uuid,
|
||||
agent_id = $2::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = $3::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
`
|
||||
|
||||
type UpdateChatBuildAgentBindingParams struct {
|
||||
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateChatBuildAgentBinding, arg.BuildID, arg.AgentID, arg.ID)
|
||||
var i Chat
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatByID = `-- name: UpdateChatByID :one
|
||||
UPDATE
|
||||
chats
|
||||
@@ -5711,7 +5775,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
`
|
||||
|
||||
type UpdateChatByIDParams struct {
|
||||
@@ -5741,6 +5805,8 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -5780,7 +5846,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
`
|
||||
|
||||
type UpdateChatLabelsByIDParams struct {
|
||||
@@ -5810,6 +5876,8 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -5823,7 +5891,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
`
|
||||
|
||||
type UpdateChatMCPServerIDsParams struct {
|
||||
@@ -5853,6 +5921,8 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -5917,7 +5987,7 @@ SET
|
||||
WHERE
|
||||
id = $6::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
`
|
||||
|
||||
type UpdateChatStatusParams struct {
|
||||
@@ -5958,29 +6028,36 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatWorkspace = `-- name: UpdateChatWorkspace :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
const updateChatWorkspaceBinding = `-- name: UpdateChatWorkspaceBinding :one
|
||||
UPDATE chats SET
|
||||
workspace_id = $1::uuid,
|
||||
build_id = $2::uuid,
|
||||
agent_id = $3::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
WHERE id = $4::uuid
|
||||
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
|
||||
`
|
||||
|
||||
type UpdateChatWorkspaceParams struct {
|
||||
type UpdateChatWorkspaceBindingParams struct {
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateChatWorkspace, arg.WorkspaceID, arg.ID)
|
||||
func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateChatWorkspaceBinding,
|
||||
arg.WorkspaceID,
|
||||
arg.BuildID,
|
||||
arg.AgentID,
|
||||
arg.ID,
|
||||
)
|
||||
var i Chat
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
@@ -6001,6 +6078,8 @@ func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWork
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -180,6 +180,8 @@ LIMIT
|
||||
INSERT INTO chats (
|
||||
owner_id,
|
||||
workspace_id,
|
||||
build_id,
|
||||
agent_id,
|
||||
parent_chat_id,
|
||||
root_chat_id,
|
||||
last_model_config_id,
|
||||
@@ -190,6 +192,8 @@ INSERT INTO chats (
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
sqlc.narg('build_id')::uuid,
|
||||
sqlc.narg('agent_id')::uuid,
|
||||
sqlc.narg('parent_chat_id')::uuid,
|
||||
sqlc.narg('root_chat_id')::uuid,
|
||||
@last_model_config_id::uuid,
|
||||
@@ -305,16 +309,23 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatWorkspace :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
-- name: UpdateChatWorkspaceBinding :one
|
||||
UPDATE chats SET
|
||||
workspace_id = sqlc.narg('workspace_id')::uuid,
|
||||
build_id = sqlc.narg('build_id')::uuid,
|
||||
agent_id = sqlc.narg('agent_id')::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatBuildAgentBinding :one
|
||||
UPDATE chats SET
|
||||
build_id = sqlc.narg('build_id')::uuid,
|
||||
agent_id = sqlc.narg('agent_id')::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatMCPServerIDs :one
|
||||
UPDATE
|
||||
|
||||
@@ -3600,6 +3600,12 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.
|
||||
if c.WorkspaceID.Valid {
|
||||
chat.WorkspaceID = &c.WorkspaceID.UUID
|
||||
}
|
||||
if c.BuildID.Valid {
|
||||
chat.BuildID = &c.BuildID.UUID
|
||||
}
|
||||
if c.AgentID.Valid {
|
||||
chat.AgentID = &c.AgentID.UUID
|
||||
}
|
||||
if diffStatus != nil {
|
||||
convertedDiffStatus := db2sdk.ChatDiffStatus(c.ID, diffStatus)
|
||||
chat.DiffStatus = &convertedDiffStatus
|
||||
|
||||
+293
-99
@@ -53,6 +53,8 @@ const (
|
||||
DefaultInFlightChatStaleAfter = 5 * time.Minute
|
||||
|
||||
homeInstructionLookupTimeout = 5 * time.Second
|
||||
instructionCacheTTL = 5 * time.Minute
|
||||
workspaceDialValidationDelay = 5 * time.Second
|
||||
// DefaultChatHeartbeatInterval is the default time between chat
|
||||
// heartbeat updates while a chat is being processed.
|
||||
DefaultChatHeartbeatInterval = 30 * time.Second
|
||||
@@ -158,18 +160,26 @@ type turnWorkspaceContext struct {
|
||||
currentChat *database.Chat
|
||||
loadChatSnapshot func(context.Context, uuid.UUID) (database.Chat, error)
|
||||
|
||||
mu sync.Mutex
|
||||
agent database.WorkspaceAgent
|
||||
agentLoaded bool
|
||||
conn workspacesdk.AgentConn
|
||||
releaseConn func()
|
||||
mu sync.Mutex
|
||||
agent database.WorkspaceAgent
|
||||
agentLoaded bool
|
||||
conn workspacesdk.AgentConn
|
||||
releaseConn func()
|
||||
cachedWorkspaceID uuid.NullUUID
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) close() {
|
||||
c.clearCachedWorkspaceState()
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) clearCachedWorkspaceState() {
|
||||
c.mu.Lock()
|
||||
releaseConn := c.releaseConn
|
||||
c.agent = database.WorkspaceAgent{}
|
||||
c.agentLoaded = false
|
||||
c.conn = nil
|
||||
c.releaseConn = nil
|
||||
c.cachedWorkspaceID = uuid.NullUUID{}
|
||||
c.mu.Unlock()
|
||||
|
||||
if releaseConn != nil {
|
||||
@@ -177,6 +187,68 @@ func (c *turnWorkspaceContext) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) setCurrentChat(chat database.Chat) {
|
||||
c.chatStateMu.Lock()
|
||||
*c.currentChat = chat
|
||||
c.chatStateMu.Unlock()
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) currentChatSnapshot() database.Chat {
|
||||
c.chatStateMu.Lock()
|
||||
chatSnapshot := *c.currentChat
|
||||
c.chatStateMu.Unlock()
|
||||
return chatSnapshot
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) selectWorkspace(chat database.Chat) {
|
||||
c.setCurrentChat(chat)
|
||||
c.clearCachedWorkspaceState()
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) currentWorkspaceMatches(expected uuid.NullUUID) (database.Chat, bool) {
|
||||
chatSnapshot := c.currentChatSnapshot()
|
||||
return chatSnapshot, nullUUIDEqual(chatSnapshot.WorkspaceID, expected)
|
||||
}
|
||||
|
||||
func nullUUIDEqual(left, right uuid.NullUUID) bool {
|
||||
if left.Valid != right.Valid {
|
||||
return false
|
||||
}
|
||||
if !left.Valid {
|
||||
return true
|
||||
}
|
||||
return left.UUID == right.UUID
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) persistBuildAgentBinding(
|
||||
ctx context.Context,
|
||||
chatSnapshot database.Chat,
|
||||
buildID uuid.UUID,
|
||||
agentID uuid.UUID,
|
||||
) (database.Chat, error) {
|
||||
updatedChat, err := c.server.db.UpdateChatBuildAgentBinding(
|
||||
ctx,
|
||||
database.UpdateChatBuildAgentBindingParams{
|
||||
ID: chatSnapshot.ID,
|
||||
BuildID: uuid.NullUUID{
|
||||
UUID: buildID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return chatSnapshot, xerrors.Errorf(
|
||||
"update chat build/agent binding: %w", err,
|
||||
)
|
||||
}
|
||||
c.setCurrentChat(updatedChat)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) getWorkspaceAgent(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
_, agent, err := c.ensureWorkspaceAgent(ctx)
|
||||
return agent, err
|
||||
@@ -189,128 +261,245 @@ func (c *turnWorkspaceContext) ensureWorkspaceAgent(
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.agentLoaded {
|
||||
c.chatStateMu.Lock()
|
||||
chatSnapshot := *c.currentChat
|
||||
c.chatStateMu.Unlock()
|
||||
return chatSnapshot, c.agent, nil
|
||||
chatSnapshot := c.currentChatSnapshot()
|
||||
if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) {
|
||||
return chatSnapshot, c.agent, nil
|
||||
}
|
||||
c.agent = database.WorkspaceAgent{}
|
||||
c.agentLoaded = false
|
||||
}
|
||||
|
||||
return c.loadWorkspaceAgentLocked(ctx)
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) refreshWorkspaceAgent(
|
||||
ctx context.Context,
|
||||
) (database.Chat, database.WorkspaceAgent, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.agent = database.WorkspaceAgent{}
|
||||
c.agentLoaded = false
|
||||
return c.loadWorkspaceAgentLocked(ctx)
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
|
||||
ctx context.Context,
|
||||
) (database.Chat, database.WorkspaceAgent, error) {
|
||||
c.chatStateMu.Lock()
|
||||
chatSnapshot := *c.currentChat
|
||||
c.chatStateMu.Unlock()
|
||||
chatSnapshot := c.currentChatSnapshot()
|
||||
|
||||
if !chatSnapshot.WorkspaceID.Valid {
|
||||
refreshedChat, refreshErr := refreshChatWorkspaceSnapshot(
|
||||
for attempt := 0; attempt < 2; attempt++ {
|
||||
if !chatSnapshot.WorkspaceID.Valid {
|
||||
refreshedChat, refreshErr := refreshChatWorkspaceSnapshot(
|
||||
ctx,
|
||||
chatSnapshot,
|
||||
c.loadChatSnapshot,
|
||||
)
|
||||
if refreshErr != nil {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, refreshErr
|
||||
}
|
||||
if refreshedChat.WorkspaceID.Valid {
|
||||
c.setCurrentChat(refreshedChat)
|
||||
chatSnapshot = refreshedChat
|
||||
}
|
||||
}
|
||||
|
||||
if !chatSnapshot.WorkspaceID.Valid {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace")
|
||||
}
|
||||
|
||||
if chatSnapshot.AgentID.Valid {
|
||||
agent, err := c.server.db.GetWorkspaceAgentByID(ctx, chatSnapshot.AgentID.UUID)
|
||||
if err == nil {
|
||||
latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID)
|
||||
if !workspaceMatches {
|
||||
chatSnapshot = latestChat
|
||||
continue
|
||||
}
|
||||
c.agent = agent
|
||||
c.agentLoaded = true
|
||||
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
|
||||
return chatSnapshot, c.agent, nil
|
||||
}
|
||||
if !xerrors.Is(err, sql.ErrNoRows) {
|
||||
c.server.logger.Warn(ctx, "agent binding lookup failed, re-resolving",
|
||||
slog.F("agent_id", chatSnapshot.AgentID.UUID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
ctx,
|
||||
chatSnapshot.WorkspaceID.UUID,
|
||||
)
|
||||
if err != nil || len(agents) == 0 {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent")
|
||||
}
|
||||
|
||||
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf("get latest workspace build: %w", err)
|
||||
}
|
||||
|
||||
updatedChat, err := c.persistBuildAgentBinding(
|
||||
ctx,
|
||||
chatSnapshot,
|
||||
c.loadChatSnapshot,
|
||||
build.ID,
|
||||
agents[0].ID,
|
||||
)
|
||||
if refreshErr != nil {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, refreshErr
|
||||
if err != nil {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, err
|
||||
}
|
||||
if refreshedChat.WorkspaceID.Valid {
|
||||
c.chatStateMu.Lock()
|
||||
*c.currentChat = refreshedChat
|
||||
c.chatStateMu.Unlock()
|
||||
chatSnapshot = refreshedChat
|
||||
|
||||
chatSnapshot = updatedChat
|
||||
latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID)
|
||||
if !workspaceMatches {
|
||||
chatSnapshot = latestChat
|
||||
continue
|
||||
}
|
||||
c.agent = agents[0]
|
||||
c.agentLoaded = true
|
||||
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
|
||||
return chatSnapshot, c.agent, nil
|
||||
}
|
||||
|
||||
if !chatSnapshot.WorkspaceID.Valid {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace")
|
||||
}
|
||||
|
||||
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
ctx,
|
||||
chatSnapshot.WorkspaceID.UUID,
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New(
|
||||
"chat workspace changed while resolving agent",
|
||||
)
|
||||
if err != nil || len(agents) == 0 {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent")
|
||||
}
|
||||
|
||||
// getWorkspaceConnLocked returns the cached connection when it still matches
|
||||
// the current workspace. When the workspace changed, it clears the stale
|
||||
// cached state and returns the release func for the caller to run after
|
||||
// unlocking.
|
||||
func (c *turnWorkspaceContext) getWorkspaceConnLocked() (workspacesdk.AgentConn, func()) {
|
||||
if c.conn == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
c.agent = agents[0]
|
||||
c.agentLoaded = true
|
||||
return chatSnapshot, c.agent, nil
|
||||
chatSnapshot := c.currentChatSnapshot()
|
||||
if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) {
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
agentRelease := c.releaseConn
|
||||
c.agent = database.WorkspaceAgent{}
|
||||
c.agentLoaded = false
|
||||
c.conn = nil
|
||||
c.releaseConn = nil
|
||||
c.cachedWorkspaceID = uuid.NullUUID{}
|
||||
return nil, agentRelease
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspacesdk.AgentConn, error) {
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
currentConn := c.conn
|
||||
c.mu.Unlock()
|
||||
return currentConn, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if c.server.agentConnFn == nil {
|
||||
return nil, xerrors.New("workspace agent connector is not configured")
|
||||
}
|
||||
|
||||
chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
agentConn, agentRelease, err := c.server.agentConnFn(ctx, agent.ID)
|
||||
if err != nil {
|
||||
refreshedChat, refreshedAgent, refreshErr := c.refreshWorkspaceAgent(ctx)
|
||||
if refreshErr != nil {
|
||||
return nil, xerrors.Errorf("connect to workspace agent: %w", err)
|
||||
}
|
||||
|
||||
retryConn, retryRelease, retryErr := c.server.agentConnFn(ctx, refreshedAgent.ID)
|
||||
if retryErr != nil {
|
||||
return nil, xerrors.Errorf("connect to workspace agent after refresh: %w", retryErr)
|
||||
}
|
||||
|
||||
chatSnapshot = refreshedChat
|
||||
agentConn = retryConn
|
||||
agentRelease = retryRelease
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn == nil {
|
||||
c.conn = agentConn
|
||||
c.releaseConn = agentRelease
|
||||
|
||||
var ancestorIDs []string
|
||||
if chatSnapshot.ParentChatID.Valid {
|
||||
ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String())
|
||||
}
|
||||
ancestorJSON, marshalErr := json.Marshal(ancestorIDs)
|
||||
if marshalErr != nil {
|
||||
ancestorJSON = []byte("[]")
|
||||
}
|
||||
agentConn.SetExtraHeaders(http.Header{
|
||||
workspacesdk.CoderChatIDHeader: {chatSnapshot.ID.String()},
|
||||
workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)},
|
||||
})
|
||||
|
||||
for attempt := 0; attempt < 2; attempt++ {
|
||||
c.mu.Lock()
|
||||
currentConn, staleRelease := c.getWorkspaceConnLocked()
|
||||
c.mu.Unlock()
|
||||
return agentConn, nil
|
||||
}
|
||||
currentConn := c.conn
|
||||
c.mu.Unlock()
|
||||
if currentConn != nil {
|
||||
return currentConn, nil
|
||||
}
|
||||
if staleRelease != nil {
|
||||
staleRelease()
|
||||
}
|
||||
|
||||
agentRelease()
|
||||
return currentConn, nil
|
||||
chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialResult, err := dialWithLazyValidation(
|
||||
ctx,
|
||||
agent.ID,
|
||||
chatSnapshot.WorkspaceID.UUID,
|
||||
DialFunc(c.server.agentConnFn),
|
||||
func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) {
|
||||
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspaceID)
|
||||
if err != nil || len(agents) == 0 {
|
||||
return uuid.Nil, xerrors.New("chat has no workspace agent")
|
||||
}
|
||||
return agents[0].ID, nil
|
||||
},
|
||||
workspaceDialValidationDelay,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
agentConn := dialResult.Conn
|
||||
agentRelease := dialResult.Release
|
||||
if dialResult.WasSwitched {
|
||||
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
if agentRelease != nil {
|
||||
agentRelease()
|
||||
}
|
||||
return nil, xerrors.Errorf("get latest workspace build: %w", err)
|
||||
}
|
||||
|
||||
switchedAgent, err := c.server.db.GetWorkspaceAgentByID(ctx, dialResult.AgentID)
|
||||
if err != nil {
|
||||
if agentRelease != nil {
|
||||
agentRelease()
|
||||
}
|
||||
return nil, xerrors.Errorf("get workspace agent by id: %w", err)
|
||||
}
|
||||
|
||||
updatedChat, err := c.persistBuildAgentBinding(
|
||||
ctx,
|
||||
chatSnapshot,
|
||||
build.ID,
|
||||
switchedAgent.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if agentRelease != nil {
|
||||
agentRelease()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
chatSnapshot = updatedChat
|
||||
|
||||
c.mu.Lock()
|
||||
c.agent = switchedAgent
|
||||
c.agentLoaded = true
|
||||
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
if _, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID); !workspaceMatches {
|
||||
if agentRelease != nil {
|
||||
agentRelease()
|
||||
}
|
||||
c.clearCachedWorkspaceState()
|
||||
continue
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn == nil {
|
||||
c.conn = agentConn
|
||||
c.releaseConn = agentRelease
|
||||
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
|
||||
|
||||
var ancestorIDs []string
|
||||
if chatSnapshot.ParentChatID.Valid {
|
||||
ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String())
|
||||
}
|
||||
ancestorJSON, marshalErr := json.Marshal(ancestorIDs)
|
||||
if marshalErr != nil {
|
||||
ancestorJSON = []byte("[]")
|
||||
}
|
||||
agentConn.SetExtraHeaders(http.Header{
|
||||
workspacesdk.CoderChatIDHeader: {chatSnapshot.ID.String()},
|
||||
workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)},
|
||||
})
|
||||
|
||||
c.mu.Unlock()
|
||||
return agentConn, nil
|
||||
}
|
||||
currentConn = c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if agentRelease != nil {
|
||||
agentRelease()
|
||||
}
|
||||
return currentConn, nil
|
||||
}
|
||||
|
||||
return nil, xerrors.New("chat workspace changed while connecting")
|
||||
}
|
||||
|
||||
// AgentConnFunc provides access to workspace agent connections.
|
||||
@@ -420,6 +609,8 @@ func (e *UsageLimitExceededError) Error() string {
|
||||
type CreateOptions struct {
|
||||
OwnerID uuid.UUID
|
||||
WorkspaceID uuid.NullUUID
|
||||
BuildID uuid.NullUUID
|
||||
AgentID uuid.NullUUID
|
||||
ParentChatID uuid.NullUUID
|
||||
RootChatID uuid.NullUUID
|
||||
Title string
|
||||
@@ -525,6 +716,8 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: opts.OwnerID,
|
||||
WorkspaceID: opts.WorkspaceID,
|
||||
BuildID: opts.BuildID,
|
||||
AgentID: opts.AgentID,
|
||||
ParentChatID: opts.ParentChatID,
|
||||
RootChatID: opts.RootChatID,
|
||||
LastModelConfigID: opts.ModelConfigID,
|
||||
@@ -3461,6 +3654,7 @@ func (p *Server) runChat(
|
||||
AgentConnFn: chattool.AgentConnFunc(p.agentConnFn),
|
||||
AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout,
|
||||
WorkspaceMu: &workspaceMu,
|
||||
OnChatUpdated: workspaceCtx.selectWorkspace,
|
||||
Logger: p.logger,
|
||||
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
||||
}),
|
||||
|
||||
@@ -112,24 +112,29 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
workspaceAgent := database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
ID: agentID,
|
||||
OperatingSystem: "linux",
|
||||
Directory: "/home/coder/project",
|
||||
ExpandedDirectory: "/home/coder/project",
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
db.EXPECT().GetWorkspaceAgentByID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
|
||||
agentID,
|
||||
).Return(workspaceAgent, nil).Times(1)
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
@@ -180,7 +185,7 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
|
||||
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
|
||||
func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -188,6 +193,53 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
workspaceAgent := database.WorkspaceAgent{ID: agentID}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(workspaceAgent, nil).Times(1)
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: &Server{db: db},
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chat, chatSnapshot)
|
||||
require.Equal(t, workspaceAgent, agent)
|
||||
|
||||
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, workspaceAgent, gotAgent)
|
||||
require.Equal(t, chat, currentChat)
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContext_NullBindingLazyBind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
buildID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
@@ -195,18 +247,135 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
workspaceAgent := database.WorkspaceAgent{ID: agentID}
|
||||
updatedChat := chat
|
||||
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
||||
updatedChat.AgentID = uuid.NullUUID{UUID: agentID, Valid: true}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{initialAgent}, nil),
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{workspaceAgent}, nil),
|
||||
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
||||
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
||||
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
||||
ID: chat.ID,
|
||||
}).Return(updatedChat, nil),
|
||||
)
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: &Server{db: db},
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updatedChat, chatSnapshot)
|
||||
require.Equal(t, workspaceAgent, agent)
|
||||
require.Equal(t, updatedChat, currentChat)
|
||||
|
||||
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, workspaceAgent, gotAgent)
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContext_StaleBindingRepair(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
staleAgentID := uuid.New()
|
||||
buildID := uuid.New()
|
||||
currentAgentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: staleAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
|
||||
updatedChat := chat
|
||||
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
||||
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(database.WorkspaceAgent{}, xerrors.New("missing agent")),
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
|
||||
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
||||
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
||||
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
ID: chat.ID,
|
||||
}).Return(updatedChat, nil),
|
||||
)
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: &Server{db: db},
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updatedChat, chatSnapshot)
|
||||
require.Equal(t, currentAgent, agent)
|
||||
require.Equal(t, updatedChat, currentChat)
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
staleAgentID := uuid.New()
|
||||
currentAgentID := uuid.New()
|
||||
buildID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: staleAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
|
||||
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
|
||||
updatedChat := chat
|
||||
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
||||
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(staleAgent, nil),
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
|
||||
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), currentAgentID).Return(currentAgent, nil),
|
||||
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
||||
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
ID: chat.ID,
|
||||
}).Return(updatedChat, nil),
|
||||
)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
@@ -216,7 +385,7 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.
|
||||
server := &Server{db: db}
|
||||
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialed = append(dialed, agentID)
|
||||
if agentID == initialAgent.ID {
|
||||
if agentID == staleAgentID {
|
||||
return nil, nil, xerrors.New("dial failed")
|
||||
}
|
||||
return conn, func() {}, nil
|
||||
@@ -235,7 +404,112 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, gotConn)
|
||||
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
|
||||
require.Equal(t, []uuid.UUID{staleAgentID, currentAgentID}, dialed)
|
||||
require.Equal(t, updatedChat, currentChat)
|
||||
|
||||
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, currentAgent, gotAgent)
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
currentChat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
updatedChat := database.Chat{
|
||||
ID: currentChat.ID,
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
cachedConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
releaseCalls := 0
|
||||
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
chatStateMu: &sync.Mutex{},
|
||||
currentChat: ¤tChat,
|
||||
}
|
||||
workspaceCtx.agent = database.WorkspaceAgent{ID: uuid.New()}
|
||||
workspaceCtx.agentLoaded = true
|
||||
workspaceCtx.conn = cachedConn
|
||||
workspaceCtx.cachedWorkspaceID = currentChat.WorkspaceID
|
||||
workspaceCtx.releaseConn = func() {
|
||||
releaseCalls++
|
||||
}
|
||||
|
||||
workspaceCtx.selectWorkspace(updatedChat)
|
||||
|
||||
require.Equal(t, updatedChat, currentChat)
|
||||
require.Equal(t, 1, releaseCalls)
|
||||
|
||||
workspaceCtx.mu.Lock()
|
||||
defer workspaceCtx.mu.Unlock()
|
||||
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
|
||||
require.False(t, workspaceCtx.agentLoaded)
|
||||
require.Nil(t, workspaceCtx.conn)
|
||||
require.Nil(t, workspaceCtx.releaseConn)
|
||||
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferentWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceOneID := uuid.New()
|
||||
workspaceTwoID := uuid.New()
|
||||
buildID := uuid.New()
|
||||
cachedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
resolvedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceTwoID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
updatedChat := chat
|
||||
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
|
||||
updatedChat.AgentID = uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return([]database.WorkspaceAgent{resolvedAgent}, nil),
|
||||
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return(database.WorkspaceBuild{ID: buildID}, nil),
|
||||
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
|
||||
ID: chat.ID,
|
||||
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true},
|
||||
}).Return(updatedChat, nil),
|
||||
)
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: &Server{db: db},
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
workspaceCtx.agent = cachedAgent
|
||||
workspaceCtx.agentLoaded = true
|
||||
workspaceCtx.cachedWorkspaceID = uuid.NullUUID{UUID: workspaceOneID, Valid: true}
|
||||
defer workspaceCtx.close()
|
||||
|
||||
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updatedChat, chatSnapshot)
|
||||
require.Equal(t, resolvedAgent, agent)
|
||||
require.Equal(t, updatedChat, currentChat)
|
||||
}
|
||||
|
||||
func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
|
||||
|
||||
@@ -2280,7 +2280,7 @@ func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) {
|
||||
|
||||
// Link the workspace to the chat in the DB, simulating what
|
||||
// the create_workspace tool does mid-conversation.
|
||||
_, err = db.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
|
||||
_, err = db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
ID: chat.ID,
|
||||
})
|
||||
|
||||
@@ -66,6 +66,7 @@ type CreateWorkspaceOptions struct {
|
||||
AgentConnFn AgentConnFunc
|
||||
AgentInactiveDisconnectTimeout time.Duration
|
||||
WorkspaceMu *sync.Mutex
|
||||
OnChatUpdated func(database.Chat)
|
||||
Logger slog.Logger
|
||||
AllowedTemplateIDs func() map[uuid.UUID]bool
|
||||
}
|
||||
@@ -211,20 +212,29 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// Persist workspace + agent association on the chat.
|
||||
// Persist the workspace binding on the chat.
|
||||
if options.DB != nil && options.ChatID != uuid.Nil {
|
||||
if _, err := options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
|
||||
updatedChat, err := options.DB.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
|
||||
ID: options.ChatID,
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspace.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
// BuildID and AgentID are intentionally left null
|
||||
// here. The chatd runtime (loadWorkspaceAgentLocked)
|
||||
// will bind them on the next turn. Authoritative
|
||||
// tool-path binding is deferred to a follow-up PR.
|
||||
BuildID: uuid.NullUUID{},
|
||||
AgentID: uuid.NullUUID{},
|
||||
})
|
||||
if err != nil {
|
||||
options.Logger.Error(ctx, "failed to persist chat workspace association",
|
||||
slog.F("chat_id", options.ChatID),
|
||||
slog.F("workspace_id", workspace.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
} else if options.OnChatUpdated != nil {
|
||||
options.OnChatUpdated(updatedChat)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// DialResult contains the outcome of dialWithLazyValidation.
|
||||
type DialResult struct {
|
||||
Conn workspacesdk.AgentConn
|
||||
Release func()
|
||||
AgentID uuid.UUID // The agent that was actually dialed.
|
||||
WasSwitched bool // True if validation discovered a different agent.
|
||||
}
|
||||
|
||||
// DialFunc dials an agent by ID and returns a connection.
|
||||
type DialFunc func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error)
|
||||
|
||||
// ValidateFunc returns the current agent ID for a workspace.
|
||||
type ValidateFunc func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error)
|
||||
|
||||
type dialOut struct {
|
||||
conn workspacesdk.AgentConn
|
||||
release func()
|
||||
err error
|
||||
}
|
||||
|
||||
// dialWithLazyValidation dials an agent and only consults the database if the
|
||||
// original dial is slow or fails quickly. This keeps the common path free of
|
||||
// latest-build lookups while still repairing stale bindings.
|
||||
//
|
||||
// Outcomes:
|
||||
// - The dial succeeds before delay, so validation is skipped.
|
||||
// - The timer fires and validation confirms the same agent, so the original
|
||||
// dial continues.
|
||||
// - The timer fires and validation finds a different agent, so the stale
|
||||
// dial is canceled and the new agent is dialed instead.
|
||||
// - The dial fails before delay, so validation runs immediately and either
|
||||
// switches to a different agent or retries the current one once.
|
||||
func dialWithLazyValidation(
|
||||
ctx context.Context,
|
||||
agentID uuid.UUID,
|
||||
workspaceID uuid.UUID,
|
||||
dialFn DialFunc,
|
||||
validateFn ValidateFunc,
|
||||
delay time.Duration,
|
||||
) (DialResult, error) {
|
||||
wrapErr := func(err error) error {
|
||||
return xerrors.Errorf("dial with lazy validation: %w", err)
|
||||
}
|
||||
|
||||
dialCtx, dialCancel := context.WithCancel(ctx)
|
||||
results := make(chan dialOut, 1)
|
||||
go func() {
|
||||
conn, release, err := dialFn(dialCtx, agentID)
|
||||
results <- dialOut{conn: conn, release: release, err: err}
|
||||
}()
|
||||
|
||||
drained := false
|
||||
defer func() {
|
||||
dialCancel()
|
||||
if drained {
|
||||
return
|
||||
}
|
||||
// Drain without blocking the caller. dialFn may take time to honor
|
||||
// cancellation, but any late-arriving successful connection still needs to
|
||||
// be released.
|
||||
go func() {
|
||||
result := <-results
|
||||
if result.err == nil && result.release != nil {
|
||||
result.release()
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
resultForAgent := func(dialedAgentID uuid.UUID, result dialOut, switched bool) DialResult {
|
||||
return DialResult{
|
||||
Conn: result.conn,
|
||||
Release: result.release,
|
||||
AgentID: dialedAgentID,
|
||||
WasSwitched: switched,
|
||||
}
|
||||
}
|
||||
dialAgent := func(targetAgentID uuid.UUID, switched bool) (DialResult, error) {
|
||||
conn, release, err := dialFn(ctx, targetAgentID)
|
||||
if err != nil {
|
||||
return DialResult{}, wrapErr(err)
|
||||
}
|
||||
return resultForAgent(targetAgentID, dialOut{conn: conn, release: release}, switched), nil
|
||||
}
|
||||
preferReadyOriginalDial := func() (DialResult, bool) {
|
||||
select {
|
||||
case result := <-results:
|
||||
drained = true
|
||||
if result.err != nil {
|
||||
return DialResult{}, false
|
||||
}
|
||||
return resultForAgent(agentID, result, false), true
|
||||
default:
|
||||
return DialResult{}, false
|
||||
}
|
||||
}
|
||||
waitForOriginalDial := func(waitCtx context.Context) (DialResult, error) {
|
||||
select {
|
||||
case result := <-results:
|
||||
drained = true
|
||||
if result.err != nil {
|
||||
return DialResult{}, wrapErr(result.err)
|
||||
}
|
||||
return resultForAgent(agentID, result, false), nil
|
||||
case <-waitCtx.Done():
|
||||
if ready, ok := preferReadyOriginalDial(); ok {
|
||||
return ready, nil
|
||||
}
|
||||
return DialResult{}, waitCtx.Err()
|
||||
}
|
||||
}
|
||||
validateBinding := func() (uuid.UUID, error) {
|
||||
validatedAgentID, err := validateFn(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return uuid.Nil, wrapErr(err)
|
||||
}
|
||||
return validatedAgentID, nil
|
||||
}
|
||||
resolveFastFailure := func() (DialResult, error) {
|
||||
validatedAgentID, err := validateBinding()
|
||||
if err != nil {
|
||||
return DialResult{}, err
|
||||
}
|
||||
if validatedAgentID == agentID {
|
||||
return dialAgent(agentID, false)
|
||||
}
|
||||
return dialAgent(validatedAgentID, true)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case result := <-results:
|
||||
drained = true
|
||||
if result.err == nil {
|
||||
return resultForAgent(agentID, result, false), nil
|
||||
}
|
||||
return resolveFastFailure()
|
||||
|
||||
case <-timer.C:
|
||||
validatedAgentID, validationErr := validateFn(ctx, workspaceID)
|
||||
if validationErr != nil || validatedAgentID == agentID {
|
||||
// Validation could not prove the binding was stale, so keep waiting on
|
||||
// the original dial.
|
||||
return waitForOriginalDial(ctx)
|
||||
}
|
||||
// The original dial is stale. Cancel it first, then let the deferred drain
|
||||
// release any late result while we dial the validated agent immediately.
|
||||
dialCancel()
|
||||
return dialAgent(validatedAgentID, true)
|
||||
|
||||
case <-ctx.Done():
|
||||
if ready, ok := preferReadyOriginalDial(); ok {
|
||||
return ready, nil
|
||||
}
|
||||
return DialResult{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,563 @@
|
||||
package chatd //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestDialWithLazyValidation_FastDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
agentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
var releaseCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
|
||||
result, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
agentID,
|
||||
workspaceID,
|
||||
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
if id != agentID {
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
return conn, func() {
|
||||
releaseCalls.Add(1)
|
||||
}, nil
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
validateCalls.Add(1)
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
},
|
||||
time.Minute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, result.Conn)
|
||||
require.Equal(t, agentID, result.AgentID)
|
||||
require.False(t, result.WasSwitched)
|
||||
require.EqualValues(t, 0, validateCalls.Load())
|
||||
|
||||
if result.Release != nil {
|
||||
result.Release()
|
||||
}
|
||||
require.EqualValues(t, 1, releaseCalls.Load())
|
||||
}
|
||||
|
||||
func TestDialWithLazyValidation_SlowDialSameAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
agentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
unblockDial := make(chan struct{})
|
||||
|
||||
var releaseCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
|
||||
result, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
agentID,
|
||||
workspaceID,
|
||||
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
if id != agentID {
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
select {
|
||||
case <-unblockDial:
|
||||
return conn, func() {
|
||||
releaseCalls.Add(1)
|
||||
}, nil
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
validateCalls.Add(1)
|
||||
close(unblockDial)
|
||||
return agentID, nil
|
||||
},
|
||||
0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, result.Conn)
|
||||
require.Equal(t, agentID, result.AgentID)
|
||||
require.False(t, result.WasSwitched)
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
|
||||
if result.Release != nil {
|
||||
result.Release()
|
||||
}
|
||||
require.EqualValues(t, 1, releaseCalls.Load())
|
||||
}
|
||||
|
||||
func TestDialWithLazyValidation_SlowDialStaleAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("LateSuccessReleasesStaleConn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
staleAgentID := uuid.New()
|
||||
currentAgentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
staleConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
currentConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
var dialCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
var staleReleaseCalls atomic.Int32
|
||||
var currentReleaseCalls atomic.Int32
|
||||
|
||||
result, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
staleAgentID,
|
||||
workspaceID,
|
||||
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialCalls.Add(1)
|
||||
switch id {
|
||||
case staleAgentID:
|
||||
<-ctx.Done()
|
||||
return staleConn, func() {
|
||||
staleReleaseCalls.Add(1)
|
||||
}, nil
|
||||
case currentAgentID:
|
||||
return currentConn, func() {
|
||||
currentReleaseCalls.Add(1)
|
||||
}, nil
|
||||
default:
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
validateCalls.Add(1)
|
||||
return currentAgentID, nil
|
||||
},
|
||||
0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, currentConn, result.Conn)
|
||||
require.Equal(t, currentAgentID, result.AgentID)
|
||||
require.True(t, result.WasSwitched)
|
||||
require.Eventually(t, func() bool {
|
||||
return dialCalls.Load() == 2
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
require.Eventually(t, func() bool {
|
||||
return staleReleaseCalls.Load() == 1
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
if result.Release != nil {
|
||||
result.Release()
|
||||
}
|
||||
require.EqualValues(t, 1, currentReleaseCalls.Load())
|
||||
})
|
||||
|
||||
t.Run("CanceledFailureDoesNotReleaseStaleConn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
staleAgentID := uuid.New()
|
||||
currentAgentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
currentConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
var dialCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
var staleReleaseCalls atomic.Int32
|
||||
var currentReleaseCalls atomic.Int32
|
||||
|
||||
result, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
staleAgentID,
|
||||
workspaceID,
|
||||
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialCalls.Add(1)
|
||||
switch id {
|
||||
case staleAgentID:
|
||||
<-ctx.Done()
|
||||
return nil, func() {
|
||||
staleReleaseCalls.Add(1)
|
||||
}, ctx.Err()
|
||||
case currentAgentID:
|
||||
return currentConn, func() {
|
||||
currentReleaseCalls.Add(1)
|
||||
}, nil
|
||||
default:
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
validateCalls.Add(1)
|
||||
return currentAgentID, nil
|
||||
},
|
||||
0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, currentConn, result.Conn)
|
||||
require.Equal(t, currentAgentID, result.AgentID)
|
||||
require.True(t, result.WasSwitched)
|
||||
require.Eventually(t, func() bool {
|
||||
return dialCalls.Load() == 2
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
require.EqualValues(t, 0, staleReleaseCalls.Load())
|
||||
|
||||
if result.Release != nil {
|
||||
result.Release()
|
||||
}
|
||||
require.EqualValues(t, 1, currentReleaseCalls.Load())
|
||||
})
|
||||
|
||||
t.Run("SwitchDoesNotBlock", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
staleAgentID := uuid.New()
|
||||
currentAgentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
staleConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
currentConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
staleDialStarted := make(chan struct{})
|
||||
allowStaleReturn := make(chan struct{})
|
||||
|
||||
var dialCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
var staleReleaseCalls atomic.Int32
|
||||
var currentReleaseCalls atomic.Int32
|
||||
var staleReturnReleased atomic.Bool
|
||||
releaseStaleReturn := func() {
|
||||
if staleReturnReleased.CompareAndSwap(false, true) {
|
||||
close(allowStaleReturn)
|
||||
}
|
||||
}
|
||||
defer releaseStaleReturn()
|
||||
|
||||
resultCh := make(chan DialResult, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
result, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
staleAgentID,
|
||||
workspaceID,
|
||||
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialCalls.Add(1)
|
||||
switch id {
|
||||
case staleAgentID:
|
||||
close(staleDialStarted)
|
||||
<-allowStaleReturn
|
||||
return staleConn, func() {
|
||||
staleReleaseCalls.Add(1)
|
||||
}, nil
|
||||
case currentAgentID:
|
||||
return currentConn, func() {
|
||||
currentReleaseCalls.Add(1)
|
||||
}, nil
|
||||
default:
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
<-staleDialStarted
|
||||
validateCalls.Add(1)
|
||||
return currentAgentID, nil
|
||||
},
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
resultCh <- result
|
||||
}()
|
||||
|
||||
var result DialResult
|
||||
select {
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case result = <-resultCh:
|
||||
require.Same(t, currentConn, result.Conn)
|
||||
require.Equal(t, currentAgentID, result.AgentID)
|
||||
require.True(t, result.WasSwitched)
|
||||
releaseStaleReturn()
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("dialWithLazyValidation blocked on stale dial cleanup")
|
||||
}
|
||||
|
||||
require.EqualValues(t, 2, dialCalls.Load())
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
require.Eventually(t, func() bool {
|
||||
return staleReleaseCalls.Load() == 1
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
if result.Release != nil {
|
||||
result.Release()
|
||||
}
|
||||
require.EqualValues(t, 1, currentReleaseCalls.Load())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDialWithLazyValidation_FastFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
staleAgentID := uuid.New()
|
||||
currentAgentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
currentConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
var dialCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
var currentReleaseCalls atomic.Int32
|
||||
|
||||
result, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
staleAgentID,
|
||||
workspaceID,
|
||||
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
switch dialCalls.Add(1) {
|
||||
case 1:
|
||||
if id != staleAgentID {
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
return nil, nil, xerrors.New("dial failed")
|
||||
case 2:
|
||||
if id != currentAgentID {
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
return currentConn, func() {
|
||||
currentReleaseCalls.Add(1)
|
||||
}, nil
|
||||
default:
|
||||
return nil, nil, xerrors.New("unexpected dial call")
|
||||
}
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
validateCalls.Add(1)
|
||||
return currentAgentID, nil
|
||||
},
|
||||
time.Minute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, currentConn, result.Conn)
|
||||
require.Equal(t, currentAgentID, result.AgentID)
|
||||
require.True(t, result.WasSwitched)
|
||||
require.EqualValues(t, 2, dialCalls.Load())
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
|
||||
if result.Release != nil {
|
||||
result.Release()
|
||||
}
|
||||
require.EqualValues(t, 1, currentReleaseCalls.Load())
|
||||
}
|
||||
|
||||
func TestDialWithLazyValidation_FastFailureSameAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
agentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
var dialCalls atomic.Int32
|
||||
var releaseCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
|
||||
result, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
agentID,
|
||||
workspaceID,
|
||||
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
if id != agentID {
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
switch dialCalls.Add(1) {
|
||||
case 1:
|
||||
return nil, nil, xerrors.New("dial failed")
|
||||
case 2:
|
||||
return conn, func() {
|
||||
releaseCalls.Add(1)
|
||||
}, nil
|
||||
default:
|
||||
return nil, nil, xerrors.New("unexpected dial call")
|
||||
}
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
validateCalls.Add(1)
|
||||
return agentID, nil
|
||||
},
|
||||
time.Minute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, result.Conn)
|
||||
require.Equal(t, agentID, result.AgentID)
|
||||
require.False(t, result.WasSwitched)
|
||||
require.EqualValues(t, 2, dialCalls.Load())
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
|
||||
if result.Release != nil {
|
||||
result.Release()
|
||||
}
|
||||
require.EqualValues(t, 1, releaseCalls.Load())
|
||||
}
|
||||
|
||||
func TestDialWithLazyValidation_FastFailureSameAgentRetryFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
agentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
|
||||
var dialCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
|
||||
_, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
agentID,
|
||||
workspaceID,
|
||||
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
if id != agentID {
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
switch dialCalls.Add(1) {
|
||||
case 1:
|
||||
return nil, nil, xerrors.New("dial failed")
|
||||
case 2:
|
||||
return nil, nil, xerrors.New("retry failed")
|
||||
default:
|
||||
return nil, nil, xerrors.New("unexpected dial call")
|
||||
}
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
validateCalls.Add(1)
|
||||
return agentID, nil
|
||||
},
|
||||
time.Minute,
|
||||
)
|
||||
require.EqualError(t, err, "dial with lazy validation: retry failed")
|
||||
require.EqualValues(t, 2, dialCalls.Load())
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
}
|
||||
|
||||
func TestDialWithLazyValidation_ValidationError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
agentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
unblockDial := make(chan struct{})
|
||||
|
||||
var releaseCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
|
||||
result, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
agentID,
|
||||
workspaceID,
|
||||
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
if id != agentID {
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
select {
|
||||
case <-unblockDial:
|
||||
return conn, func() {
|
||||
releaseCalls.Add(1)
|
||||
}, nil
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
validateCalls.Add(1)
|
||||
// Validation fails — code should fall back to waiting
|
||||
// for the original dial.
|
||||
close(unblockDial)
|
||||
return uuid.Nil, xerrors.New("db connection reset")
|
||||
},
|
||||
0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, result.Conn)
|
||||
require.Equal(t, agentID, result.AgentID)
|
||||
require.False(t, result.WasSwitched)
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
|
||||
if result.Release != nil {
|
||||
result.Release()
|
||||
}
|
||||
require.EqualValues(t, 1, releaseCalls.Load())
|
||||
}
|
||||
|
||||
func TestDialWithLazyValidation_ContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
agentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
|
||||
var validateCalls atomic.Int32
|
||||
|
||||
_, err := dialWithLazyValidation(
|
||||
ctx,
|
||||
agentID,
|
||||
workspaceID,
|
||||
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
if id != agentID {
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
<-ctx.Done()
|
||||
return nil, nil, ctx.Err()
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
validateCalls.Add(1)
|
||||
cancel()
|
||||
return agentID, nil
|
||||
},
|
||||
0,
|
||||
)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
}
|
||||
@@ -313,6 +313,8 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
childChat, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
BuildID: parent.BuildID,
|
||||
AgentID: parent.AgentID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
@@ -383,6 +385,8 @@ func (p *Server) createChildSubagentChat(
|
||||
child, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
BuildID: parent.BuildID,
|
||||
AgentID: parent.AgentID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
|
||||
@@ -149,6 +149,45 @@ func seedInternalChatDeps(
|
||||
return user, model
|
||||
}
|
||||
|
||||
func seedWorkspaceBinding(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
userID uuid.UUID,
|
||||
) (database.WorkspaceTable, database.WorkspaceBuild, database.WorkspaceAgent) {
|
||||
t.Helper()
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: userID,
|
||||
})
|
||||
tpl := dbgen.Template(t, db, database.Template{
|
||||
CreatedBy: userID,
|
||||
OrganizationID: org.ID,
|
||||
ActiveVersionID: tv.ID,
|
||||
})
|
||||
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
TemplateID: tpl.ID,
|
||||
OwnerID: userID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
InitiatorID: userID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
TemplateVersionID: tv.ID,
|
||||
WorkspaceID: workspace.ID,
|
||||
JobID: job.ID,
|
||||
})
|
||||
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
JobID: job.ID,
|
||||
})
|
||||
agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ResourceID: resource.ID})
|
||||
return workspace, build, agent
|
||||
}
|
||||
|
||||
// findToolByName returns the tool with the given name from the
|
||||
// slice, or nil if no match is found.
|
||||
func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
|
||||
@@ -165,6 +204,49 @@ func chatdTestContext(t *testing.T) context.Context {
|
||||
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
|
||||
}
|
||||
|
||||
func TestCreateChildSubagentChatInheritsWorkspaceBinding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, build, agent := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspace.ID,
|
||||
Valid: true,
|
||||
},
|
||||
BuildID: uuid.NullUUID{
|
||||
UUID: build.ID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Title: "bound-parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, parentChat.WorkspaceID, childChat.WorkspaceID)
|
||||
require.Equal(t, parentChat.BuildID, childChat.BuildID)
|
||||
require.Equal(t, parentChat.AgentID, childChat.AgentID)
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -292,13 +374,26 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, build, agent := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
// The parent uses an OpenAI model.
|
||||
require.Equal(t, "openai", model.Provider,
|
||||
"seed helper must create an OpenAI model")
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspace.ID,
|
||||
Valid: true,
|
||||
},
|
||||
BuildID: uuid.NullUUID{
|
||||
UUID: build.ID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Title: "parent-openai",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
@@ -332,6 +427,10 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, parentChat.WorkspaceID, childChat.WorkspaceID)
|
||||
require.Equal(t, parentChat.BuildID, childChat.BuildID)
|
||||
require.Equal(t, parentChat.AgentID, childChat.AgentID)
|
||||
|
||||
// The child must have Mode=computer_use which causes
|
||||
// runChat to override the model to the predefined computer
|
||||
// use model instead of using the parent's model config.
|
||||
|
||||
@@ -49,6 +49,8 @@ type Chat struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
BuildID *uuid.UUID `json:"build_id,omitempty" format:"uuid"`
|
||||
AgentID *uuid.UUID `json:"agent_id,omitempty" format:"uuid"`
|
||||
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
|
||||
LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"`
|
||||
|
||||
Generated
+2
@@ -1089,6 +1089,8 @@ export interface Chat {
|
||||
readonly id: string;
|
||||
readonly owner_id: string;
|
||||
readonly workspace_id?: string;
|
||||
readonly build_id?: string;
|
||||
readonly agent_id?: string;
|
||||
readonly parent_chat_id?: string;
|
||||
readonly root_chat_id?: string;
|
||||
readonly last_model_config_id: string;
|
||||
|
||||
Reference in New Issue
Block a user