perf(coderd/x/chatd): persist workspace agent binding across chat turns (#23274)

## Summary

This change removes the steady-state "resolve the latest workspace
agent" query from chat execution.

Instead of asking the database for the latest build's agent on every
turn, a chat now persists the workspace/build/agent binding it actually
uses and reuses that binding across subsequent turns. The common path
becomes "load the bound agent by ID and dial it", with fallback paths to
repair the binding when it is missing, stale, or intentionally changed.

## What changes

- add `build_id` and `agent_id` binding fields to `chats`, alongside the
existing `workspace_id`
- expose those fields through the chat API / SDK so the execution
context is explicit
- load the persisted binding first in chatd, instead of always resolving
the latest build's agent
- persist a refreshed binding when chatd has to re-resolve the workspace
agent
- keep child / subagent chats on the same bound workspace context by
inheriting the parent binding
- leave `build_id` / `agent_id` unset for flows like `create_workspace`,
then bind them lazily on the next agent-backed turn

## Runtime behavior

The binding is treated as an optimistic cache of the agent a chat should
use:

- if the bound agent still exists and dials successfully, we use it
without a latest-build lookup
- if the bound agent is missing or no longer reachable, chatd
re-resolves against the latest build and persists the new binding
- if a workspace mutation changes the chat's target workspace, the
binding is updated as part of that mutation

To avoid reintroducing a hot-path query, dialing uses lazy validation:

- start dialing the cached agent immediately
- only validate against the latest build if the dial is still pending
after a short delay
- if validation finds a different agent, cancel the stale dial, switch
to the current agent, and persist the repaired binding

## Result

The hot path stops issuing
`GetWorkspaceAgentsInLatestBuildByWorkspaceID` for every user message,
which is the source of the DB pressure this PR is addressing. At the
same time, chats still converge to the correct workspace agent when the
binding becomes stale due to rebuilds or explicit workspace changes.
This commit is contained in:
Ethan
2026-03-26 17:22:38 +11:00
committed by GitHub
parent 17aea0b19c
commit 61e31ec5cc
24 changed files with 1655 additions and 179 deletions
+14 -10
View File
@@ -5619,6 +5619,18 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
return update(q.log, q.auth, fetch, q.db.UpdateAPIKeyByID)(ctx, arg)
}
// UpdateChatBuildAgentBinding rewrites the build/agent binding of a chat
// after checking that the caller may update that chat. The chat is fetched
// first so the authorization check runs against the stored row rather than
// against caller-supplied arguments.
func (q *querier) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
	existing, fetchErr := q.db.GetChatByID(ctx, arg.ID)
	if fetchErr != nil {
		return database.Chat{}, fetchErr
	}

	authErr := q.authorizeContext(ctx, policy.ActionUpdate, existing)
	if authErr != nil {
		return database.Chat{}, authErr
	}

	return q.db.UpdateChatBuildAgentBinding(ctx, arg)
}
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
@@ -5706,7 +5718,7 @@ func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatS
return q.db.UpdateChatStatus(ctx, arg)
}
func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
@@ -5715,15 +5727,7 @@ func (q *querier) UpdateChatWorkspace(ctx context.Context, arg database.UpdateCh
return database.Chat{}, err
}
// UpdateChatWorkspace is manually implemented for chat tables and may not be
// present on every wrapped store interface yet.
chatWorkspaceUpdater, ok := q.db.(interface {
UpdateChatWorkspace(context.Context, database.UpdateChatWorkspaceParams) (database.Chat, error)
})
if !ok {
return database.Chat{}, xerrors.New("update chat workspace is not implemented by wrapped store")
}
return chatWorkspaceUpdater.UpdateChatWorkspace(ctx, arg)
return q.db.UpdateChatWorkspaceBinding(ctx, arg)
}
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
+19 -5
View File
@@ -819,15 +819,29 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpdateChatStatus(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatWorkspace", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
s.Run("UpdateChatBuildAgentBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatWorkspaceParams{
ID: chat.ID,
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
arg := database.UpdateChatBuildAgentBindingParams{
ID: chat.ID,
BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
}
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatWorkspace(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
dbm.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
}))
s.Run("UpdateChatWorkspaceBinding", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatWorkspaceBindingParams{
ID: chat.ID,
WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
BuildID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
AgentID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
}
updatedChat := testutil.Fake(s.T(), faker, database.Chat{ID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatWorkspaceBinding(gomock.Any(), arg).Return(updatedChat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(updatedChat)
}))
s.Run("UnsetDefaultChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
+12 -4
View File
@@ -4000,6 +4000,14 @@ func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.Up
return r0
}
// UpdateChatBuildAgentBinding forwards to the wrapped store and records the
// call's latency and a per-route/per-method call count under the query name
// "UpdateChatBuildAgentBinding".
func (m queryMetricsStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
	began := time.Now()
	chat, err := m.s.UpdateChatBuildAgentBinding(ctx, arg)

	latency := m.queryLatencies.WithLabelValues("UpdateChatBuildAgentBinding")
	latency.Observe(time.Since(began).Seconds())

	calls := m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatBuildAgentBinding")
	calls.Inc()

	return chat, err
}
func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatByID(ctx, arg)
@@ -4064,11 +4072,11 @@ func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.Up
return r0, r1
}
func (m queryMetricsStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
func (m queryMetricsStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatWorkspace(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatWorkspace").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspace").Inc()
r0, r1 := m.s.UpdateChatWorkspaceBinding(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatWorkspaceBinding").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatWorkspaceBinding").Inc()
return r0, r1
}
+21 -6
View File
@@ -7552,6 +7552,21 @@ func (mr *MockStoreMockRecorder) UpdateAPIKeyByID(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAPIKeyByID", reflect.TypeOf((*MockStore)(nil).UpdateAPIKeyByID), ctx, arg)
}
// UpdateChatBuildAgentBinding mocks base method.
//
// The call is dispatched through the mock controller under the method name
// "UpdateChatBuildAgentBinding"; whatever the matching expectation returns is
// handed back to the caller.
func (m *MockStore) UpdateChatBuildAgentBinding(ctx context.Context, arg database.UpdateChatBuildAgentBindingParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatBuildAgentBinding", ctx, arg)
// The type-assertion results are deliberately discarded: a nil or
// mismatched return value yields the zero Chat / nil error instead of
// panicking.
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatBuildAgentBinding indicates an expected call of UpdateChatBuildAgentBinding.
//
// ctx and arg may be concrete values or gomock matchers; the returned
// *gomock.Call is used to configure return values and call counts.
func (mr *MockStoreMockRecorder) UpdateChatBuildAgentBinding(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
// The method's reflect.Type is passed so the controller can type-check the
// recorded expectation against the real method signature.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatBuildAgentBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatBuildAgentBinding), ctx, arg)
}
// UpdateChatByID mocks base method.
func (m *MockStore) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) (database.Chat, error) {
m.ctrl.T.Helper()
@@ -7672,19 +7687,19 @@ func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg)
}
// UpdateChatWorkspace mocks base method.
func (m *MockStore) UpdateChatWorkspace(ctx context.Context, arg database.UpdateChatWorkspaceParams) (database.Chat, error) {
// UpdateChatWorkspaceBinding mocks base method.
func (m *MockStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatWorkspace", ctx, arg)
ret := m.ctrl.Call(m, "UpdateChatWorkspaceBinding", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatWorkspace indicates an expected call of UpdateChatWorkspace.
func (mr *MockStoreMockRecorder) UpdateChatWorkspace(ctx, arg any) *gomock.Call {
// UpdateChatWorkspaceBinding indicates an expected call of UpdateChatWorkspaceBinding.
func (mr *MockStoreMockRecorder) UpdateChatWorkspaceBinding(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspace", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspace), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatWorkspaceBinding", reflect.TypeOf((*MockStore)(nil).UpdateChatWorkspaceBinding), ctx, arg)
}
// UpdateCryptoKeyDeletesAt mocks base method.
+9 -1
View File
@@ -1399,7 +1399,9 @@ CREATE TABLE chats (
last_error text,
mode chat_mode,
mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL,
labels jsonb DEFAULT '{}'::jsonb NOT NULL
labels jsonb DEFAULT '{}'::jsonb NOT NULL,
build_id uuid,
agent_id uuid
);
CREATE TABLE connection_logs (
@@ -4033,6 +4035,12 @@ ALTER TABLE ONLY chat_providers
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
@@ -20,6 +20,8 @@ const (
ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
ForeignKeyChatsLastModelConfigID ForeignKeyConstraint = "chats_last_model_config_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_last_model_config_id_fkey FOREIGN KEY (last_model_config_id) REFERENCES chat_model_configs(id);
ForeignKeyChatsOwnerID ForeignKeyConstraint = "chats_owner_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatsParentChatID ForeignKeyConstraint = "chats_parent_chat_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_parent_chat_id_fkey FOREIGN KEY (parent_chat_id) REFERENCES chats(id) ON DELETE SET NULL;
@@ -0,0 +1,3 @@
-- Removes the build_id and agent_id binding columns from chats.
-- IF EXISTS keeps the statement from failing when a column is already absent.
ALTER TABLE chats
DROP COLUMN IF EXISTS build_id,
DROP COLUMN IF EXISTS agent_id;
@@ -0,0 +1,3 @@
-- Adds the nullable build_id and agent_id binding columns to chats.
-- ON DELETE SET NULL clears the stored binding when the referenced
-- workspace build or workspace agent row is deleted, instead of blocking the
-- delete or removing the chat.
ALTER TABLE chats
ADD COLUMN build_id UUID REFERENCES workspace_builds(id) ON DELETE SET NULL,
ADD COLUMN agent_id UUID REFERENCES workspace_agents(id) ON DELETE SET NULL;
+2
View File
@@ -791,6 +791,8 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
); err != nil {
return nil, err
}
+2
View File
@@ -4171,6 +4171,8 @@ type Chat struct {
Mode NullChatMode `db:"mode" json:"mode"`
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
Labels StringMap `db:"labels" json:"labels"`
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
}
type ChatDiffStatus struct {
+2 -1
View File
@@ -819,6 +819,7 @@ type sqlcQuerier interface {
UnsetDefaultChatModelConfigs(ctx context.Context) error
UpdateAIBridgeInterceptionEnded(ctx context.Context, arg UpdateAIBridgeInterceptionEndedParams) (AIBridgeInterception, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
@@ -829,7 +830,7 @@ type sqlcQuerier interface {
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error)
UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error)
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
+104 -25
View File
@@ -3823,7 +3823,7 @@ WHERE
$3::int
)
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
`
type AcquireChatsParams struct {
@@ -3862,6 +3862,8 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) (
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
); err != nil {
return nil, err
}
@@ -4095,7 +4097,7 @@ func (q *sqlQuerier) DeleteChatUsageLimitUserOverride(ctx context.Context, userI
const getChatByID = `-- name: GetChatByID :one
SELECT
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
FROM
chats
WHERE
@@ -4124,12 +4126,14 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
)
return i, err
}
const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels FROM chats WHERE id = $1::uuid FOR UPDATE
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id FROM chats WHERE id = $1::uuid FOR UPDATE
`
func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) {
@@ -4154,6 +4158,8 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
)
return i, err
}
@@ -4998,7 +5004,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u
const getChats = `-- name: GetChats :many
SELECT
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
FROM
chats
WHERE
@@ -5089,6 +5095,8 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat,
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
); err != nil {
return nil, err
}
@@ -5154,7 +5162,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
const getStaleChats = `-- name: GetStaleChats :many
SELECT
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
FROM
chats
WHERE
@@ -5192,6 +5200,8 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
); err != nil {
return nil, err
}
@@ -5250,6 +5260,8 @@ const insertChat = `-- name: InsertChat :one
INSERT INTO chats (
owner_id,
workspace_id,
build_id,
agent_id,
parent_chat_id,
root_chat_id,
last_model_config_id,
@@ -5263,18 +5275,22 @@ INSERT INTO chats (
$3::uuid,
$4::uuid,
$5::uuid,
$6::text,
$7::chat_mode,
COALESCE($8::uuid[], '{}'::uuid[]),
COALESCE($9::jsonb, '{}'::jsonb)
$6::uuid,
$7::uuid,
$8::text,
$9::chat_mode,
COALESCE($10::uuid[], '{}'::uuid[]),
COALESCE($11::jsonb, '{}'::jsonb)
)
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
`
type InsertChatParams struct {
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
@@ -5288,6 +5304,8 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
row := q.db.QueryRowContext(ctx, insertChat,
arg.OwnerID,
arg.WorkspaceID,
arg.BuildID,
arg.AgentID,
arg.ParentChatID,
arg.RootChatID,
arg.LastModelConfigID,
@@ -5316,6 +5334,8 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
)
return i, err
}
@@ -5702,6 +5722,50 @@ func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) error
return err
}
// updateChatBuildAgentBinding is the SQL text for the
// UpdateChatBuildAgentBinding query: it overwrites build_id and agent_id on
// one chat, bumps updated_at, and returns the full updated row. Positional
// parameters are $1 = build_id, $2 = agent_id, $3 = chat id.
const updateChatBuildAgentBinding = `-- name: UpdateChatBuildAgentBinding :one
UPDATE chats SET
build_id = $1::uuid,
agent_id = $2::uuid,
updated_at = NOW()
WHERE
id = $3::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
`
// UpdateChatBuildAgentBindingParams carries the arguments for
// UpdateChatBuildAgentBinding.
type UpdateChatBuildAgentBindingParams struct {
// BuildID is the new build_id value; an invalid NullUUID is sent as NULL.
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
// AgentID is the new agent_id value; an invalid NullUUID is sent as NULL.
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
// ID selects the chat row to update.
ID uuid.UUID `db:"id" json:"id"`
}
// UpdateChatBuildAgentBinding runs the updateChatBuildAgentBinding statement
// and returns the updated chat row. The error from Scan is returned as-is,
// alongside whatever was scanned into the Chat value.
func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error) {
// Argument order must match the $1/$2/$3 placeholders in the query text.
row := q.db.QueryRowContext(ctx, updateChatBuildAgentBinding, arg.BuildID, arg.AgentID, arg.ID)
var i Chat
// Scan targets are listed in the same order as the RETURNING column list.
err := row.Scan(
&i.ID,
&i.OwnerID,
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.ParentChatID,
&i.RootChatID,
&i.LastModelConfigID,
&i.Archived,
&i.LastError,
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
)
return i, err
}
const updateChatByID = `-- name: UpdateChatByID :one
UPDATE
chats
@@ -5711,7 +5775,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
`
type UpdateChatByIDParams struct {
@@ -5741,6 +5805,8 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
)
return i, err
}
@@ -5780,7 +5846,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
`
type UpdateChatLabelsByIDParams struct {
@@ -5810,6 +5876,8 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
)
return i, err
}
@@ -5823,7 +5891,7 @@ SET
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
`
type UpdateChatMCPServerIDsParams struct {
@@ -5853,6 +5921,8 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
)
return i, err
}
@@ -5917,7 +5987,7 @@ SET
WHERE
id = $6::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
`
type UpdateChatStatusParams struct {
@@ -5958,29 +6028,36 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
)
return i, err
}
const updateChatWorkspace = `-- name: UpdateChatWorkspace :one
UPDATE
chats
SET
const updateChatWorkspaceBinding = `-- name: UpdateChatWorkspaceBinding :one
UPDATE chats SET
workspace_id = $1::uuid,
build_id = $2::uuid,
agent_id = $3::uuid,
updated_at = NOW()
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
WHERE id = $4::uuid
RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id
`
type UpdateChatWorkspaceParams struct {
type UpdateChatWorkspaceBindingParams struct {
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
BuildID uuid.NullUUID `db:"build_id" json:"build_id"`
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
ID uuid.UUID `db:"id" json:"id"`
}
func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWorkspaceParams) (Chat, error) {
row := q.db.QueryRowContext(ctx, updateChatWorkspace, arg.WorkspaceID, arg.ID)
func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error) {
row := q.db.QueryRowContext(ctx, updateChatWorkspaceBinding,
arg.WorkspaceID,
arg.BuildID,
arg.AgentID,
arg.ID,
)
var i Chat
err := row.Scan(
&i.ID,
@@ -6001,6 +6078,8 @@ func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWork
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
)
return i, err
}
+17 -6
View File
@@ -180,6 +180,8 @@ LIMIT
INSERT INTO chats (
owner_id,
workspace_id,
build_id,
agent_id,
parent_chat_id,
root_chat_id,
last_model_config_id,
@@ -190,6 +192,8 @@ INSERT INTO chats (
) VALUES (
@owner_id::uuid,
sqlc.narg('workspace_id')::uuid,
sqlc.narg('build_id')::uuid,
sqlc.narg('agent_id')::uuid,
sqlc.narg('parent_chat_id')::uuid,
sqlc.narg('root_chat_id')::uuid,
@last_model_config_id::uuid,
@@ -305,16 +309,23 @@ WHERE
RETURNING
*;
-- name: UpdateChatWorkspace :one
UPDATE
chats
SET
-- name: UpdateChatWorkspaceBinding :one
UPDATE chats SET
workspace_id = sqlc.narg('workspace_id')::uuid,
build_id = sqlc.narg('build_id')::uuid,
agent_id = sqlc.narg('agent_id')::uuid,
updated_at = NOW()
WHERE id = @id::uuid
RETURNING *;
-- name: UpdateChatBuildAgentBinding :one
UPDATE chats SET
build_id = sqlc.narg('build_id')::uuid,
agent_id = sqlc.narg('agent_id')::uuid,
updated_at = NOW()
WHERE
id = @id::uuid
RETURNING
*;
RETURNING *;
-- name: UpdateChatMCPServerIDs :one
UPDATE
+6
View File
@@ -3600,6 +3600,12 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.
if c.WorkspaceID.Valid {
chat.WorkspaceID = &c.WorkspaceID.UUID
}
if c.BuildID.Valid {
chat.BuildID = &c.BuildID.UUID
}
if c.AgentID.Valid {
chat.AgentID = &c.AgentID.UUID
}
if diffStatus != nil {
convertedDiffStatus := db2sdk.ChatDiffStatus(c.ID, diffStatus)
chat.DiffStatus = &convertedDiffStatus
+293 -99
View File
@@ -53,6 +53,8 @@ const (
DefaultInFlightChatStaleAfter = 5 * time.Minute
homeInstructionLookupTimeout = 5 * time.Second
instructionCacheTTL = 5 * time.Minute
workspaceDialValidationDelay = 5 * time.Second
// DefaultChatHeartbeatInterval is the default time between chat
// heartbeat updates while a chat is being processed.
DefaultChatHeartbeatInterval = 30 * time.Second
@@ -158,18 +160,26 @@ type turnWorkspaceContext struct {
currentChat *database.Chat
loadChatSnapshot func(context.Context, uuid.UUID) (database.Chat, error)
mu sync.Mutex
agent database.WorkspaceAgent
agentLoaded bool
conn workspacesdk.AgentConn
releaseConn func()
mu sync.Mutex
agent database.WorkspaceAgent
agentLoaded bool
conn workspacesdk.AgentConn
releaseConn func()
cachedWorkspaceID uuid.NullUUID
}
// close tears down this turn's workspace context by clearing all cached
// workspace state via clearCachedWorkspaceState.
func (c *turnWorkspaceContext) close() {
c.clearCachedWorkspaceState()
}
func (c *turnWorkspaceContext) clearCachedWorkspaceState() {
c.mu.Lock()
releaseConn := c.releaseConn
c.agent = database.WorkspaceAgent{}
c.agentLoaded = false
c.conn = nil
c.releaseConn = nil
c.cachedWorkspaceID = uuid.NullUUID{}
c.mu.Unlock()
if releaseConn != nil {
@@ -177,6 +187,68 @@ func (c *turnWorkspaceContext) close() {
}
}
// setCurrentChat overwrites the chat value that currentChat points at. The
// write happens under chatStateMu so it is serialized with the other users of
// that mutex.
func (c *turnWorkspaceContext) setCurrentChat(chat database.Chat) {
	c.chatStateMu.Lock()
	defer c.chatStateMu.Unlock()

	*c.currentChat = chat
}
// currentChatSnapshot returns a by-value copy of the chat that currentChat
// points at, taken while holding chatStateMu.
func (c *turnWorkspaceContext) currentChatSnapshot() database.Chat {
	c.chatStateMu.Lock()
	defer c.chatStateMu.Unlock()

	return *c.currentChat
}
// selectWorkspace makes chat the current chat and then discards any cached
// workspace state, so the next lookup starts from the new chat value rather
// than from state cached for the previous one.
func (c *turnWorkspaceContext) selectWorkspace(chat database.Chat) {
// Publish the new chat first, then drop the cache that belonged to the
// old one.
c.setCurrentChat(chat)
c.clearCachedWorkspaceState()
}
// currentWorkspaceMatches takes a fresh snapshot of the current chat and
// reports whether its WorkspaceID equals expected (NULL-aware, see
// nullUUIDEqual). The snapshot is returned either way so the caller can
// continue with the latest chat state.
func (c *turnWorkspaceContext) currentWorkspaceMatches(expected uuid.NullUUID) (database.Chat, bool) {
	latest := c.currentChatSnapshot()
	sameWorkspace := nullUUIDEqual(latest.WorkspaceID, expected)
	return latest, sameWorkspace
}
// nullUUIDEqual reports whether two nullable UUIDs are equivalent. Two
// invalid (NULL) values compare equal regardless of their UUID payload, and a
// valid value never equals an invalid one.
func nullUUIDEqual(left, right uuid.NullUUID) bool {
	switch {
	case left.Valid != right.Valid:
		return false
	case !left.Valid:
		// Both sides are NULL.
		return true
	default:
		return left.UUID == right.UUID
	}
}
// persistBuildAgentBinding stores buildID and agentID as the chat's
// build/agent binding and makes the updated row the current chat.
//
// On a database error the passed-in chatSnapshot is returned unchanged
// together with the wrapped error, and the current chat is left untouched.
func (c *turnWorkspaceContext) persistBuildAgentBinding(
	ctx context.Context,
	chatSnapshot database.Chat,
	buildID uuid.UUID,
	agentID uuid.UUID,
) (database.Chat, error) {
	params := database.UpdateChatBuildAgentBindingParams{
		ID:      chatSnapshot.ID,
		BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
		AgentID: uuid.NullUUID{UUID: agentID, Valid: true},
	}

	persisted, err := c.server.db.UpdateChatBuildAgentBinding(ctx, params)
	if err != nil {
		return chatSnapshot, xerrors.Errorf(
			"update chat build/agent binding: %w", err,
		)
	}

	c.setCurrentChat(persisted)
	return persisted, nil
}
func (c *turnWorkspaceContext) getWorkspaceAgent(ctx context.Context) (database.WorkspaceAgent, error) {
_, agent, err := c.ensureWorkspaceAgent(ctx)
return agent, err
@@ -189,128 +261,245 @@ func (c *turnWorkspaceContext) ensureWorkspaceAgent(
defer c.mu.Unlock()
if c.agentLoaded {
c.chatStateMu.Lock()
chatSnapshot := *c.currentChat
c.chatStateMu.Unlock()
return chatSnapshot, c.agent, nil
chatSnapshot := c.currentChatSnapshot()
if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) {
return chatSnapshot, c.agent, nil
}
c.agent = database.WorkspaceAgent{}
c.agentLoaded = false
}
return c.loadWorkspaceAgentLocked(ctx)
}
// refreshWorkspaceAgent throws away the cached agent and resolves it again
// through loadWorkspaceAgentLocked. c.mu is held for the whole operation.
func (c *turnWorkspaceContext) refreshWorkspaceAgent(
	ctx context.Context,
) (database.Chat, database.WorkspaceAgent, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Invalidate the cached agent before reloading.
	c.agentLoaded = false
	c.agent = database.WorkspaceAgent{}

	return c.loadWorkspaceAgentLocked(ctx)
}
func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
ctx context.Context,
) (database.Chat, database.WorkspaceAgent, error) {
c.chatStateMu.Lock()
chatSnapshot := *c.currentChat
c.chatStateMu.Unlock()
chatSnapshot := c.currentChatSnapshot()
if !chatSnapshot.WorkspaceID.Valid {
refreshedChat, refreshErr := refreshChatWorkspaceSnapshot(
for attempt := 0; attempt < 2; attempt++ {
if !chatSnapshot.WorkspaceID.Valid {
refreshedChat, refreshErr := refreshChatWorkspaceSnapshot(
ctx,
chatSnapshot,
c.loadChatSnapshot,
)
if refreshErr != nil {
return chatSnapshot, database.WorkspaceAgent{}, refreshErr
}
if refreshedChat.WorkspaceID.Valid {
c.setCurrentChat(refreshedChat)
chatSnapshot = refreshedChat
}
}
if !chatSnapshot.WorkspaceID.Valid {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace")
}
if chatSnapshot.AgentID.Valid {
agent, err := c.server.db.GetWorkspaceAgentByID(ctx, chatSnapshot.AgentID.UUID)
if err == nil {
latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID)
if !workspaceMatches {
chatSnapshot = latestChat
continue
}
c.agent = agent
c.agentLoaded = true
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
return chatSnapshot, c.agent, nil
}
if !xerrors.Is(err, sql.ErrNoRows) {
c.server.logger.Warn(ctx, "agent binding lookup failed, re-resolving",
slog.F("agent_id", chatSnapshot.AgentID.UUID),
slog.Error(err),
)
}
}
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
ctx,
chatSnapshot.WorkspaceID.UUID,
)
if err != nil || len(agents) == 0 {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent")
}
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
if err != nil {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf("get latest workspace build: %w", err)
}
updatedChat, err := c.persistBuildAgentBinding(
ctx,
chatSnapshot,
c.loadChatSnapshot,
build.ID,
agents[0].ID,
)
if refreshErr != nil {
return chatSnapshot, database.WorkspaceAgent{}, refreshErr
if err != nil {
return chatSnapshot, database.WorkspaceAgent{}, err
}
if refreshedChat.WorkspaceID.Valid {
c.chatStateMu.Lock()
*c.currentChat = refreshedChat
c.chatStateMu.Unlock()
chatSnapshot = refreshedChat
chatSnapshot = updatedChat
latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID)
if !workspaceMatches {
chatSnapshot = latestChat
continue
}
c.agent = agents[0]
c.agentLoaded = true
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
return chatSnapshot, c.agent, nil
}
if !chatSnapshot.WorkspaceID.Valid {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace")
}
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
ctx,
chatSnapshot.WorkspaceID.UUID,
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New(
"chat workspace changed while resolving agent",
)
if err != nil || len(agents) == 0 {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent")
}
// getWorkspaceConnLocked returns the cached connection when it still matches
// the current workspace. When the workspace changed, it clears the stale
// cached state and returns the release func for the caller to run after
// unlocking.
func (c *turnWorkspaceContext) getWorkspaceConnLocked() (workspacesdk.AgentConn, func()) {
if c.conn == nil {
return nil, nil
}
c.agent = agents[0]
c.agentLoaded = true
return chatSnapshot, c.agent, nil
chatSnapshot := c.currentChatSnapshot()
if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) {
return c.conn, nil
}
agentRelease := c.releaseConn
c.agent = database.WorkspaceAgent{}
c.agentLoaded = false
c.conn = nil
c.releaseConn = nil
c.cachedWorkspaceID = uuid.NullUUID{}
return nil, agentRelease
}
func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspacesdk.AgentConn, error) {
c.mu.Lock()
if c.conn != nil {
currentConn := c.conn
c.mu.Unlock()
return currentConn, nil
}
c.mu.Unlock()
if c.server.agentConnFn == nil {
return nil, xerrors.New("workspace agent connector is not configured")
}
chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx)
if err != nil {
return nil, err
}
agentConn, agentRelease, err := c.server.agentConnFn(ctx, agent.ID)
if err != nil {
refreshedChat, refreshedAgent, refreshErr := c.refreshWorkspaceAgent(ctx)
if refreshErr != nil {
return nil, xerrors.Errorf("connect to workspace agent: %w", err)
}
retryConn, retryRelease, retryErr := c.server.agentConnFn(ctx, refreshedAgent.ID)
if retryErr != nil {
return nil, xerrors.Errorf("connect to workspace agent after refresh: %w", retryErr)
}
chatSnapshot = refreshedChat
agentConn = retryConn
agentRelease = retryRelease
}
c.mu.Lock()
if c.conn == nil {
c.conn = agentConn
c.releaseConn = agentRelease
var ancestorIDs []string
if chatSnapshot.ParentChatID.Valid {
ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String())
}
ancestorJSON, marshalErr := json.Marshal(ancestorIDs)
if marshalErr != nil {
ancestorJSON = []byte("[]")
}
agentConn.SetExtraHeaders(http.Header{
workspacesdk.CoderChatIDHeader: {chatSnapshot.ID.String()},
workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)},
})
for attempt := 0; attempt < 2; attempt++ {
c.mu.Lock()
currentConn, staleRelease := c.getWorkspaceConnLocked()
c.mu.Unlock()
return agentConn, nil
}
currentConn := c.conn
c.mu.Unlock()
if currentConn != nil {
return currentConn, nil
}
if staleRelease != nil {
staleRelease()
}
agentRelease()
return currentConn, nil
chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx)
if err != nil {
return nil, err
}
dialResult, err := dialWithLazyValidation(
ctx,
agent.ID,
chatSnapshot.WorkspaceID.UUID,
DialFunc(c.server.agentConnFn),
func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) {
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspaceID)
if err != nil || len(agents) == 0 {
return uuid.Nil, xerrors.New("chat has no workspace agent")
}
return agents[0].ID, nil
},
workspaceDialValidationDelay,
)
if err != nil {
return nil, err
}
agentConn := dialResult.Conn
agentRelease := dialResult.Release
if dialResult.WasSwitched {
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
if err != nil {
if agentRelease != nil {
agentRelease()
}
return nil, xerrors.Errorf("get latest workspace build: %w", err)
}
switchedAgent, err := c.server.db.GetWorkspaceAgentByID(ctx, dialResult.AgentID)
if err != nil {
if agentRelease != nil {
agentRelease()
}
return nil, xerrors.Errorf("get workspace agent by id: %w", err)
}
updatedChat, err := c.persistBuildAgentBinding(
ctx,
chatSnapshot,
build.ID,
switchedAgent.ID,
)
if err != nil {
if agentRelease != nil {
agentRelease()
}
return nil, err
}
chatSnapshot = updatedChat
c.mu.Lock()
c.agent = switchedAgent
c.agentLoaded = true
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
c.mu.Unlock()
}
if _, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID); !workspaceMatches {
if agentRelease != nil {
agentRelease()
}
c.clearCachedWorkspaceState()
continue
}
c.mu.Lock()
if c.conn == nil {
c.conn = agentConn
c.releaseConn = agentRelease
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
var ancestorIDs []string
if chatSnapshot.ParentChatID.Valid {
ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String())
}
ancestorJSON, marshalErr := json.Marshal(ancestorIDs)
if marshalErr != nil {
ancestorJSON = []byte("[]")
}
agentConn.SetExtraHeaders(http.Header{
workspacesdk.CoderChatIDHeader: {chatSnapshot.ID.String()},
workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)},
})
c.mu.Unlock()
return agentConn, nil
}
currentConn = c.conn
c.mu.Unlock()
if agentRelease != nil {
agentRelease()
}
return currentConn, nil
}
return nil, xerrors.New("chat workspace changed while connecting")
}
// AgentConnFunc provides access to workspace agent connections.
@@ -420,6 +609,8 @@ func (e *UsageLimitExceededError) Error() string {
type CreateOptions struct {
OwnerID uuid.UUID
WorkspaceID uuid.NullUUID
BuildID uuid.NullUUID
AgentID uuid.NullUUID
ParentChatID uuid.NullUUID
RootChatID uuid.NullUUID
Title string
@@ -525,6 +716,8 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{
OwnerID: opts.OwnerID,
WorkspaceID: opts.WorkspaceID,
BuildID: opts.BuildID,
AgentID: opts.AgentID,
ParentChatID: opts.ParentChatID,
RootChatID: opts.RootChatID,
LastModelConfigID: opts.ModelConfigID,
@@ -3461,6 +3654,7 @@ func (p *Server) runChat(
AgentConnFn: chattool.AgentConnFunc(p.agentConnFn),
AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout,
WorkspaceMu: &workspaceMu,
OnChatUpdated: workspaceCtx.selectWorkspace,
Logger: p.logger,
AllowedTemplateIDs: p.chatTemplateAllowlist,
}),
+291 -17
View File
@@ -112,24 +112,29 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: uuid.New(),
ID: agentID,
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
db.EXPECT().GetWorkspaceAgentByID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
conn := agentconnmock.NewMockAgentConn(ctrl)
@@ -180,7 +185,7 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
require.Contains(t, instruction, "Working Directory: /home/coder/project")
}
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) {
t.Parallel()
ctx := context.Background()
@@ -188,6 +193,53 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{ID: agentID}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(workspaceAgent, nil).Times(1)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, chat, chatSnapshot)
require.Equal(t, workspaceAgent, agent)
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, workspaceAgent, gotAgent)
require.Equal(t, chat, currentChat)
}
func TestTurnWorkspaceContext_NullBindingLazyBind(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
buildID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
@@ -195,18 +247,135 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.
Valid: true,
},
}
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
workspaceAgent := database.WorkspaceAgent{ID: agentID}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: agentID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{initialAgent}, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{workspaceAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: agentID, Valid: true},
ID: chat.ID,
}).Return(updatedChat, nil),
)
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: &Server{db: db},
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
chatSnapshot, agent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, updatedChat, chatSnapshot)
require.Equal(t, workspaceAgent, agent)
require.Equal(t, updatedChat, currentChat)
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, workspaceAgent, gotAgent)
}
// TestTurnWorkspaceContext_StaleBindingRepair checks the repair path of
// ensureWorkspaceAgent: the chat is bound to an agent whose lookup by ID
// fails, so the latest build's agent must be resolved, the build/agent
// binding must be persisted, and the shared chat record must be updated.
func TestTurnWorkspaceContext_StaleBindingRepair(t *testing.T) {
	t.Parallel()

	var (
		ctx            = context.Background()
		ctrl           = gomock.NewController(t)
		db             = dbmock.NewMockStore(ctrl)
		workspaceID    = uuid.New()
		staleAgentID   = uuid.New()
		buildID        = uuid.New()
		currentAgentID = uuid.New()
	)

	chat := database.Chat{
		ID:          uuid.New(),
		WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
		AgentID:     uuid.NullUUID{UUID: staleAgentID, Valid: true},
	}
	currentAgent := database.WorkspaceAgent{ID: currentAgentID}

	repairedBuild := uuid.NullUUID{UUID: buildID, Valid: true}
	repairedAgent := uuid.NullUUID{UUID: currentAgentID, Valid: true}
	updatedChat := chat
	updatedChat.BuildID = repairedBuild
	updatedChat.AgentID = repairedAgent

	// The calls must happen in exactly this order: failed binding lookup,
	// latest-build agent resolution, latest build lookup, binding update.
	gomock.InOrder(
		db.EXPECT().
			GetWorkspaceAgentByID(gomock.Any(), staleAgentID).
			Return(database.WorkspaceAgent{}, xerrors.New("missing agent")),
		db.EXPECT().
			GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
			Return([]database.WorkspaceAgent{currentAgent}, nil),
		db.EXPECT().
			GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
			Return(database.WorkspaceBuild{ID: buildID}, nil),
		db.EXPECT().
			UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
				ID:      chat.ID,
				BuildID: repairedBuild,
				AgentID: repairedAgent,
			}).
			Return(updatedChat, nil),
	)

	currentChat := chat
	workspaceCtx := turnWorkspaceContext{
		server:      &Server{db: db},
		chatStateMu: &sync.Mutex{},
		currentChat: &currentChat,
		loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) {
			return database.Chat{}, nil
		},
	}
	t.Cleanup(workspaceCtx.close)

	gotChat, gotAgent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
	require.NoError(t, err)
	require.Equal(t, updatedChat, gotChat)
	require.Equal(t, currentAgent, gotAgent)
	require.Equal(t, updatedChat, currentChat)
}
func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
staleAgentID := uuid.New()
currentAgentID := uuid.New()
buildID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: staleAgentID,
Valid: true,
},
}
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
currentAgent := database.WorkspaceAgent{ID: currentAgentID}
updatedChat := chat
updatedChat.BuildID = uuid.NullUUID{UUID: buildID, Valid: true}
updatedChat.AgentID = uuid.NullUUID{UUID: currentAgentID, Valid: true}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).Return(staleAgent, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).Return([]database.WorkspaceAgent{currentAgent}, nil),
db.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).Return(database.WorkspaceBuild{ID: buildID}, nil),
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), currentAgentID).Return(currentAgent, nil),
db.EXPECT().UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
AgentID: uuid.NullUUID{UUID: currentAgentID, Valid: true},
ID: chat.ID,
}).Return(updatedChat, nil),
)
conn := agentconnmock.NewMockAgentConn(ctrl)
@@ -216,7 +385,7 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.
server := &Server{db: db}
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialed = append(dialed, agentID)
if agentID == initialAgent.ID {
if agentID == staleAgentID {
return nil, nil, xerrors.New("dial failed")
}
return conn, func() {}, nil
@@ -235,7 +404,112 @@ func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, conn, gotConn)
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
require.Equal(t, []uuid.UUID{staleAgentID, currentAgentID}, dialed)
require.Equal(t, updatedChat, currentChat)
gotAgent, err := workspaceCtx.getWorkspaceAgent(ctx)
require.NoError(t, err)
require.Equal(t, currentAgent, gotAgent)
}
// TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState seeds a turn
// context with a cached agent, connection and release func, then switches the
// chat to a different workspace via selectWorkspace. It expects the shared
// chat to be replaced, the release func to run exactly once, and every cached
// field to be reset to its zero value.
func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)

	chatID := uuid.New()
	currentChat := database.Chat{
		ID:          chatID,
		WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
	}
	updatedChat := database.Chat{
		ID:          chatID,
		WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
	}

	var releaseCalls int
	cachedConn := agentconnmock.NewMockAgentConn(ctrl)

	workspaceCtx := turnWorkspaceContext{
		chatStateMu: &sync.Mutex{},
		currentChat: &currentChat,
	}
	// Pre-populate the cache as if an earlier call had resolved and dialed
	// the agent for the original workspace.
	workspaceCtx.cachedWorkspaceID = currentChat.WorkspaceID
	workspaceCtx.agentLoaded = true
	workspaceCtx.agent = database.WorkspaceAgent{ID: uuid.New()}
	workspaceCtx.conn = cachedConn
	workspaceCtx.releaseConn = func() { releaseCalls++ }

	workspaceCtx.selectWorkspace(updatedChat)

	require.Equal(t, updatedChat, currentChat)
	require.Equal(t, 1, releaseCalls)

	workspaceCtx.mu.Lock()
	defer workspaceCtx.mu.Unlock()
	require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
	require.False(t, workspaceCtx.agentLoaded)
	require.Nil(t, workspaceCtx.conn)
	require.Nil(t, workspaceCtx.releaseConn)
	require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
}
// TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferentWorkspace
// seeds the context with an agent cached for workspace one while the chat
// points at workspace two. ensureWorkspaceAgent must ignore the cached agent,
// resolve the agent from workspace two's latest build, and persist the new
// build/agent binding.
func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferentWorkspace(t *testing.T) {
	t.Parallel()

	var (
		ctx            = context.Background()
		ctrl           = gomock.NewController(t)
		db             = dbmock.NewMockStore(ctrl)
		workspaceOneID = uuid.New()
		workspaceTwoID = uuid.New()
		buildID        = uuid.New()
		cachedAgent    = database.WorkspaceAgent{ID: uuid.New()}
		resolvedAgent  = database.WorkspaceAgent{ID: uuid.New()}
	)

	chat := database.Chat{
		ID:          uuid.New(),
		WorkspaceID: uuid.NullUUID{UUID: workspaceTwoID, Valid: true},
	}

	boundBuild := uuid.NullUUID{UUID: buildID, Valid: true}
	boundAgent := uuid.NullUUID{UUID: resolvedAgent.ID, Valid: true}
	updatedChat := chat
	updatedChat.BuildID = boundBuild
	updatedChat.AgentID = boundAgent

	gomock.InOrder(
		db.EXPECT().
			GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceTwoID).
			Return([]database.WorkspaceAgent{resolvedAgent}, nil),
		db.EXPECT().
			GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceTwoID).
			Return(database.WorkspaceBuild{ID: buildID}, nil),
		db.EXPECT().
			UpdateChatBuildAgentBinding(gomock.Any(), database.UpdateChatBuildAgentBindingParams{
				ID:      chat.ID,
				BuildID: boundBuild,
				AgentID: boundAgent,
			}).
			Return(updatedChat, nil),
	)

	currentChat := chat
	workspaceCtx := turnWorkspaceContext{
		server:      &Server{db: db},
		chatStateMu: &sync.Mutex{},
		currentChat: &currentChat,
		loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) {
			return database.Chat{}, nil
		},
	}
	// The cache claims to be valid, but for a workspace the chat no longer
	// targets.
	workspaceCtx.cachedWorkspaceID = uuid.NullUUID{UUID: workspaceOneID, Valid: true}
	workspaceCtx.agentLoaded = true
	workspaceCtx.agent = cachedAgent
	defer workspaceCtx.close()

	gotChat, gotAgent, err := workspaceCtx.ensureWorkspaceAgent(ctx)
	require.NoError(t, err)
	require.Equal(t, updatedChat, gotChat)
	require.Equal(t, resolvedAgent, gotAgent)
	require.Equal(t, updatedChat, currentChat)
}
func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
+1 -1
View File
@@ -2280,7 +2280,7 @@ func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) {
// Link the workspace to the chat in the DB, simulating what
// the create_workspace tool does mid-conversation.
_, err = db.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
_, err = db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
ID: chat.ID,
})
+13 -3
View File
@@ -66,6 +66,7 @@ type CreateWorkspaceOptions struct {
AgentConnFn AgentConnFunc
AgentInactiveDisconnectTimeout time.Duration
WorkspaceMu *sync.Mutex
OnChatUpdated func(database.Chat)
Logger slog.Logger
AllowedTemplateIDs func() map[uuid.UUID]bool
}
@@ -211,20 +212,29 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
}
}
// Persist workspace + agent association on the chat.
// Persist the workspace binding on the chat.
if options.DB != nil && options.ChatID != uuid.Nil {
if _, err := options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
updatedChat, err := options.DB.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
ID: options.ChatID,
WorkspaceID: uuid.NullUUID{
UUID: workspace.ID,
Valid: true,
},
}); err != nil {
// BuildID and AgentID are intentionally left null
// here. The chatd runtime (loadWorkspaceAgentLocked)
// will bind them on the next turn. Authoritative
// tool-path binding is deferred to a follow-up PR.
BuildID: uuid.NullUUID{},
AgentID: uuid.NullUUID{},
})
if err != nil {
options.Logger.Error(ctx, "failed to persist chat workspace association",
slog.F("chat_id", options.ChatID),
slog.F("workspace_id", workspace.ID),
slog.Error(err),
)
} else if options.OnChatUpdated != nil {
options.OnChatUpdated(updatedChat)
}
}
+170
View File
@@ -0,0 +1,170 @@
package chatd
import (
"context"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// DialResult contains the outcome of dialWithLazyValidation.
type DialResult struct {
// Conn is the connection returned by the dial that was ultimately used.
Conn workspacesdk.AgentConn
// Release is the release func returned by the DialFunc alongside Conn.
// It may be nil; callers nil-check it before invoking it.
Release func()
AgentID uuid.UUID // The agent that was actually dialed.
WasSwitched bool // True if validation discovered a different agent.
}
// DialFunc dials an agent by ID and returns a connection, a func that
// releases that connection, and an error. dialWithLazyValidation only treats
// the connection and release func as usable when the error is nil.
type DialFunc func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error)
// ValidateFunc returns the current agent ID for a workspace.
// dialWithLazyValidation compares the returned ID against the agent it was
// asked to dial to decide whether the binding is stale.
type ValidateFunc func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error)
// dialOut carries the three return values of a DialFunc call from the
// background dial goroutine back to dialWithLazyValidation.
type dialOut struct {
conn workspacesdk.AgentConn
release func()
err error
}
// dialWithLazyValidation dials an agent and only consults the database if the
// original dial is slow or fails quickly. This keeps the common path free of
// latest-build lookups while still repairing stale bindings.
//
// Outcomes:
// - The dial succeeds before delay, so validation is skipped.
// - The timer fires and validation confirms the same agent, so the original
// dial continues.
// - The timer fires and validation finds a different agent, so the stale
// dial is canceled and the new agent is dialed instead.
// - The dial fails before delay, so validation runs immediately and either
// switches to a different agent or retries the current one once.
//
// Parameters:
// - agentID is the agent dialed first, in a background goroutine.
// - workspaceID is passed to validateFn unchanged.
// - dialFn performs every dial; validateFn reports the agent the workspace
// should currently use.
// - delay is how long the original dial may stay pending before validateFn
// is consulted.
//
// The returned DialResult has WasSwitched set only when the agent reported by
// validateFn was dialed instead of agentID. Dial errors and fast-failure
// validation errors are wrapped with "dial with lazy validation"; context
// errors are returned unwrapped. A validation error on the timer path is
// ignored and the function keeps waiting for the original dial.
func dialWithLazyValidation(
ctx context.Context,
agentID uuid.UUID,
workspaceID uuid.UUID,
dialFn DialFunc,
validateFn ValidateFunc,
delay time.Duration,
) (DialResult, error) {
wrapErr := func(err error) error {
return xerrors.Errorf("dial with lazy validation: %w", err)
}
// The original dial gets its own cancelable context so it can be abandoned
// without canceling the caller's ctx.
dialCtx, dialCancel := context.WithCancel(ctx)
// Buffered with capacity 1: the goroutine sends exactly once and never
// blocks, even if nobody is receiving anymore.
results := make(chan dialOut, 1)
go func() {
conn, release, err := dialFn(dialCtx, agentID)
results <- dialOut{conn: conn, release: release, err: err}
}()
// drained records whether this function already received the original
// dial's result. It is only touched from the calling goroutine.
drained := false
defer func() {
// NOTE(review): dialCtx is canceled on every return path, including when
// the original dial's connection is handed back to the caller. Confirm that
// dialFn does not tie the returned connection's lifetime to its ctx.
dialCancel()
if drained {
return
}
// Drain without blocking the caller. dialFn may take time to honor
// cancellation, but any late-arriving successful connection still needs to
// be released.
go func() {
result := <-results
if result.err == nil && result.release != nil {
result.release()
}
}()
}()
// resultForAgent packages a successful dial into a DialResult.
resultForAgent := func(dialedAgentID uuid.UUID, result dialOut, switched bool) DialResult {
return DialResult{
Conn: result.conn,
Release: result.release,
AgentID: dialedAgentID,
WasSwitched: switched,
}
}
// dialAgent dials synchronously using the caller's ctx (not dialCtx), so the
// resulting connection is unaffected by the deferred dialCancel.
dialAgent := func(targetAgentID uuid.UUID, switched bool) (DialResult, error) {
conn, release, err := dialFn(ctx, targetAgentID)
if err != nil {
return DialResult{}, wrapErr(err)
}
return resultForAgent(targetAgentID, dialOut{conn: conn, release: release}, switched), nil
}
// preferReadyOriginalDial is a non-blocking check for an already-finished
// original dial. It reports true only for a successful result; a failed
// result is consumed (drained) and reported as not ready.
preferReadyOriginalDial := func() (DialResult, bool) {
select {
case result := <-results:
drained = true
if result.err != nil {
return DialResult{}, false
}
return resultForAgent(agentID, result, false), true
default:
return DialResult{}, false
}
}
// waitForOriginalDial blocks until the original dial finishes or waitCtx is
// done. On cancellation a result that is already available still wins;
// otherwise waitCtx.Err() is returned unwrapped.
waitForOriginalDial := func(waitCtx context.Context) (DialResult, error) {
select {
case result := <-results:
drained = true
if result.err != nil {
return DialResult{}, wrapErr(result.err)
}
return resultForAgent(agentID, result, false), nil
case <-waitCtx.Done():
if ready, ok := preferReadyOriginalDial(); ok {
return ready, nil
}
return DialResult{}, waitCtx.Err()
}
}
// validateBinding calls validateFn and wraps its error. It is used on the
// fast-failure path, where a validation error is fatal.
validateBinding := func() (uuid.UUID, error) {
validatedAgentID, err := validateFn(ctx, workspaceID)
if err != nil {
return uuid.Nil, wrapErr(err)
}
return validatedAgentID, nil
}
// resolveFastFailure handles an original dial that failed before the timer
// fired: validate, then either retry the same agent once or dial the agent
// validation reported.
// NOTE(review): the original dial error is not passed in, so when validation
// fails the caller only sees the validation error.
resolveFastFailure := func() (DialResult, error) {
validatedAgentID, err := validateBinding()
if err != nil {
return DialResult{}, err
}
if validatedAgentID == agentID {
return dialAgent(agentID, false)
}
return dialAgent(validatedAgentID, true)
}
timer := time.NewTimer(delay)
defer timer.Stop()
select {
case result := <-results:
drained = true
if result.err == nil {
// Common path: the bound agent answered before the delay elapsed.
return resultForAgent(agentID, result, false), nil
}
return resolveFastFailure()
case <-timer.C:
// validateFn is called directly (not via validateBinding) because its error
// is deliberately ignored here.
validatedAgentID, validationErr := validateFn(ctx, workspaceID)
if validationErr != nil || validatedAgentID == agentID {
// Validation could not prove the binding was stale, so keep waiting on
// the original dial.
return waitForOriginalDial(ctx)
}
// The original dial is stale. Cancel it first, then let the deferred drain
// release any late result while we dial the validated agent immediately.
dialCancel()
return dialAgent(validatedAgentID, true)
case <-ctx.Done():
// The caller gave up, but a connection that is already available is still
// returned rather than dropped.
if ready, ok := preferReadyOriginalDial(); ok {
return ready, nil
}
return DialResult{}, ctx.Err()
}
}
+563
View File
@@ -0,0 +1,563 @@
package chatd //nolint:testpackage // Uses internal symbols.
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
)
// TestDialWithLazyValidation_FastDial covers the common path: the dial for
// the bound agent returns immediately, so with a one-minute validation delay
// the validate callback must never be invoked and the original connection is
// returned unswitched.
func TestDialWithLazyValidation_FastDial(t *testing.T) {
	t.Parallel()

	var (
		ctrl          = gomock.NewController(t)
		agentID       = uuid.New()
		workspaceID   = uuid.New()
		conn          = agentconnmock.NewMockAgentConn(ctrl)
		releaseCalls  atomic.Int32
		validateCalls atomic.Int32
	)

	dial := func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		if id != agentID {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		release := func() {
			releaseCalls.Add(1)
		}
		return conn, release, nil
	}
	// Validation always fails; the test passes only if it is never reached.
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		validateCalls.Add(1)
		return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
	}

	result, err := dialWithLazyValidation(
		context.Background(),
		agentID,
		workspaceID,
		dial,
		validate,
		time.Minute,
	)
	require.NoError(t, err)
	require.Same(t, conn, result.Conn)
	require.Equal(t, agentID, result.AgentID)
	require.False(t, result.WasSwitched)
	require.EqualValues(t, 0, validateCalls.Load())

	if release := result.Release; release != nil {
		release()
	}
	require.EqualValues(t, 1, releaseCalls.Load())
}
// TestDialWithLazyValidation_SlowDialSameAgent covers a dial that is still
// pending when the (zero) validation delay elapses. Validation reports the
// same agent and unblocks the dial, so the original connection must be
// returned unswitched after exactly one validation call.
func TestDialWithLazyValidation_SlowDialSameAgent(t *testing.T) {
	t.Parallel()

	var (
		ctrl          = gomock.NewController(t)
		agentID       = uuid.New()
		workspaceID   = uuid.New()
		conn          = agentconnmock.NewMockAgentConn(ctrl)
		unblockDial   = make(chan struct{})
		releaseCalls  atomic.Int32
		validateCalls atomic.Int32
	)

	// The dial blocks until validation has run (or its context is canceled).
	dial := func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		if id != agentID {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		select {
		case <-unblockDial:
			release := func() {
				releaseCalls.Add(1)
			}
			return conn, release, nil
		case <-ctx.Done():
			return nil, nil, ctx.Err()
		}
	}
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		if id != workspaceID {
			return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
		}
		validateCalls.Add(1)
		close(unblockDial)
		return agentID, nil
	}

	result, err := dialWithLazyValidation(
		context.Background(),
		agentID,
		workspaceID,
		dial,
		validate,
		0,
	)
	require.NoError(t, err)
	require.Same(t, conn, result.Conn)
	require.Equal(t, agentID, result.AgentID)
	require.False(t, result.WasSwitched)
	require.EqualValues(t, 1, validateCalls.Load())

	if release := result.Release; release != nil {
		release()
	}
	require.EqualValues(t, 1, releaseCalls.Load())
}
// TestDialWithLazyValidation_SlowDialStaleAgent exercises
// dialWithLazyValidation when the dial to the initially supplied agent ID
// does not complete on its own and the validation callback reports a
// different agent ID. Every subtest expects the returned result to carry
// the connection for the agent ID returned by validation, with
// WasSwitched set.
func TestDialWithLazyValidation_SlowDialStaleAgent(t *testing.T) {
t.Parallel()
// The stale dial blocks until its context is done and then reports a
// successful connection anyway. The test expects the release func of
// that late connection to be invoked exactly once.
t.Run("LateSuccessReleasesStaleConn", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
staleAgentID := uuid.New()
currentAgentID := uuid.New()
workspaceID := uuid.New()
staleConn := agentconnmock.NewMockAgentConn(ctrl)
currentConn := agentconnmock.NewMockAgentConn(ctrl)
var dialCalls atomic.Int32
var validateCalls atomic.Int32
var staleReleaseCalls atomic.Int32
var currentReleaseCalls atomic.Int32
result, err := dialWithLazyValidation(
context.Background(),
staleAgentID,
workspaceID,
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialCalls.Add(1)
switch id {
case staleAgentID:
// Returns only after ctx is done, yet still hands back a usable
// conn and release func.
// NOTE(review): presumably dialWithLazyValidation cancels this ctx
// when it switches agents — confirm against the implementation.
<-ctx.Done()
return staleConn, func() {
staleReleaseCalls.Add(1)
}, nil
case currentAgentID:
return currentConn, func() {
currentReleaseCalls.Add(1)
}, nil
default:
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
}
},
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
if id != workspaceID {
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
}
validateCalls.Add(1)
return currentAgentID, nil
},
// NOTE(review): presumably the delay before validation kicks in;
// confirm against dialWithLazyValidation's signature.
0,
)
require.NoError(t, err)
require.Same(t, currentConn, result.Conn)
require.Equal(t, currentAgentID, result.AgentID)
require.True(t, result.WasSwitched)
// Both the stale and the current agent must have been dialed.
require.Eventually(t, func() bool {
return dialCalls.Load() == 2
}, testutil.WaitShort, testutil.IntervalFast)
require.EqualValues(t, 1, validateCalls.Load())
// Polled rather than read once: the stale release is allowed to happen
// after dialWithLazyValidation has already returned.
require.Eventually(t, func() bool {
return staleReleaseCalls.Load() == 1
}, testutil.WaitShort, testutil.IntervalFast)
if result.Release != nil {
result.Release()
}
require.EqualValues(t, 1, currentReleaseCalls.Load())
})
// The stale dial blocks until its context is done and then fails with
// ctx.Err(), returning a nil conn together with a non-nil release func.
// The test expects that release func not to have been invoked.
t.Run("CanceledFailureDoesNotReleaseStaleConn", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
staleAgentID := uuid.New()
currentAgentID := uuid.New()
workspaceID := uuid.New()
currentConn := agentconnmock.NewMockAgentConn(ctrl)
var dialCalls atomic.Int32
var validateCalls atomic.Int32
var staleReleaseCalls atomic.Int32
var currentReleaseCalls atomic.Int32
result, err := dialWithLazyValidation(
context.Background(),
staleAgentID,
workspaceID,
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialCalls.Add(1)
switch id {
case staleAgentID:
// Failed dial: nil conn plus the context error. The release func
// only counts invocations so the test can detect a wrongful call.
<-ctx.Done()
return nil, func() {
staleReleaseCalls.Add(1)
}, ctx.Err()
case currentAgentID:
return currentConn, func() {
currentReleaseCalls.Add(1)
}, nil
default:
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
}
},
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
if id != workspaceID {
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
}
validateCalls.Add(1)
return currentAgentID, nil
},
0,
)
require.NoError(t, err)
require.Same(t, currentConn, result.Conn)
require.Equal(t, currentAgentID, result.AgentID)
require.True(t, result.WasSwitched)
require.Eventually(t, func() bool {
return dialCalls.Load() == 2
}, testutil.WaitShort, testutil.IntervalFast)
require.EqualValues(t, 1, validateCalls.Load())
// NOTE(review): this is a single point-in-time read of the counter, not
// a polled check over an interval.
require.EqualValues(t, 0, staleReleaseCalls.Load())
if result.Release != nil {
result.Release()
}
require.EqualValues(t, 1, currentReleaseCalls.Load())
})
// The stale dial ignores its context and stays blocked until the test
// explicitly lets it return. The test expects dialWithLazyValidation to
// return the current agent's connection while that dial is still
// blocked, and the stale connection to be released once it does return.
t.Run("SwitchDoesNotBlock", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
staleAgentID := uuid.New()
currentAgentID := uuid.New()
workspaceID := uuid.New()
staleConn := agentconnmock.NewMockAgentConn(ctrl)
currentConn := agentconnmock.NewMockAgentConn(ctrl)
// Closed by the dial callback once the stale dial has begun.
staleDialStarted := make(chan struct{})
// Closed by the test to let the blocked stale dial return.
allowStaleReturn := make(chan struct{})
var dialCalls atomic.Int32
var validateCalls atomic.Int32
var staleReleaseCalls atomic.Int32
var currentReleaseCalls atomic.Int32
var staleReturnReleased atomic.Bool
// releaseStaleReturn closes allowStaleReturn at most once; the
// CompareAndSwap guard makes the explicit call and the deferred call
// below safe to combine.
releaseStaleReturn := func() {
if staleReturnReleased.CompareAndSwap(false, true) {
close(allowStaleReturn)
}
}
// Unblock the stale dial even when the test exits early on a failure.
defer releaseStaleReturn()
// Both channels have capacity 1 so the goroutine below can always send
// its single outcome without waiting for a receiver.
resultCh := make(chan DialResult, 1)
errCh := make(chan error, 1)
// Run the call in a goroutine so the test can bound how long it takes.
go func() {
result, err := dialWithLazyValidation(
context.Background(),
staleAgentID,
workspaceID,
func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialCalls.Add(1)
switch id {
case staleAgentID:
// Signal that the stale dial is in flight, then block on the test's
// channel rather than on the dial context.
close(staleDialStarted)
<-allowStaleReturn
return staleConn, func() {
staleReleaseCalls.Add(1)
}, nil
case currentAgentID:
return currentConn, func() {
currentReleaseCalls.Add(1)
}, nil
default:
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
}
},
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
if id != workspaceID {
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
}
// Do not report the current agent until the stale dial has started.
<-staleDialStarted
validateCalls.Add(1)
return currentAgentID, nil
},
0,
)
if err != nil {
errCh <- err
return
}
resultCh <- result
}()
var result DialResult
select {
case err := <-errCh:
require.NoError(t, err)
case result = <-resultCh:
require.Same(t, currentConn, result.Conn)
require.Equal(t, currentAgentID, result.AgentID)
require.True(t, result.WasSwitched)
// The result arrived while the stale dial was still blocked; only now
// let that dial return.
releaseStaleReturn()
case <-time.After(testutil.WaitShort):
t.Fatal("dialWithLazyValidation blocked on stale dial cleanup")
}
require.EqualValues(t, 2, dialCalls.Load())
require.EqualValues(t, 1, validateCalls.Load())
// The stale conn is returned after the switch, so its release is polled.
require.Eventually(t, func() bool {
return staleReleaseCalls.Load() == 1
}, testutil.WaitShort, testutil.IntervalFast)
if result.Release != nil {
result.Release()
}
require.EqualValues(t, 1, currentReleaseCalls.Load())
})
}
// TestDialWithLazyValidation_FastFailure checks the path where the first
// dial, made to the agent ID passed in, fails immediately. The test expects
// exactly one validation lookup, a second dial to the agent ID that lookup
// returns, and a result flagged as switched.
func TestDialWithLazyValidation_FastFailure(t *testing.T) {
	t.Parallel()

	var (
		mockCtrl    = gomock.NewController(t)
		boundAgent  = uuid.New()
		latestAgent = uuid.New()
		wsID        = uuid.New()
		latestConn  = agentconnmock.NewMockAgentConn(mockCtrl)
	)
	var (
		dials       atomic.Int32
		validations atomic.Int32
		releases    atomic.Int32
	)

	// dial fails the first attempt and succeeds on the second. Each attempt
	// must target the expected agent ID, and a third attempt is an error.
	dial := func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		attempt := dials.Add(1)

		var want uuid.UUID
		switch attempt {
		case 1:
			want = boundAgent
		case 2:
			want = latestAgent
		default:
			return nil, nil, xerrors.New("unexpected dial call")
		}
		if id != want {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		if attempt == 1 {
			return nil, nil, xerrors.New("dial failed")
		}
		return latestConn, func() { releases.Add(1) }, nil
	}

	// validate reports latestAgent for the expected workspace.
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		if id != wsID {
			return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
		}
		validations.Add(1)
		return latestAgent, nil
	}

	// NOTE(review): time.Minute is presumably the lazy-validation delay; the
	// single validation asserted below would then be driven by the dial
	// failure rather than by that timer — confirm against the signature.
	got, err := dialWithLazyValidation(
		context.Background(),
		boundAgent,
		wsID,
		dial,
		validate,
		time.Minute,
	)
	require.NoError(t, err)
	require.Same(t, latestConn, got.Conn)
	require.Equal(t, latestAgent, got.AgentID)
	require.True(t, got.WasSwitched)
	require.EqualValues(t, 2, dials.Load())
	require.EqualValues(t, 1, validations.Load())

	if release := got.Release; release != nil {
		release()
	}
	require.EqualValues(t, 1, releases.Load())
}
// TestDialWithLazyValidation_FastFailureSameAgent checks the path where the
// first dial fails immediately and validation returns the same agent ID that
// was passed in. The test expects a second dial to that same agent to
// succeed and the result not to be flagged as switched.
func TestDialWithLazyValidation_FastFailureSameAgent(t *testing.T) {
	t.Parallel()

	var (
		mockCtrl   = gomock.NewController(t)
		boundAgent = uuid.New()
		wsID       = uuid.New()
		agentConn  = agentconnmock.NewMockAgentConn(mockCtrl)
	)
	var (
		dials       atomic.Int32
		releases    atomic.Int32
		validations atomic.Int32
	)

	// dial accepts only boundAgent; the first attempt fails, the second
	// succeeds, and any further attempt is an error.
	dial := func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		if id != boundAgent {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		switch attempt := dials.Add(1); attempt {
		case 1:
			return nil, nil, xerrors.New("dial failed")
		case 2:
			return agentConn, func() { releases.Add(1) }, nil
		}
		return nil, nil, xerrors.New("unexpected dial call")
	}

	// validate confirms that boundAgent is still the workspace's agent.
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		if id != wsID {
			return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
		}
		validations.Add(1)
		return boundAgent, nil
	}

	got, err := dialWithLazyValidation(
		context.Background(),
		boundAgent,
		wsID,
		dial,
		validate,
		time.Minute,
	)
	require.NoError(t, err)
	require.Same(t, agentConn, got.Conn)
	require.Equal(t, boundAgent, got.AgentID)
	require.False(t, got.WasSwitched)
	require.EqualValues(t, 2, dials.Load())
	require.EqualValues(t, 1, validations.Load())

	if release := got.Release; release != nil {
		release()
	}
	require.EqualValues(t, 1, releases.Load())
}
// TestDialWithLazyValidation_FastFailureSameAgentRetryFails checks the path
// where the first dial fails, validation returns the same agent ID, and the
// retry dial fails as well. The test expects the retry's error to surface,
// wrapped with the "dial with lazy validation" prefix.
func TestDialWithLazyValidation_FastFailureSameAgentRetryFails(t *testing.T) {
	t.Parallel()

	boundAgent := uuid.New()
	wsID := uuid.New()
	var (
		dials       atomic.Int32
		validations atomic.Int32
	)

	// dial accepts only boundAgent and fails every attempt, with a distinct
	// message per attempt so the test can tell which error was returned.
	dial := func(_ context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		if id != boundAgent {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		switch attempt := dials.Add(1); attempt {
		case 1:
			return nil, nil, xerrors.New("dial failed")
		case 2:
			return nil, nil, xerrors.New("retry failed")
		}
		return nil, nil, xerrors.New("unexpected dial call")
	}

	// validate confirms that boundAgent is still the workspace's agent.
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		if id != wsID {
			return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
		}
		validations.Add(1)
		return boundAgent, nil
	}

	_, err := dialWithLazyValidation(
		context.Background(),
		boundAgent,
		wsID,
		dial,
		validate,
		time.Minute,
	)
	require.EqualError(t, err, "dial with lazy validation: retry failed")
	require.EqualValues(t, 2, dials.Load())
	require.EqualValues(t, 1, validations.Load())
}
// TestDialWithLazyValidation_ValidationError checks the path where the
// validation lookup itself returns an error while the original dial is still
// pending. The test expects the original dial's connection to be returned,
// not flagged as switched, with no error surfaced to the caller.
func TestDialWithLazyValidation_ValidationError(t *testing.T) {
	t.Parallel()

	mockCtrl := gomock.NewController(t)
	boundAgent := uuid.New()
	wsID := uuid.New()
	agentConn := agentconnmock.NewMockAgentConn(mockCtrl)
	// dialGate holds the dial open until the validation callback has run.
	dialGate := make(chan struct{})
	var (
		releases    atomic.Int32
		validations atomic.Int32
	)

	// dial accepts only boundAgent and completes once dialGate is closed,
	// or fails with the context error if the dial context ends first.
	dial := func(dialCtx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		if id != boundAgent {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		select {
		case <-dialCtx.Done():
			return nil, nil, dialCtx.Err()
		case <-dialGate:
		}
		return agentConn, func() { releases.Add(1) }, nil
	}

	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		if id != wsID {
			return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
		}
		validations.Add(1)
		// The lookup fails, so the code under test should fall back to
		// waiting for the original dial; open the gate so it can finish.
		close(dialGate)
		return uuid.Nil, xerrors.New("db connection reset")
	}

	got, err := dialWithLazyValidation(
		context.Background(),
		boundAgent,
		wsID,
		dial,
		validate,
		0,
	)
	require.NoError(t, err)
	require.Same(t, agentConn, got.Conn)
	require.Equal(t, boundAgent, got.AgentID)
	require.False(t, got.WasSwitched)
	require.EqualValues(t, 1, validations.Load())

	if release := got.Release; release != nil {
		release()
	}
	require.EqualValues(t, 1, releases.Load())
}
// TestDialWithLazyValidation_ContextCanceled checks the path where the
// caller's context is canceled while the dial is still pending. The
// cancellation is triggered from inside the validation callback, and the
// test expects context.Canceled to be returned.
func TestDialWithLazyValidation_ContextCanceled(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	boundAgent := uuid.New()
	wsID := uuid.New()
	var validations atomic.Int32

	// dial accepts only boundAgent and never succeeds: it waits for the
	// dial context to end and reports that context's error.
	dial := func(dialCtx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		if id != boundAgent {
			return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
		}
		<-dialCtx.Done()
		return nil, nil, dialCtx.Err()
	}

	// validate cancels the caller's context before confirming the same
	// agent, so the pending dial can only end through cancellation.
	validate := func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
		if id != wsID {
			return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
		}
		validations.Add(1)
		cancel()
		return boundAgent, nil
	}

	_, err := dialWithLazyValidation(ctx, boundAgent, wsID, dial, validate, 0)
	require.ErrorIs(t, err, context.Canceled)
	require.EqualValues(t, 1, validations.Load())
}
+4
View File
@@ -313,6 +313,8 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
childChat, err := p.CreateChat(ctx, CreateOptions{
OwnerID: parent.OwnerID,
WorkspaceID: parent.WorkspaceID,
BuildID: parent.BuildID,
AgentID: parent.AgentID,
ParentChatID: uuid.NullUUID{
UUID: parent.ID,
Valid: true,
@@ -383,6 +385,8 @@ func (p *Server) createChildSubagentChat(
child, err := p.CreateChat(ctx, CreateOptions{
OwnerID: parent.OwnerID,
WorkspaceID: parent.WorkspaceID,
BuildID: parent.BuildID,
AgentID: parent.AgentID,
ParentChatID: uuid.NullUUID{
UUID: parent.ID,
Valid: true,
+100 -1
View File
@@ -149,6 +149,45 @@ func seedInternalChatDeps(
return user, model
}
// seedWorkspaceBinding inserts, via dbgen, the rows a test needs to bind a
// chat to a workspace: an organization, a template version, a template, a
// workspace owned by userID, a provisioner job, a workspace build, a
// workspace resource, and a workspace agent. The build and the resource
// share the same job ID, and the agent hangs off the resource. It returns
// the workspace, the build, and the agent.
func seedWorkspaceBinding(
	t *testing.T,
	db database.Store,
	userID uuid.UUID,
) (database.WorkspaceTable, database.WorkspaceBuild, database.WorkspaceAgent) {
	t.Helper()

	orgID := dbgen.Organization(t, db, database.Organization{}).ID

	versionID := dbgen.TemplateVersion(t, db, database.TemplateVersion{
		OrganizationID: orgID,
		CreatedBy:      userID,
	}).ID

	tmplID := dbgen.Template(t, db, database.Template{
		CreatedBy:       userID,
		OrganizationID:  orgID,
		ActiveVersionID: versionID,
	}).ID

	ws := dbgen.Workspace(t, db, database.WorkspaceTable{
		TemplateID:     tmplID,
		OwnerID:        userID,
		OrganizationID: orgID,
	})

	jobID := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
		InitiatorID:    userID,
		OrganizationID: orgID,
	}).ID

	wsBuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
		TemplateVersionID: versionID,
		WorkspaceID:       ws.ID,
		JobID:             jobID,
	})

	// The resource is keyed to the same job as the build above.
	resourceID := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
		Transition: database.WorkspaceTransitionStart,
		JobID:      jobID,
	}).ID

	return ws, wsBuild, dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
		ResourceID: resourceID,
	})
}
// findToolByName returns the tool with the given name from the
// slice, or nil if no match is found.
func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
@@ -165,6 +204,49 @@ func chatdTestContext(t *testing.T) context.Context {
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
}
// TestCreateChildSubagentChatInheritsWorkspaceBinding creates a parent chat
// with workspace, build, and agent IDs set, spawns a child through
// createChildSubagentChat, and asserts that the stored child row carries the
// same three binding fields as the stored parent row.
func TestCreateChildSubagentChatInheritsWorkspaceBinding(t *testing.T) {
	t.Parallel()

	store, pubsub := dbtestutil.NewDB(t)
	srv := newInternalTestServer(t, store, pubsub, chatprovider.ProviderAPIKeys{})
	ctx := chatdTestContext(t)
	owner, modelCfg := seedInternalChatDeps(ctx, t, store)
	ws, wsBuild, wsAgent := seedWorkspaceBinding(t, store, owner.ID)

	// bound wraps an ID as a valid (non-null) NullUUID.
	bound := func(id uuid.UUID) uuid.NullUUID {
		return uuid.NullUUID{UUID: id, Valid: true}
	}

	parentOpts := CreateOptions{
		OwnerID:            owner.ID,
		WorkspaceID:        bound(ws.ID),
		BuildID:            bound(wsBuild.ID),
		AgentID:            bound(wsAgent.ID),
		Title:              "bound-parent",
		ModelConfigID:      modelCfg.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	}
	parent, err := srv.CreateChat(ctx, parentOpts)
	require.NoError(t, err)

	// Re-read the parent so the comparison uses what was actually stored.
	storedParent, err := store.GetChatByID(ctx, parent.ID)
	require.NoError(t, err)

	child, err := srv.createChildSubagentChat(ctx, storedParent, "inspect bindings", "")
	require.NoError(t, err)

	storedChild, err := store.GetChatByID(ctx, child.ID)
	require.NoError(t, err)

	require.Equal(t, storedParent.WorkspaceID, storedChild.WorkspaceID)
	require.Equal(t, storedParent.BuildID, storedChild.BuildID)
	require.Equal(t, storedParent.AgentID, storedChild.AgentID)
}
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
t.Parallel()
@@ -292,13 +374,26 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
ctx := chatdTestContext(t)
user, model := seedInternalChatDeps(ctx, t, db)
workspace, build, agent := seedWorkspaceBinding(t, db, user.ID)
// The parent uses an OpenAI model.
require.Equal(t, "openai", model.Provider,
"seed helper must create an OpenAI model")
parent, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{
UUID: workspace.ID,
Valid: true,
},
BuildID: uuid.NullUUID{
UUID: build.ID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agent.ID,
Valid: true,
},
Title: "parent-openai",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
@@ -332,6 +427,10 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
childChat, err := db.GetChatByID(ctx, childID)
require.NoError(t, err)
require.Equal(t, parentChat.WorkspaceID, childChat.WorkspaceID)
require.Equal(t, parentChat.BuildID, childChat.BuildID)
require.Equal(t, parentChat.AgentID, childChat.AgentID)
// The child must have Mode=computer_use which causes
// runChat to override the model to the predefined computer
// use model instead of using the parent's model config.
+2
View File
@@ -49,6 +49,8 @@ type Chat struct {
ID uuid.UUID `json:"id" format:"uuid"`
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
BuildID *uuid.UUID `json:"build_id,omitempty" format:"uuid"`
AgentID *uuid.UUID `json:"agent_id,omitempty" format:"uuid"`
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"`
+2
View File
@@ -1089,6 +1089,8 @@ export interface Chat {
readonly id: string;
readonly owner_id: string;
readonly workspace_id?: string;
readonly build_id?: string;
readonly agent_id?: string;
readonly parent_chat_id?: string;
readonly root_chat_id?: string;
readonly last_model_config_id: string;