diff --git a/coderd/coderd.go b/coderd/coderd.go index 746959cb85..e1ae2a9502 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1245,6 +1245,7 @@ func New(options *Options) *API { r.Get("/git", api.watchChatGit) }) r.Post("/interrupt", api.interruptChat) + r.Post("/tool-results", api.postChatToolResults) r.Post("/title/regenerate", api.regenerateChatTitle) r.Get("/diff", api.getChatDiffContents) r.Route("/queue/{queuedMessage}", func(r chi.Router) { diff --git a/coderd/database/db2sdk/db2sdk_test.go b/coderd/database/db2sdk/db2sdk_test.go index 2282f36fbd..38584a6017 100644 --- a/coderd/database/db2sdk/db2sdk_test.go +++ b/coderd/database/db2sdk/db2sdk_test.go @@ -552,6 +552,10 @@ func TestChat_AllFieldsPopulated(t *testing.T) { RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`), Valid: true, }, + DynamicTools: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`[{"name":"tool1","description":"test tool","inputSchema":{"type":"object"}}]`), + Valid: true, + }, } // Only ChatID is needed here. This test checks that // Chat.DiffStatus is non-nil, not that every DiffStatus diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index b3d7d9081d..505c6451e5 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -293,7 +293,8 @@ CREATE TYPE chat_status AS ENUM ( 'running', 'paused', 'completed', - 'error' + 'error', + 'requires_action' ); CREATE TYPE connection_status AS ENUM ( @@ -1418,7 +1419,8 @@ CREATE TABLE chats ( agent_id uuid, pin_order integer DEFAULT 0 NOT NULL, last_read_message_id bigint, - last_injected_context jsonb + last_injected_context jsonb, + dynamic_tools jsonb ); CREATE TABLE connection_logs ( diff --git a/coderd/database/migrations/000463_chat_dynamic_tools.down.sql b/coderd/database/migrations/000463_chat_dynamic_tools.down.sql new file mode 100644 index 0000000000..9a8fedf2e7 --- /dev/null +++ b/coderd/database/migrations/000463_chat_dynamic_tools.down.sql @@ -0,0 +1,31 @@ +-- First update any rows using the value we're about to remove. +-- The column type is still the original chat_status at this point. +UPDATE chats SET status = 'error' WHERE status = 'requires_action'; + +-- Drop the column (this is independent of the enum). +ALTER TABLE chats DROP COLUMN IF EXISTS dynamic_tools; + +-- Drop the partial index that references the chat_status enum type. +-- It must be removed before the rename-create-cast-drop cycle +-- because the index's WHERE clause (status = 'pending'::chat_status) +-- would otherwise cause a cross-type comparison failure. +DROP INDEX IF EXISTS idx_chats_pending; + +-- Now recreate the enum without requires_action. +-- We must use the rename-create-cast-drop pattern. +ALTER TYPE chat_status RENAME TO chat_status_old; +CREATE TYPE chat_status AS ENUM ( + 'waiting', + 'pending', + 'running', + 'paused', + 'completed', + 'error' +); +ALTER TABLE chats ALTER COLUMN status DROP DEFAULT; +ALTER TABLE chats ALTER COLUMN status TYPE chat_status USING status::text::chat_status; +ALTER TABLE chats ALTER COLUMN status SET DEFAULT 'waiting'; +DROP TYPE chat_status_old; + +-- Recreate the partial index. +CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status); diff --git a/coderd/database/migrations/000463_chat_dynamic_tools.up.sql b/coderd/database/migrations/000463_chat_dynamic_tools.up.sql new file mode 100644 index 0000000000..1601462f79 --- /dev/null +++ b/coderd/database/migrations/000463_chat_dynamic_tools.up.sql @@ -0,0 +1,3 @@ +ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'requires_action'; + +ALTER TABLE chats ADD COLUMN dynamic_tools JSONB DEFAULT NULL; diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 169e46b3d4..63bb6afd1d 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -798,6 +798,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, &i.Chat.PinOrder, &i.Chat.LastReadMessageID, &i.Chat.LastInjectedContext, + &i.Chat.DynamicTools, &i.HasUnread); err != nil { return nil, err } diff --git a/coderd/database/models.go b/coderd/database/models.go index a9e70db01a..c33bcfd653 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1290,12 +1290,13 @@ func AllChatModeValues() []ChatMode { type ChatStatus string const ( - ChatStatusWaiting ChatStatus = "waiting" - ChatStatusPending ChatStatus = "pending" - ChatStatusRunning ChatStatus = "running" - ChatStatusPaused ChatStatus = "paused" - ChatStatusCompleted ChatStatus = "completed" - ChatStatusError ChatStatus = "error" + ChatStatusWaiting ChatStatus = "waiting" + ChatStatusPending ChatStatus = "pending" + ChatStatusRunning ChatStatus = "running" + ChatStatusPaused ChatStatus = "paused" + ChatStatusCompleted ChatStatus = "completed" + ChatStatusError ChatStatus = "error" + ChatStatusRequiresAction ChatStatus = "requires_action" ) func (e *ChatStatus) Scan(src interface{}) error { @@ -1340,7 +1341,8 @@ func (e ChatStatus) Valid() bool { ChatStatusRunning, ChatStatusPaused, ChatStatusCompleted, - ChatStatusError: + ChatStatusError, + ChatStatusRequiresAction: return true } return false @@ -1354,6 +1356,7 @@ func AllChatStatusValues() []ChatStatus { ChatStatusPaused, ChatStatusCompleted, ChatStatusError, + ChatStatusRequiresAction, } } @@ -4180,6 +4183,7 @@ type Chat struct { PinOrder int32 `db:"pin_order" json:"pin_order"` LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"` LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"` + DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"` } type ChatDiffStatus struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 80dfc4070e..5dc83c2ccb 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -509,8 +509,10 @@ type sqlcQuerier interface { GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error) GetRuntimeConfig(ctx context.Context, key string) (string, error) - // Find chats that appear stuck (running but heartbeat has expired). - // Used for recovery after coderd crashes or long hangs. + // Find chats that appear stuck and need recovery. This covers: + // 1. Running chats whose heartbeat has expired (worker crash). + // 2. Chats awaiting client action (requires_action) past the + // timeout threshold (client disappeared). GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error) GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error) GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index c734ccdfc2..089c40cacf 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4228,7 +4228,7 @@ WHERE $3::int ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type AcquireChatsParams struct { @@ -4272,6 +4272,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) ( &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ); err != nil { return nil, err } @@ -4410,9 +4411,9 @@ WITH chats AS ( UPDATE chats SET archived = true, pin_order = 0, updated_at = NOW() WHERE id = $1::uuid OR root_chat_id = $1::uuid - RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ) -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC ` @@ -4450,6 +4451,7 @@ func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ); err != nil { return nil, err } @@ -4587,7 +4589,7 @@ func (q *sqlQuerier) DeleteOldChats(ctx context.Context, arg DeleteOldChatsParam const getChatByID = `-- name: GetChatByID :one SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats WHERE @@ -4621,12 +4623,13 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context FROM chats WHERE id = $1::uuid FOR UPDATE +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats WHERE id = $1::uuid FOR UPDATE ` func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) { @@ -4656,6 +4659,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -5710,7 +5714,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u const getChats = `-- name: GetChats :many SELECT - chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, + chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, chats.dynamic_tools, EXISTS ( SELECT 1 FROM chat_messages cm WHERE cm.chat_id = chats.id @@ -5818,6 +5822,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha &i.Chat.PinOrder, &i.Chat.LastReadMessageID, &i.Chat.LastInjectedContext, + &i.Chat.DynamicTools, &i.HasUnread, ); err != nil { return nil, err @@ -5834,7 +5839,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha } const getChatsByWorkspaceIDs = `-- name: GetChatsByWorkspaceIDs :many -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats WHERE archived = false AND workspace_id = ANY($1::uuid[]) @@ -5874,6 +5879,7 @@ func (q *sqlQuerier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ); err != nil { return nil, err } @@ -6001,16 +6007,20 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh const getStaleChats = `-- name: GetStaleChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats WHERE - status = 'running'::chat_status - AND heartbeat_at < $1::timestamptz + (status = 'running'::chat_status + AND heartbeat_at < $1::timestamptz) + OR (status = 'requires_action'::chat_status + AND updated_at < $1::timestamptz) ` -// Find chats that appear stuck (running but heartbeat has expired). -// Used for recovery after coderd crashes or long hangs. +// Find chats that appear stuck and need recovery. This covers: +// 1. Running chats whose heartbeat has expired (worker crash). +// 2. Chats awaiting client action (requires_action) past the +// timeout threshold (client disappeared). func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error) { rows, err := q.db.QueryContext(ctx, getStaleChats, staleThreshold) if err != nil { @@ -6044,6 +6054,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ); err != nil { return nil, err } @@ -6111,7 +6122,8 @@ INSERT INTO chats ( mode, status, mcp_server_ids, - labels + labels, + dynamic_tools ) VALUES ( $1::uuid, $2::uuid, @@ -6124,10 +6136,11 @@ INSERT INTO chats ( $9::chat_mode, $10::chat_status, COALESCE($11::uuid[], '{}'::uuid[]), - COALESCE($12::jsonb, '{}'::jsonb) + COALESCE($12::jsonb, '{}'::jsonb), + $13::jsonb ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type InsertChatParams struct { @@ -6143,6 +6156,7 @@ type InsertChatParams struct { Status ChatStatus `db:"status" json:"status"` MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` Labels pqtype.NullRawMessage `db:"labels" json:"labels"` + DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"` } func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) { @@ -6159,6 +6173,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat arg.Status, pq.Array(arg.MCPServerIDs), arg.Labels, + arg.DynamicTools, ) var i Chat err := row.Scan( @@ -6185,6 +6200,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -6680,9 +6696,9 @@ WITH chats AS ( archived = false, updated_at = NOW() WHERE id = $1::uuid OR root_chat_id = $1::uuid - RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ) -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools FROM chats ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC ` @@ -6724,6 +6740,7 @@ func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Cha &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ); err != nil { return nil, err } @@ -6804,7 +6821,7 @@ UPDATE chats SET updated_at = NOW() WHERE id = $3::uuid -RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type UpdateChatBuildAgentBindingParams struct { @@ -6840,6 +6857,7 @@ func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg Update &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -6853,7 +6871,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type UpdateChatByIDParams struct { @@ -6888,6 +6906,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -6946,7 +6965,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type UpdateChatLabelsByIDParams struct { @@ -6981,6 +7000,7 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -6990,7 +7010,7 @@ UPDATE chats SET last_injected_context = $1::jsonb WHERE id = $2::uuid -RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type UpdateChatLastInjectedContextParams struct { @@ -7029,6 +7049,7 @@ func (q *sqlQuerier) UpdateChatLastInjectedContext(ctx context.Context, arg Upda &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -7042,7 +7063,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type UpdateChatLastModelConfigByIDParams struct { @@ -7077,6 +7098,7 @@ func (q *sqlQuerier) UpdateChatLastModelConfigByID(ctx context.Context, arg Upda &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -7108,7 +7130,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type UpdateChatMCPServerIDsParams struct { @@ -7143,6 +7165,7 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -7278,7 +7301,7 @@ SET WHERE id = $6::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type UpdateChatStatusParams struct { @@ -7324,6 +7347,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -7341,7 +7365,7 @@ SET WHERE id = $7::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type UpdateChatStatusPreserveUpdatedAtParams struct { @@ -7389,6 +7413,7 @@ func (q *sqlQuerier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } @@ -7400,7 +7425,7 @@ UPDATE chats SET agent_id = $3::uuid, updated_at = NOW() WHERE id = $4::uuid -RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools ` type UpdateChatWorkspaceBindingParams struct { @@ -7442,6 +7467,7 @@ func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateC &i.PinOrder, &i.LastReadMessageID, &i.LastInjectedContext, + &i.DynamicTools, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 49094e7841..587b4a4da6 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -399,7 +399,8 @@ INSERT INTO chats ( mode, status, mcp_server_ids, - labels + labels, + dynamic_tools ) VALUES ( @owner_id::uuid, sqlc.narg('workspace_id')::uuid, @@ -412,7 +413,8 @@ INSERT INTO chats ( sqlc.narg('mode')::chat_mode, @status::chat_status, COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]), - COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb) + COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb), + sqlc.narg('dynamic_tools')::jsonb ) RETURNING *; @@ -669,15 +671,19 @@ RETURNING *; -- name: GetStaleChats :many --- Find chats that appear stuck (running but heartbeat has expired). --- Used for recovery after coderd crashes or long hangs. +-- Find chats that appear stuck and need recovery. This covers: +-- 1. Running chats whose heartbeat has expired (worker crash). +-- 2. Chats awaiting client action (requires_action) past the +-- timeout threshold (client disappeared). SELECT * FROM chats WHERE - status = 'running'::chat_status - AND heartbeat_at < @stale_threshold::timestamptz; + (status = 'running'::chat_status + AND heartbeat_at < @stale_threshold::timestamptz) + OR (status = 'requires_action'::chat_status + AND updated_at < @stale_threshold::timestamptz); -- name: UpdateChatHeartbeats :many -- Bumps the heartbeat timestamp for the given set of chat IDs, diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 3804d694e6..db8745652a 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -398,6 +398,10 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { return } + // Cap the raw request body to prevent excessive memory use + // from large dynamic tool schemas. + r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes)) + var req codersdk.CreateChatRequest if !httpapi.Read(ctx, rw, r, &req) { return @@ -488,6 +492,50 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { return } + if len(req.UnsafeDynamicTools) > 250 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Too many dynamic tools.", + Detail: "Maximum 250 dynamic tools per chat.", + }) + return + } + + // Validate that dynamic tool names are non-empty and unique + // within the list. Name collision with built-in tools is + // checked at chatloop time when the full tool set is known. + if len(req.UnsafeDynamicTools) > 0 { + seenNames := make(map[string]struct{}, len(req.UnsafeDynamicTools)) + for _, dt := range req.UnsafeDynamicTools { + if dt.Name == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Dynamic tool name must not be empty.", + }) + return + } + if _, exists := seenNames[dt.Name]; exists { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Duplicate dynamic tool name.", + Detail: fmt.Sprintf("Tool %q appears more than once.", dt.Name), + }) + return + } + seenNames[dt.Name] = struct{}{} + } + } + + var dynamicToolsJSON json.RawMessage + if len(req.UnsafeDynamicTools) > 0 { + var err error + dynamicToolsJSON, err = json.Marshal(req.UnsafeDynamicTools) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to marshal dynamic tools.", + Detail: err.Error(), + }) + return + } + } + chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{ OwnerID: apiKey.UserID, WorkspaceID: workspaceSelection.WorkspaceID, @@ -497,6 +545,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { InitialUserContent: contentBlocks, MCPServerIDs: mcpServerIDs, Labels: labels, + DynamicTools: dynamicToolsJSON, }) if err != nil { if maybeWriteLimitErr(ctx, rw, err) { @@ -5751,3 +5800,77 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) { RecentPRs: prEntries, }) } + +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// +//nolint:revive // HTTP handler writes to ResponseWriter. +func (api *API) postChatToolResults(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + chat := httpmw.ChatParam(r) + apiKey := httpmw.APIKey(r) + + // Cap the raw request body to prevent excessive memory use. + r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes)) + var req codersdk.SubmitToolResultsRequest + + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + if len(req.Results) == 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "At least one tool result is required.", + }) + return + } + + // Fast-path check outside the transaction. The authoritative + // check happens inside SubmitToolResults under a row lock. + if chat.Status != database.ChatStatusRequiresAction { + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: "Chat is not waiting for tool results.", + Detail: fmt.Sprintf("Chat status is %q, expected %q.", chat.Status, database.ChatStatusRequiresAction), + }) + return + } + + var dynamicTools json.RawMessage + if chat.DynamicTools.Valid { + dynamicTools = chat.DynamicTools.RawMessage + } + + err := api.chatDaemon.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{ + ChatID: chat.ID, + UserID: apiKey.UserID, + ModelConfigID: chat.LastModelConfigID, + Results: req.Results, + DynamicTools: dynamicTools, + }) + if err != nil { + var validationErr *chatd.ToolResultValidationError + var conflictErr *chatd.ToolResultStatusConflictError + switch { + case errors.As(err, &conflictErr): + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: "Chat is not waiting for tool results.", + Detail: err.Error(), + }) + case errors.As(err, &validationErr): + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: validationErr.Message, + Detail: validationErr.Detail, + }) + default: + api.Logger.Error(ctx, "tool results submission failed", + slog.F("chat_id", chat.ID), + slog.Error(err), + ) + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error submitting tool results.", + }) + } + return + } + + rw.WriteHeader(http.StatusNoContent) +} diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index cebe1fde9c..66208c183e 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -16,7 +16,9 @@ import ( "time" "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" "github.com/shopspring/decimal" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" "golang.org/x/xerrors" @@ -8209,6 +8211,375 @@ func TestGetChatsByWorkspace(t *testing.T) { }) } +func TestSubmitToolResults(t *testing.T) { + t.Parallel() + + // setupRequiresAction creates a chat via the DB with dynamic tools, + // inserts an assistant message containing tool-call parts for each + // given toolCallID, and sets the chat status to requires_action. + // It returns the chat row so callers can exercise the endpoint. + setupRequiresAction := func( + ctx context.Context, + t *testing.T, + db database.Store, + ownerID uuid.UUID, + modelConfigID uuid.UUID, + dynamicToolName string, + toolCallIDs []string, + ) database.Chat { + t.Helper() + + // Marshal dynamic tools into the chat row. + dynamicTools := []mcp.Tool{{ + Name: dynamicToolName, + Description: "a test dynamic tool", + InputSchema: mcp.ToolInputSchema{Type: "object"}, + }} + dtJSON, err := json.Marshal(dynamicTools) + require.NoError(t, err) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + Status: database.ChatStatusWaiting, + OwnerID: ownerID, + LastModelConfigID: modelConfigID, + Title: "tool-results-test", + DynamicTools: pqtype.NullRawMessage{RawMessage: dtJSON, Valid: true}, + }) + require.NoError(t, err) + + // Build assistant message with tool-call parts. + parts := make([]codersdk.ChatMessagePart, 0, len(toolCallIDs)) + for _, id := range toolCallIDs { + parts = append(parts, codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: id, + ToolName: dynamicToolName, + Args: json.RawMessage(`{"key":"value"}`), + }) + } + content, err := chatprompt.MarshalParts(parts) + require.NoError(t, err) + + _, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{ + ChatID: chat.ID, + CreatedBy: []uuid.UUID{uuid.Nil}, + ModelConfigID: []uuid.UUID{modelConfigID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, + ContentVersion: []int16{chatprompt.CurrentContentVersion}, + Content: []string{string(content.RawMessage)}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0}, + OutputTokens: []int64{0}, + TotalTokens: []int64{0}, + ReasoningTokens: []int64{0}, + CacheCreationTokens: []int64{0}, + CacheReadTokens: []int64{0}, + ContextLimit: []int64{0}, + Compressed: []bool{false}, + TotalCostMicros: []int64{0}, + RuntimeMs: []int64{0}, + }) + require.NoError(t, err) + + // Transition to requires_action. + chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRequiresAction, + }) + require.NoError(t, err) + require.Equal(t, database.ChatStatusRequiresAction, chat.Status) + + return chat + } + + t.Run("Success", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + const toolName = "my_dynamic_tool" + toolCallIDs := []string{"call_abc", "call_def"} + + chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs) + + err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ + Results: []codersdk.ToolResult{ + {ToolCallID: "call_abc", Output: json.RawMessage(`"result_a"`)}, + {ToolCallID: "call_def", Output: json.RawMessage(`"result_b"`)}, + }, + }) + require.NoError(t, err) + + // Verify status is no longer requires_action. The chatd + // loop may have already picked the chat up and + // transitioned it further (pending → running → …), so we + // accept any non-requires_action status. + gotChat, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.NotEqual(t, codersdk.ChatStatusRequiresAction, gotChat.Status, + "chat should no longer be in requires_action after submitting tool results") + + // Verify tool-result messages were persisted. + msgsResp, err := client.GetChatMessages(ctx, chat.ID, nil) + require.NoError(t, err) + + var toolResultCount int + for _, msg := range msgsResp.Messages { + if msg.Role == codersdk.ChatMessageRoleTool { + toolResultCount++ + } + } + require.Equal(t, len(toolCallIDs), toolResultCount, + "expected one tool-result message per submitted result") + }) + + t.Run("WrongStatus", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + // Create a chat that is NOT in requires_action status. + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + Status: database.ChatStatusWaiting, + OwnerID: user.UserID, + LastModelConfigID: modelConfig.ID, + Title: "wrong-status-test", + }) + require.NoError(t, err) + + err = client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ + Results: []codersdk.ToolResult{ + {ToolCallID: "call_xyz", Output: json.RawMessage(`"nope"`)}, + }, + }) + requireSDKError(t, err, http.StatusConflict) + }) + + t.Run("MissingResult", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + const toolName = "my_dynamic_tool" + toolCallIDs := []string{"call_one", "call_two"} + + chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs) + + // Submit only one of the two required results. + err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ + Results: []codersdk.ToolResult{ + {ToolCallID: "call_one", Output: json.RawMessage(`"partial"`)}, + }, + }) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("UnexpectedResult", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + const toolName = "my_dynamic_tool" + toolCallIDs := []string{"call_real"} + + chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs) + + // Submit a result with a wrong tool_call_id. + err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ + Results: []codersdk.ToolResult{ + {ToolCallID: "call_bogus", Output: json.RawMessage(`"wrong"`)}, + }, + }) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("InvalidJSONOutput", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + const toolName = "my_dynamic_tool" + toolCallIDs := []string{"call_json"} + + chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs) + + // We must bypass the SDK client because json.RawMessage + // rejects invalid JSON during json.Marshal. A raw HTTP + // request lets the invalid payload reach the server so we + // can verify server-side validation. + rawBody := `{"results":[{"tool_call_id":"call_json","output":not-json,"is_error":false}]}` + url := client.URL.JoinPath(fmt.Sprintf("/api/experimental/chats/%s/tool-results", chat.ID)).String() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(rawBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) + + t.Run("DuplicateToolCallID", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + const toolName = "my_dynamic_tool" + toolCallIDs := []string{"call_dup1", "call_dup2"} + + chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs) + + err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ + Results: []codersdk.ToolResult{ + {ToolCallID: "call_dup1", Output: json.RawMessage(`"result_a"`)}, + {ToolCallID: "call_dup1", Output: json.RawMessage(`"result_b"`)}, + }, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Contains(t, sdkErr.Message, "Duplicate tool_call_id") + }) + + t.Run("EmptyResults", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + const toolName = "my_dynamic_tool" + toolCallIDs := []string{"call_empty"} + + chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs) + + err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ + Results: []codersdk.ToolResult{}, + }) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("NotFoundForDifferentUser", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + const toolName = "my_dynamic_tool" + toolCallIDs := []string{"call_other"} + + chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs) + + // Create a second user and try to submit tool results + // to user A's chat. + otherClientRaw, _ := coderdtest.CreateAnotherUser( + t, client.Client, user.OrganizationID, + rbac.RoleAgentsAccess(), + ) + otherClient := codersdk.NewExperimentalClient(otherClientRaw) + + err := otherClient.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ + Results: []codersdk.ToolResult{ + {ToolCallID: "call_other", Output: json.RawMessage(`"nope"`)}, + }, + }) + requireSDKError(t, err, http.StatusNotFound) + }) +} + +func TestPostChats_DynamicToolValidation(t *testing.T) { + t.Parallel() + + t.Run("TooManyTools", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + tools := make([]codersdk.DynamicTool, 251) + for i := range tools { + tools[i] = codersdk.DynamicTool{ + Name: fmt.Sprintf("tool-%d", i), + } + } + + _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "hello", + }}, + UnsafeDynamicTools: tools, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Too many dynamic tools.", sdkErr.Message) + }) + + t.Run("EmptyToolName", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "hello", + }}, + UnsafeDynamicTools: []codersdk.DynamicTool{ + {Name: ""}, + }, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Dynamic tool name must not be empty.", sdkErr.Message) + }) + + t.Run("DuplicateToolName", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "hello", + }}, + UnsafeDynamicTools: []codersdk.DynamicTool{ + {Name: "dup-tool"}, + {Name: "dup-tool"}, + }, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Duplicate dynamic tool name.", sdkErr.Message) + }) +} + func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error { t.Helper() diff --git a/coderd/pubsub/chatevent.go b/coderd/pubsub/chatevent.go index bdadf01055..426d0e395a 100644 --- a/coderd/pubsub/chatevent.go +++ b/coderd/pubsub/chatevent.go @@ -32,8 +32,9 @@ func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) } type ChatEvent struct { - Kind ChatEventKind `json:"kind"` - Chat codersdk.Chat `json:"chat"` + Kind ChatEventKind `json:"kind"` + Chat codersdk.Chat `json:"chat"` + ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"` } type ChatEventKind string @@ -44,4 +45,5 @@ const ( ChatEventKindCreated ChatEventKind = "created" ChatEventKindDeleted ChatEventKind = "deleted" ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change" + ChatEventKindActionRequired ChatEventKind = "action_required" ) diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 45a0fdd46b..b02ae12ac8 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -788,6 +788,7 @@ type CreateOptions struct { InitialUserContent []codersdk.ChatMessagePart MCPServerIDs []uuid.UUID Labels database.StringMap + DynamicTools json.RawMessage } // SendMessageBusyBehavior controls what happens when a chat is already active. @@ -899,6 +900,10 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C RawMessage: labelsJSON, Valid: true, }, + DynamicTools: pqtype.NullRawMessage{ + RawMessage: opts.DynamicTools, + Valid: len(opts.DynamicTools) > 0, + }, }) if err != nil { return xerrors.Errorf("insert chat: %w", err) @@ -1546,6 +1551,238 @@ func (p *Server) PromoteQueued( return result, nil } +// SubmitToolResultsOptions controls tool result submission. +type SubmitToolResultsOptions struct { + ChatID uuid.UUID + UserID uuid.UUID + ModelConfigID uuid.UUID + Results []codersdk.ToolResult + DynamicTools json.RawMessage +} + +// ToolResultValidationError indicates the submitted tool results +// failed validation (e.g. missing, duplicate, or unexpected IDs, +// or invalid JSON output). +type ToolResultValidationError struct { + Message string + Detail string +} + +func (e *ToolResultValidationError) Error() string { + if e.Detail != "" { + return e.Message + ": " + e.Detail + } + return e.Message +} + +// ToolResultStatusConflictError indicates the chat is not in the +// requires_action state expected for tool result submission. +type ToolResultStatusConflictError struct { + ActualStatus database.ChatStatus +} + +func (e *ToolResultStatusConflictError) Error() string { + return fmt.Sprintf( + "chat status is %q, expected %q", + e.ActualStatus, database.ChatStatusRequiresAction, + ) +} + +// SubmitToolResults validates and persists client-provided tool +// results, transitions the chat to pending, and wakes the run +// loop. The caller is responsible for the fast-path status check; +// this method performs an authoritative re-check under a row lock. +func (p *Server) SubmitToolResults( + ctx context.Context, + opts SubmitToolResultsOptions, +) error { + dynamicToolNames, err := parseDynamicToolNames(pqtype.NullRawMessage{ + RawMessage: opts.DynamicTools, + Valid: len(opts.DynamicTools) > 0, + }) + if err != nil { + return xerrors.Errorf("parse chat dynamic tools: %w", err) + } + + // The GetLastChatMessageByRole lookup and all subsequent + // validation and persistence run inside a single transaction + // so the assistant message cannot change between reads. + var statusConflict *ToolResultStatusConflictError + txErr := p.db.InTx(func(tx database.Store) error { + // Authoritative status check under row lock. + locked, lockErr := tx.GetChatByIDForUpdate(ctx, opts.ChatID) + if lockErr != nil { + return xerrors.Errorf("lock chat for update: %w", lockErr) + } + if locked.Status != database.ChatStatusRequiresAction { + statusConflict = &ToolResultStatusConflictError{ + ActualStatus: locked.Status, + } + return statusConflict + } + + // Get the last assistant message inside the transaction + // for consistency with the row lock above. + lastAssistant, err := tx.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{ + ChatID: opts.ChatID, + Role: database.ChatMessageRoleAssistant, + }) + if err != nil { + return xerrors.Errorf("get last assistant message: %w", err) + } + + // Collect tool-call IDs that already have results. + // When a dynamic tool name collides with a built-in, + // the chatloop executes it as a built-in and persists + // the result. Those calls must not count as pending. + afterMsgs, afterErr := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: opts.ChatID, + AfterID: lastAssistant.ID, + }) + if afterErr != nil { + return xerrors.Errorf("get messages after assistant: %w", afterErr) + } + handledCallIDs := make(map[string]bool) + for _, msg := range afterMsgs { + if msg.Role != database.ChatMessageRoleTool { + continue + } + msgParts, msgParseErr := chatprompt.ParseContent(msg) + if msgParseErr != nil { + continue + } + for _, mp := range msgParts { + if mp.Type == codersdk.ChatMessagePartTypeToolResult { + handledCallIDs[mp.ToolCallID] = true + } + } + } + + // Extract pending dynamic tool-call IDs, skipping any + // that were already handled by the chatloop. + pendingCallIDs := make(map[string]bool) + toolCallIDToName := make(map[string]string) + parts, parseErr := chatprompt.ParseContent(lastAssistant) + if parseErr != nil { + return xerrors.Errorf("parse assistant message: %w", parseErr) + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolCall && + dynamicToolNames[part.ToolName] && + !handledCallIDs[part.ToolCallID] { + pendingCallIDs[part.ToolCallID] = true + toolCallIDToName[part.ToolCallID] = part.ToolName + } + } + + // Validate submitted results match pending calls exactly. + submittedIDs := make(map[string]bool, len(opts.Results)) + for _, result := range opts.Results { + if submittedIDs[result.ToolCallID] { + return &ToolResultValidationError{ + Message: "Duplicate tool_call_id in results.", + Detail: fmt.Sprintf("Duplicate tool call ID %q.", result.ToolCallID), + } + } + submittedIDs[result.ToolCallID] = true + } + for id := range pendingCallIDs { + if !submittedIDs[id] { + return &ToolResultValidationError{ + Message: "Missing tool result.", + Detail: fmt.Sprintf("Missing result for tool call %q.", id), + } + } + } + for id := range submittedIDs { + if !pendingCallIDs[id] { + return &ToolResultValidationError{ + Message: "Unexpected tool result.", + Detail: fmt.Sprintf("No pending tool call with ID %q.", id), + } + } + } + + // Marshal each tool result into a separate message row. + resultContents := make([]pqtype.NullRawMessage, 0, len(opts.Results)) + for _, result := range opts.Results { + if !json.Valid(result.Output) { + return &ToolResultValidationError{ + Message: "Tool result output must be valid JSON.", + Detail: fmt.Sprintf("Output for tool call %q is not valid JSON.", result.ToolCallID), + } + } + part := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeToolResult, + ToolCallID: result.ToolCallID, + ToolName: toolCallIDToName[result.ToolCallID], + Result: result.Output, + IsError: result.IsError, + } + marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{part}) + if marshalErr != nil { + return xerrors.Errorf("marshal tool result: %w", marshalErr) + } + resultContents = append(resultContents, marshaled) + } + + // Insert tool-result messages. + n := len(resultContents) + params := database.InsertChatMessagesParams{ + ChatID: opts.ChatID, + CreatedBy: make([]uuid.UUID, n), + ModelConfigID: make([]uuid.UUID, n), + Role: make([]database.ChatMessageRole, n), + Content: make([]string, n), + ContentVersion: make([]int16, n), + Visibility: make([]database.ChatMessageVisibility, n), + InputTokens: make([]int64, n), + OutputTokens: make([]int64, n), + TotalTokens: make([]int64, n), + ReasoningTokens: make([]int64, n), + CacheCreationTokens: make([]int64, n), + CacheReadTokens: make([]int64, n), + ContextLimit: make([]int64, n), + Compressed: make([]bool, n), + TotalCostMicros: make([]int64, n), + RuntimeMs: make([]int64, n), + ProviderResponseID: make([]string, n), + } + for i, rc := range resultContents { + params.CreatedBy[i] = opts.UserID + params.ModelConfigID[i] = opts.ModelConfigID + params.Role[i] = database.ChatMessageRoleTool + params.Content[i] = string(rc.RawMessage) + params.ContentVersion[i] = chatprompt.CurrentContentVersion + params.Visibility[i] = database.ChatMessageVisibilityBoth + } + if _, insertErr := tx.InsertChatMessages(ctx, params); insertErr != nil { + return xerrors.Errorf("insert tool results: %w", insertErr) + } + + // Transition chat to pending. + if _, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: opts.ChatID, + Status: database.ChatStatusPending, + WorkerID: uuid.NullUUID{}, + StartedAt: sql.NullTime{}, + HeartbeatAt: sql.NullTime{}, + LastError: sql.NullString{}, + }); updateErr != nil { + return xerrors.Errorf("update chat status: %w", updateErr) + } + + return nil + }, nil) + if txErr != nil { + return txErr + } + + // Wake the chatd run loop so it processes the chat immediately. + p.signalWake() + return nil +} + // InterruptChat interrupts execution, sets waiting status, and broadcasts status updates. func (p *Server) InterruptChat( ctx context.Context, @@ -1555,6 +1792,32 @@ func (p *Server) InterruptChat( return chat } + // If the chat is in requires_action, insert synthetic error + // tool-result messages for each pending dynamic tool call + // before transitioning to waiting. Without this, the LLM + // would see unmatched tool-call parts on the next run. + if chat.Status == database.ChatStatusRequiresAction { + if txErr := p.db.InTx(func(tx database.Store) error { + locked, lockErr := tx.GetChatByIDForUpdate(ctx, chat.ID) + if lockErr != nil { + return xerrors.Errorf("lock chat for interrupt: %w", lockErr) + } + // Another request may have already transitioned + // the chat (e.g. SubmitToolResults committed + // between our snapshot and this lock). + if locked.Status != database.ChatStatusRequiresAction { + return nil + } + return insertSyntheticToolResultsTx(ctx, tx, locked, "Tool execution interrupted by user") + }, nil); txErr != nil { + p.logger.Error(ctx, "failed to insert synthetic tool results during interrupt", + slog.F("chat_id", chat.ID), + slog.Error(txErr), + ) + // Fall through — still try to set waiting status. + } + } + updatedChat, err := p.setChatWaiting(ctx, chat.ID) if err != nil { p.logger.Error(ctx, "failed to mark chat as waiting", @@ -2345,7 +2608,7 @@ func insertUserMessageAndSetPending( // queued while a chat is active. func shouldQueueUserMessage(status database.ChatStatus) bool { switch status { - case database.ChatStatusRunning, database.ChatStatusPending: + case database.ChatStatusRunning, database.ChatStatusPending, database.ChatStatusRequiresAction: return true default: return false @@ -3218,8 +3481,12 @@ func (p *Server) Subscribe( // Pubsub will deliver a duplicate status // later; the frontend deduplicates it // (setChatStatus is idempotent). + // action_required is also transient and + // only published on the local stream, so + // it must be forwarded here. if event.Type == codersdk.ChatStreamEventTypeMessagePart || - event.Type == codersdk.ChatStreamEventTypeStatus { + event.Type == codersdk.ChatStreamEventTypeStatus || + event.Type == codersdk.ChatStreamEventTypeActionRequired { select { case <-mergedCtx.Done(): return @@ -3345,6 +3612,51 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch } } +// pendingToStreamToolCalls converts a slice of chatloop pending +// tool calls into the SDK streaming representation. +func pendingToStreamToolCalls(pending []chatloop.PendingToolCall) []codersdk.ChatStreamToolCall { + calls := make([]codersdk.ChatStreamToolCall, len(pending)) + for i, tc := range pending { + calls[i] = codersdk.ChatStreamToolCall{ + ToolCallID: tc.ToolCallID, + ToolName: tc.ToolName, + Args: tc.Args, + } + } + return calls +} + +// publishChatActionRequired broadcasts an action_required event via +// PostgreSQL pubsub so that global watchers can react to dynamic +// tool calls without streaming each chat individually. +func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloop.PendingToolCall) { + if p.pubsub == nil { + return + } + toolCalls := pendingToStreamToolCalls(pending) + sdkChat := db2sdk.Chat(chat, nil, nil) + + event := coderdpubsub.ChatEvent{ + Kind: coderdpubsub.ChatEventKindActionRequired, + Chat: sdkChat, + ToolCalls: toolCalls, + } + payload, err := json.Marshal(event) + if err != nil { + p.logger.Error(context.Background(), "failed to marshal chat action_required pubsub event", + slog.F("chat_id", chat.ID), + slog.Error(err), + ) + return + } + if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil { + p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event", + slog.F("chat_id", chat.ID), + slog.Error(err), + ) + } +} + // PublishDiffStatusChange broadcasts a diff_status_change event for // the given chat so that watching clients know to re-fetch the diff // status. This is called from the HTTP layer after the diff status @@ -3849,6 +4161,21 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { } p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil) + // When the chat is parked in requires_action, + // publish the stream event and global pubsub event + // after the DB status has committed. Publishing + // here (not in runChat) prevents a race where a + // fast client reacts before the status is visible. + if status == database.ChatStatusRequiresAction && len(runResult.PendingDynamicToolCalls) > 0 { + toolCalls := pendingToStreamToolCalls(runResult.PendingDynamicToolCalls) + p.publishEvent(chat.ID, codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeActionRequired, + ActionRequired: &codersdk.ChatStreamActionRequired{ + ToolCalls: toolCalls, + }, + }) + p.publishChatActionRequired(updatedChat, runResult.PendingDynamicToolCalls) + } if !wasInterrupted { p.maybeSendPushNotification(cleanupCtx, updatedChat, status, lastError, runResult, logger) } @@ -3877,6 +4204,13 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { return } + // The LLM invoked a dynamic tool — park the chat in + // requires_action so the client can supply tool results. + if len(runResult.PendingDynamicToolCalls) > 0 { + status = database.ChatStatusRequiresAction + return + } + // If runChat completed successfully but the server context was // canceled (e.g. during Close()), the chat should be returned // to pending so another replica can pick it up. There is a @@ -3943,9 +4277,10 @@ func (t *generatedChatTitle) Load() (string, bool) { } type runChatResult struct { - FinalAssistantText string - PushSummaryModel fantasy.LanguageModel - ProviderKeys chatprovider.ProviderAPIKeys + FinalAssistantText string + PushSummaryModel fantasy.LanguageModel + ProviderKeys chatprovider.ProviderAPIKeys + PendingDynamicToolCalls []chatloop.PendingToolCall } func (p *Server) runChat( @@ -4249,8 +4584,8 @@ func (p *Server) runChat( // server. toolNameToConfigID := make(map[string]uuid.UUID) for _, t := range mcpTools { - if mcp, ok := t.(mcpclient.MCPToolIdentifier); ok { - toolNameToConfigID[t.Info().Name] = mcp.MCPServerConfigID() + if mcpTool, ok := t.(mcpclient.MCPToolIdentifier); ok { + toolNameToConfigID[t.Info().Name] = mcpTool.MCPServerConfigID() } } @@ -4269,6 +4604,7 @@ func (p *Server) runChat( // (which is the common case). modelConfigContextLimit := modelConfig.ContextLimit var finalAssistantText string + var pendingDynamicCalls []chatloop.PendingToolCall persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error { // If the chat context has been canceled, bail out before @@ -4288,6 +4624,10 @@ func (p *Server) runChat( return persistCtx.Err() } + // Capture pending dynamic tool calls so the caller + // can surface them after chatloop.Run returns. + pendingDynamicCalls = step.PendingDynamicToolCalls + // Split the step content into assistant blocks and tool // result blocks so they can be stored as separate messages // with the appropriate roles. Provider-executed tool results @@ -4674,6 +5014,39 @@ func (p *Server) runChat( tools = append(tools, mcpTools...) tools = append(tools, workspaceMCPTools...) + // Append dynamic tools declared by the client at chat + // creation time. These appear in the LLM's tool list but + // are never executed by the chatloop — the client handles + // execution via POST /tool-results. + dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools) + if err != nil { + return result, xerrors.Errorf("parse dynamic tool names: %w", err) + } + // Unmarshal the full definitions separately so we can + // build the filtered list below. parseDynamicToolNames + // already validated the JSON, so this cannot fail. + var dynamicToolDefs []codersdk.DynamicTool + if chat.DynamicTools.Valid { + if err := json.Unmarshal(chat.DynamicTools.RawMessage, &dynamicToolDefs); err != nil { + return result, xerrors.Errorf("unmarshal dynamic tools: %w", err) + } + } + for _, t := range tools { + info := t.Info() + if dynamicToolNames[info.Name] { + logger.Warn(ctx, "dynamic tool name collides with built-in tool, built-in takes precedence", + slog.F("tool_name", info.Name)) + delete(dynamicToolNames, info.Name) + } + } + + var filteredDefs []codersdk.DynamicTool + for _, dt := range dynamicToolDefs { + if dynamicToolNames[dt.Name] { + filteredDefs = append(filteredDefs, dt) + } + } + tools = append(tools, dynamicToolsFromSDK(p.logger, filteredDefs)...) // Build provider-native tools (e.g., web search) based on // the model configuration. var providerTools []chatloop.ProviderTool @@ -4717,8 +5090,7 @@ func (p *Server) runChat( ) prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount) } - - err := chatloop.Run(ctx, chatloop.RunOptions{ + err = chatloop.Run(ctx, chatloop.RunOptions{ Model: model, Messages: prompt, Tools: tools, MaxSteps: maxChatSteps, @@ -4726,6 +5098,9 @@ func (p *Server) runChat( ModelConfig: callConfig, ProviderOptions: providerOptions, ProviderTools: providerTools, + // dynamicToolNames now contains only names that don't + // collide with built-in/MCP tools. + DynamicToolNames: dynamicToolNames, ContextLimitFallback: modelConfigContextLimit, @@ -4803,6 +5178,15 @@ func (p *Server) runChat( p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err)) }, }) + if errors.Is(err, chatloop.ErrDynamicToolCall) { + // The stream event is published in processChat's + // defer after the DB status transitions to + // requires_action, preventing a race where a fast + // client reacts before the status is committed. + result.FinalAssistantText = finalAssistantText + result.PendingDynamicToolCalls = pendingDynamicCalls + return result, nil + } if err != nil { classified := chaterror.Classify(err).WithProvider(model.Provider()) return result, chaterror.WithClassification(err, classified) @@ -5424,7 +5808,9 @@ func (p *Server) recoverStaleChats(ctx context.Context) { recovered := 0 for _, chat := range staleChats { - p.logger.Info(ctx, "recovering stale chat", slog.F("chat_id", chat.ID)) + p.logger.Info(ctx, "recovering stale chat", + slog.F("chat_id", chat.ID), + slog.F("status", chat.Status)) // Use a transaction with FOR UPDATE to avoid a TOCTOU race: // between GetStaleChats (a bare SELECT) and here, the chat's @@ -5436,34 +5822,73 @@ func (p *Server) recoverStaleChats(ctx context.Context) { return xerrors.Errorf("lock chat for recovery: %w", lockErr) } - // Only recover chats that are still running. - // Between GetStaleChats and this lock, the chat - // may have completed normally. - if locked.Status != database.ChatStatusRunning { + switch locked.Status { + case database.ChatStatusRunning: + // Re-check: only recover if the chat is still stale. + // A valid heartbeat at or after the threshold means + // the chat was refreshed after our snapshot. + if locked.HeartbeatAt.Valid && !locked.HeartbeatAt.Time.Before(staleAfter) { + p.logger.Debug(ctx, "chat heartbeat refreshed since snapshot, skipping recovery", + slog.F("chat_id", chat.ID)) + return nil + } + case database.ChatStatusRequiresAction: + // Re-check: the chat may have been updated after + // our snapshot, similar to the heartbeat check for + // running chats. + if !locked.UpdatedAt.Before(staleAfter) { + p.logger.Debug(ctx, "chat updated since snapshot, skipping recovery", + slog.F("chat_id", chat.ID)) + return nil + } + default: + // Status changed since our snapshot; skip. p.logger.Debug(ctx, "chat status changed since snapshot, skipping recovery", slog.F("chat_id", chat.ID), slog.F("status", locked.Status)) return nil } - // Re-check: only recover if the chat is still stale. - // A valid heartbeat that is at or after the stale - // threshold means the chat was refreshed after our - // initial snapshot — skip it. - if locked.HeartbeatAt.Valid && !locked.HeartbeatAt.Time.Before(staleAfter) { - p.logger.Debug(ctx, "chat heartbeat refreshed since snapshot, skipping recovery", - slog.F("chat_id", chat.ID)) - return nil + lastError := sql.NullString{} + if locked.Status == database.ChatStatusRequiresAction { + lastError = sql.NullString{ + String: "Dynamic tool execution timed out", + Valid: true, + } } - // Reset to pending so any replica can pick it up. + recoverStatus := database.ChatStatusPending + if locked.Status == database.ChatStatusRequiresAction { + // Timed-out requires_action chats have dangling + // tool calls with no matching results. Setting + // them back to pending would replay incomplete + // tool calls to the LLM, so mark them as errors. + recoverStatus = database.ChatStatusError + } + + // Insert synthetic error tool-result messages + // so the LLM history remains valid if the user + // retries the chat later. + if locked.Status == database.ChatStatusRequiresAction { + if synthErr := insertSyntheticToolResultsTx(ctx, tx, locked, "Dynamic tool execution timed out"); synthErr != nil { + p.logger.Warn(ctx, "failed to insert synthetic tool results during stale recovery", + slog.F("chat_id", chat.ID), + slog.Error(synthErr), + ) + // Continue with error status even if + // synthetic results fail to insert. + } + } + + // Reset so any replica can pick it up (pending) or + // the client sees the failure (error). _, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ ID: chat.ID, - Status: database.ChatStatusPending, + Status: recoverStatus, WorkerID: uuid.NullUUID{}, StartedAt: sql.NullTime{}, HeartbeatAt: sql.NullTime{}, - LastError: sql.NullString{}, + LastError: lastError, }) if updateErr != nil { return updateErr @@ -5482,6 +5907,119 @@ func (p *Server) recoverStaleChats(ctx context.Context) { } } +// insertSyntheticToolResultsTx inserts error tool-result messages for +// every pending dynamic tool call in the last assistant message. This +// keeps the LLM message history valid (every tool-call has a matching +// tool-result) when a requires_action chat times out or is interrupted. +// It operates on the provided store, which may be a transaction handle. +func insertSyntheticToolResultsTx( + ctx context.Context, + store database.Store, + chat database.Chat, + reason string, +) error { + dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools) + if err != nil { + return xerrors.Errorf("parse dynamic tools: %w", err) + } + if len(dynamicToolNames) == 0 { + return nil + } + + // Get the last assistant message to find pending tool calls. + lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{ + ChatID: chat.ID, + Role: database.ChatMessageRoleAssistant, + }) + if err != nil { + return xerrors.Errorf("get last assistant message: %w", err) + } + + parts, err := chatprompt.ParseContent(lastAssistant) + if err != nil { + return xerrors.Errorf("parse assistant message: %w", err) + } + + // Collect dynamic tool calls that need synthetic results. + var resultContents []pqtype.NullRawMessage + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeToolCall || !dynamicToolNames[part.ToolName] { + continue + } + resultPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeToolResult, + ToolCallID: part.ToolCallID, + ToolName: part.ToolName, + Result: json.RawMessage(fmt.Sprintf("%q", reason)), + IsError: true, + } + marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{resultPart}) + if marshalErr != nil { + return xerrors.Errorf("marshal synthetic tool result: %w", marshalErr) + } + resultContents = append(resultContents, marshaled) + } + + if len(resultContents) == 0 { + return nil + } + + // Insert tool-result messages using the same pattern as + // SubmitToolResults. + n := len(resultContents) + params := database.InsertChatMessagesParams{ + ChatID: chat.ID, + CreatedBy: make([]uuid.UUID, n), + ModelConfigID: make([]uuid.UUID, n), + Role: make([]database.ChatMessageRole, n), + Content: make([]string, n), + ContentVersion: make([]int16, n), + Visibility: make([]database.ChatMessageVisibility, n), + InputTokens: make([]int64, n), + OutputTokens: make([]int64, n), + TotalTokens: make([]int64, n), + ReasoningTokens: make([]int64, n), + CacheCreationTokens: make([]int64, n), + CacheReadTokens: make([]int64, n), + ContextLimit: make([]int64, n), + Compressed: make([]bool, n), + TotalCostMicros: make([]int64, n), + RuntimeMs: make([]int64, n), + ProviderResponseID: make([]string, n), + } + for i, rc := range resultContents { + params.CreatedBy[i] = uuid.Nil + params.ModelConfigID[i] = chat.LastModelConfigID + params.Role[i] = database.ChatMessageRoleTool + params.Content[i] = string(rc.RawMessage) + params.ContentVersion[i] = chatprompt.CurrentContentVersion + params.Visibility[i] = database.ChatMessageVisibilityBoth + } + if _, err := store.InsertChatMessages(ctx, params); err != nil { + return xerrors.Errorf("insert synthetic tool results: %w", err) + } + + return nil +} + +// parseDynamicToolNames unmarshals the dynamic tools JSON column +// and returns a map of tool names. This centralizes the repeated +// pattern of deserializing DynamicTools into a name set. +func parseDynamicToolNames(raw pqtype.NullRawMessage) (map[string]bool, error) { + if !raw.Valid || len(raw.RawMessage) == 0 { + return make(map[string]bool), nil + } + var tools []codersdk.DynamicTool + if err := json.Unmarshal(raw.RawMessage, &tools); err != nil { + return nil, xerrors.Errorf("unmarshal dynamic tools: %w", err) + } + names := make(map[string]bool, len(tools)) + for _, t := range tools { + names[t.Name] = true + } + return names, nil +} + // maybeSendPushNotification sends a web push notification when an // agent chat reaches a terminal state. For errors it dispatches // synchronously; for successful completions it spawns a goroutine diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index c8388a7054..c290eb33f2 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -1531,6 +1531,70 @@ func TestRecoverStaleChatsPeriodically(t *testing.T) { }, testutil.WaitMedium, testutil.IntervalFast) } +func TestRecoverStaleRequiresActionChat(t *testing.T) { + t.Parallel() + + db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t) + + ctx := testutil.Context(t, testutil.WaitLong) + user, model := seedChatDependencies(ctx, t, db) + + // Use a very short stale threshold so the periodic recovery + // kicks in quickly during the test. + staleAfter := 500 * time.Millisecond + + // Create a chat and set it to requires_action to simulate a + // client that disappeared while the chat was waiting for + // dynamic tool results. + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + Status: database.ChatStatusWaiting, + OwnerID: user.ID, + Title: "stale-requires-action", + LastModelConfigID: model.ID, + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRequiresAction, + }) + require.NoError(t, err) + + // Backdate updated_at so the chat appears stale to the + // recovery loop without needing time.Sleep. + _, err = rawDB.ExecContext(ctx, + "UPDATE chats SET updated_at = $1 WHERE id = $2", + time.Now().Add(-time.Hour), chat.ID) + require.NoError(t, err) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := chatd.New(chatd.Config{ + Logger: logger, + Database: db, + ReplicaID: uuid.New(), + Pubsub: ps, + PendingChatAcquireInterval: testutil.WaitLong, + InFlightChatStaleAfter: staleAfter, + }) + t.Cleanup(func() { + require.NoError(t, server.Close()) + }) + + // The stale recovery should transition the requires_action + // chat to error with the timeout message. + var chatResult database.Chat + require.Eventually(t, func() bool { + chatResult, err = db.GetChatByID(ctx, chat.ID) + if err != nil { + return false + } + return chatResult.Status == database.ChatStatusError + }, testutil.WaitMedium, testutil.IntervalFast) + + require.Contains(t, chatResult.LastError.String, "Dynamic tool execution timed out") + require.False(t, chatResult.WorkerID.Valid) +} + func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) { t.Parallel() @@ -1882,6 +1946,518 @@ func TestPersistToolResultWithBinaryData(t *testing.T) { require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include execute tool output") } +func TestDynamicToolCallPausesAndResumes(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + // Track streaming calls to the mock LLM. + var streamedCallCount atomic.Int32 + var streamedCallsMu sync.Mutex + streamedCalls := make([]chattest.OpenAIRequest, 0, 2) + + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + // Non-streaming requests are title generation — return a + // simple title. + if !req.Stream { + return chattest.OpenAINonStreamingResponse("Dynamic tool test") + } + + // Capture the full request for later assertions. + streamedCallsMu.Lock() + streamedCalls = append(streamedCalls, chattest.OpenAIRequest{ + Messages: append([]chattest.OpenAIMessage(nil), req.Messages...), + Tools: append([]chattest.OpenAITool(nil), req.Tools...), + Stream: req.Stream, + }) + streamedCallsMu.Unlock() + + if streamedCallCount.Add(1) == 1 { + // First call: the LLM invokes our dynamic tool. + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk( + "my_dynamic_tool", + `{"input":"hello world"}`, + ), + ) + } + // Second call: the LLM returns a normal text response. + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("Dynamic tool result received.")..., + ) + }) + + user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + + // Dynamic tools do not need a workspace connection, but the + // chatd server always builds workspace tools. Use an active + // server without an agent connection — the built-in tools + // are never invoked because the only tool call targets our + // dynamic tool. + server := newActiveTestServer(t, db, ps) + + // Create a chat with a dynamic tool. + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "my_dynamic_tool", + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "input": map[string]any{"type": "string"}, + }, + Required: []string{"input"}, + }, + }}) + require.NoError(t, err) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "dynamic-tool-pause-resume", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Please call the dynamic tool."), + }, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + + // 1. Wait for the chat to reach requires_action status. + var chatResult database.Chat + require.Eventually(t, func() bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusRequiresAction || + got.Status == database.ChatStatusError + }, testutil.WaitLong, testutil.IntervalFast) + + require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status, + "expected requires_action, got %s (last_error=%q)", + chatResult.Status, chatResult.LastError.String) + + // 2. Read the assistant message to find the tool-call ID. + var toolCallID string + var toolCallFound bool + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if dbErr != nil { + return false + } + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleAssistant { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + continue + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" { + toolCallID = part.ToolCallID + toolCallFound = true + return true + } + } + } + return false + }, testutil.IntervalFast) + require.True(t, toolCallFound, "expected to find tool call for my_dynamic_tool") + require.NotEmpty(t, toolCallID) + + // 3. Submit tool results via SubmitToolResults. + toolResultOutput := json.RawMessage(`{"result":"dynamic tool output"}`) + err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{ + ChatID: chat.ID, + UserID: user.ID, + ModelConfigID: chatResult.LastModelConfigID, + Results: []codersdk.ToolResult{{ + ToolCallID: toolCallID, + Output: toolResultOutput, + }}, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + + // 4. Wait for the chat to reach a terminal status. + require.Eventually(t, func() bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError + }, testutil.WaitLong, testutil.IntervalFast) + + // 5. Verify the chat completed successfully. + if chatResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String) + } + + // 6. Verify the mock received exactly 2 streaming calls. + require.Equal(t, int32(2), streamedCallCount.Load(), + "expected exactly 2 streaming calls to the LLM") + + streamedCallsMu.Lock() + recordedCalls := append([]chattest.OpenAIRequest(nil), streamedCalls...) + streamedCallsMu.Unlock() + require.Len(t, recordedCalls, 2) + + // 7. Verify the dynamic tool appeared in the first call's tool list. + var foundDynamicTool bool + for _, tool := range recordedCalls[0].Tools { + if tool.Function.Name == "my_dynamic_tool" { + foundDynamicTool = true + break + } + } + require.True(t, foundDynamicTool, + "expected 'my_dynamic_tool' in the first LLM call's tool list") + + // 8. Verify the second call's messages contain the tool result. + var foundToolResultInSecondCall bool + for _, message := range recordedCalls[1].Messages { + if message.Role != "tool" { + continue + } + if strings.Contains(message.Content, "dynamic tool output") { + foundToolResultInSecondCall = true + break + } + } + require.True(t, foundToolResultInSecondCall, + "expected second LLM call to include the submitted dynamic tool result") +} + +func TestDynamicToolCallMixedWithBuiltIn(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + // Track streaming calls to the mock LLM. + var streamedCallCount atomic.Int32 + var streamedCallsMu sync.Mutex + streamedCalls := make([]chattest.OpenAIRequest, 0, 2) + + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("Mixed tool test") + } + + streamedCallsMu.Lock() + streamedCalls = append(streamedCalls, chattest.OpenAIRequest{ + Messages: append([]chattest.OpenAIMessage(nil), req.Messages...), + Tools: append([]chattest.OpenAITool(nil), req.Tools...), + Stream: req.Stream, + }) + streamedCallsMu.Unlock() + + if streamedCallCount.Add(1) == 1 { + // First call: return TWO tool calls in one + // response — a built-in tool (read_file) and a + // dynamic tool (my_dynamic_tool). + builtinChunk := chattest.OpenAIToolCallChunk( + "read_file", + `{"path":"/tmp/test.txt"}`, + ) + dynamicChunk := chattest.OpenAIToolCallChunk( + "my_dynamic_tool", + `{"input":"hello world"}`, + ) + // Merge both tool calls into one chunk with + // separate indices so the LLM appears to have + // requested both tools simultaneously. + mergedChunk := builtinChunk + dynCall := dynamicChunk.Choices[0].ToolCalls[0] + dynCall.Index = 1 + mergedChunk.Choices[0].ToolCalls = append( + mergedChunk.Choices[0].ToolCalls, + dynCall, + ) + return chattest.OpenAIStreamingResponse(mergedChunk) + } + // Second call (after tool results): normal text + // response. + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("All done.")..., + ) + }) + + user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + server := newActiveTestServer(t, db, ps) + + // Create a chat with a dynamic tool. + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "my_dynamic_tool", + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "input": map[string]any{"type": "string"}, + }, + Required: []string{"input"}, + }, + }}) + require.NoError(t, err) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "mixed-builtin-dynamic", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Call both tools."), + }, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + + // 1. Wait for the chat to reach requires_action status. + var chatResult database.Chat + require.Eventually(t, func() bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusRequiresAction || + got.Status == database.ChatStatusError + }, testutil.WaitLong, testutil.IntervalFast) + + require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status, + "expected requires_action, got %s (last_error=%q)", + chatResult.Status, chatResult.LastError.String) + + // 2. Verify the built-in tool (read_file) was already + // executed by checking that a tool result message + // exists for it in the database. + var builtinToolResultFound bool + var toolCallID string + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if dbErr != nil { + return false + } + for _, msg := range messages { + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + continue + } + for _, part := range parts { + // Check for the built-in tool result. + if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "read_file" { + builtinToolResultFound = true + } + // Find the dynamic tool call ID. + if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" { + toolCallID = part.ToolCallID + } + } + } + return builtinToolResultFound && toolCallID != "" + }, testutil.IntervalFast) + + require.True(t, builtinToolResultFound, + "expected read_file tool result in the DB before dynamic tool resolution") + require.NotEmpty(t, toolCallID) + + // 3. Submit dynamic tool results. + err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{ + ChatID: chat.ID, + UserID: user.ID, + ModelConfigID: chatResult.LastModelConfigID, + Results: []codersdk.ToolResult{{ + ToolCallID: toolCallID, + Output: json.RawMessage(`{"result":"dynamic output"}`), + }}, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + + // 4. Wait for the chat to complete. + require.Eventually(t, func() bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError + }, testutil.WaitLong, testutil.IntervalFast) + + if chatResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String) + } + + // 5. Verify the LLM received exactly 2 streaming calls. + require.Equal(t, int32(2), streamedCallCount.Load(), + "expected exactly 2 streaming calls to the LLM") +} + +func TestSubmitToolResultsConcurrency(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + // The mock LLM returns a dynamic tool call on the first streaming + // request, then a plain text reply on the second. + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("Concurrency test") + } + if streamedCallCount.Add(1) == 1 { + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk( + "my_dynamic_tool", + `{"input":"hello"}`, + ), + ) + } + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("Done.")..., + ) + }) + + user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + server := newActiveTestServer(t, db, ps) + + // Create a chat with a dynamic tool. + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "my_dynamic_tool", + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "input": map[string]any{"type": "string"}, + }, + Required: []string{"input"}, + }, + }}) + require.NoError(t, err) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "concurrency-tool-results", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Please call the dynamic tool."), + }, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + + // Wait for the chat to reach requires_action status. + var chatResult database.Chat + require.Eventually(t, func() bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusRequiresAction || + got.Status == database.ChatStatusError + }, testutil.WaitLong, testutil.IntervalFast) + require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status, + "expected requires_action, got %s (last_error=%q)", + chatResult.Status, chatResult.LastError.String) + + // Find the tool call ID from the assistant message. + var toolCallID string + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if dbErr != nil { + return false + } + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleAssistant { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + if parseErr != nil { + continue + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" { + toolCallID = part.ToolCallID + return true + } + } + } + return false + }, testutil.IntervalFast) + require.NotEmpty(t, toolCallID) + + // Spawn N goroutines that all try to submit tool results at the + // same time. Exactly one should succeed; the rest must get a + // ToolResultStatusConflictError. + const numGoroutines = 10 + var ( + wg sync.WaitGroup + ready = make(chan struct{}) + successes atomic.Int32 + conflicts atomic.Int32 + unexpectedErrors = make(chan error, numGoroutines) + ) + + for range numGoroutines { + wg.Go(func() { + // Wait for all goroutines to be ready. + <-ready + + submitErr := server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{ + ChatID: chat.ID, + UserID: user.ID, + ModelConfigID: chatResult.LastModelConfigID, + Results: []codersdk.ToolResult{{ + ToolCallID: toolCallID, + Output: json.RawMessage(`{"result":"concurrent output"}`), + }}, + DynamicTools: dynamicToolsJSON, + }) + + if submitErr == nil { + successes.Add(1) + return + } + var conflict *chatd.ToolResultStatusConflictError + if errors.As(submitErr, &conflict) { + conflicts.Add(1) + return + } + // Collect unexpected errors for assertion + // outside the goroutine (require.NoError + // calls t.FailNow which is illegal here). + unexpectedErrors <- submitErr + }) + } + // Release all goroutines at once. + close(ready) + + wg.Wait() + close(unexpectedErrors) + + for ue := range unexpectedErrors { + require.NoError(t, ue, "unexpected error from SubmitToolResults") + } + + require.Equal(t, int32(1), successes.Load(), + "expected exactly 1 goroutine to succeed") + require.Equal(t, int32(numGoroutines-1), conflicts.Load(), + "expected %d conflict errors", numGoroutines-1) +} + func ptrRef[T any](v T) *T { return &v } diff --git a/coderd/x/chatd/chatloop/chatloop.go b/coderd/x/chatd/chatloop/chatloop.go index 6453c60f23..d380cd8d88 100644 --- a/coderd/x/chatd/chatloop/chatloop.go +++ b/coderd/x/chatd/chatloop/chatloop.go @@ -38,13 +38,23 @@ const ( ) var ( - ErrInterrupted = xerrors.New("chat interrupted") + ErrInterrupted = xerrors.New("chat interrupted") + ErrDynamicToolCall = xerrors.New("dynamic tool call") errStartupTimeout = xerrors.New( "chat response did not start before the startup timeout", ) ) +// PendingToolCall describes a tool call that targets a dynamic +// tool. These calls are not executed by the chatloop; instead +// they are persisted so the caller can fulfill them externally. +type PendingToolCall struct { + ToolCallID string + ToolName string + Args string +} + // PersistedStep contains the full content of a completed or // interrupted agent step. Content includes both assistant blocks // (text, reasoning, tool calls) and tool result blocks. The @@ -60,6 +70,11 @@ type PersistedStep struct { // Zero indicates the duration was not measured (e.g. // interrupted steps). Runtime time.Duration + // PendingDynamicToolCalls lists tool calls that target + // dynamic tools. When non-empty the chatloop exits with + // ErrDynamicToolCall so the caller can execute them + // externally and resume the loop. + PendingDynamicToolCalls []PendingToolCall } // RunOptions configures a single streaming chat loop run. @@ -77,6 +92,12 @@ type RunOptions struct { ActiveTools []string ContextLimitFallback int64 + // DynamicToolNames lists tool names that are handled + // externally. When the model invokes one of these tools + // the chatloop persists partial results and exits with + // ErrDynamicToolCall instead of executing the tool. + DynamicToolNames map[string]bool + // ModelConfig holds per-call LLM parameters (temperature, // max tokens, etc.) read from the chat model configuration. ModelConfig codersdk.ChatModelCallConfig @@ -385,7 +406,22 @@ func Run(ctx context.Context, opts RunOptions) error { return ctx.Err() } - toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, result.toolCalls, func(tr fantasy.ToolResultContent) { + // Partition tool calls into built-in and dynamic. + var builtinCalls, dynamicCalls []fantasy.ToolCallContent + if len(opts.DynamicToolNames) > 0 { + for _, tc := range result.toolCalls { + if opts.DynamicToolNames[tc.ToolName] { + dynamicCalls = append(dynamicCalls, tc) + } else { + builtinCalls = append(builtinCalls, tc) + } + } + } else { + builtinCalls = result.toolCalls + } + + // Execute only built-in tools. + toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, builtinCalls, func(tr fantasy.ToolResultContent) { publishMessagePart( codersdk.ChatMessageRoleTool, chatprompt.PartFromContent(tr), @@ -395,6 +431,47 @@ func Run(ctx context.Context, opts RunOptions) error { result.content = append(result.content, tr) } + // If dynamic tools were called, persist what we + // have (assistant + built-in results) and exit so + // the caller can execute them externally. + if len(dynamicCalls) > 0 { + pending := make([]PendingToolCall, 0, len(dynamicCalls)) + for _, dc := range dynamicCalls { + pending = append(pending, PendingToolCall{ + ToolCallID: dc.ToolCallID, + ToolName: dc.ToolName, + Args: dc.Input, + }) + } + + contextLimit := extractContextLimit(result.providerMetadata) + if !contextLimit.Valid && opts.ContextLimitFallback > 0 { + contextLimit = sql.NullInt64{ + Int64: opts.ContextLimitFallback, + Valid: true, + } + } + + if err := opts.PersistStep(ctx, PersistedStep{ + Content: result.content, + Usage: result.usage, + ContextLimit: contextLimit, + ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata), + Runtime: time.Since(stepStart), + PendingDynamicToolCalls: pending, + }); err != nil { + if errors.Is(err, ErrInterrupted) { + persistInterruptedStep(ctx, opts, &result) + return ErrInterrupted + } + return xerrors.Errorf("persist step: %w", err) + } + + tryCompactOnExit(ctx, opts, result.usage, result.providerMetadata) + + return ErrDynamicToolCall + } + // Check for interruption after tool execution. // Tools that were canceled mid-flight produce error // results via ctx cancellation. Persist the full @@ -1088,6 +1165,38 @@ func persistInterruptedStep( } } +// tryCompactOnExit runs compaction when the chatloop is about +// to exit early (e.g. via ErrDynamicToolCall). The normal +// inline and post-run compaction paths are unreachable in +// early-exit scenarios, so this ensures the context window +// doesn't grow unbounded. +func tryCompactOnExit( + ctx context.Context, + opts RunOptions, + usage fantasy.Usage, + metadata fantasy.ProviderMetadata, +) { + if opts.Compaction == nil || opts.ReloadMessages == nil { + return + } + reloaded, err := opts.ReloadMessages(ctx) + if err != nil { + return + } + _, compactErr := tryCompact( + ctx, + opts.Model, + opts.Compaction, + opts.ContextLimitFallback, + usage, + metadata, + reloaded, + ) + if compactErr != nil && opts.Compaction.OnError != nil { + opts.Compaction.OnError(compactErr) + } +} + // buildToolDefinitions converts AgentTool definitions into the // fantasy.Tool slice expected by fantasy.Call. When activeTools // is non-empty, only function tools whose name appears in the diff --git a/coderd/x/chatd/chatloop/compaction_test.go b/coderd/x/chatd/chatloop/compaction_test.go index 5c0f501126..5a443bcd62 100644 --- a/coderd/x/chatd/chatloop/compaction_test.go +++ b/coderd/x/chatd/chatloop/compaction_test.go @@ -713,4 +713,76 @@ func TestRun_Compaction(t *testing.T) { } require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)") }) + + t.Run("TriggersOnDynamicToolExit", func(t *testing.T) { + t.Parallel() + + var persistCompactionCalls int + const summaryText = "compaction summary for dynamic tool exit" + + // The LLM calls a dynamic tool. Usage is above the + // compaction threshold so compaction should fire even + // though the chatloop exits via ErrDynamicToolCall. + model := &loopTestModel{ + provider: "fake", + streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + return streamFromParts([]fantasy.StreamPart{ + {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "my_dynamic_tool"}, + {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"query": "test"}`}, + {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, + { + Type: fantasy.StreamPartTypeToolCall, + ID: "tc-1", + ToolCallName: "my_dynamic_tool", + ToolCallInput: `{"query": "test"}`, + }, + { + Type: fantasy.StreamPartTypeFinish, + FinishReason: fantasy.FinishReasonToolCalls, + Usage: fantasy.Usage{ + InputTokens: 80, + TotalTokens: 85, + }, + }, + }), nil + }, + generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { + return &fantasy.Response{ + Content: []fantasy.Content{ + fantasy.TextContent{Text: summaryText}, + }, + }, nil + }, + } + + err := Run(context.Background(), RunOptions{ + Model: model, + Messages: []fantasy.Message{ + textMessage(fantasy.MessageRoleUser, "hello"), + }, + MaxSteps: 5, + DynamicToolNames: map[string]bool{"my_dynamic_tool": true}, + PersistStep: func(_ context.Context, _ PersistedStep) error { + return nil + }, + ContextLimitFallback: 100, + Compaction: &CompactionOptions{ + ThresholdPercent: 70, + SummaryPrompt: "summarize now", + Persist: func(_ context.Context, result CompactionResult) error { + persistCompactionCalls++ + require.Contains(t, result.SystemSummary, summaryText) + return nil + }, + }, + ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { + return []fantasy.Message{ + textMessage(fantasy.MessageRoleUser, "hello"), + }, nil + }, + }) + require.ErrorIs(t, err, ErrDynamicToolCall) + require.Equal(t, 1, persistCompactionCalls, + "compaction must fire before dynamic tool exit") + }) } diff --git a/coderd/x/chatd/dynamictool.go b/coderd/x/chatd/dynamictool.go new file mode 100644 index 0000000000..98ad4b6ff7 --- /dev/null +++ b/coderd/x/chatd/dynamictool.go @@ -0,0 +1,91 @@ +package chatd + +import ( + "context" + "encoding/json" + + "charm.land/fantasy" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/codersdk" +) + +// dynamicTool wraps a codersdk.DynamicTool as a fantasy.AgentTool. +// These tools are presented to the LLM but never executed by the +// chatloop — when the LLM calls one, the chatloop exits with +// requires_action status and the client handles execution. +// The Run method should never be called; it returns an error if +// it is, as a safety net. +type dynamicTool struct { + name string + description string + parameters map[string]any + required []string + opts fantasy.ProviderOptions +} + +// dynamicToolsFromSDK converts codersdk.DynamicTool definitions +// into fantasy.AgentTool implementations for inclusion in the LLM +// tool list. +func dynamicToolsFromSDK(logger slog.Logger, tools []codersdk.DynamicTool) []fantasy.AgentTool { + if len(tools) == 0 { + return nil + } + result := make([]fantasy.AgentTool, 0, len(tools)) + for _, t := range tools { + dt := &dynamicTool{ + name: t.Name, + description: t.Description, + } + // InputSchema is a full JSON Schema object stored as + // json.RawMessage. Extract the "properties" and + // "required" fields that fantasy.ToolInfo expects. + if len(t.InputSchema) > 0 { + var schema struct { + Properties map[string]any `json:"properties"` + Required []string `json:"required"` + } + if err := json.Unmarshal(t.InputSchema, &schema); err != nil { + // Defensive: present the tool with no parameter + // constraints rather than failing. The LLM may + // hallucinate argument shapes, but the tool will + // still appear in the tool list. + logger.Warn(context.Background(), "failed to parse dynamic tool input schema", + slog.F("tool_name", t.Name), + slog.Error(err)) + } else { + dt.parameters = schema.Properties + dt.required = schema.Required + } + } + result = append(result, dt) + } + return result +} + +func (t *dynamicTool) Info() fantasy.ToolInfo { + return fantasy.ToolInfo{ + Name: t.name, + Description: t.description, + Parameters: t.parameters, + Required: t.required, + } +} + +func (*dynamicTool) Run(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + // Dynamic tools are never executed by the chatloop. If this + // method is called, it indicates a bug in the chatloop's + // dynamic tool detection logic. + return fantasy.NewTextErrorResponse( + "dynamic tool called in chatloop — this is a bug; " + + "dynamic tools should be handled by the client", + ), nil +} + +func (t *dynamicTool) ProviderOptions() fantasy.ProviderOptions { + return t.opts +} + +func (t *dynamicTool) SetProviderOptions(opts fantasy.ProviderOptions) { + t.opts = opts +} diff --git a/coderd/x/chatd/dynamictool_internal_test.go b/coderd/x/chatd/dynamictool_internal_test.go new file mode 100644 index 0000000000..a6474c7c67 --- /dev/null +++ b/coderd/x/chatd/dynamictool_internal_test.go @@ -0,0 +1,114 @@ +package chatd + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/codersdk" +) + +func TestDynamicToolsFromSDK(t *testing.T) { + t.Parallel() + + t.Run("EmptySlice", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + result := dynamicToolsFromSDK(logger, nil) + require.Nil(t, result) + }) + + t.Run("ValidToolWithSchema", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + tools := []codersdk.DynamicTool{ + { + Name: "my_tool", + Description: "A useful tool", + InputSchema: json.RawMessage(`{"type":"object","properties":{"input":{"type":"string"}},"required":["input"]}`), + }, + } + result := dynamicToolsFromSDK(logger, tools) + require.Len(t, result, 1) + + info := result[0].Info() + require.Equal(t, "my_tool", info.Name) + require.Equal(t, "A useful tool", info.Description) + require.NotNil(t, info.Parameters) + require.Contains(t, info.Parameters, "input") + require.Equal(t, []string{"input"}, info.Required) + }) + + t.Run("ToolWithoutSchema", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + tools := []codersdk.DynamicTool{ + { + Name: "no_schema", + Description: "Tool with no schema", + }, + } + result := dynamicToolsFromSDK(logger, tools) + require.Len(t, result, 1) + + info := result[0].Info() + require.Equal(t, "no_schema", info.Name) + require.Nil(t, info.Parameters) + require.Nil(t, info.Required) + }) + + t.Run("MalformedSchema", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + tools := []codersdk.DynamicTool{ + { + Name: "bad_schema", + Description: "Tool with malformed schema", + InputSchema: json.RawMessage("not-json"), + }, + } + result := dynamicToolsFromSDK(logger, tools) + require.Len(t, result, 1) + + info := result[0].Info() + require.Equal(t, "bad_schema", info.Name) + require.Nil(t, info.Parameters) + require.Nil(t, info.Required) + }) + + t.Run("MultipleTools", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + tools := []codersdk.DynamicTool{ + {Name: "first", Description: "First tool"}, + {Name: "second", Description: "Second tool"}, + {Name: "third", Description: "Third tool"}, + } + result := dynamicToolsFromSDK(logger, tools) + require.Len(t, result, 3) + require.Equal(t, "first", result[0].Info().Name) + require.Equal(t, "second", result[1].Info().Name) + require.Equal(t, "third", result[2].Info().Name) + }) + + t.Run("SchemaWithoutProperties", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + tools := []codersdk.DynamicTool{ + { + Name: "bare_schema", + Description: "Schema with no properties", + InputSchema: json.RawMessage(`{"type":"object"}`), + }, + } + result := dynamicToolsFromSDK(logger, tools) + require.Len(t, result, 1) + + info := result[0].Info() + require.Equal(t, "bare_schema", info.Name) + require.Nil(t, info.Parameters) + require.Nil(t, info.Required) + }) +} diff --git a/codersdk/chats.go b/codersdk/chats.go index dc3dc53046..07f0cce878 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/invopop/jsonschema" "github.com/shopspring/decimal" "golang.org/x/xerrors" @@ -42,12 +43,13 @@ func CompactionThresholdKey(modelConfigID uuid.UUID) string { type ChatStatus string const ( - ChatStatusWaiting ChatStatus = "waiting" - ChatStatusPending ChatStatus = "pending" - ChatStatusRunning ChatStatus = "running" - ChatStatusPaused ChatStatus = "paused" - ChatStatusCompleted ChatStatus = "completed" - ChatStatusError ChatStatus = "error" + ChatStatusWaiting ChatStatus = "waiting" + ChatStatusPending ChatStatus = "pending" + ChatStatusRunning ChatStatus = "running" + ChatStatusPaused ChatStatus = "paused" + ChatStatusCompleted ChatStatus = "completed" + ChatStatusError ChatStatus = "error" + ChatStatusRequiresAction ChatStatus = "requires_action" ) // Chat represents a chat session with an AI agent. @@ -361,6 +363,18 @@ type ChatInputPart struct { Content string `json:"content,omitempty"` } +// SubmitToolResultsRequest is the body for POST /chats/{id}/tool-results. +type SubmitToolResultsRequest struct { + Results []ToolResult `json:"results"` +} + +// ToolResult is the client's response to a dynamic tool call. +type ToolResult struct { + ToolCallID string `json:"tool_call_id"` + Output json.RawMessage `json:"output"` + IsError bool `json:"is_error"` +} + // CreateChatRequest is the request to create a new chat. type CreateChatRequest struct { Content []ChatInputPart `json:"content"` @@ -369,6 +383,10 @@ type CreateChatRequest struct { ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"` Labels map[string]string `json:"labels,omitempty"` + // UnsafeDynamicTools declares client-executed tools that the + // LLM can invoke. This API is highly experimental and highly + // subject to change. + UnsafeDynamicTools []DynamicTool `json:"unsafe_dynamic_tools,omitempty"` } // UpdateChatRequest is the request to update a chat. @@ -928,12 +946,13 @@ type ChatDiffContents struct { type ChatStreamEventType string const ( - ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part" - ChatStreamEventTypeMessage ChatStreamEventType = "message" - ChatStreamEventTypeStatus ChatStreamEventType = "status" - ChatStreamEventTypeError ChatStreamEventType = "error" - ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update" - ChatStreamEventTypeRetry ChatStreamEventType = "retry" + ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part" + ChatStreamEventTypeMessage ChatStreamEventType = "message" + ChatStreamEventTypeStatus ChatStreamEventType = "status" + ChatStreamEventTypeError ChatStreamEventType = "error" + ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update" + ChatStreamEventTypeRetry ChatStreamEventType = "retry" + ChatStreamEventTypeActionRequired ChatStreamEventType = "action_required" ) // ChatQueuedMessage represents a queued message waiting to be processed. @@ -988,16 +1007,123 @@ type ChatStreamRetry struct { RetryingAt time.Time `json:"retrying_at" format:"date-time"` } +// ChatStreamActionRequired is the payload of an action_required stream event. +type ChatStreamActionRequired struct { + ToolCalls []ChatStreamToolCall `json:"tool_calls"` +} + +// ChatStreamToolCall describes a pending dynamic tool call that the client +// must execute. +type ChatStreamToolCall struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Args string `json:"args"` +} + +// DynamicToolCall represents a pending tool invocation from the +// chat stream that the client must execute and submit back. +type DynamicToolCall struct { + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Args string `json:"args"` +} + +// DynamicToolResponse holds the output of a dynamic tool +// execution. IsError indicates a tool-level error the LLM +// should see, as opposed to an infrastructure failure +// (returned as the error return value). +type DynamicToolResponse struct { + Content string `json:"content"` + IsError bool `json:"is_error"` +} + +// DynamicTool describes a client-declared tool definition. On the +// client side, the Handler callback executes the tool when the LLM +// invokes it. On the server side, only Name, Description, and +// InputSchema are used (Handler is not serialized). +type DynamicTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + // InputSchema's JSON key "input_schema" uses snake_case for + // SDK consistency, deviating from the camelCase "inputSchema" + // convention used by MCP. + InputSchema json.RawMessage `json:"input_schema"` + + // Handler executes the tool when the LLM invokes it. + // Not serialized — this only exists on the client side. + Handler func(ctx context.Context, call DynamicToolCall) (DynamicToolResponse, error) `json:"-"` +} + +// NewDynamicTool creates a DynamicTool with a typed handler. +// The JSON schema is derived from T using invopop/jsonschema. +// The handler receives deserialized args and the DynamicToolCall metadata. +func NewDynamicTool[T any]( + name, description string, + handler func(ctx context.Context, args T, call DynamicToolCall) (DynamicToolResponse, error), +) DynamicTool { + reflector := jsonschema.Reflector{ + DoNotReference: true, + Anonymous: true, + AllowAdditionalProperties: true, + } + schema := reflector.Reflect(new(T)) + schema.Version = "" + schemaJSON, err := json.Marshal(schema) + if err != nil { + panic(fmt.Sprintf("codersdk: failed to marshal schema for %q: %v", name, err)) + } + + return DynamicTool{ + Name: name, + Description: description, + InputSchema: schemaJSON, + Handler: func(ctx context.Context, call DynamicToolCall) (DynamicToolResponse, error) { + var parsed T + if err := json.Unmarshal([]byte(call.Args), &parsed); err != nil { + return DynamicToolResponse{ + Content: fmt.Sprintf("invalid parameters: %s", err), + IsError: true, + }, nil + } + return handler(ctx, parsed, call) + }, + } +} + +// ChatWatchEventKind represents the kind of event in the chat watch stream. +type ChatWatchEventKind string + +const ( + ChatWatchEventKindStatusChange ChatWatchEventKind = "status_change" + ChatWatchEventKindTitleChange ChatWatchEventKind = "title_change" + ChatWatchEventKindCreated ChatWatchEventKind = "created" + ChatWatchEventKindDeleted ChatWatchEventKind = "deleted" + ChatWatchEventKindDiffStatusChange ChatWatchEventKind = "diff_status_change" + ChatWatchEventKindActionRequired ChatWatchEventKind = "action_required" +) + +// ChatWatchEvent represents an event from the global chat watch stream. +// It delivers lifecycle events (created, status change, title change) +// for all of the authenticated user's chats. When Kind is +// ActionRequired, ToolCalls contains the pending dynamic tool +// invocations the client must execute and submit back. +type ChatWatchEvent struct { + Kind ChatWatchEventKind `json:"kind"` + Chat Chat `json:"chat"` + ToolCalls []ChatStreamToolCall `json:"tool_calls,omitempty"` +} + // ChatStreamEvent represents a real-time update for chat streaming. type ChatStreamEvent struct { - Type ChatStreamEventType `json:"type"` - ChatID uuid.UUID `json:"chat_id" format:"uuid"` - Message *ChatMessage `json:"message,omitempty"` - MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"` - Status *ChatStreamStatus `json:"status,omitempty"` - Error *ChatStreamError `json:"error,omitempty"` - Retry *ChatStreamRetry `json:"retry,omitempty"` - QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"` + Type ChatStreamEventType `json:"type"` + ChatID uuid.UUID `json:"chat_id" format:"uuid"` + Message *ChatMessage `json:"message,omitempty"` + MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"` + Status *ChatStreamStatus `json:"status,omitempty"` + Error *ChatStreamError `json:"error,omitempty"` + Retry *ChatStreamRetry `json:"retry,omitempty"` + QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"` + ActionRequired *ChatStreamActionRequired `json:"action_required,omitempty"` } type chatStreamEnvelope struct { @@ -1940,6 +2066,73 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o }), nil } +// WatchChats streams lifecycle events for all of the authenticated +// user's chats in real time. The returned channel emits +// ChatWatchEvent values for status changes, title changes, creation, +// deletion, diff-status changes, and action-required notifications. +// Callers must close the returned io.Closer to release the websocket +// connection when done. +func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEvent, io.Closer, error) { + conn, err := c.Dial( + ctx, + "/api/experimental/chats/watch", + &websocket.DialOptions{CompressionMode: websocket.CompressionDisabled}, + ) + if err != nil { + return nil, nil, err + } + conn.SetReadLimit(1 << 22) // 4MiB + + streamCtx, streamCancel := context.WithCancel(ctx) + events := make(chan ChatWatchEvent, 128) + + go func() { + defer close(events) + defer streamCancel() + defer func() { + _ = conn.Close(websocket.StatusNormalClosure, "") + }() + + for { + var envelope chatStreamEnvelope + if err := wsjson.Read(streamCtx, conn, &envelope); err != nil { + if streamCtx.Err() != nil { + return + } + switch websocket.CloseStatus(err) { + case websocket.StatusNormalClosure, websocket.StatusGoingAway: + return + } + return + } + + switch envelope.Type { + case ServerSentEventTypePing: + continue + case ServerSentEventTypeData: + var event ChatWatchEvent + if err := json.Unmarshal(envelope.Data, &event); err != nil { + return + } + select { + case <-streamCtx.Done(): + return + case events <- event: + } + case ServerSentEventTypeError: + return + default: + return + } + } + }() + + return events, closeFunc(func() error { + streamCancel() + return nil + }), nil +} + // GetChat returns a chat by ID. func (c *ExperimentalClient) GetChat(ctx context.Context, chatID uuid.UUID) (Chat, error) { res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s", chatID), nil) @@ -2247,6 +2440,20 @@ func (c *ExperimentalClient) GetMyChatUsageLimitStatus(ctx context.Context) (Cha return resp, json.NewDecoder(res.Body).Decode(&resp) } +// SubmitToolResults submits the results of dynamic tool calls for a chat +// that is in requires_action status. +func (c *ExperimentalClient) SubmitToolResults(ctx context.Context, chatID uuid.UUID, req SubmitToolResultsRequest) error { + res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/tool-results", chatID), req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + // GetChatsByWorkspace returns a mapping of workspace ID to the latest // non-archived chat ID for each requested workspace. Workspaces with // no chats are omitted from the response. diff --git a/codersdk/chats_test.go b/codersdk/chats_test.go index 219f8696f5..e2afe06c86 100644 --- a/codersdk/chats_test.go +++ b/codersdk/chats_test.go @@ -469,6 +469,68 @@ func TestChat_JSONRoundTrip(t *testing.T) { require.Equal(t, original, decoded) } +func TestNewDynamicTool(t *testing.T) { + t.Parallel() + + type testArgs struct { + Query string `json:"query"` + } + + t.Run("CorrectSchema", func(t *testing.T) { + t.Parallel() + + tool := codersdk.NewDynamicTool( + "search", "search things", + func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) { + return codersdk.DynamicToolResponse{Content: args.Query}, nil + }, + ) + + require.Equal(t, "search", tool.Name) + require.Equal(t, "search things", tool.Description) + require.Contains(t, string(tool.InputSchema), `"query"`) + require.Contains(t, string(tool.InputSchema), `"string"`) + }) + + t.Run("HandlerReceivesArgs", func(t *testing.T) { + t.Parallel() + + var received testArgs + tool := codersdk.NewDynamicTool( + "search", "search things", + func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) { + received = args + return codersdk.DynamicToolResponse{Content: "ok"}, nil + }, + ) + + resp, err := tool.Handler(context.Background(), codersdk.DynamicToolCall{ + Args: `{"query":"hello"}`, + }) + require.NoError(t, err) + require.Equal(t, "ok", resp.Content) + require.Equal(t, "hello", received.Query) + }) + + t.Run("InvalidJSONArgs", func(t *testing.T) { + t.Parallel() + + tool := codersdk.NewDynamicTool( + "search", "search things", + func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) { + return codersdk.DynamicToolResponse{Content: "should not reach"}, nil + }, + ) + + resp, err := tool.Handler(context.Background(), codersdk.DynamicToolCall{ + Args: "not-json", + }) + require.NoError(t, err) + require.True(t, resp.IsError) + require.Contains(t, resp.Content, "invalid parameters") + }) +} + //nolint:tparallel,paralleltest func TestParseChatWorkspaceTTL(t *testing.T) { t.Parallel() diff --git a/go.mod b/go.mod index b762cddf15..c0736ef03d 100644 --- a/go.mod +++ b/go.mod @@ -492,6 +492,7 @@ require ( github.com/elazarl/goproxy v1.8.0 github.com/fsnotify/fsnotify v1.9.0 github.com/go-git/go-git/v5 v5.17.1 + github.com/invopop/jsonschema v0.13.0 github.com/mark3labs/mcp-go v0.38.0 github.com/shopspring/decimal v1.4.0 gonum.org/v1/gonum v0.17.0 @@ -566,7 +567,6 @@ require ( github.com/hashicorp/go-getter v1.8.4 // indirect github.com/hexops/gotextdiff v1.0.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/invopop/jsonschema v0.13.0 // indirect github.com/jackmordaunt/icns/v3 v3.0.1 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/kaptinlin/go-i18n v0.2.4 // indirect diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 6fe7278af1..2809c2f19b 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1940,6 +1940,7 @@ export type ChatStatus = | "error" | "paused" | "pending" + | "requires_action" | "running" | "waiting"; @@ -1948,10 +1949,19 @@ export const ChatStatuses: ChatStatus[] = [ "error", "paused", "pending", + "requires_action", "running", "waiting", ]; +// From codersdk/chats.go +/** + * ChatStreamActionRequired is the payload of an action_required stream event. + */ +export interface ChatStreamActionRequired { + readonly tool_calls: readonly ChatStreamToolCall[]; +} + // From codersdk/chats.go /** * ChatStreamError represents an error event in the stream. @@ -1992,10 +2002,12 @@ export interface ChatStreamEvent { readonly error?: ChatStreamError; readonly retry?: ChatStreamRetry; readonly queued_messages?: readonly ChatQueuedMessage[]; + readonly action_required?: ChatStreamActionRequired; } // From codersdk/chats.go export type ChatStreamEventType = + | "action_required" | "error" | "message" | "message_part" @@ -2004,6 +2016,7 @@ export type ChatStreamEventType = | "status"; export const ChatStreamEventTypes: ChatStreamEventType[] = [ + "action_required", "error", "message", "message_part", @@ -2065,6 +2078,17 @@ export interface ChatStreamStatus { readonly status: ChatStatus; } +// From codersdk/chats.go +/** + * ChatStreamToolCall describes a pending dynamic tool call that the client + * must execute. + */ +export interface ChatStreamToolCall { + readonly tool_call_id: string; + readonly tool_name: string; + readonly args: string; +} + // From codersdk/chats.go /** * ChatSystemPromptResponse is the response body for the chat system prompt @@ -2217,6 +2241,38 @@ export interface ChatUsageLimitStatus { readonly period_end?: string; } +// From codersdk/chats.go +/** + * ChatWatchEvent represents an event from the global chat watch stream. + * It delivers lifecycle events (created, status change, title change) + * for all of the authenticated user's chats. When Kind is + * ActionRequired, ToolCalls contains the pending dynamic tool + * invocations the client must execute and submit back. + */ +export interface ChatWatchEvent { + readonly kind: ChatWatchEventKind; + readonly chat: Chat; + readonly tool_calls?: readonly ChatStreamToolCall[]; +} + +// From codersdk/chats.go +export type ChatWatchEventKind = + | "action_required" + | "created" + | "deleted" + | "diff_status_change" + | "status_change" + | "title_change"; + +export const ChatWatchEventKinds: ChatWatchEventKind[] = [ + "action_required", + "created", + "deleted", + "diff_status_change", + "status_change", + "title_change", +]; + // From codersdk/chats.go /** * ChatWorkspaceTTLResponse is the response for getting the chat @@ -2424,6 +2480,12 @@ export interface CreateChatRequest { readonly model_config_id?: string; readonly mcp_server_ids?: readonly string[]; readonly labels?: Record; + /** + * UnsafeDynamicTools declares client-executed tools that the + * LLM can invoke. This API is highly experimental and highly + * subject to change. + */ + readonly unsafe_dynamic_tools?: readonly DynamicTool[]; } // From codersdk/users.go @@ -3224,6 +3286,47 @@ export interface DynamicParametersResponse { readonly parameters: readonly PreviewParameter[]; } +// From codersdk/chats.go +/** + * DynamicTool describes a client-declared tool definition. On the + * client side, the Handler callback executes the tool when the LLM + * invokes it. On the server side, only Name, Description, and + * InputSchema are used (Handler is not serialized). + */ +export interface DynamicTool { + readonly name: string; + readonly description?: string; + /** + * InputSchema's JSON key "input_schema" uses snake_case for + * SDK consistency, deviating from the camelCase "inputSchema" + * convention used by MCP. + */ + readonly input_schema: Record; +} + +// From codersdk/chats.go +/** + * DynamicToolCall represents a pending tool invocation from the + * chat stream that the client must execute and submit back. + */ +export interface DynamicToolCall { + readonly tool_call_id: string; + readonly tool_name: string; + readonly args: string; +} + +// From codersdk/chats.go +/** + * DynamicToolResponse holds the output of a dynamic tool + * execution. IsError indicates a tool-level error the LLM + * should see, as opposed to an infrastructure failure + * (returned as the error return value). + */ +export interface DynamicToolResponse { + readonly content: string; + readonly is_error: boolean; +} + // From codersdk/chats.go /** * EditChatMessageRequest is the request to edit a user message in a chat. @@ -6432,6 +6535,14 @@ export interface StreamChatOptions { export const SubdomainAppSessionTokenCookie = "coder_subdomain_app_session_token"; +// From codersdk/chats.go +/** + * SubmitToolResultsRequest is the body for POST /chats/{id}/tool-results. + */ +export interface SubmitToolResultsRequest { + readonly results: readonly ToolResult[]; +} + // From codersdk/deployment.go export interface SupportConfig { readonly links: SerpentStruct; @@ -7180,6 +7291,16 @@ export interface TokensFilter { readonly include_expired: boolean; } +// From codersdk/chats.go +/** + * ToolResult is the client's response to a dynamic tool call. + */ +export interface ToolResult { + readonly tool_call_id: string; + readonly output: Record; + readonly is_error: boolean; +} + // From codersdk/deployment.go export interface TraceConfig { readonly enable: boolean; diff --git a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx index 3377f8e478..125da20408 100644 --- a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx +++ b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx @@ -150,6 +150,7 @@ const statusConfig = { pending: { icon: Loader2Icon, className: "text-content-link animate-spin" }, running: { icon: Loader2Icon, className: "text-content-link animate-spin" }, paused: { icon: PauseIcon, className: "text-content-warning" }, + requires_action: { icon: PauseIcon, className: "text-content-warning" }, error: { icon: AlertTriangleIcon, className: "text-content-destructive" }, completed: { icon: CheckIcon, className: "text-content-secondary" }, } as const;