From 61e31ec5cc664b82477b2e534c25ceebaf5909f2 Mon Sep 17 00:00:00 2001 From: Ethan <39577870+ethanndickson@users.noreply.github.com> Date: Thu, 26 Mar 2026 17:22:38 +1100 Subject: [PATCH] perf(coderd/x/chatd): persist workspace agent binding across chat turns (#23274) ## Summary This change removes the steady-state "resolve the latest workspace agent" query from chat execution. Instead of asking the database for the latest build's agent on every turn, a chat now persists the workspace/build/agent binding it actually uses and reuses that binding across subsequent turns. The common path becomes "load the bound agent by ID and dial it", with fallback paths to repair the binding when it is missing, stale, or intentionally changed. ## What changes - add `build_id` and `agent_id` binding fields to `chats`, alongside the existing `workspace_id` - expose those fields through the chat API / SDK so the execution context is explicit - load the persisted binding first in chatd, instead of always resolving the latest build's agent - persist a refreshed binding when chatd has to re-resolve the workspace agent - keep child / subagent chats on the same bound workspace context by inheriting the parent binding - leave `build_id` / `agent_id` unset for flows like `create_workspace`, then bind them lazily on the next agent-backed turn ## Runtime behavior The binding is treated as an optimistic cache of the agent a chat should use: - if the bound agent still exists and dials successfully, we use it without a latest-build lookup - if the bound agent is missing or no longer reachable, chatd re-resolves against the latest build and persists the new binding - if a workspace mutation changes the chat's target workspace, the binding is updated as part of that mutation To avoid reintroducing a hot-path query, dialing uses lazy validation: - start dialing the cached agent immediately - only validate against the latest build if the dial is still pending after a short delay - if validation finds a different agent, 
cancel the stale dial, switch to the current agent, and persist the repaired binding ## Result The hot path stops issuing `GetWorkspaceAgentsInLatestBuildByWorkspaceID` for every user message, which is the source of the DB pressure this PR is addressing. At the same time, chats still converge to the correct workspace agent when the binding becomes stale due to rebuilds or explicit workspace changes. --- coderd/database/dbauthz/dbauthz.go | 24 +- coderd/database/dbauthz/dbauthz_test.go | 24 +- coderd/database/dbmetrics/querymetrics.go | 16 +- coderd/database/dbmock/dbmock.go | 27 +- coderd/database/dump.sql | 10 +- coderd/database/foreign_key_constraint.go | 2 + .../000452_chat_workspace_binding.down.sql | 3 + .../000452_chat_workspace_binding.up.sql | 3 + coderd/database/modelqueries.go | 2 + coderd/database/models.go | 2 + coderd/database/querier.go | 3 +- coderd/database/queries.sql.go | 129 +++- coderd/database/queries/chats.sql | 23 +- coderd/exp_chats.go | 6 + coderd/x/chatd/chatd.go | 392 +++++++++--- coderd/x/chatd/chatd_internal_test.go | 308 +++++++++- coderd/x/chatd/chatd_test.go | 2 +- coderd/x/chatd/chattool/createworkspace.go | 16 +- coderd/x/chatd/dialvalidation.go | 170 ++++++ coderd/x/chatd/dialvalidation_test.go | 563 ++++++++++++++++++ coderd/x/chatd/subagent.go | 4 + coderd/x/chatd/subagent_internal_test.go | 101 +++- codersdk/chats.go | 2 + site/src/api/typesGenerated.ts | 2 + 24 files changed, 1655 insertions(+), 179 deletions(-) create mode 100644 coderd/database/migrations/000452_chat_workspace_binding.down.sql create mode 100644 coderd/database/migrations/000452_chat_workspace_binding.up.sql create mode 100644 coderd/x/chatd/dialvalidation.go create mode 100644 coderd/x/chatd/dialvalidation_test.go diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 2d704add79..c2b29ce64e 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -5619,6 +5619,18 @@ func (q *querier) 
UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg) } +func (q *querier) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) { + chat, err := q.db.GetChatByID(ctx, arg.ID) + if err != nil { + return database.Chat{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return database.Chat{}, err + } + + return q.db.UpdateChatBuildAgentBinding(ctx, arg) +} + func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) { chat, err := q.db.GetChatByID(ctx, arg.ID) if err != nil { @@ -5706,7 +5718,7 @@ func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatS return q.db.UpdateChatStatus(ctx, arg) } -func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) { +func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) { chat, err := q.db.GetChatByID(ctx, arg.ID) if err != nil { return database.Chat{}, err @@ -5715,15 +5727,7 @@ func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateCh return database.Chat{}, err } - // UpdateChatWorkspace is manually implemented for chat tables and may not be - // present on every wrapped store interface yet. 
- chatWorkspaceUpdater, ok := q.db.(interface { - UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error) - }) - if !ok { - return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store") - } - return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg) + return q.db.UpdateChatWorkspaceBinding(ctx, arg) } func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 31765996dc..4d657e2f12 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -819,15 +819,29 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) })) - s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + s.Run("UpdateChatBuildAgentBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) - arg := database.UpdateChatWorkspaceParams{ - ID: chat.ID, - WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + arg := database.UpdateChatBuildAgentBindingParams{ + ID: chat.ID, + BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, } updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID}) dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() - dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes() + dbm.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat) + })) + s.Run("UpdateChatWorkspaceBinding", 
s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.UpdateChatWorkspaceBindingParams{ + ID: chat.ID, + WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + } + updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID}) + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().UpdateChatWorkspaceBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat) })) s.Run("UnsetDefaultChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index e1f6c1e73b..1353d11d1e 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -4000,6 +4000,14 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up return r0 } +func (m queryMetricsStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) { + start := time.Now() + r0, r1 := m.s.UpdateChatBuildAgentBinding(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatBuildAgentBinding").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatBuildAgentBinding").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) { start := time.Now() r0, r1 := m.s.UpdateChatByID(ctx, arg) @@ -4064,11 +4072,11 @@ func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.Up return r0, r1 } -func (m queryMetricsStore) UpdateChatWorkspace(ctx 
context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) { +func (m queryMetricsStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) { start := time.Now() - r0, r1 := m.s.UpdateChatWorkspace(ctx, arg) - m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc() + r0, r1 := m.s.UpdateChatWorkspaceBinding(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatWorkspaceBinding").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspaceBinding").Inc() return r0, r1 } diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 50d7bd9e09..13622b6a1a 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -7552,6 +7552,21 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), ctx, arg) } +// UpdateChatBuildAgentBinding mocks base method. +func (m *MockStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateChatBuildAgentBinding", ctx, arg) + ret0, _ := ret[0].(database.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateChatBuildAgentBinding indicates an expected call of UpdateChatBuildAgentBinding. 
+func (mr *MockStoreMockRecorder) UpdateChatBuildAgentBinding(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatBuildAgentBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatBuildAgentBinding), ctx, arg) +} + // UpdateChatByID mocks base method. func (m *MockStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) { m.ctrl.T.Helper() @@ -7672,19 +7687,19 @@ func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg) } -// UpdateChatWorkspace mocks base method. -func (m *MockStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) { +// UpdateChatWorkspaceBinding mocks base method. +func (m *MockStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateChatWorkspace", ctx, arg) + ret := m.ctrl.Call(m, "UpdateChatWorkspaceBinding", ctx, arg) ret0, _ := ret[0].(database.Chat) ret1, _ := ret[1].(error) return ret0, ret1 } -// UpdateChatWorkspace indicates an expected call of UpdateChatWorkspace. -func (mr *MockStoreMockRecorder) UpdateChatWorkspace(ctx, arg any) *gomock.Call { +// UpdateChatWorkspaceBinding indicates an expected call of UpdateChatWorkspaceBinding. +func (mr *MockStoreMockRecorder) UpdateChatWorkspaceBinding(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspace", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspace), ctx, arg) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspaceBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspaceBinding), ctx, arg) } // UpdateCryptoKeyDeletesAt mocks base method. 
diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 1c3d158b59..154d22cd62 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1399,7 +1399,9 @@ CREATE TABLE chats ( last_error text, mode chat_mode, mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL, - labels jsonb DEFAULT '{}'::jsonb NOT NULL + labels jsonb DEFAULT '{}'::jsonb NOT NULL, + build_id uuid, + agent_id uuid ); CREATE TABLE connection_logs ( @@ -4033,6 +4035,12 @@ ALTER TABLE ONLY chat_providers ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; +ALTER TABLE ONLY chats + ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL; + +ALTER TABLE ONLY chats + ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL; + ALTER TABLE ONLY chats ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id); diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index b6095e0547..4f7ec37f0b 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -20,6 +20,8 @@ const ( ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY 
(chat_id) REFERENCES chats(id) ON DELETE CASCADE; + ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL; + ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL; ForeignKeyChatsLastModelConfigID ForeignKeyConstraint = "chats_last_model_config_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id); ForeignKeyChatsOwnerID ForeignKeyConstraint = "chats_owner_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyChatsParentChatID ForeignKeyConstraint = "chats_parent_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL; diff --git a/coderd/database/migrations/000452_chat_workspace_binding.down.sql b/coderd/database/migrations/000452_chat_workspace_binding.down.sql new file mode 100644 index 0000000000..c192261389 --- /dev/null +++ b/coderd/database/migrations/000452_chat_workspace_binding.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE chats + DROP COLUMN IF EXISTS build_id, + DROP COLUMN IF EXISTS agent_id; diff --git a/coderd/database/migrations/000452_chat_workspace_binding.up.sql b/coderd/database/migrations/000452_chat_workspace_binding.up.sql new file mode 100644 index 0000000000..8788ac93f0 --- /dev/null +++ b/coderd/database/migrations/000452_chat_workspace_binding.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE chats + ADD COLUMN build_id UUID REFERENCES workspace_builds(id) ON DELETE SET NULL, + ADD COLUMN agent_id UUID REFERENCES workspace_agents(id) ON DELETE SET NULL; diff --git 
a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index e62cb25922..33d181053d 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -791,6 +791,8 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ); err != nil { return nil, err } diff --git a/coderd/database/models.go b/coderd/database/models.go index 08b06cd811..d1c614923b 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4171,6 +4171,8 @@ type Chat struct { Mode NullChatMode `db:"mode" json:"mode"` MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` Labels StringMap `db:"labels" json:"labels"` + BuildID uuid.NullUUID `db:"build_id" json:"build_id"` + AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` } type ChatDiffStatus struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index cf6cc78ce2..a1fc9cc05e 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -819,6 +819,7 @@ type sqlcQuerier interface { UnsetDefaultChatModelConfigs(ctx context.Context) error UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error + UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error) // Bumps the heartbeat timestamp for a running chat so that other // replicas know the worker is still alive. 
@@ -829,7 +830,7 @@ type sqlcQuerier interface { UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error) - UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error) + UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error) UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index d276c524fe..43f85ce750 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3823,7 +3823,7 @@ WHERE $3::int ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type AcquireChatsParams struct { @@ -3862,6 +3862,8 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) ( &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ); err != nil { return nil, err } @@ -4095,7 +4097,7 @@ func (q *sqlQuerier) DeleteChatUsageLimitUserOverride(ctx context.Context, userI const getChatByID = `-- name: GetChatByID :one SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, 
parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id FROM chats WHERE @@ -4124,12 +4126,14 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels FROM chats WHERE id = $1::uuid FOR UPDATE +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id FROM chats WHERE id = $1::uuid FOR UPDATE ` func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) { @@ -4154,6 +4158,8 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -4998,7 +5004,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u const getChats = `-- name: GetChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, 
labels, build_id, agent_id FROM chats WHERE @@ -5089,6 +5095,8 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ); err != nil { return nil, err } @@ -5154,7 +5162,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh const getStaleChats = `-- name: GetStaleChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id FROM chats WHERE @@ -5192,6 +5200,8 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ); err != nil { return nil, err } @@ -5250,6 +5260,8 @@ const insertChat = `-- name: InsertChat :one INSERT INTO chats ( owner_id, workspace_id, + build_id, + agent_id, parent_chat_id, root_chat_id, last_model_config_id, @@ -5263,18 +5275,22 @@ INSERT INTO chats ( $3::uuid, $4::uuid, $5::uuid, - $6::text, - $7::chat_mode, - COALESCE($8::uuid[], '{}'::uuid[]), - COALESCE($9::jsonb, '{}'::jsonb) + $6::uuid, + $7::uuid, + $8::text, + $9::chat_mode, + COALESCE($10::uuid[], '{}'::uuid[]), + COALESCE($11::jsonb, '{}'::jsonb) ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, 
labels, build_id, agent_id ` type InsertChatParams struct { OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"` + BuildID uuid.NullUUID `db:"build_id" json:"build_id"` + AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"` RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"` LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"` @@ -5288,6 +5304,8 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat row := q.db.QueryRowContext(ctx, insertChat, arg.OwnerID, arg.WorkspaceID, + arg.BuildID, + arg.AgentID, arg.ParentChatID, arg.RootChatID, arg.LastModelConfigID, @@ -5316,6 +5334,8 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -5702,6 +5722,50 @@ func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error return err } +const updateChatBuildAgentBinding = `-- name: UpdateChatBuildAgentBinding :one +UPDATE chats SET + build_id = $1::uuid, + agent_id = $2::uuid, + updated_at = NOW() +WHERE + id = $3::uuid +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id +` + +type UpdateChatBuildAgentBindingParams struct { + BuildID uuid.NullUUID `db:"build_id" json:"build_id"` + AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatBuildAgentBinding, arg.BuildID, arg.AgentID, arg.ID) + var i Chat + err := row.Scan( + &i.ID, + &i.OwnerID, + 
&i.WorkspaceID, + &i.Title, + &i.Status, + &i.WorkerID, + &i.StartedAt, + &i.HeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ParentChatID, + &i.RootChatID, + &i.LastModelConfigID, + &i.Archived, + &i.LastError, + &i.Mode, + pq.Array(&i.MCPServerIDs), + &i.Labels, + &i.BuildID, + &i.AgentID, + ) + return i, err +} + const updateChatByID = `-- name: UpdateChatByID :one UPDATE chats @@ -5711,7 +5775,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type UpdateChatByIDParams struct { @@ -5741,6 +5805,8 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -5780,7 +5846,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type UpdateChatLabelsByIDParams struct { @@ -5810,6 +5876,8 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -5823,7 +5891,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, 
updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type UpdateChatMCPServerIDsParams struct { @@ -5853,6 +5921,8 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -5917,7 +5987,7 @@ SET WHERE id = $6::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type UpdateChatStatusParams struct { @@ -5958,29 +6028,36 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } -const updateChatWorkspace = `-- name: UpdateChatWorkspace :one -UPDATE - chats -SET +const updateChatWorkspaceBinding = `-- name: UpdateChatWorkspaceBinding :one +UPDATE chats SET workspace_id = $1::uuid, + build_id = $2::uuid, + agent_id = $3::uuid, updated_at = NOW() -WHERE - id = $2::uuid -RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels +WHERE id = $4::uuid +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, 
last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` -type UpdateChatWorkspaceParams struct { +type UpdateChatWorkspaceBindingParams struct { WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"` + BuildID uuid.NullUUID `db:"build_id" json:"build_id"` + AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` ID uuid.UUID `db:"id" json:"id"` } -func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error) { - row := q.db.QueryRowContext(ctx, updateChatWorkspace, arg.WorkspaceID, arg.ID) +func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatWorkspaceBinding, + arg.WorkspaceID, + arg.BuildID, + arg.AgentID, + arg.ID, + ) var i Chat err := row.Scan( &i.ID, @@ -6001,6 +6078,8 @@ func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWork &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 0130d71774..0b9d3b7b24 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -180,6 +180,8 @@ LIMIT INSERT INTO chats ( owner_id, workspace_id, + build_id, + agent_id, parent_chat_id, root_chat_id, last_model_config_id, @@ -190,6 +192,8 @@ INSERT INTO chats ( ) VALUES ( @owner_id::uuid, sqlc.narg('workspace_id')::uuid, + sqlc.narg('build_id')::uuid, + sqlc.narg('agent_id')::uuid, sqlc.narg('parent_chat_id')::uuid, sqlc.narg('root_chat_id')::uuid, @last_model_config_id::uuid, @@ -305,16 +309,23 @@ WHERE RETURNING *; --- name: UpdateChatWorkspace :one -UPDATE - chats -SET +-- name: UpdateChatWorkspaceBinding :one +UPDATE chats SET workspace_id = sqlc.narg('workspace_id')::uuid, + build_id = sqlc.narg('build_id')::uuid, + agent_id = sqlc.narg('agent_id')::uuid, + updated_at = NOW() +WHERE id = 
@id::uuid +RETURNING *; + +-- name: UpdateChatBuildAgentBinding :one +UPDATE chats SET + build_id = sqlc.narg('build_id')::uuid, + agent_id = sqlc.narg('agent_id')::uuid, updated_at = NOW() WHERE id = @id::uuid -RETURNING - *; +RETURNING *; -- name: UpdateChatMCPServerIDs :one UPDATE diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 79f97f3ffe..37b1927764 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -3600,6 +3600,12 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk. if c.WorkspaceID.Valid { chat.WorkspaceID = &c.WorkspaceID.UUID } + if c.BuildID.Valid { + chat.BuildID = &c.BuildID.UUID + } + if c.AgentID.Valid { + chat.AgentID = &c.AgentID.UUID + } if diffStatus != nil { convertedDiffStatus := db2sdk.ChatDiffStatus(c.ID, diffStatus) chat.DiffStatus = &convertedDiffStatus diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 710d90db2e..4f88792106 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -53,6 +53,8 @@ const ( DefaultInFlightChatStaleAfter = 5 * time.Minute homeInstructionLookupTimeout = 5 * time.Second + instructionCacheTTL = 5 * time.Minute + workspaceDialValidationDelay = 5 * time.Second // DefaultChatHeartbeatInterval is the default time between chat // heartbeat updates while a chat is being processed. 
DefaultChatHeartbeatInterval = 30 * time.Second @@ -158,18 +160,26 @@ type turnWorkspaceContext struct { currentChat *database.Chat loadChatSnapshot func(context.Context, uuid.UUID) (database.Chat, error) - mu sync.Mutex - agent database.WorkspaceAgent - agentLoaded bool - conn workspacesdk.AgentConn - releaseConn func() + mu sync.Mutex + agent database.WorkspaceAgent + agentLoaded bool + conn workspacesdk.AgentConn + releaseConn func() + cachedWorkspaceID uuid.NullUUID } func (c *turnWorkspaceContext) close() { + c.clearCachedWorkspaceState() +} + +func (c *turnWorkspaceContext) clearCachedWorkspaceState() { c.mu.Lock() releaseConn := c.releaseConn + c.agent = database.WorkspaceAgent{} + c.agentLoaded = false c.conn = nil c.releaseConn = nil + c.cachedWorkspaceID = uuid.NullUUID{} c.mu.Unlock() if releaseConn != nil { @@ -177,6 +187,68 @@ func (c *turnWorkspaceContext) close() { } } +func (c *turnWorkspaceContext) setCurrentChat(chat database.Chat) { + c.chatStateMu.Lock() + *c.currentChat = chat + c.chatStateMu.Unlock() +} + +func (c *turnWorkspaceContext) currentChatSnapshot() database.Chat { + c.chatStateMu.Lock() + chatSnapshot := *c.currentChat + c.chatStateMu.Unlock() + return chatSnapshot +} + +func (c *turnWorkspaceContext) selectWorkspace(chat database.Chat) { + c.setCurrentChat(chat) + c.clearCachedWorkspaceState() +} + +func (c *turnWorkspaceContext) currentWorkspaceMatches(expected uuid.NullUUID) (database.Chat, bool) { + chatSnapshot := c.currentChatSnapshot() + return chatSnapshot, nullUUIDEqual(chatSnapshot.WorkspaceID, expected) +} + +func nullUUIDEqual(left, right uuid.NullUUID) bool { + if left.Valid != right.Valid { + return false + } + if !left.Valid { + return true + } + return left.UUID == right.UUID +} + +func (c *turnWorkspaceContext) persistBuildAgentBinding( + ctx context.Context, + chatSnapshot database.Chat, + buildID uuid.UUID, + agentID uuid.UUID, +) (database.Chat, error) { + updatedChat, err := 
c.server.db.UpdateChatBuildAgentBinding( + ctx, + database.UpdateChatBuildAgentBindingParams{ + ID: chatSnapshot.ID, + BuildID: uuid.NullUUID{ + UUID: buildID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + }, + ) + if err != nil { + return chatSnapshot, xerrors.Errorf( + "update chat build/agent binding: %w", err, + ) + } + c.setCurrentChat(updatedChat) + return updatedChat, nil +} + func (c *turnWorkspaceContext) getWorkspaceAgent(ctx context.Context) (database.WorkspaceAgent, error) { _, agent, err := c.ensureWorkspaceAgent(ctx) return agent, err @@ -189,128 +261,245 @@ func (c *turnWorkspaceContext) ensureWorkspaceAgent( defer c.mu.Unlock() if c.agentLoaded { - c.chatStateMu.Lock() - chatSnapshot := *c.currentChat - c.chatStateMu.Unlock() - return chatSnapshot, c.agent, nil + chatSnapshot := c.currentChatSnapshot() + if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) { + return chatSnapshot, c.agent, nil + } + c.agent = database.WorkspaceAgent{} + c.agentLoaded = false } return c.loadWorkspaceAgentLocked(ctx) } -func (c *turnWorkspaceContext) refreshWorkspaceAgent( - ctx context.Context, -) (database.Chat, database.WorkspaceAgent, error) { - c.mu.Lock() - defer c.mu.Unlock() - - c.agent = database.WorkspaceAgent{} - c.agentLoaded = false - return c.loadWorkspaceAgentLocked(ctx) -} - func (c *turnWorkspaceContext) loadWorkspaceAgentLocked( ctx context.Context, ) (database.Chat, database.WorkspaceAgent, error) { - c.chatStateMu.Lock() - chatSnapshot := *c.currentChat - c.chatStateMu.Unlock() + chatSnapshot := c.currentChatSnapshot() - if !chatSnapshot.WorkspaceID.Valid { - refreshedChat, refreshErr := refreshChatWorkspaceSnapshot( + for attempt := 0; attempt < 2; attempt++ { + if !chatSnapshot.WorkspaceID.Valid { + refreshedChat, refreshErr := refreshChatWorkspaceSnapshot( + ctx, + chatSnapshot, + c.loadChatSnapshot, + ) + if refreshErr != nil { + return chatSnapshot, database.WorkspaceAgent{}, refreshErr + } + 
if refreshedChat.WorkspaceID.Valid { + c.setCurrentChat(refreshedChat) + chatSnapshot = refreshedChat + } + } + + if !chatSnapshot.WorkspaceID.Valid { + return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace") + } + + if chatSnapshot.AgentID.Valid { + agent, err := c.server.db.GetWorkspaceAgentByID(ctx, chatSnapshot.AgentID.UUID) + if err == nil { + latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID) + if !workspaceMatches { + chatSnapshot = latestChat + continue + } + c.agent = agent + c.agentLoaded = true + c.cachedWorkspaceID = chatSnapshot.WorkspaceID + return chatSnapshot, c.agent, nil + } + if !xerrors.Is(err, sql.ErrNoRows) { + c.server.logger.Warn(ctx, "agent binding lookup failed, re-resolving", + slog.F("agent_id", chatSnapshot.AgentID.UUID), + slog.Error(err), + ) + } + } + + agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID( + ctx, + chatSnapshot.WorkspaceID.UUID, + ) + if err != nil || len(agents) == 0 { + return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent") + } + + build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID) + if err != nil { + return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf("get latest workspace build: %w", err) + } + + updatedChat, err := c.persistBuildAgentBinding( ctx, chatSnapshot, - c.loadChatSnapshot, + build.ID, + agents[0].ID, ) - if refreshErr != nil { - return chatSnapshot, database.WorkspaceAgent{}, refreshErr + if err != nil { + return chatSnapshot, database.WorkspaceAgent{}, err } - if refreshedChat.WorkspaceID.Valid { - c.chatStateMu.Lock() - *c.currentChat = refreshedChat - c.chatStateMu.Unlock() - chatSnapshot = refreshedChat + + chatSnapshot = updatedChat + latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID) + if !workspaceMatches { + chatSnapshot = latestChat + continue } + c.agent = agents[0] + c.agentLoaded = 
true + c.cachedWorkspaceID = chatSnapshot.WorkspaceID + return chatSnapshot, c.agent, nil } - if !chatSnapshot.WorkspaceID.Valid { - return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace") - } - - agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID( - ctx, - chatSnapshot.WorkspaceID.UUID, + return chatSnapshot, database.WorkspaceAgent{}, xerrors.New( + "chat workspace changed while resolving agent", ) - if err != nil || len(agents) == 0 { - return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent") +} + +// getWorkspaceConnLocked returns the cached connection when it still matches +// the current workspace. When the workspace changed, it clears the stale +// cached state and returns the release func for the caller to run after +// unlocking. +func (c *turnWorkspaceContext) getWorkspaceConnLocked() (workspacesdk.AgentConn, func()) { + if c.conn == nil { + return nil, nil } - c.agent = agents[0] - c.agentLoaded = true - return chatSnapshot, c.agent, nil + chatSnapshot := c.currentChatSnapshot() + if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) { + return c.conn, nil + } + + agentRelease := c.releaseConn + c.agent = database.WorkspaceAgent{} + c.agentLoaded = false + c.conn = nil + c.releaseConn = nil + c.cachedWorkspaceID = uuid.NullUUID{} + return nil, agentRelease } func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspacesdk.AgentConn, error) { - c.mu.Lock() - if c.conn != nil { - currentConn := c.conn - c.mu.Unlock() - return currentConn, nil - } - c.mu.Unlock() - if c.server.agentConnFn == nil { return nil, xerrors.New("workspace agent connector is not configured") } - chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx) - if err != nil { - return nil, err - } - - agentConn, agentRelease, err := c.server.agentConnFn(ctx, agent.ID) - if err != nil { - refreshedChat, refreshedAgent, refreshErr := c.refreshWorkspaceAgent(ctx) - if refreshErr != 
nil { - return nil, xerrors.Errorf("connect to workspace agent: %w", err) - } - - retryConn, retryRelease, retryErr := c.server.agentConnFn(ctx, refreshedAgent.ID) - if retryErr != nil { - return nil, xerrors.Errorf("connect to workspace agent after refresh: %w", retryErr) - } - - chatSnapshot = refreshedChat - agentConn = retryConn - agentRelease = retryRelease - } - - c.mu.Lock() - if c.conn == nil { - c.conn = agentConn - c.releaseConn = agentRelease - - var ancestorIDs []string - if chatSnapshot.ParentChatID.Valid { - ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String()) - } - ancestorJSON, marshalErr := json.Marshal(ancestorIDs) - if marshalErr != nil { - ancestorJSON = []byte("[]") - } - agentConn.SetExtraHeaders(http.Header{ - workspacesdk.CoderChatIDHeader: {chatSnapshot.ID.String()}, - workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)}, - }) - + for attempt := 0; attempt < 2; attempt++ { + c.mu.Lock() + currentConn, staleRelease := c.getWorkspaceConnLocked() c.mu.Unlock() - return agentConn, nil - } - currentConn := c.conn - c.mu.Unlock() + if currentConn != nil { + return currentConn, nil + } + if staleRelease != nil { + staleRelease() + } - agentRelease() - return currentConn, nil + chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx) + if err != nil { + return nil, err + } + + dialResult, err := dialWithLazyValidation( + ctx, + agent.ID, + chatSnapshot.WorkspaceID.UUID, + DialFunc(c.server.agentConnFn), + func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) { + agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspaceID) + if err != nil || len(agents) == 0 { + return uuid.Nil, xerrors.New("chat has no workspace agent") + } + return agents[0].ID, nil + }, + workspaceDialValidationDelay, + ) + if err != nil { + return nil, err + } + + agentConn := dialResult.Conn + agentRelease := dialResult.Release + if dialResult.WasSwitched { + build, err := 
c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID) + if err != nil { + if agentRelease != nil { + agentRelease() + } + return nil, xerrors.Errorf("get latest workspace build: %w", err) + } + + switchedAgent, err := c.server.db.GetWorkspaceAgentByID(ctx, dialResult.AgentID) + if err != nil { + if agentRelease != nil { + agentRelease() + } + return nil, xerrors.Errorf("get workspace agent by id: %w", err) + } + + updatedChat, err := c.persistBuildAgentBinding( + ctx, + chatSnapshot, + build.ID, + switchedAgent.ID, + ) + if err != nil { + if agentRelease != nil { + agentRelease() + } + return nil, err + } + chatSnapshot = updatedChat + + c.mu.Lock() + c.agent = switchedAgent + c.agentLoaded = true + c.cachedWorkspaceID = chatSnapshot.WorkspaceID + c.mu.Unlock() + } + + if _, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID); !workspaceMatches { + if agentRelease != nil { + agentRelease() + } + c.clearCachedWorkspaceState() + continue + } + + c.mu.Lock() + if c.conn == nil { + c.conn = agentConn + c.releaseConn = agentRelease + c.cachedWorkspaceID = chatSnapshot.WorkspaceID + + var ancestorIDs []string + if chatSnapshot.ParentChatID.Valid { + ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String()) + } + ancestorJSON, marshalErr := json.Marshal(ancestorIDs) + if marshalErr != nil { + ancestorJSON = []byte("[]") + } + agentConn.SetExtraHeaders(http.Header{ + workspacesdk.CoderChatIDHeader: {chatSnapshot.ID.String()}, + workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)}, + }) + + c.mu.Unlock() + return agentConn, nil + } + currentConn = c.conn + c.mu.Unlock() + + if agentRelease != nil { + agentRelease() + } + return currentConn, nil + } + + return nil, xerrors.New("chat workspace changed while connecting") } // AgentConnFunc provides access to workspace agent connections. 
@@ -420,6 +609,8 @@ func (e *UsageLimitExceededError) Error() string { type CreateOptions struct { OwnerID uuid.UUID WorkspaceID uuid.NullUUID + BuildID uuid.NullUUID + AgentID uuid.NullUUID ParentChatID uuid.NullUUID RootChatID uuid.NullUUID Title string @@ -525,6 +716,8 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{ OwnerID: opts.OwnerID, WorkspaceID: opts.WorkspaceID, + BuildID: opts.BuildID, + AgentID: opts.AgentID, ParentChatID: opts.ParentChatID, RootChatID: opts.RootChatID, LastModelConfigID: opts.ModelConfigID, @@ -3461,6 +3654,7 @@ func (p *Server) runChat( AgentConnFn: chattool.AgentConnFunc(p.agentConnFn), AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout, WorkspaceMu: &workspaceMu, + OnChatUpdated: workspaceCtx.selectWorkspace, Logger: p.logger, AllowedTemplateIDs: p.chatTemplateAllowlist, }), diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index dce239cc1f..958f3ba19e 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -112,24 +112,29 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) { db := dbmock.NewMockStore(ctrl) workspaceID := uuid.New() + agentID := uuid.New() chat := database.Chat{ ID: uuid.New(), WorkspaceID: uuid.NullUUID{ UUID: workspaceID, Valid: true, }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, } workspaceAgent := database.WorkspaceAgent{ - ID: uuid.New(), + ID: agentID, OperatingSystem: "linux", Directory: "/home/coder/project", ExpandedDirectory: "/home/coder/project", } - db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID( + db.EXPECT().GetWorkspaceAgentByID( gomock.Any(), - workspaceID, - ).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1) + agentID, + ).Return(workspaceAgent, nil).Times(1) db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, 
nil).AnyTimes() conn := agentconnmock.NewMockAgentConn(ctrl) @@ -180,7 +185,7 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) { require.Contains(t, instruction, "Working Directory: /home/coder/project") } -func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) { +func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) { t.Parallel() ctx := context.Background() @@ -188,6 +193,53 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing. db := dbmock.NewMockStore(ctrl) workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + workspaceAgent := database.WorkspaceAgent{ID: agentID} + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(workspaceAgent, nil).Times(1) + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: &Server{db: db}, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + t.Cleanup(workspaceCtx.close) + + chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, chat, chatSnapshot) + require.Equal(t, workspaceAgent, agent) + + gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, workspaceAgent, gotAgent) + require.Equal(t, chat, currentChat) +} + +func TestTurnWorkspaceContext_NullBindingLazyBind(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + buildID := uuid.New() + agentID := uuid.New() chat := database.Chat{ ID: uuid.New(), WorkspaceID: uuid.NullUUID{ @@ -195,18 +247,135 @@ func
TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing. Valid: true, }, } - initialAgent := database.WorkspaceAgent{ID: uuid.New()} - refreshedAgent := database.WorkspaceAgent{ID: uuid.New()} + workspaceAgent := database.WorkspaceAgent{ID: agentID} + updatedChat := chat + updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true} + updatedChat.AgentID = uuid.NullUUID{UUID: agentID, Valid: true} gomock.InOrder( - db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID( - gomock.Any(), - workspaceID, - ).Return([]database.WorkspaceAgent{initialAgent}, nil), - db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID( - gomock.Any(), - workspaceID, - ).Return([]database.WorkspaceAgent{refreshedAgent}, nil), + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{workspaceAgent}, nil), + db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil), + db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{ + BuildID: uuid.NullUUID{UUID: buildID, Valid: true}, + AgentID: uuid.NullUUID{UUID: agentID, Valid: true}, + ID: chat.ID, + }).Return(updatedChat, nil), + ) + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: &Server{db: db}, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + t.Cleanup(workspaceCtx.close) + + chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, updatedChat, chatSnapshot) + require.Equal(t, workspaceAgent, agent) + require.Equal(t, updatedChat, currentChat) + + gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, workspaceAgent, gotAgent) +} + +func
TestTurnWorkspaceContext_StaleBindingRepair(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + staleAgentID := uuid.New() + buildID := uuid.New() + currentAgentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: staleAgentID, + Valid: true, + }, + } + currentAgent := database.WorkspaceAgent{ID: currentAgentID} + updatedChat := chat + updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true} + updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true} + + gomock.InOrder( + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(database.WorkspaceAgent{}, xerrors.New("missing agent")), + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil), + db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil), + db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{ + BuildID: uuid.NullUUID{UUID: buildID, Valid: true}, + AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true}, + ID: chat.ID, + }).Return(updatedChat, nil), + ) + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: &Server{db: db}, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + t.Cleanup(workspaceCtx.close) + + chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, updatedChat, chatSnapshot) + require.Equal(t, currentAgent, agent) + require.Equal(t, updatedChat, currentChat) +} + +func
TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + staleAgentID := uuid.New() + currentAgentID := uuid.New() + buildID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: staleAgentID, + Valid: true, + }, + } + staleAgent := database.WorkspaceAgent{ID: staleAgentID} + currentAgent := database.WorkspaceAgent{ID: currentAgentID} + updatedChat := chat + updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true} + updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true} + + gomock.InOrder( + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(staleAgent, nil), + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil), + db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil), + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), currentAgentID).Return(currentAgent, nil), + db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{ + BuildID: uuid.NullUUID{UUID: buildID, Valid: true}, + AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true}, + ID: chat.ID, + }).Return(updatedChat, nil), ) conn := agentconnmock.NewMockAgentConn(ctrl) @@ -216,7 +385,7 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing. 
server := &Server{db: db} server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { dialed = append(dialed, agentID) - if agentID == initialAgent.ID { + if agentID == staleAgentID { return nil, nil, xerrors.New("dial failed") } return conn, func() {}, nil @@ -235,7 +404,112 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing. gotConn, err := workspaceCtx.getWorkspaceConn(ctx) require.NoError(t, err) require.Same(t, conn, gotConn) - require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed) + require.Equal(t, []uuid.UUID{staleAgentID, currentAgentID}, dialed) + require.Equal(t, updatedChat, currentChat) + + gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, currentAgent, gotAgent) +} + +func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + currentChat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + } + updatedChat := database.Chat{ + ID: currentChat.ID, + WorkspaceID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + } + cachedConn := agentconnmock.NewMockAgentConn(ctrl) + releaseCalls := 0 + + workspaceCtx := turnWorkspaceContext{ + chatStateMu: &sync.Mutex{}, + currentChat: &currentChat, + } + workspaceCtx.agent = database.WorkspaceAgent{ID: uuid.New()} + workspaceCtx.agentLoaded = true + workspaceCtx.conn = cachedConn + workspaceCtx.cachedWorkspaceID = currentChat.WorkspaceID + workspaceCtx.releaseConn = func() { + releaseCalls++ + } + + workspaceCtx.selectWorkspace(updatedChat) + + require.Equal(t, updatedChat, currentChat) + require.Equal(t, 1, releaseCalls) + + workspaceCtx.mu.Lock() + defer workspaceCtx.mu.Unlock() + require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent) + require.False(t, workspaceCtx.agentLoaded) + require.Nil(t, workspaceCtx.conn) + require.Nil(t,
workspaceCtx.releaseConn) + require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID) +} + +func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferentWorkspace(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceOneID := uuid.New() + workspaceTwoID := uuid.New() + buildID := uuid.New() + cachedAgent := database.WorkspaceAgent{ID: uuid.New()} + resolvedAgent := database.WorkspaceAgent{ID: uuid.New()} + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceTwoID, + Valid: true, + }, + } + updatedChat := chat + updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true} + updatedChat.AgentID = uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true} + + gomock.InOrder( + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return([]database.WorkspaceAgent{resolvedAgent}, nil), + db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return(database.WorkspaceBuild{ID: buildID}, nil), + db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{ + ID: chat.ID, + BuildID: uuid.NullUUID{UUID: buildID, Valid: true}, + AgentID: uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true}, + }).Return(updatedChat, nil), + ) + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: &Server{db: db}, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + workspaceCtx.agent = cachedAgent + workspaceCtx.agentLoaded = true + workspaceCtx.cachedWorkspaceID = uuid.NullUUID{UUID: workspaceOneID, Valid: true} + defer workspaceCtx.close() + + chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, updatedChat, chatSnapshot) +
require.Equal(t, resolvedAgent, agent) + require.Equal(t, updatedChat, currentChat) } func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) { diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 66172177e9..b3f846117b 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -2280,7 +2280,7 @@ func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) { // Link the workspace to the chat in the DB, simulating what // the create_workspace tool does mid-conversation. - _, err = db.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{ + _, err = db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, ID: chat.ID, }) diff --git a/coderd/x/chatd/chattool/createworkspace.go b/coderd/x/chatd/chattool/createworkspace.go index 33f27285f9..9f00c2108f 100644 --- a/coderd/x/chatd/chattool/createworkspace.go +++ b/coderd/x/chatd/chattool/createworkspace.go @@ -66,6 +66,7 @@ type CreateWorkspaceOptions struct { AgentConnFn AgentConnFunc AgentInactiveDisconnectTimeout time.Duration WorkspaceMu *sync.Mutex + OnChatUpdated func(database.Chat) Logger slog.Logger AllowedTemplateIDs func() map[uuid.UUID]bool } @@ -211,20 +212,29 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool { } } - // Persist workspace + agent association on the chat. + // Persist the workspace binding on the chat. if options.DB != nil && options.ChatID != uuid.Nil { - if _, err := options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{ + updatedChat, err := options.DB.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ ID: options.ChatID, WorkspaceID: uuid.NullUUID{ UUID: workspace.ID, Valid: true, }, - }); err != nil { + // BuildID and AgentID are intentionally left null + // here. The chatd runtime (loadWorkspaceAgentLocked) + // will bind them on the next turn. 
Authoritative + // tool-path binding is deferred to a follow-up PR. + BuildID: uuid.NullUUID{}, + AgentID: uuid.NullUUID{}, + }) + if err != nil { options.Logger.Error(ctx, "failed to persist chat workspace association", slog.F("chat_id", options.ChatID), slog.F("workspace_id", workspace.ID), slog.Error(err), ) + } else if options.OnChatUpdated != nil { + options.OnChatUpdated(updatedChat) } } diff --git a/coderd/x/chatd/dialvalidation.go b/coderd/x/chatd/dialvalidation.go new file mode 100644 index 0000000000..06c1c536af --- /dev/null +++ b/coderd/x/chatd/dialvalidation.go @@ -0,0 +1,170 @@ +package chatd + +import ( + "context" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// DialResult contains the outcome of dialWithLazyValidation. +type DialResult struct { + Conn workspacesdk.AgentConn + Release func() + AgentID uuid.UUID // The agent that was actually dialed. + WasSwitched bool // True if validation discovered a different agent. +} + +// DialFunc dials an agent by ID and returns a connection. +type DialFunc func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) + +// ValidateFunc returns the current agent ID for a workspace. +type ValidateFunc func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) + +type dialOut struct { + conn workspacesdk.AgentConn + release func() + err error +} + +// dialWithLazyValidation dials an agent and only consults the database if the +// original dial is slow or fails quickly. This keeps the common path free of +// latest-build lookups while still repairing stale bindings. +// +// Outcomes: +// - The dial succeeds before delay, so validation is skipped. +// - The timer fires and validation confirms the same agent, so the original +// dial continues. +// - The timer fires and validation finds a different agent, so the stale +// dial is canceled and the new agent is dialed instead. 
+// - The dial fails before delay, so validation runs immediately and either +// switches to a different agent or retries the current one once. +func dialWithLazyValidation( + ctx context.Context, + agentID uuid.UUID, + workspaceID uuid.UUID, + dialFn DialFunc, + validateFn ValidateFunc, + delay time.Duration, +) (DialResult, error) { + wrapErr := func(err error) error { + return xerrors.Errorf("dial with lazy validation: %w", err) + } + + dialCtx, dialCancel := context.WithCancel(ctx) + results := make(chan dialOut, 1) + go func() { + conn, release, err := dialFn(dialCtx, agentID) + results <- dialOut{conn: conn, release: release, err: err} + }() + + drained := false + defer func() { + dialCancel() + if drained { + return + } + // Drain without blocking the caller. dialFn may take time to honor + // cancellation, but any late-arriving successful connection still needs to + // be released. + go func() { + result := <-results + if result.err == nil && result.release != nil { + result.release() + } + }() + }() + + resultForAgent := func(dialedAgentID uuid.UUID, result dialOut, switched bool) DialResult { + return DialResult{ + Conn: result.conn, + Release: result.release, + AgentID: dialedAgentID, + WasSwitched: switched, + } + } + dialAgent := func(targetAgentID uuid.UUID, switched bool) (DialResult, error) { + conn, release, err := dialFn(ctx, targetAgentID) + if err != nil { + return DialResult{}, wrapErr(err) + } + return resultForAgent(targetAgentID, dialOut{conn: conn, release: release}, switched), nil + } + preferReadyOriginalDial := func() (DialResult, bool) { + select { + case result := <-results: + drained = true + if result.err != nil { + return DialResult{}, false + } + return resultForAgent(agentID, result, false), true + default: + return DialResult{}, false + } + } + waitForOriginalDial := func(waitCtx context.Context) (DialResult, error) { + select { + case result := <-results: + drained = true + if result.err != nil { + return DialResult{}, 
wrapErr(result.err) + } + return resultForAgent(agentID, result, false), nil + case <-waitCtx.Done(): + if ready, ok := preferReadyOriginalDial(); ok { + return ready, nil + } + return DialResult{}, waitCtx.Err() + } + } + validateBinding := func() (uuid.UUID, error) { + validatedAgentID, err := validateFn(ctx, workspaceID) + if err != nil { + return uuid.Nil, wrapErr(err) + } + return validatedAgentID, nil + } + resolveFastFailure := func() (DialResult, error) { + validatedAgentID, err := validateBinding() + if err != nil { + return DialResult{}, err + } + if validatedAgentID == agentID { + return dialAgent(agentID, false) + } + return dialAgent(validatedAgentID, true) + } + + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case result := <-results: + drained = true + if result.err == nil { + return resultForAgent(agentID, result, false), nil + } + return resolveFastFailure() + + case <-timer.C: + validatedAgentID, validationErr := validateFn(ctx, workspaceID) + if validationErr != nil || validatedAgentID == agentID { + // Validation could not prove the binding was stale, so keep waiting on + // the original dial. + return waitForOriginalDial(ctx) + } + // The original dial is stale. Cancel it first, then let the deferred drain + // release any late result while we dial the validated agent immediately. + dialCancel() + return dialAgent(validatedAgentID, true) + + case <-ctx.Done(): + if ready, ok := preferReadyOriginalDial(); ok { + return ready, nil + } + return DialResult{}, ctx.Err() + } +} diff --git a/coderd/x/chatd/dialvalidation_test.go b/coderd/x/chatd/dialvalidation_test.go new file mode 100644 index 0000000000..c2fea03acb --- /dev/null +++ b/coderd/x/chatd/dialvalidation_test.go @@ -0,0 +1,563 @@ +package chatd //nolint:testpackage // Uses internal symbols. 
+ +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock" + "github.com/coder/coder/v2/testutil" +) + +func TestDialWithLazyValidation_FastDial(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + agentID := uuid.New() + workspaceID := uuid.New() + conn := agentconnmock.NewMockAgentConn(ctrl) + + var releaseCalls atomic.Int32 + var validateCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + return conn, func() { + releaseCalls.Add(1) + }, nil + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + validateCalls.Add(1) + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + }, + time.Minute, + ) + require.NoError(t, err) + require.Same(t, conn, result.Conn) + require.Equal(t, agentID, result.AgentID) + require.False(t, result.WasSwitched) + require.EqualValues(t, 0, validateCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, releaseCalls.Load()) +} + +func TestDialWithLazyValidation_SlowDialSameAgent(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + agentID := uuid.New() + workspaceID := uuid.New() + conn := agentconnmock.NewMockAgentConn(ctrl) + unblockDial := make(chan struct{}) + + var releaseCalls atomic.Int32 + var validateCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID 
%q", id) + } + select { + case <-unblockDial: + return conn, func() { + releaseCalls.Add(1) + }, nil + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + close(unblockDial) + return agentID, nil + }, + 0, + ) + require.NoError(t, err) + require.Same(t, conn, result.Conn) + require.Equal(t, agentID, result.AgentID) + require.False(t, result.WasSwitched) + require.EqualValues(t, 1, validateCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, releaseCalls.Load()) +} + +func TestDialWithLazyValidation_SlowDialStaleAgent(t *testing.T) { + t.Parallel() + + t.Run("LateSuccessReleasesStaleConn", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + staleAgentID := uuid.New() + currentAgentID := uuid.New() + workspaceID := uuid.New() + staleConn := agentconnmock.NewMockAgentConn(ctrl) + currentConn := agentconnmock.NewMockAgentConn(ctrl) + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + var staleReleaseCalls atomic.Int32 + var currentReleaseCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + staleAgentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalls.Add(1) + switch id { + case staleAgentID: + <-ctx.Done() + return staleConn, func() { + staleReleaseCalls.Add(1) + }, nil + case currentAgentID: + return currentConn, func() { + currentReleaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return currentAgentID, nil + }, + 0, + ) + require.NoError(t, err) + require.Same(t, 
currentConn, result.Conn) + require.Equal(t, currentAgentID, result.AgentID) + require.True(t, result.WasSwitched) + require.Eventually(t, func() bool { + return dialCalls.Load() == 2 + }, testutil.WaitShort, testutil.IntervalFast) + require.EqualValues(t, 1, validateCalls.Load()) + require.Eventually(t, func() bool { + return staleReleaseCalls.Load() == 1 + }, testutil.WaitShort, testutil.IntervalFast) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, currentReleaseCalls.Load()) + }) + + t.Run("CanceledFailureDoesNotReleaseStaleConn", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + staleAgentID := uuid.New() + currentAgentID := uuid.New() + workspaceID := uuid.New() + currentConn := agentconnmock.NewMockAgentConn(ctrl) + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + var staleReleaseCalls atomic.Int32 + var currentReleaseCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + staleAgentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalls.Add(1) + switch id { + case staleAgentID: + <-ctx.Done() + return nil, func() { + staleReleaseCalls.Add(1) + }, ctx.Err() + case currentAgentID: + return currentConn, func() { + currentReleaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return currentAgentID, nil + }, + 0, + ) + require.NoError(t, err) + require.Same(t, currentConn, result.Conn) + require.Equal(t, currentAgentID, result.AgentID) + require.True(t, result.WasSwitched) + require.Eventually(t, func() bool { + return dialCalls.Load() == 2 + }, testutil.WaitShort, testutil.IntervalFast) + require.EqualValues(t, 1, validateCalls.Load()) + require.EqualValues(t, 0, 
staleReleaseCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, currentReleaseCalls.Load()) + }) + + t.Run("SwitchDoesNotBlock", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + staleAgentID := uuid.New() + currentAgentID := uuid.New() + workspaceID := uuid.New() + staleConn := agentconnmock.NewMockAgentConn(ctrl) + currentConn := agentconnmock.NewMockAgentConn(ctrl) + staleDialStarted := make(chan struct{}) + allowStaleReturn := make(chan struct{}) + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + var staleReleaseCalls atomic.Int32 + var currentReleaseCalls atomic.Int32 + var staleReturnReleased atomic.Bool + releaseStaleReturn := func() { + if staleReturnReleased.CompareAndSwap(false, true) { + close(allowStaleReturn) + } + } + defer releaseStaleReturn() + + resultCh := make(chan DialResult, 1) + errCh := make(chan error, 1) + go func() { + result, err := dialWithLazyValidation( + context.Background(), + staleAgentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalls.Add(1) + switch id { + case staleAgentID: + close(staleDialStarted) + <-allowStaleReturn + return staleConn, func() { + staleReleaseCalls.Add(1) + }, nil + case currentAgentID: + return currentConn, func() { + currentReleaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + <-staleDialStarted + validateCalls.Add(1) + return currentAgentID, nil + }, + 0, + ) + if err != nil { + errCh <- err + return + } + resultCh <- result + }() + + var result DialResult + select { + case err := <-errCh: + require.NoError(t, err) + case result = <-resultCh: + require.Same(t, currentConn, result.Conn) + require.Equal(t, currentAgentID, result.AgentID) + 
require.True(t, result.WasSwitched) + releaseStaleReturn() + case <-time.After(testutil.WaitShort): + t.Fatal("dialWithLazyValidation blocked on stale dial cleanup") + } + + require.EqualValues(t, 2, dialCalls.Load()) + require.EqualValues(t, 1, validateCalls.Load()) + require.Eventually(t, func() bool { + return staleReleaseCalls.Load() == 1 + }, testutil.WaitShort, testutil.IntervalFast) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, currentReleaseCalls.Load()) + }) +} + +func TestDialWithLazyValidation_FastFailure(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + staleAgentID := uuid.New() + currentAgentID := uuid.New() + workspaceID := uuid.New() + currentConn := agentconnmock.NewMockAgentConn(ctrl) + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + var currentReleaseCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + staleAgentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + switch dialCalls.Add(1) { + case 1: + if id != staleAgentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + return nil, nil, xerrors.New("dial failed") + case 2: + if id != currentAgentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + return currentConn, func() { + currentReleaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.New("unexpected dial call") + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return currentAgentID, nil + }, + time.Minute, + ) + require.NoError(t, err) + require.Same(t, currentConn, result.Conn) + require.Equal(t, currentAgentID, result.AgentID) + require.True(t, result.WasSwitched) + require.EqualValues(t, 2, dialCalls.Load()) + require.EqualValues(t, 1, validateCalls.Load()) + + if result.Release != nil { + 
result.Release() + } + require.EqualValues(t, 1, currentReleaseCalls.Load()) +} + +func TestDialWithLazyValidation_FastFailureSameAgent(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + agentID := uuid.New() + workspaceID := uuid.New() + conn := agentconnmock.NewMockAgentConn(ctrl) + + var dialCalls atomic.Int32 + var releaseCalls atomic.Int32 + var validateCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + switch dialCalls.Add(1) { + case 1: + return nil, nil, xerrors.New("dial failed") + case 2: + return conn, func() { + releaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.New("unexpected dial call") + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return agentID, nil + }, + time.Minute, + ) + require.NoError(t, err) + require.Same(t, conn, result.Conn) + require.Equal(t, agentID, result.AgentID) + require.False(t, result.WasSwitched) + require.EqualValues(t, 2, dialCalls.Load()) + require.EqualValues(t, 1, validateCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, releaseCalls.Load()) +} + +func TestDialWithLazyValidation_FastFailureSameAgentRetryFails(t *testing.T) { + t.Parallel() + + agentID := uuid.New() + workspaceID := uuid.New() + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + + _, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + switch dialCalls.Add(1) { + case 1: + return nil, nil, 
xerrors.New("dial failed") + case 2: + return nil, nil, xerrors.New("retry failed") + default: + return nil, nil, xerrors.New("unexpected dial call") + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return agentID, nil + }, + time.Minute, + ) + require.EqualError(t, err, "dial with lazy validation: retry failed") + require.EqualValues(t, 2, dialCalls.Load()) + require.EqualValues(t, 1, validateCalls.Load()) +} + +func TestDialWithLazyValidation_ValidationError(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + agentID := uuid.New() + workspaceID := uuid.New() + conn := agentconnmock.NewMockAgentConn(ctrl) + unblockDial := make(chan struct{}) + + var releaseCalls atomic.Int32 + var validateCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + select { + case <-unblockDial: + return conn, func() { + releaseCalls.Add(1) + }, nil + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + // Validation fails — code should fall back to waiting + // for the original dial. 
+ close(unblockDial) + return uuid.Nil, xerrors.New("db connection reset") + }, + 0, + ) + require.NoError(t, err) + require.Same(t, conn, result.Conn) + require.Equal(t, agentID, result.AgentID) + require.False(t, result.WasSwitched) + require.EqualValues(t, 1, validateCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, releaseCalls.Load()) +} + +func TestDialWithLazyValidation_ContextCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + agentID := uuid.New() + workspaceID := uuid.New() + + var validateCalls atomic.Int32 + + _, err := dialWithLazyValidation( + ctx, + agentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + <-ctx.Done() + return nil, nil, ctx.Err() + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + cancel() + return agentID, nil + }, + 0, + ) + require.ErrorIs(t, err, context.Canceled) + require.EqualValues(t, 1, validateCalls.Load()) +} diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 0475d9e8a2..cd1d71a0e1 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -313,6 +313,8 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database. 
childChat, err := p.CreateChat(ctx, CreateOptions{ OwnerID: parent.OwnerID, WorkspaceID: parent.WorkspaceID, + BuildID: parent.BuildID, + AgentID: parent.AgentID, ParentChatID: uuid.NullUUID{ UUID: parent.ID, Valid: true, @@ -383,6 +385,8 @@ func (p *Server) createChildSubagentChat( child, err := p.CreateChat(ctx, CreateOptions{ OwnerID: parent.OwnerID, WorkspaceID: parent.WorkspaceID, + BuildID: parent.BuildID, + AgentID: parent.AgentID, ParentChatID: uuid.NullUUID{ UUID: parent.ID, Valid: true, diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index 55ed19cf94..fb31ba176c 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -149,6 +149,45 @@ func seedInternalChatDeps( return user, model } +func seedWorkspaceBinding( + t *testing.T, + db database.Store, + userID uuid.UUID, +) (database.WorkspaceTable, database.WorkspaceBuild, database.WorkspaceAgent) { + t.Helper() + + org := dbgen.Organization(t, db, database.Organization{}) + tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + CreatedBy: userID, + }) + tpl := dbgen.Template(t, db, database.Template{ + CreatedBy: userID, + OrganizationID: org.ID, + ActiveVersionID: tv.ID, + }) + workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ + TemplateID: tpl.ID, + OwnerID: userID, + OrganizationID: org.ID, + }) + job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ + InitiatorID: userID, + OrganizationID: org.ID, + }) + build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + TemplateVersionID: tv.ID, + WorkspaceID: workspace.ID, + JobID: job.ID, + }) + resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ + Transition: database.WorkspaceTransitionStart, + JobID: job.ID, + }) + agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ResourceID: resource.ID}) + return workspace, build, agent +} + // findToolByName returns the tool with the 
given name from the // slice, or nil if no match is found. func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool { @@ -165,6 +204,49 @@ func chatdTestContext(t *testing.T) context.Context { return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong)) } +func TestCreateChildSubagentChatInheritsWorkspaceBinding(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + ctx := chatdTestContext(t) + user, model := seedInternalChatDeps(ctx, t, db) + workspace, build, agent := seedWorkspaceBinding(t, db, user.ID) + + parent, err := server.CreateChat(ctx, CreateOptions{ + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{ + UUID: workspace.ID, + Valid: true, + }, + BuildID: uuid.NullUUID{ + UUID: build.ID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agent.ID, + Valid: true, + }, + Title: "bound-parent", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + parentChat, err := db.GetChatByID(ctx, parent.ID) + require.NoError(t, err) + + child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "") + require.NoError(t, err) + + childChat, err := db.GetChatByID(ctx, child.ID) + require.NoError(t, err) + require.Equal(t, parentChat.WorkspaceID, childChat.WorkspaceID) + require.Equal(t, parentChat.BuildID, childChat.BuildID) + require.Equal(t, parentChat.AgentID, childChat.AgentID) +} + func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) { t.Parallel() @@ -292,13 +374,26 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) { ctx := chatdTestContext(t) user, model := seedInternalChatDeps(ctx, t, db) + workspace, build, agent := seedWorkspaceBinding(t, db, user.ID) // The parent uses an OpenAI model. 
require.Equal(t, "openai", model.Provider, "seed helper must create an OpenAI model") parent, err := server.CreateChat(ctx, CreateOptions{ - OwnerID: user.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{ + UUID: workspace.ID, + Valid: true, + }, + BuildID: uuid.NullUUID{ + UUID: build.ID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agent.ID, + Valid: true, + }, Title: "parent-openai", ModelConfigID: model.ID, InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, @@ -332,6 +427,10 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) { childChat, err := db.GetChatByID(ctx, childID) require.NoError(t, err) + require.Equal(t, parentChat.WorkspaceID, childChat.WorkspaceID) + require.Equal(t, parentChat.BuildID, childChat.BuildID) + require.Equal(t, parentChat.AgentID, childChat.AgentID) + // The child must have Mode=computer_use which causes // runChat to override the model to the predefined computer // use model instead of using the parent's model config. 
diff --git a/codersdk/chats.go b/codersdk/chats.go index 355c10ed11..b4eb35d482 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -49,6 +49,8 @@ type Chat struct { ID uuid.UUID `json:"id" format:"uuid"` OwnerID uuid.UUID `json:"owner_id" format:"uuid"` WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` + BuildID *uuid.UUID `json:"build_id,omitempty" format:"uuid"` + AgentID *uuid.UUID `json:"agent_id,omitempty" format:"uuid"` ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"` RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"` LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"` diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index c1fd4cc6c0..6b4a5513dc 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1089,6 +1089,8 @@ export interface Chat { readonly id: string; readonly owner_id: string; readonly workspace_id?: string; + readonly build_id?: string; + readonly agent_id?: string; readonly parent_chat_id?: string; readonly root_chat_id?: string; readonly last_model_config_id: string;