diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 2d704add79..c2b29ce64e 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -5619,6 +5619,18 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg) } +func (q *querier) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) { + chat, err := q.db.GetChatByID(ctx, arg.ID) + if err != nil { + return database.Chat{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return database.Chat{}, err + } + + return q.db.UpdateChatBuildAgentBinding(ctx, arg) +} + func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) { chat, err := q.db.GetChatByID(ctx, arg.ID) if err != nil { @@ -5706,7 +5718,7 @@ func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatS return q.db.UpdateChatStatus(ctx, arg) } -func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) { +func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) { chat, err := q.db.GetChatByID(ctx, arg.ID) if err != nil { return database.Chat{}, err @@ -5715,15 +5727,7 @@ func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateCh return database.Chat{}, err } - // UpdateChatWorkspace is manually implemented for chat tables and may not be - // present on every wrapped store interface yet. - chatWorkspaceUpdater, ok := q.db.(interface { - UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error) - }) - if !ok { - return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store") - } - return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg) + return q.db.UpdateChatWorkspaceBinding(ctx, arg) } func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 31765996dc..4d657e2f12 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -819,15 +819,29 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) })) - s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + s.Run("UpdateChatBuildAgentBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) - arg := database.UpdateChatWorkspaceParams{ - ID: chat.ID, - WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + arg := database.UpdateChatBuildAgentBindingParams{ + ID: chat.ID, + BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, } updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID}) dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() - dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes() + dbm.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat) + })) + s.Run("UpdateChatWorkspaceBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.UpdateChatWorkspaceBindingParams{ + ID: chat.ID, + WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + } + updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID}) + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().UpdateChatWorkspaceBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat) })) s.Run("UnsetDefaultChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index e1f6c1e73b..1353d11d1e 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -4000,6 +4000,14 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up return r0 } +func (m queryMetricsStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) { + start := time.Now() + r0, r1 := m.s.UpdateChatBuildAgentBinding(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatBuildAgentBinding").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatBuildAgentBinding").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) { start := time.Now() r0, r1 := m.s.UpdateChatByID(ctx, arg) @@ -4064,11 +4072,11 @@ func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.Up return r0, r1 } -func (m queryMetricsStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) { +func (m queryMetricsStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) { start := time.Now() - r0, r1 := m.s.UpdateChatWorkspace(ctx, arg) - m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc() + r0, r1 := m.s.UpdateChatWorkspaceBinding(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatWorkspaceBinding").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspaceBinding").Inc() return r0, r1 } diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 50d7bd9e09..13622b6a1a 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -7552,6 +7552,21 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), ctx, arg) } +// UpdateChatBuildAgentBinding mocks base method. +func (m *MockStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateChatBuildAgentBinding", ctx, arg) + ret0, _ := ret[0].(database.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateChatBuildAgentBinding indicates an expected call of UpdateChatBuildAgentBinding. +func (mr *MockStoreMockRecorder) UpdateChatBuildAgentBinding(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatBuildAgentBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatBuildAgentBinding), ctx, arg) +} + // UpdateChatByID mocks base method. func (m *MockStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) { m.ctrl.T.Helper() @@ -7672,19 +7687,19 @@ func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg) } -// UpdateChatWorkspace mocks base method. -func (m *MockStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) { +// UpdateChatWorkspaceBinding mocks base method. +func (m *MockStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateChatWorkspace", ctx, arg) + ret := m.ctrl.Call(m, "UpdateChatWorkspaceBinding", ctx, arg) ret0, _ := ret[0].(database.Chat) ret1, _ := ret[1].(error) return ret0, ret1 } -// UpdateChatWorkspace indicates an expected call of UpdateChatWorkspace. -func (mr *MockStoreMockRecorder) UpdateChatWorkspace(ctx, arg any) *gomock.Call { +// UpdateChatWorkspaceBinding indicates an expected call of UpdateChatWorkspaceBinding. +func (mr *MockStoreMockRecorder) UpdateChatWorkspaceBinding(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspace", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspace), ctx, arg) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspaceBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspaceBinding), ctx, arg) } // UpdateCryptoKeyDeletesAt mocks base method. diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 1c3d158b59..154d22cd62 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1399,7 +1399,9 @@ CREATE TABLE chats ( last_error text, mode chat_mode, mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL, - labels jsonb DEFAULT '{}'::jsonb NOT NULL + labels jsonb DEFAULT '{}'::jsonb NOT NULL, + build_id uuid, + agent_id uuid ); CREATE TABLE connection_logs ( @@ -4033,6 +4035,12 @@ ALTER TABLE ONLY chat_providers ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; +ALTER TABLE ONLY chats + ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL; + +ALTER TABLE ONLY chats + ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL; + ALTER TABLE ONLY chats ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id); diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index b6095e0547..4f7ec37f0b 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -20,6 +20,8 @@ const ( ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; + ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL; + ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL; ForeignKeyChatsLastModelConfigID ForeignKeyConstraint = "chats_last_model_config_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id); ForeignKeyChatsOwnerID ForeignKeyConstraint = "chats_owner_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyChatsParentChatID ForeignKeyConstraint = "chats_parent_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL; diff --git a/coderd/database/migrations/000452_chat_workspace_binding.down.sql b/coderd/database/migrations/000452_chat_workspace_binding.down.sql new file mode 100644 index 0000000000..c192261389 --- /dev/null +++ b/coderd/database/migrations/000452_chat_workspace_binding.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE chats + DROP COLUMN IF EXISTS build_id, + DROP COLUMN IF EXISTS agent_id; diff --git a/coderd/database/migrations/000452_chat_workspace_binding.up.sql b/coderd/database/migrations/000452_chat_workspace_binding.up.sql new file mode 100644 index 0000000000..8788ac93f0 --- /dev/null +++ b/coderd/database/migrations/000452_chat_workspace_binding.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE chats + ADD COLUMN build_id UUID REFERENCES workspace_builds(id) ON DELETE SET NULL, + ADD COLUMN agent_id UUID REFERENCES workspace_agents(id) ON DELETE SET NULL; diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index e62cb25922..33d181053d 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -791,6 +791,8 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ); err != nil { return nil, err } diff --git a/coderd/database/models.go b/coderd/database/models.go index 08b06cd811..d1c614923b 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4171,6 +4171,8 @@ type Chat struct { Mode NullChatMode `db:"mode" json:"mode"` MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` Labels StringMap `db:"labels" json:"labels"` + BuildID uuid.NullUUID `db:"build_id" json:"build_id"` + AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` } type ChatDiffStatus struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index cf6cc78ce2..a1fc9cc05e 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -819,6 +819,7 @@ type sqlcQuerier interface { UnsetDefaultChatModelConfigs(ctx context.Context) error UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error + UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error) // Bumps the heartbeat timestamp for a running chat so that other // replicas know the worker is still alive. @@ -829,7 +830,7 @@ type sqlcQuerier interface { UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error) - UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error) + UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error) UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index d276c524fe..43f85ce750 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3823,7 +3823,7 @@ WHERE $3::int ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type AcquireChatsParams struct { @@ -3862,6 +3862,8 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) ( &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ); err != nil { return nil, err } @@ -4095,7 +4097,7 @@ func (q *sqlQuerier) DeleteChatUsageLimitUserOverride(ctx context.Context, userI const getChatByID = `-- name: GetChatByID :one SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id FROM chats WHERE @@ -4124,12 +4126,14 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels FROM chats WHERE id = $1::uuid FOR UPDATE +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id FROM chats WHERE id = $1::uuid FOR UPDATE ` func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) { @@ -4154,6 +4158,8 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -4998,7 +5004,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u const getChats = `-- name: GetChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id FROM chats WHERE @@ -5089,6 +5095,8 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ); err != nil { return nil, err } @@ -5154,7 +5162,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh const getStaleChats = `-- name: GetStaleChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id FROM chats WHERE @@ -5192,6 +5200,8 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ); err != nil { return nil, err } @@ -5250,6 +5260,8 @@ const insertChat = `-- name: InsertChat :one INSERT INTO chats ( owner_id, workspace_id, + build_id, + agent_id, parent_chat_id, root_chat_id, last_model_config_id, @@ -5263,18 +5275,22 @@ INSERT INTO chats ( $3::uuid, $4::uuid, $5::uuid, - $6::text, - $7::chat_mode, - COALESCE($8::uuid[], '{}'::uuid[]), - COALESCE($9::jsonb, '{}'::jsonb) + $6::uuid, + $7::uuid, + $8::text, + $9::chat_mode, + COALESCE($10::uuid[], '{}'::uuid[]), + COALESCE($11::jsonb, '{}'::jsonb) ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type InsertChatParams struct { OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"` + BuildID uuid.NullUUID `db:"build_id" json:"build_id"` + AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"` RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"` LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"` @@ -5288,6 +5304,8 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat row := q.db.QueryRowContext(ctx, insertChat, arg.OwnerID, arg.WorkspaceID, + arg.BuildID, + arg.AgentID, arg.ParentChatID, arg.RootChatID, arg.LastModelConfigID, @@ -5316,6 +5334,8 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -5702,6 +5722,50 @@ func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error return err } +const updateChatBuildAgentBinding = `-- name: UpdateChatBuildAgentBinding :one +UPDATE chats SET + build_id = $1::uuid, + agent_id = $2::uuid, + updated_at = NOW() +WHERE + id = $3::uuid +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id +` + +type UpdateChatBuildAgentBindingParams struct { + BuildID uuid.NullUUID `db:"build_id" json:"build_id"` + AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatBuildAgentBinding, arg.BuildID, arg.AgentID, arg.ID) + var i Chat + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.WorkspaceID, + &i.Title, + &i.Status, + &i.WorkerID, + &i.StartedAt, + &i.HeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ParentChatID, + &i.RootChatID, + &i.LastModelConfigID, + &i.Archived, + &i.LastError, + &i.Mode, + pq.Array(&i.MCPServerIDs), + &i.Labels, + &i.BuildID, + &i.AgentID, + ) + return i, err +} + const updateChatByID = `-- name: UpdateChatByID :one UPDATE chats @@ -5711,7 +5775,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type UpdateChatByIDParams struct { @@ -5741,6 +5805,8 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -5780,7 +5846,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type UpdateChatLabelsByIDParams struct { @@ -5810,6 +5876,8 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -5823,7 +5891,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type UpdateChatMCPServerIDsParams struct { @@ -5853,6 +5921,8 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } @@ -5917,7 +5987,7 @@ SET WHERE id = $6::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` type UpdateChatStatusParams struct { @@ -5958,29 +6028,36 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } -const updateChatWorkspace = `-- name: UpdateChatWorkspace :one -UPDATE - chats -SET +const updateChatWorkspaceBinding = `-- name: UpdateChatWorkspaceBinding :one +UPDATE chats SET workspace_id = $1::uuid, + build_id = $2::uuid, + agent_id = $3::uuid, updated_at = NOW() -WHERE - id = $2::uuid -RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels +WHERE id = $4::uuid +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id ` -type UpdateChatWorkspaceParams struct { +type UpdateChatWorkspaceBindingParams struct { WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"` + BuildID uuid.NullUUID `db:"build_id" json:"build_id"` + AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` ID uuid.UUID `db:"id" json:"id"` } -func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error) { - row := q.db.QueryRowContext(ctx, updateChatWorkspace, arg.WorkspaceID, arg.ID) +func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatWorkspaceBinding, + arg.WorkspaceID, + arg.BuildID, + arg.AgentID, + arg.ID, + ) var i Chat err := row.Scan( &i.ID, @@ -6001,6 +6078,8 @@ func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWork &i.Mode, pq.Array(&i.MCPServerIDs), &i.Labels, + &i.BuildID, + &i.AgentID, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 0130d71774..0b9d3b7b24 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -180,6 +180,8 @@ LIMIT INSERT INTO chats ( owner_id, workspace_id, + build_id, + agent_id, parent_chat_id, root_chat_id, last_model_config_id, @@ -190,6 +192,8 @@ INSERT INTO chats ( ) VALUES ( @owner_id::uuid, sqlc.narg('workspace_id')::uuid, + sqlc.narg('build_id')::uuid, + sqlc.narg('agent_id')::uuid, sqlc.narg('parent_chat_id')::uuid, sqlc.narg('root_chat_id')::uuid, @last_model_config_id::uuid, @@ -305,16 +309,23 @@ WHERE RETURNING *; --- name: UpdateChatWorkspace :one -UPDATE - chats -SET +-- name: UpdateChatWorkspaceBinding :one +UPDATE chats SET workspace_id = sqlc.narg('workspace_id')::uuid, + build_id = sqlc.narg('build_id')::uuid, + agent_id = sqlc.narg('agent_id')::uuid, + updated_at = NOW() +WHERE id = @id::uuid +RETURNING *; + +-- name: UpdateChatBuildAgentBinding :one +UPDATE chats SET + build_id = sqlc.narg('build_id')::uuid, + agent_id = sqlc.narg('agent_id')::uuid, updated_at = NOW() WHERE id = @id::uuid -RETURNING - *; +RETURNING *; -- name: UpdateChatMCPServerIDs :one UPDATE diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 79f97f3ffe..37b1927764 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -3600,6 +3600,12 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk. if c.WorkspaceID.Valid { chat.WorkspaceID = &c.WorkspaceID.UUID } + if c.BuildID.Valid { + chat.BuildID = &c.BuildID.UUID + } + if c.AgentID.Valid { + chat.AgentID = &c.AgentID.UUID + } if diffStatus != nil { convertedDiffStatus := db2sdk.ChatDiffStatus(c.ID, diffStatus) chat.DiffStatus = &convertedDiffStatus diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 710d90db2e..4f88792106 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -53,6 +53,8 @@ const ( DefaultInFlightChatStaleAfter = 5 * time.Minute homeInstructionLookupTimeout = 5 * time.Second + instructionCacheTTL = 5 * time.Minute + workspaceDialValidationDelay = 5 * time.Second // DefaultChatHeartbeatInterval is the default time between chat // heartbeat updates while a chat is being processed. DefaultChatHeartbeatInterval = 30 * time.Second @@ -158,18 +160,26 @@ type turnWorkspaceContext struct { currentChat *database.Chat loadChatSnapshot func(context.Context, uuid.UUID) (database.Chat, error) - mu sync.Mutex - agent database.WorkspaceAgent - agentLoaded bool - conn workspacesdk.AgentConn - releaseConn func() + mu sync.Mutex + agent database.WorkspaceAgent + agentLoaded bool + conn workspacesdk.AgentConn + releaseConn func() + cachedWorkspaceID uuid.NullUUID } func (c *turnWorkspaceContext) close() { + c.clearCachedWorkspaceState() +} + +func (c *turnWorkspaceContext) clearCachedWorkspaceState() { c.mu.Lock() releaseConn := c.releaseConn + c.agent = database.WorkspaceAgent{} + c.agentLoaded = false c.conn = nil c.releaseConn = nil + c.cachedWorkspaceID = uuid.NullUUID{} c.mu.Unlock() if releaseConn != nil { @@ -177,6 +187,68 @@ func (c *turnWorkspaceContext) close() { } } +func (c *turnWorkspaceContext) setCurrentChat(chat database.Chat) { + c.chatStateMu.Lock() + *c.currentChat = chat + c.chatStateMu.Unlock() +} + +func (c *turnWorkspaceContext) currentChatSnapshot() database.Chat { + c.chatStateMu.Lock() + chatSnapshot := *c.currentChat + c.chatStateMu.Unlock() + return chatSnapshot +} + +func (c *turnWorkspaceContext) selectWorkspace(chat database.Chat) { + c.setCurrentChat(chat) + c.clearCachedWorkspaceState() +} + +func (c *turnWorkspaceContext) currentWorkspaceMatches(expected uuid.NullUUID) (database.Chat, bool) { + chatSnapshot := c.currentChatSnapshot() + return chatSnapshot, nullUUIDEqual(chatSnapshot.WorkspaceID, expected) +} + +func nullUUIDEqual(left, right uuid.NullUUID) bool { + if left.Valid != right.Valid { + return false + } + if !left.Valid { + return true + } + return left.UUID == right.UUID +} + +func (c *turnWorkspaceContext) persistBuildAgentBinding( + ctx context.Context, + chatSnapshot database.Chat, + buildID uuid.UUID, + agentID uuid.UUID, +) (database.Chat, error) { + updatedChat, err := c.server.db.UpdateChatBuildAgentBinding( + ctx, + database.UpdateChatBuildAgentBindingParams{ + ID: chatSnapshot.ID, + BuildID: uuid.NullUUID{ + UUID: buildID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + }, + ) + if err != nil { + return chatSnapshot, xerrors.Errorf( + "update chat build/agent binding: %w", err, + ) + } + c.setCurrentChat(updatedChat) + return updatedChat, nil +} + func (c *turnWorkspaceContext) getWorkspaceAgent(ctx context.Context) (database.WorkspaceAgent, error) { _, agent, err := c.ensureWorkspaceAgent(ctx) return agent, err @@ -189,128 +261,245 @@ func (c *turnWorkspaceContext) ensureWorkspaceAgent( defer c.mu.Unlock() if c.agentLoaded { - c.chatStateMu.Lock() - chatSnapshot := *c.currentChat - c.chatStateMu.Unlock() - return chatSnapshot, c.agent, nil + chatSnapshot := c.currentChatSnapshot() + if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) { + return chatSnapshot, c.agent, nil + } + c.agent = database.WorkspaceAgent{} + c.agentLoaded = false } return c.loadWorkspaceAgentLocked(ctx) } -func (c *turnWorkspaceContext) refreshWorkspaceAgent( - ctx context.Context, -) (database.Chat, database.WorkspaceAgent, error) { - c.mu.Lock() - defer c.mu.Unlock() - - c.agent = database.WorkspaceAgent{} - c.agentLoaded = false - return c.loadWorkspaceAgentLocked(ctx) -} - func (c *turnWorkspaceContext) loadWorkspaceAgentLocked( ctx context.Context, ) (database.Chat, database.WorkspaceAgent, error) { - c.chatStateMu.Lock() - chatSnapshot := *c.currentChat - c.chatStateMu.Unlock() + chatSnapshot := c.currentChatSnapshot() - if !chatSnapshot.WorkspaceID.Valid { - refreshedChat, refreshErr := refreshChatWorkspaceSnapshot( + for attempt := 0; attempt < 2; attempt++ { + if !chatSnapshot.WorkspaceID.Valid { + refreshedChat, refreshErr := refreshChatWorkspaceSnapshot( + ctx, + chatSnapshot, + c.loadChatSnapshot, + ) + if refreshErr != nil { + return chatSnapshot, database.WorkspaceAgent{}, refreshErr + } + if refreshedChat.WorkspaceID.Valid { + c.setCurrentChat(refreshedChat) + chatSnapshot = refreshedChat + } + } + + if !chatSnapshot.WorkspaceID.Valid { + return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace") + } + + if chatSnapshot.AgentID.Valid { + agent, err := c.server.db.GetWorkspaceAgentByID(ctx, chatSnapshot.AgentID.UUID) + if err == nil { + latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID) + if !workspaceMatches { + chatSnapshot = latestChat + continue + } + c.agent = agent + c.agentLoaded = true + c.cachedWorkspaceID = chatSnapshot.WorkspaceID + return chatSnapshot, c.agent, nil + } + if !xerrors.Is(err, sql.ErrNoRows) { + c.server.logger.Warn(ctx, "agent binding lookup failed, re-resolving", + slog.F("agent_id", chatSnapshot.AgentID.UUID), + slog.Error(err), + ) + } + } + + agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID( + ctx, + chatSnapshot.WorkspaceID.UUID, + ) + if err != nil || len(agents) == 0 { + return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent") + } + + build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID) + if err != nil { + return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf("get latest workspace build: %w", err) + } + + updatedChat, err := c.persistBuildAgentBinding( ctx, chatSnapshot, - c.loadChatSnapshot, + build.ID, + agents[0].ID, ) - if refreshErr != nil { - return chatSnapshot, database.WorkspaceAgent{}, refreshErr + if err != nil { + return chatSnapshot, database.WorkspaceAgent{}, err } - if refreshedChat.WorkspaceID.Valid { - c.chatStateMu.Lock() - *c.currentChat = refreshedChat - c.chatStateMu.Unlock() - chatSnapshot = refreshedChat + + chatSnapshot = updatedChat + latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID) + if !workspaceMatches { + chatSnapshot = latestChat + continue } + c.agent = agents[0] + c.agentLoaded = true + c.cachedWorkspaceID = chatSnapshot.WorkspaceID + return chatSnapshot, c.agent, nil } - if !chatSnapshot.WorkspaceID.Valid { - return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace") - } - - agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID( - ctx, - chatSnapshot.WorkspaceID.UUID, + return chatSnapshot, database.WorkspaceAgent{}, xerrors.New( + "chat workspace changed while resolving agent", ) - if err != nil || len(agents) == 0 { - return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent") +} + +// getWorkspaceConnLocked returns the cached connection when it still matches +// the current workspace. When the workspace changed, it clears the stale +// cached state and returns the release func for the caller to run after +// unlocking. +func (c *turnWorkspaceContext) getWorkspaceConnLocked() (workspacesdk.AgentConn, func()) { + if c.conn == nil { + return nil, nil } - c.agent = agents[0] - c.agentLoaded = true - return chatSnapshot, c.agent, nil + chatSnapshot := c.currentChatSnapshot() + if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) { + return c.conn, nil + } + + agentRelease := c.releaseConn + c.agent = database.WorkspaceAgent{} + c.agentLoaded = false + c.conn = nil + c.releaseConn = nil + c.cachedWorkspaceID = uuid.NullUUID{} + return nil, agentRelease } func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspacesdk.AgentConn, error) { - c.mu.Lock() - if c.conn != nil { - currentConn := c.conn - c.mu.Unlock() - return currentConn, nil - } - c.mu.Unlock() - if c.server.agentConnFn == nil { return nil, xerrors.New("workspace agent connector is not configured") } - chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx) - if err != nil { - return nil, err - } - - agentConn, agentRelease, err := c.server.agentConnFn(ctx, agent.ID) - if err != nil { - refreshedChat, refreshedAgent, refreshErr := c.refreshWorkspaceAgent(ctx) - if refreshErr != nil { - return nil, xerrors.Errorf("connect to workspace agent: %w", err) - } - - retryConn, retryRelease, retryErr := c.server.agentConnFn(ctx, refreshedAgent.ID) - if retryErr != nil { - return nil, xerrors.Errorf("connect to workspace agent after refresh: %w", retryErr) - } - - chatSnapshot = refreshedChat - agentConn = retryConn - agentRelease = retryRelease - } - - c.mu.Lock() - if c.conn == nil { - c.conn = agentConn - c.releaseConn = agentRelease - - var ancestorIDs []string - if chatSnapshot.ParentChatID.Valid { - ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String()) - } - ancestorJSON, marshalErr := json.Marshal(ancestorIDs) - if marshalErr != nil { - ancestorJSON = []byte("[]") - } - agentConn.SetExtraHeaders(http.Header{ - workspacesdk.CoderChatIDHeader: {chatSnapshot.ID.String()}, - workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)}, - }) - + for attempt := 0; attempt < 2; attempt++ { + c.mu.Lock() + currentConn, staleRelease := c.getWorkspaceConnLocked() c.mu.Unlock() - return agentConn, nil - } - currentConn := c.conn - c.mu.Unlock() + if currentConn != nil { + return currentConn, nil + } + if staleRelease != nil { + staleRelease() + } - agentRelease() - return currentConn, nil + chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx) + if err != nil { + return nil, err + } + + dialResult, err := dialWithLazyValidation( + ctx, + agent.ID, + chatSnapshot.WorkspaceID.UUID, + DialFunc(c.server.agentConnFn), + func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) { + agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspaceID) + if err != nil || len(agents) == 0 { + return uuid.Nil, xerrors.New("chat has no workspace agent") + } + return agents[0].ID, nil + }, + workspaceDialValidationDelay, + ) + if err != nil { + return nil, err + } + + agentConn := dialResult.Conn + agentRelease := dialResult.Release + if dialResult.WasSwitched { + build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID) + if err != nil { + if agentRelease != nil { + agentRelease() + } + return nil, xerrors.Errorf("get latest workspace build: %w", err) + } + + switchedAgent, err := c.server.db.GetWorkspaceAgentByID(ctx, dialResult.AgentID) + if err != nil { + if agentRelease != nil { + agentRelease() + } + return nil, xerrors.Errorf("get workspace agent by id: %w", err) + } + + updatedChat, err := c.persistBuildAgentBinding( + ctx, + chatSnapshot, + build.ID, + switchedAgent.ID, + ) + if err != nil { + if agentRelease != nil { + agentRelease() + } + return nil, err + } + chatSnapshot = updatedChat + + c.mu.Lock() + c.agent = switchedAgent + c.agentLoaded = true + c.cachedWorkspaceID = chatSnapshot.WorkspaceID + c.mu.Unlock() + } + + if _, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID); !workspaceMatches { + if agentRelease != nil { + agentRelease() + } + c.clearCachedWorkspaceState() + continue + } + + c.mu.Lock() + if c.conn == nil { + c.conn = agentConn + c.releaseConn = agentRelease + c.cachedWorkspaceID = chatSnapshot.WorkspaceID + + var ancestorIDs []string + if chatSnapshot.ParentChatID.Valid { + ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String()) + } + ancestorJSON, marshalErr := json.Marshal(ancestorIDs) + if marshalErr != nil { + ancestorJSON = []byte("[]") + } + agentConn.SetExtraHeaders(http.Header{ + workspacesdk.CoderChatIDHeader: {chatSnapshot.ID.String()}, + workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)}, + }) + + c.mu.Unlock() + return agentConn, nil + } + currentConn = c.conn + c.mu.Unlock() + + if agentRelease != nil { + agentRelease() + } + return currentConn, nil + } + + return nil, xerrors.New("chat workspace changed while connecting") } // AgentConnFunc provides access to workspace agent connections. @@ -420,6 +609,8 @@ func (e *UsageLimitExceededError) Error() string { type CreateOptions struct { OwnerID uuid.UUID WorkspaceID uuid.NullUUID + BuildID uuid.NullUUID + AgentID uuid.NullUUID ParentChatID uuid.NullUUID RootChatID uuid.NullUUID Title string @@ -525,6 +716,8 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{ OwnerID: opts.OwnerID, WorkspaceID: opts.WorkspaceID, + BuildID: opts.BuildID, + AgentID: opts.AgentID, ParentChatID: opts.ParentChatID, RootChatID: opts.RootChatID, LastModelConfigID: opts.ModelConfigID, @@ -3461,6 +3654,7 @@ func (p *Server) runChat( AgentConnFn: chattool.AgentConnFunc(p.agentConnFn), AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout, WorkspaceMu: &workspaceMu, + OnChatUpdated: workspaceCtx.selectWorkspace, Logger: p.logger, AllowedTemplateIDs: p.chatTemplateAllowlist, }), diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index dce239cc1f..958f3ba19e 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -112,24 +112,29 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) { db := dbmock.NewMockStore(ctrl) workspaceID := uuid.New() + agentID := uuid.New() chat := database.Chat{ ID: uuid.New(), WorkspaceID: uuid.NullUUID{ UUID: workspaceID, Valid: true, }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, } workspaceAgent := database.WorkspaceAgent{ - ID: uuid.New(), + ID: agentID, OperatingSystem: "linux", Directory: "/home/coder/project", ExpandedDirectory: "/home/coder/project", } - db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID( + db.EXPECT().GetWorkspaceAgentByID( gomock.Any(), - workspaceID, - ).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1) + agentID, + ).Return(workspaceAgent, nil).Times(1) db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() conn := agentconnmock.NewMockAgentConn(ctrl) @@ -180,7 +185,7 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) { require.Contains(t, instruction, "Working Directory: /home/coder/project") } -func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) { +func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) { t.Parallel() ctx := context.Background() @@ -188,6 +193,53 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing. db := dbmock.NewMockStore(ctrl) workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + workspaceAgent := database.WorkspaceAgent{ID: agentID} + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(workspaceAgent, nil).Times(1) + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: &Server{db: db}, + chatStateMu: chatStateMu, + currentChat: ¤tChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + t.Cleanup(workspaceCtx.close) + + chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, chat, chatSnapshot) + require.Equal(t, workspaceAgent, agent) + + gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, workspaceAgent, gotAgent) + require.Equal(t, chat, currentChat) +} + +func TestTurnWorkspaceContext_NullBindingLazyBind(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + buildID := uuid.New() + agentID := uuid.New() chat := database.Chat{ ID: uuid.New(), WorkspaceID: uuid.NullUUID{ @@ -195,18 +247,135 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing. Valid: true, }, } - initialAgent := database.WorkspaceAgent{ID: uuid.New()} - refreshedAgent := database.WorkspaceAgent{ID: uuid.New()} + workspaceAgent := database.WorkspaceAgent{ID: agentID} + updatedChat := chat + updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true} + updatedChat.AgentID = uuid.NullUUID{UUID: agentID, Valid: true} gomock.InOrder( - db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID( - gomock.Any(), - workspaceID, - ).Return([]database.WorkspaceAgent{initialAgent}, nil), - db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID( - gomock.Any(), - workspaceID, - ).Return([]database.WorkspaceAgent{refreshedAgent}, nil), + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{workspaceAgent}, nil), + db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil), + db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{ + BuildID: uuid.NullUUID{UUID: buildID, Valid: true}, + AgentID: uuid.NullUUID{UUID: agentID, Valid: true}, + ID: chat.ID, + }).Return(updatedChat, nil), + ) + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: &Server{db: db}, + chatStateMu: chatStateMu, + currentChat: ¤tChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + t.Cleanup(workspaceCtx.close) + + chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, updatedChat, chatSnapshot) + require.Equal(t, workspaceAgent, agent) + require.Equal(t, updatedChat, currentChat) + + gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, workspaceAgent, gotAgent) +} + +func TestTurnWorkspaceContext_StaleBindingRepair(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + staleAgentID := uuid.New() + buildID := uuid.New() + currentAgentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: staleAgentID, + Valid: true, + }, + } + currentAgent := database.WorkspaceAgent{ID: currentAgentID} + updatedChat := chat + updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true} + updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true} + + gomock.InOrder( + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(database.WorkspaceAgent{}, xerrors.New("missing agent")), + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil), + db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil), + db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{ + BuildID: uuid.NullUUID{UUID: buildID, Valid: true}, + AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true}, + ID: chat.ID, + }).Return(updatedChat, nil), + ) + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: &Server{db: db}, + chatStateMu: chatStateMu, + currentChat: ¤tChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + t.Cleanup(workspaceCtx.close) + + chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, updatedChat, chatSnapshot) + require.Equal(t, currentAgent, agent) + require.Equal(t, updatedChat, currentChat) +} + +func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + staleAgentID := uuid.New() + currentAgentID := uuid.New() + buildID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: staleAgentID, + Valid: true, + }, + } + staleAgent := database.WorkspaceAgent{ID: staleAgentID} + currentAgent := database.WorkspaceAgent{ID: currentAgentID} + updatedChat := chat + updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true} + updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true} + + gomock.InOrder( + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(staleAgent, nil), + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil), + db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil), + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), currentAgentID).Return(currentAgent, nil), + db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{ + BuildID: uuid.NullUUID{UUID: buildID, Valid: true}, + AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true}, + ID: chat.ID, + }).Return(updatedChat, nil), ) conn := agentconnmock.NewMockAgentConn(ctrl) @@ -216,7 +385,7 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing. server := &Server{db: db} server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { dialed = append(dialed, agentID) - if agentID == initialAgent.ID { + if agentID == staleAgentID { return nil, nil, xerrors.New("dial failed") } return conn, func() {}, nil @@ -235,7 +404,112 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing. gotConn, err := workspaceCtx.getWorkspaceConn(ctx) require.NoError(t, err) require.Same(t, conn, gotConn) - require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed) + require.Equal(t, []uuid.UUID{staleAgentID, currentAgentID}, dialed) + require.Equal(t, updatedChat, currentChat) + + gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, currentAgent, gotAgent) +} + +func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + currentChat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + } + updatedChat := database.Chat{ + ID: currentChat.ID, + WorkspaceID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + } + cachedConn := agentconnmock.NewMockAgentConn(ctrl) + releaseCalls := 0 + + workspaceCtx := turnWorkspaceContext{ + chatStateMu: &sync.Mutex{}, + currentChat: ¤tChat, + } + workspaceCtx.agent = database.WorkspaceAgent{ID: uuid.New()} + workspaceCtx.agentLoaded = true + workspaceCtx.conn = cachedConn + workspaceCtx.cachedWorkspaceID = currentChat.WorkspaceID + workspaceCtx.releaseConn = func() { + releaseCalls++ + } + + workspaceCtx.selectWorkspace(updatedChat) + + require.Equal(t, updatedChat, currentChat) + require.Equal(t, 1, releaseCalls) + + workspaceCtx.mu.Lock() + defer workspaceCtx.mu.Unlock() + require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent) + require.False(t, workspaceCtx.agentLoaded) + require.Nil(t, workspaceCtx.conn) + require.Nil(t, workspaceCtx.releaseConn) + require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID) +} + +func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferentWorkspace(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceOneID := uuid.New() + workspaceTwoID := uuid.New() + buildID := uuid.New() + cachedAgent := database.WorkspaceAgent{ID: uuid.New()} + resolvedAgent := database.WorkspaceAgent{ID: uuid.New()} + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceTwoID, + Valid: true, + }, + } + updatedChat := chat + updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true} + updatedChat.AgentID = uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true} + + gomock.InOrder( + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return([]database.WorkspaceAgent{resolvedAgent}, nil), + db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceTwoID).Return(database.WorkspaceBuild{ID: buildID}, nil), + db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{ + ID: chat.ID, + BuildID: uuid.NullUUID{UUID: buildID, Valid: true}, + AgentID: uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true}, + }).Return(updatedChat, nil), + ) + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: &Server{db: db}, + chatStateMu: chatStateMu, + currentChat: ¤tChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + workspaceCtx.agent = cachedAgent + workspaceCtx.agentLoaded = true + workspaceCtx.cachedWorkspaceID = uuid.NullUUID{UUID: workspaceOneID, Valid: true} + defer workspaceCtx.close() + + chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx) + require.NoError(t, err) + require.Equal(t, updatedChat, chatSnapshot) + require.Equal(t, resolvedAgent, agent) + require.Equal(t, updatedChat, currentChat) } func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) { diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 66172177e9..b3f846117b 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -2280,7 +2280,7 @@ func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) { // Link the workspace to the chat in the DB, simulating what // the create_workspace tool does mid-conversation. - _, err = db.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{ + _, err = db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, ID: chat.ID, }) diff --git a/coderd/x/chatd/chattool/createworkspace.go b/coderd/x/chatd/chattool/createworkspace.go index 33f27285f9..9f00c2108f 100644 --- a/coderd/x/chatd/chattool/createworkspace.go +++ b/coderd/x/chatd/chattool/createworkspace.go @@ -66,6 +66,7 @@ type CreateWorkspaceOptions struct { AgentConnFn AgentConnFunc AgentInactiveDisconnectTimeout time.Duration WorkspaceMu *sync.Mutex + OnChatUpdated func(database.Chat) Logger slog.Logger AllowedTemplateIDs func() map[uuid.UUID]bool } @@ -211,20 +212,29 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool { } } - // Persist workspace + agent association on the chat. + // Persist the workspace binding on the chat. if options.DB != nil && options.ChatID != uuid.Nil { - if _, err := options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{ + updatedChat, err := options.DB.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ ID: options.ChatID, WorkspaceID: uuid.NullUUID{ UUID: workspace.ID, Valid: true, }, - }); err != nil { + // BuildID and AgentID are intentionally left null + // here. The chatd runtime (loadWorkspaceAgentLocked) + // will bind them on the next turn. Authoritative + // tool-path binding is deferred to a follow-up PR. + BuildID: uuid.NullUUID{}, + AgentID: uuid.NullUUID{}, + }) + if err != nil { options.Logger.Error(ctx, "failed to persist chat workspace association", slog.F("chat_id", options.ChatID), slog.F("workspace_id", workspace.ID), slog.Error(err), ) + } else if options.OnChatUpdated != nil { + options.OnChatUpdated(updatedChat) } } diff --git a/coderd/x/chatd/dialvalidation.go b/coderd/x/chatd/dialvalidation.go new file mode 100644 index 0000000000..06c1c536af --- /dev/null +++ b/coderd/x/chatd/dialvalidation.go @@ -0,0 +1,170 @@ +package chatd + +import ( + "context" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// DialResult contains the outcome of dialWithLazyValidation. +type DialResult struct { + Conn workspacesdk.AgentConn + Release func() + AgentID uuid.UUID // The agent that was actually dialed. + WasSwitched bool // True if validation discovered a different agent. +} + +// DialFunc dials an agent by ID and returns a connection. +type DialFunc func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) + +// ValidateFunc returns the current agent ID for a workspace. +type ValidateFunc func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) + +type dialOut struct { + conn workspacesdk.AgentConn + release func() + err error +} + +// dialWithLazyValidation dials an agent and only consults the database if the +// original dial is slow or fails quickly. This keeps the common path free of +// latest-build lookups while still repairing stale bindings. +// +// Outcomes: +// - The dial succeeds before delay, so validation is skipped. +// - The timer fires and validation confirms the same agent, so the original +// dial continues. +// - The timer fires and validation finds a different agent, so the stale +// dial is canceled and the new agent is dialed instead. +// - The dial fails before delay, so validation runs immediately and either +// switches to a different agent or retries the current one once. +func dialWithLazyValidation( + ctx context.Context, + agentID uuid.UUID, + workspaceID uuid.UUID, + dialFn DialFunc, + validateFn ValidateFunc, + delay time.Duration, +) (DialResult, error) { + wrapErr := func(err error) error { + return xerrors.Errorf("dial with lazy validation: %w", err) + } + + dialCtx, dialCancel := context.WithCancel(ctx) + results := make(chan dialOut, 1) + go func() { + conn, release, err := dialFn(dialCtx, agentID) + results <- dialOut{conn: conn, release: release, err: err} + }() + + drained := false + defer func() { + dialCancel() + if drained { + return + } + // Drain without blocking the caller. dialFn may take time to honor + // cancellation, but any late-arriving successful connection still needs to + // be released. + go func() { + result := <-results + if result.err == nil && result.release != nil { + result.release() + } + }() + }() + + resultForAgent := func(dialedAgentID uuid.UUID, result dialOut, switched bool) DialResult { + return DialResult{ + Conn: result.conn, + Release: result.release, + AgentID: dialedAgentID, + WasSwitched: switched, + } + } + dialAgent := func(targetAgentID uuid.UUID, switched bool) (DialResult, error) { + conn, release, err := dialFn(ctx, targetAgentID) + if err != nil { + return DialResult{}, wrapErr(err) + } + return resultForAgent(targetAgentID, dialOut{conn: conn, release: release}, switched), nil + } + preferReadyOriginalDial := func() (DialResult, bool) { + select { + case result := <-results: + drained = true + if result.err != nil { + return DialResult{}, false + } + return resultForAgent(agentID, result, false), true + default: + return DialResult{}, false + } + } + waitForOriginalDial := func(waitCtx context.Context) (DialResult, error) { + select { + case result := <-results: + drained = true + if result.err != nil { + return DialResult{}, wrapErr(result.err) + } + return resultForAgent(agentID, result, false), nil + case <-waitCtx.Done(): + if ready, ok := preferReadyOriginalDial(); ok { + return ready, nil + } + return DialResult{}, waitCtx.Err() + } + } + validateBinding := func() (uuid.UUID, error) { + validatedAgentID, err := validateFn(ctx, workspaceID) + if err != nil { + return uuid.Nil, wrapErr(err) + } + return validatedAgentID, nil + } + resolveFastFailure := func() (DialResult, error) { + validatedAgentID, err := validateBinding() + if err != nil { + return DialResult{}, err + } + if validatedAgentID == agentID { + return dialAgent(agentID, false) + } + return dialAgent(validatedAgentID, true) + } + + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case result := <-results: + drained = true + if result.err == nil { + return resultForAgent(agentID, result, false), nil + } + return resolveFastFailure() + + case <-timer.C: + validatedAgentID, validationErr := validateFn(ctx, workspaceID) + if validationErr != nil || validatedAgentID == agentID { + // Validation could not prove the binding was stale, so keep waiting on + // the original dial. + return waitForOriginalDial(ctx) + } + // The original dial is stale. Cancel it first, then let the deferred drain + // release any late result while we dial the validated agent immediately. + dialCancel() + return dialAgent(validatedAgentID, true) + + case <-ctx.Done(): + if ready, ok := preferReadyOriginalDial(); ok { + return ready, nil + } + return DialResult{}, ctx.Err() + } +} diff --git a/coderd/x/chatd/dialvalidation_test.go b/coderd/x/chatd/dialvalidation_test.go new file mode 100644 index 0000000000..c2fea03acb --- /dev/null +++ b/coderd/x/chatd/dialvalidation_test.go @@ -0,0 +1,563 @@ +package chatd //nolint:testpackage // Uses internal symbols. + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock" + "github.com/coder/coder/v2/testutil" +) + +func TestDialWithLazyValidation_FastDial(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + agentID := uuid.New() + workspaceID := uuid.New() + conn := agentconnmock.NewMockAgentConn(ctrl) + + var releaseCalls atomic.Int32 + var validateCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + return conn, func() { + releaseCalls.Add(1) + }, nil + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + validateCalls.Add(1) + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + }, + time.Minute, + ) + require.NoError(t, err) + require.Same(t, conn, result.Conn) + require.Equal(t, agentID, result.AgentID) + require.False(t, result.WasSwitched) + require.EqualValues(t, 0, validateCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, releaseCalls.Load()) +} + +func TestDialWithLazyValidation_SlowDialSameAgent(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + agentID := uuid.New() + workspaceID := uuid.New() + conn := agentconnmock.NewMockAgentConn(ctrl) + unblockDial := make(chan struct{}) + + var releaseCalls atomic.Int32 + var validateCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + select { + case <-unblockDial: + return conn, func() { + releaseCalls.Add(1) + }, nil + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + close(unblockDial) + return agentID, nil + }, + 0, + ) + require.NoError(t, err) + require.Same(t, conn, result.Conn) + require.Equal(t, agentID, result.AgentID) + require.False(t, result.WasSwitched) + require.EqualValues(t, 1, validateCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, releaseCalls.Load()) +} + +func TestDialWithLazyValidation_SlowDialStaleAgent(t *testing.T) { + t.Parallel() + + t.Run("LateSuccessReleasesStaleConn", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + staleAgentID := uuid.New() + currentAgentID := uuid.New() + workspaceID := uuid.New() + staleConn := agentconnmock.NewMockAgentConn(ctrl) + currentConn := agentconnmock.NewMockAgentConn(ctrl) + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + var staleReleaseCalls atomic.Int32 + var currentReleaseCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + staleAgentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalls.Add(1) + switch id { + case staleAgentID: + <-ctx.Done() + return staleConn, func() { + staleReleaseCalls.Add(1) + }, nil + case currentAgentID: + return currentConn, func() { + currentReleaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return currentAgentID, nil + }, + 0, + ) + require.NoError(t, err) + require.Same(t, currentConn, result.Conn) + require.Equal(t, currentAgentID, result.AgentID) + require.True(t, result.WasSwitched) + require.Eventually(t, func() bool { + return dialCalls.Load() == 2 + }, testutil.WaitShort, testutil.IntervalFast) + require.EqualValues(t, 1, validateCalls.Load()) + require.Eventually(t, func() bool { + return staleReleaseCalls.Load() == 1 + }, testutil.WaitShort, testutil.IntervalFast) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, currentReleaseCalls.Load()) + }) + + t.Run("CanceledFailureDoesNotReleaseStaleConn", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + staleAgentID := uuid.New() + currentAgentID := uuid.New() + workspaceID := uuid.New() + currentConn := agentconnmock.NewMockAgentConn(ctrl) + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + var staleReleaseCalls atomic.Int32 + var currentReleaseCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + staleAgentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalls.Add(1) + switch id { + case staleAgentID: + <-ctx.Done() + return nil, func() { + staleReleaseCalls.Add(1) + }, ctx.Err() + case currentAgentID: + return currentConn, func() { + currentReleaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return currentAgentID, nil + }, + 0, + ) + require.NoError(t, err) + require.Same(t, currentConn, result.Conn) + require.Equal(t, currentAgentID, result.AgentID) + require.True(t, result.WasSwitched) + require.Eventually(t, func() bool { + return dialCalls.Load() == 2 + }, testutil.WaitShort, testutil.IntervalFast) + require.EqualValues(t, 1, validateCalls.Load()) + require.EqualValues(t, 0, staleReleaseCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, currentReleaseCalls.Load()) + }) + + t.Run("SwitchDoesNotBlock", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + staleAgentID := uuid.New() + currentAgentID := uuid.New() + workspaceID := uuid.New() + staleConn := agentconnmock.NewMockAgentConn(ctrl) + currentConn := agentconnmock.NewMockAgentConn(ctrl) + staleDialStarted := make(chan struct{}) + allowStaleReturn := make(chan struct{}) + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + var staleReleaseCalls atomic.Int32 + var currentReleaseCalls atomic.Int32 + var staleReturnReleased atomic.Bool + releaseStaleReturn := func() { + if staleReturnReleased.CompareAndSwap(false, true) { + close(allowStaleReturn) + } + } + defer releaseStaleReturn() + + resultCh := make(chan DialResult, 1) + errCh := make(chan error, 1) + go func() { + result, err := dialWithLazyValidation( + context.Background(), + staleAgentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalls.Add(1) + switch id { + case staleAgentID: + close(staleDialStarted) + <-allowStaleReturn + return staleConn, func() { + staleReleaseCalls.Add(1) + }, nil + case currentAgentID: + return currentConn, func() { + currentReleaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + <-staleDialStarted + validateCalls.Add(1) + return currentAgentID, nil + }, + 0, + ) + if err != nil { + errCh <- err + return + } + resultCh <- result + }() + + var result DialResult + select { + case err := <-errCh: + require.NoError(t, err) + case result = <-resultCh: + require.Same(t, currentConn, result.Conn) + require.Equal(t, currentAgentID, result.AgentID) + require.True(t, result.WasSwitched) + releaseStaleReturn() + case <-time.After(testutil.WaitShort): + t.Fatal("dialWithLazyValidation blocked on stale dial cleanup") + } + + require.EqualValues(t, 2, dialCalls.Load()) + require.EqualValues(t, 1, validateCalls.Load()) + require.Eventually(t, func() bool { + return staleReleaseCalls.Load() == 1 + }, testutil.WaitShort, testutil.IntervalFast) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, currentReleaseCalls.Load()) + }) +} + +func TestDialWithLazyValidation_FastFailure(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + staleAgentID := uuid.New() + currentAgentID := uuid.New() + workspaceID := uuid.New() + currentConn := agentconnmock.NewMockAgentConn(ctrl) + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + var currentReleaseCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + staleAgentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + switch dialCalls.Add(1) { + case 1: + if id != staleAgentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + return nil, nil, xerrors.New("dial failed") + case 2: + if id != currentAgentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + return currentConn, func() { + currentReleaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.New("unexpected dial call") + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return currentAgentID, nil + }, + time.Minute, + ) + require.NoError(t, err) + require.Same(t, currentConn, result.Conn) + require.Equal(t, currentAgentID, result.AgentID) + require.True(t, result.WasSwitched) + require.EqualValues(t, 2, dialCalls.Load()) + require.EqualValues(t, 1, validateCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, currentReleaseCalls.Load()) +} + +func TestDialWithLazyValidation_FastFailureSameAgent(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + agentID := uuid.New() + workspaceID := uuid.New() + conn := agentconnmock.NewMockAgentConn(ctrl) + + var dialCalls atomic.Int32 + var releaseCalls atomic.Int32 + var validateCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + switch dialCalls.Add(1) { + case 1: + return nil, nil, xerrors.New("dial failed") + case 2: + return conn, func() { + releaseCalls.Add(1) + }, nil + default: + return nil, nil, xerrors.New("unexpected dial call") + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return agentID, nil + }, + time.Minute, + ) + require.NoError(t, err) + require.Same(t, conn, result.Conn) + require.Equal(t, agentID, result.AgentID) + require.False(t, result.WasSwitched) + require.EqualValues(t, 2, dialCalls.Load()) + require.EqualValues(t, 1, validateCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, releaseCalls.Load()) +} + +func TestDialWithLazyValidation_FastFailureSameAgentRetryFails(t *testing.T) { + t.Parallel() + + agentID := uuid.New() + workspaceID := uuid.New() + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + + _, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + switch dialCalls.Add(1) { + case 1: + return nil, nil, xerrors.New("dial failed") + case 2: + return nil, nil, xerrors.New("retry failed") + default: + return nil, nil, xerrors.New("unexpected dial call") + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + return agentID, nil + }, + time.Minute, + ) + require.EqualError(t, err, "dial with lazy validation: retry failed") + require.EqualValues(t, 2, dialCalls.Load()) + require.EqualValues(t, 1, validateCalls.Load()) +} + +func TestDialWithLazyValidation_ValidationError(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + agentID := uuid.New() + workspaceID := uuid.New() + conn := agentconnmock.NewMockAgentConn(ctrl) + unblockDial := make(chan struct{}) + + var releaseCalls atomic.Int32 + var validateCalls atomic.Int32 + + result, err := dialWithLazyValidation( + context.Background(), + agentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + select { + case <-unblockDial: + return conn, func() { + releaseCalls.Add(1) + }, nil + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + // Validation fails — code should fall back to waiting + // for the original dial. + close(unblockDial) + return uuid.Nil, xerrors.New("db connection reset") + }, + 0, + ) + require.NoError(t, err) + require.Same(t, conn, result.Conn) + require.Equal(t, agentID, result.AgentID) + require.False(t, result.WasSwitched) + require.EqualValues(t, 1, validateCalls.Load()) + + if result.Release != nil { + result.Release() + } + require.EqualValues(t, 1, releaseCalls.Load()) +} + +func TestDialWithLazyValidation_ContextCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + agentID := uuid.New() + workspaceID := uuid.New() + + var validateCalls atomic.Int32 + + _, err := dialWithLazyValidation( + ctx, + agentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != agentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + <-ctx.Done() + return nil, nil, ctx.Err() + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + validateCalls.Add(1) + cancel() + return agentID, nil + }, + 0, + ) + require.ErrorIs(t, err, context.Canceled) + require.EqualValues(t, 1, validateCalls.Load()) +} diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 0475d9e8a2..cd1d71a0e1 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -313,6 +313,8 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database. childChat, err := p.CreateChat(ctx, CreateOptions{ OwnerID: parent.OwnerID, WorkspaceID: parent.WorkspaceID, + BuildID: parent.BuildID, + AgentID: parent.AgentID, ParentChatID: uuid.NullUUID{ UUID: parent.ID, Valid: true, @@ -383,6 +385,8 @@ func (p *Server) createChildSubagentChat( child, err := p.CreateChat(ctx, CreateOptions{ OwnerID: parent.OwnerID, WorkspaceID: parent.WorkspaceID, + BuildID: parent.BuildID, + AgentID: parent.AgentID, ParentChatID: uuid.NullUUID{ UUID: parent.ID, Valid: true, diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index 55ed19cf94..fb31ba176c 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -149,6 +149,45 @@ func seedInternalChatDeps( return user, model } +func seedWorkspaceBinding( + t *testing.T, + db database.Store, + userID uuid.UUID, +) (database.WorkspaceTable, database.WorkspaceBuild, database.WorkspaceAgent) { + t.Helper() + + org := dbgen.Organization(t, db, database.Organization{}) + tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + CreatedBy: userID, + }) + tpl := dbgen.Template(t, db, database.Template{ + CreatedBy: userID, + OrganizationID: org.ID, + ActiveVersionID: tv.ID, + }) + workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ + TemplateID: tpl.ID, + OwnerID: userID, + OrganizationID: org.ID, + }) + job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ + InitiatorID: userID, + OrganizationID: org.ID, + }) + build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + TemplateVersionID: tv.ID, + WorkspaceID: workspace.ID, + JobID: job.ID, + }) + resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ + Transition: database.WorkspaceTransitionStart, + JobID: job.ID, + }) + agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ResourceID: resource.ID}) + return workspace, build, agent +} + // findToolByName returns the tool with the given name from the // slice, or nil if no match is found. func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool { @@ -165,6 +204,49 @@ func chatdTestContext(t *testing.T) context.Context { return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong)) } +func TestCreateChildSubagentChatInheritsWorkspaceBinding(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + ctx := chatdTestContext(t) + user, model := seedInternalChatDeps(ctx, t, db) + workspace, build, agent := seedWorkspaceBinding(t, db, user.ID) + + parent, err := server.CreateChat(ctx, CreateOptions{ + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{ + UUID: workspace.ID, + Valid: true, + }, + BuildID: uuid.NullUUID{ + UUID: build.ID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agent.ID, + Valid: true, + }, + Title: "bound-parent", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + parentChat, err := db.GetChatByID(ctx, parent.ID) + require.NoError(t, err) + + child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "") + require.NoError(t, err) + + childChat, err := db.GetChatByID(ctx, child.ID) + require.NoError(t, err) + require.Equal(t, parentChat.WorkspaceID, childChat.WorkspaceID) + require.Equal(t, parentChat.BuildID, childChat.BuildID) + require.Equal(t, parentChat.AgentID, childChat.AgentID) +} + func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) { t.Parallel() @@ -292,13 +374,26 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) { ctx := chatdTestContext(t) user, model := seedInternalChatDeps(ctx, t, db) + workspace, build, agent := seedWorkspaceBinding(t, db, user.ID) // The parent uses an OpenAI model. require.Equal(t, "openai", model.Provider, "seed helper must create an OpenAI model") parent, err := server.CreateChat(ctx, CreateOptions{ - OwnerID: user.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{ + UUID: workspace.ID, + Valid: true, + }, + BuildID: uuid.NullUUID{ + UUID: build.ID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agent.ID, + Valid: true, + }, Title: "parent-openai", ModelConfigID: model.ID, InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, @@ -332,6 +427,10 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) { childChat, err := db.GetChatByID(ctx, childID) require.NoError(t, err) + require.Equal(t, parentChat.WorkspaceID, childChat.WorkspaceID) + require.Equal(t, parentChat.BuildID, childChat.BuildID) + require.Equal(t, parentChat.AgentID, childChat.AgentID) + // The child must have Mode=computer_use which causes // runChat to override the model to the predefined computer // use model instead of using the parent's model config. diff --git a/codersdk/chats.go b/codersdk/chats.go index 355c10ed11..b4eb35d482 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -49,6 +49,8 @@ type Chat struct { ID uuid.UUID `json:"id" format:"uuid"` OwnerID uuid.UUID `json:"owner_id" format:"uuid"` WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` + BuildID *uuid.UUID `json:"build_id,omitempty" format:"uuid"` + AgentID *uuid.UUID `json:"agent_id,omitempty" format:"uuid"` ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"` RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"` LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"` diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index c1fd4cc6c0..6b4a5513dc 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1089,6 +1089,8 @@ export interface Chat { readonly id: string; readonly owner_id: string; readonly workspace_id?: string; + readonly build_id?: string; + readonly agent_id?: string; readonly parent_chat_id?: string; readonly root_chat_id?: string; readonly last_model_config_id: string;