diff --git a/coderd/database/migrations/000508_chatd_core_state_machine.down.sql b/coderd/database/migrations/000508_chatd_core_state_machine.down.sql new file mode 100644 index 0000000000..75a532d1d1 --- /dev/null +++ b/coderd/database/migrations/000508_chatd_core_state_machine.down.sql @@ -0,0 +1,106 @@ +-- Rollback for the chatd core state machine foundation migration. + +-- 1. Recreate chats_expanded without the new chat fields. We must drop +-- the view first because the subsequent column drops would fail with +-- "view depends on column". +DROP VIEW IF EXISTS chats_expanded; + +-- 2. Drop the retry state trigger and function. +DROP TRIGGER IF EXISTS trigger_sync_chat_retry_state ON chats; +DROP FUNCTION IF EXISTS sync_chat_retry_state(); + +-- 3. Drop the queue version triggers and function. +DROP TRIGGER IF EXISTS trigger_bump_chat_queue_version_on_queued_message_delete ON chat_queued_messages; +DROP TRIGGER IF EXISTS trigger_bump_chat_queue_version_on_queued_message_update ON chat_queued_messages; +DROP TRIGGER IF EXISTS trigger_bump_chat_queue_version_on_queued_message_insert ON chat_queued_messages; +DROP FUNCTION IF EXISTS bump_chat_queue_version_on_queued_message_change(); + +-- 4. Drop the message revision triggers and functions. +DROP TRIGGER IF EXISTS trigger_update_chat_history_after_message_update ON chat_messages; +DROP TRIGGER IF EXISTS trigger_update_chat_history_after_message_insert ON chat_messages; +DROP TRIGGER IF EXISTS trigger_set_chat_message_revision_on_update ON chat_messages; +DROP TRIGGER IF EXISTS trigger_set_chat_message_revision_on_insert ON chat_messages; +DROP FUNCTION IF EXISTS update_chat_history_after_message_update(); +DROP FUNCTION IF EXISTS update_chat_history_after_message_insert(); +-- The pre-split function name is kept here for backward compatibility +-- with environments that may have applied an earlier draft of the up +-- migration. DROP FUNCTION IF EXISTS is a no-op if the function is +-- absent. +DROP FUNCTION IF EXISTS update_chat_history_after_message_changes(); +DROP FUNCTION IF EXISTS set_chat_message_revision_before(); +DROP FUNCTION IF EXISTS set_chat_message_revision(); + +-- 5. Drop chat_heartbeats (and its index by association). +DROP TABLE IF EXISTS chat_heartbeats; + +-- 6. Drop the queue-order index. +DROP INDEX IF EXISTS idx_chat_queued_messages_chat_position_id; + +-- 7. Drop chat_queued_messages.position and its default sequence, plus +-- created_by. +ALTER TABLE chat_queued_messages + ALTER COLUMN position DROP DEFAULT; +ALTER TABLE chat_queued_messages + DROP COLUMN IF EXISTS position, + DROP COLUMN IF EXISTS created_by; +DROP SEQUENCE IF EXISTS chat_queued_messages_position_seq; + +-- 8. Drop chat_messages.revision. +ALTER TABLE chat_messages + DROP COLUMN IF EXISTS revision; + +-- 9. Drop the new chats columns. +ALTER TABLE chats + DROP COLUMN IF EXISTS snapshot_version, + DROP COLUMN IF EXISTS history_version, + DROP COLUMN IF EXISTS queue_version, + DROP COLUMN IF EXISTS generation_attempt, + DROP COLUMN IF EXISTS retry_state, + DROP COLUMN IF EXISTS retry_state_version, + DROP COLUMN IF EXISTS runner_id, + DROP COLUMN IF EXISTS requires_action_deadline_at; + +-- 10. Recreate chats_expanded with the pre-migration field list. +CREATE VIEW chats_expanded AS +SELECT + c.id, + c.owner_id, + c.workspace_id, + c.title, + c.status, + c.worker_id, + c.started_at, + c.heartbeat_at, + c.created_at, + c.updated_at, + c.parent_chat_id, + c.root_chat_id, + c.last_model_config_id, + c.archived, + c.last_error, + c.mode, + c.mcp_server_ids, + c.labels, + c.build_id, + c.agent_id, + c.pin_order, + c.last_read_message_id, + c.last_injected_context, + c.dynamic_tools, + c.organization_id, + c.plan_mode, + c.client_type, + c.last_turn_summary, + COALESCE(root.user_acl, c.user_acl) AS user_acl, + COALESCE(root.group_acl, c.group_acl) AS group_acl, + owner.username AS owner_username, + owner.name AS owner_name +FROM + chats c + LEFT JOIN chats root ON root.id = COALESCE(c.root_chat_id, c.parent_chat_id) + JOIN visible_users owner ON owner.id = c.owner_id; + +-- 11. The `interrupting` chat_status enum value is intentionally left +-- in place. Postgres does not support dropping a single enum value +-- without recreating the entire type, which would require rewriting +-- every chat row and is unsafe inside a transactional rollback. diff --git a/coderd/database/migrations/000508_chatd_core_state_machine.up.sql b/coderd/database/migrations/000508_chatd_core_state_machine.up.sql new file mode 100644 index 0000000000..5361cd6e16 --- /dev/null +++ b/coderd/database/migrations/000508_chatd_core_state_machine.up.sql @@ -0,0 +1,346 @@ +-- Foundation for the chatd core state machine refactor (PR 1). +-- Adds new versioning fields to chats, a revision column to chat_messages, +-- positional ordering and creator tracking to chat_queued_messages, an +-- unlogged chat_heartbeats table for ownership leases, and Postgres +-- triggers that keep history/queue versioning consistent. +-- +-- This migration does not switch any runtime code over to the new +-- state machine. Existing chatd code keeps running against the legacy +-- "pending"/"paused"/"completed" enum values and the legacy ownership +-- columns (chats.started_at, chats.heartbeat_at). Those columns and +-- enum values stay in place for compatibility. + +-- 1. Add `interrupting` to the chat_status enum. +ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'interrupting'; + +-- 2. Add new versioning, ownership, retry, and pending-action fields to chats. +ALTER TABLE chats + ADD COLUMN snapshot_version bigint NOT NULL DEFAULT 1, + ADD COLUMN history_version bigint NOT NULL DEFAULT 0, + ADD COLUMN queue_version bigint NOT NULL DEFAULT 0, + ADD COLUMN generation_attempt bigint NOT NULL DEFAULT 0, + ADD COLUMN retry_state jsonb, + ADD COLUMN retry_state_version bigint NOT NULL DEFAULT 0, + ADD COLUMN runner_id uuid, + ADD COLUMN requires_action_deadline_at timestamp with time zone; + +-- 3. Add `revision` to chat_messages (nullable for now, backfilled below). +ALTER TABLE chat_messages + ADD COLUMN revision bigint; + +-- 4. Backfill chat_messages.revision = 1 for existing rows. +UPDATE chat_messages SET revision = 1 WHERE revision IS NULL; + +-- 5. Backfill chats.history_version = 1 for chats that already have at +-- least one message. We avoid recursive trigger fire by performing the +-- backfill before the triggers are created. +UPDATE chats +SET history_version = 1 +WHERE EXISTS ( + SELECT 1 FROM chat_messages WHERE chat_messages.chat_id = chats.id +); + +-- 6. Enforce NOT NULL on chat_messages.revision. +ALTER TABLE chat_messages + ALTER COLUMN revision SET NOT NULL; + +-- 7. Add `position` and `created_by` to chat_queued_messages. +ALTER TABLE chat_queued_messages + ADD COLUMN position bigint, + ADD COLUMN created_by uuid; + +-- 8. Backfill chat_queued_messages.position per chat using row_number(), +-- ordering by created_at and breaking ties by id. +WITH ordered AS ( + SELECT + id, + row_number() OVER ( + PARTITION BY chat_id + ORDER BY created_at, id + ) AS rn + FROM chat_queued_messages +) +UPDATE chat_queued_messages +SET position = ordered.rn +FROM ordered +WHERE chat_queued_messages.id = ordered.id; + +-- 9. Backfill chat_queued_messages.created_by from chats.owner_id. +UPDATE chat_queued_messages +SET created_by = chats.owner_id +FROM chats +WHERE chat_queued_messages.chat_id = chats.id + AND chat_queued_messages.created_by IS NULL; + +-- 10. Enforce NOT NULL on chat_queued_messages.position and +-- created_by. Legacy queued-message inserts are updated to populate +-- created_by from the chat owner when no explicit creator exists. +ALTER TABLE chat_queued_messages + ALTER COLUMN position SET NOT NULL, + ALTER COLUMN created_by SET NOT NULL; + +-- 11. Index for queue-order reads and head selection. +CREATE INDEX IF NOT EXISTS idx_chat_queued_messages_chat_position_id + ON chat_queued_messages(chat_id, position, id); + +-- 12. Default sequence for new queued-message positions. +-- A global sequence is acceptable because ordering only needs to be +-- stable within a chat. +CREATE SEQUENCE IF NOT EXISTS chat_queued_messages_position_seq AS bigint START WITH 1; +SELECT setval( + 'chat_queued_messages_position_seq', + GREATEST((SELECT COALESCE(MAX(position), 0) FROM chat_queued_messages), 1) +); +ALTER TABLE chat_queued_messages + ALTER COLUMN position SET DEFAULT nextval('chat_queued_messages_position_seq'); + +-- 13. Backfill chats.queue_version = 1 for chats that already have queued +-- messages. Same trigger-avoidance reasoning as for history_version. +UPDATE chats +SET queue_version = 1 +WHERE EXISTS ( + SELECT 1 FROM chat_queued_messages WHERE chat_queued_messages.chat_id = chats.id +); + +-- 14. chat_heartbeats: unlogged table for ownership leases. Keyed by +-- (chat_id, runner_id) so a single chat can briefly have entries from +-- multiple runners during failover. +CREATE UNLOGGED TABLE IF NOT EXISTS chat_heartbeats ( + chat_id uuid NOT NULL REFERENCES chats(id) ON DELETE CASCADE, + runner_id uuid NOT NULL, + heartbeat_at timestamp with time zone NOT NULL, + PRIMARY KEY (chat_id, runner_id) +); + +CREATE INDEX IF NOT EXISTS chat_heartbeats_heartbeat_at_idx + ON chat_heartbeats (heartbeat_at); + +-- 15. Message revision trigger. +-- The BEFORE-trigger only assigns NEW.revision from chats.snapshot_version +-- and validates immutability. The chats.history_version / +-- generation_attempt update is performed by an AFTER STATEMENT trigger +-- so it doesn't conflict with CTE updates on the chats row in the same +-- command (the legacy InsertChatMessages query updates last_model_config_id +-- in a CTE on chats and then inserts messages). +CREATE FUNCTION set_chat_message_revision_before() +RETURNS trigger AS $$ +DECLARE + chat_snapshot_version bigint; +BEGIN + IF TG_OP = 'INSERT' AND NEW.revision IS NOT NULL THEN + RAISE EXCEPTION 'chat_messages.revision must be assigned by trigger'; + END IF; + + IF TG_OP = 'UPDATE' THEN + IF OLD.chat_id IS DISTINCT FROM NEW.chat_id THEN + RAISE EXCEPTION 'chat_messages.chat_id is immutable'; + END IF; + + IF OLD.revision IS DISTINCT FROM NEW.revision THEN + RAISE EXCEPTION 'chat_messages.revision must be assigned by trigger'; + END IF; + + IF OLD IS NOT DISTINCT FROM NEW THEN + RETURN NEW; + END IF; + END IF; + + SELECT snapshot_version INTO chat_snapshot_version + FROM chats WHERE id = NEW.chat_id; + + IF chat_snapshot_version IS NULL THEN + RAISE EXCEPTION 'chat % does not exist', NEW.chat_id; + END IF; + + NEW.revision = chat_snapshot_version; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- AFTER STATEMENT trigger functions. Use the transition tables to +-- update chats.history_version / generation_attempt once per chat per +-- command. Running AFTER row inserts/updates complete lets a CTE +-- update on the same chats row in the same command finalize before +-- this trigger needs to update it. +-- +-- The INSERT and UPDATE variants are split so the UPDATE variant can +-- reference both the OLD and NEW transition tables and skip rows that +-- did not actually change. Without that filter, a no-op UPDATE on a +-- chat_messages row (one whose OLD IS NOT DISTINCT FROM NEW) would +-- still advance chats.history_version whenever the chat's snapshot +-- had previously been bumped. +CREATE FUNCTION update_chat_history_after_message_insert() +RETURNS trigger AS $$ +BEGIN + UPDATE chats c + SET history_version = c.snapshot_version, + generation_attempt = 0 + FROM ( + SELECT DISTINCT chat_id FROM chat_message_history_new_rows + ) AS affected + WHERE c.id = affected.chat_id + AND ( + c.history_version IS DISTINCT FROM c.snapshot_version + OR c.generation_attempt <> 0 + ); + RETURN NULL; +END; +$$ LANGUAGE plpgsql; + +CREATE FUNCTION update_chat_history_after_message_update() +RETURNS trigger AS $$ +BEGIN + UPDATE chats c + SET history_version = c.snapshot_version, + generation_attempt = 0 + FROM ( + SELECT DISTINCT n.chat_id + FROM chat_message_history_new_rows n + JOIN chat_message_history_old_rows o ON o.id = n.id + WHERE o IS DISTINCT FROM n + ) AS affected + WHERE c.id = affected.chat_id + AND ( + c.history_version IS DISTINCT FROM c.snapshot_version + OR c.generation_attempt <> 0 + ); + RETURN NULL; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trigger_set_chat_message_revision_on_insert +BEFORE INSERT ON chat_messages +FOR EACH ROW +EXECUTE FUNCTION set_chat_message_revision_before(); + +CREATE TRIGGER trigger_set_chat_message_revision_on_update +BEFORE UPDATE ON chat_messages +FOR EACH ROW +EXECUTE FUNCTION set_chat_message_revision_before(); + +CREATE TRIGGER trigger_update_chat_history_after_message_insert +AFTER INSERT ON chat_messages +REFERENCING NEW TABLE AS chat_message_history_new_rows +FOR EACH STATEMENT +EXECUTE FUNCTION update_chat_history_after_message_insert(); + +CREATE TRIGGER trigger_update_chat_history_after_message_update +AFTER UPDATE ON chat_messages +REFERENCING OLD TABLE AS chat_message_history_old_rows NEW TABLE AS chat_message_history_new_rows +FOR EACH STATEMENT +EXECUTE FUNCTION update_chat_history_after_message_update(); + +-- 16. Queue version trigger function. +CREATE FUNCTION bump_chat_queue_version_on_queued_message_change() +RETURNS trigger AS $$ +DECLARE + changed_chat_id uuid; +BEGIN + IF TG_OP = 'DELETE' THEN + changed_chat_id = OLD.chat_id; + ELSE + changed_chat_id = NEW.chat_id; + END IF; + + UPDATE chats + SET queue_version = snapshot_version + WHERE id = changed_chat_id; + + IF TG_OP = 'DELETE' THEN + RETURN OLD; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trigger_bump_chat_queue_version_on_queued_message_insert +AFTER INSERT ON chat_queued_messages +FOR EACH ROW +EXECUTE FUNCTION bump_chat_queue_version_on_queued_message_change(); + +CREATE TRIGGER trigger_bump_chat_queue_version_on_queued_message_update +AFTER UPDATE OF content, model_config_id, position, created_by +ON chat_queued_messages +FOR EACH ROW +EXECUTE FUNCTION bump_chat_queue_version_on_queued_message_change(); + +CREATE TRIGGER trigger_bump_chat_queue_version_on_queued_message_delete +AFTER DELETE ON chat_queued_messages +FOR EACH ROW +EXECUTE FUNCTION bump_chat_queue_version_on_queued_message_change(); + +-- 17. Retry state trigger function. +CREATE FUNCTION sync_chat_retry_state() +RETURNS trigger AS $$ +BEGIN + IF OLD.retry_state_version IS DISTINCT FROM NEW.retry_state_version THEN + RAISE EXCEPTION 'chats.retry_state_version must be assigned by trigger'; + END IF; + + IF NEW.generation_attempt IS DISTINCT FROM OLD.generation_attempt THEN + NEW.retry_state = NULL; + END IF; + + IF NEW.retry_state IS DISTINCT FROM OLD.retry_state THEN + NEW.retry_state_version = NEW.snapshot_version; + END IF; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trigger_sync_chat_retry_state +BEFORE UPDATE OF retry_state, retry_state_version, generation_attempt +ON chats +FOR EACH ROW +EXECUTE FUNCTION sync_chat_retry_state(); + +-- 18. Refresh chats_expanded to include the new chat fields. Drop and +-- recreate so column ordering is stable. +DROP VIEW IF EXISTS chats_expanded; +CREATE VIEW chats_expanded AS +SELECT + c.id, + c.owner_id, + c.workspace_id, + c.title, + c.status, + c.worker_id, + c.started_at, + c.heartbeat_at, + c.created_at, + c.updated_at, + c.parent_chat_id, + c.root_chat_id, + c.last_model_config_id, + c.archived, + c.last_error, + c.mode, + c.mcp_server_ids, + c.labels, + c.build_id, + c.agent_id, + c.pin_order, + c.last_read_message_id, + c.last_injected_context, + c.dynamic_tools, + c.organization_id, + c.plan_mode, + c.client_type, + c.last_turn_summary, + c.snapshot_version, + c.history_version, + c.queue_version, + c.generation_attempt, + c.retry_state, + c.retry_state_version, + c.runner_id, + c.requires_action_deadline_at, + COALESCE(root.user_acl, c.user_acl) AS user_acl, + COALESCE(root.group_acl, c.group_acl) AS group_acl, + owner.username AS owner_username, + owner.name AS owner_name +FROM + chats c + LEFT JOIN chats root ON root.id = COALESCE(c.root_chat_id, c.parent_chat_id) + JOIN visible_users owner ON owner.id = c.owner_id; diff --git a/coderd/database/migrations/testdata/fixtures/000508_chatd_core_state_machine.up.sql b/coderd/database/migrations/testdata/fixtures/000508_chatd_core_state_machine.up.sql new file mode 100644 index 0000000000..31ce67fd1b --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000508_chatd_core_state_machine.up.sql @@ -0,0 +1,17 @@ +-- Fixture coverage for the chat_heartbeats table introduced in +-- migration 000500. The earlier chat fixtures already insert at least +-- one row into chats; we attach a heartbeat for the first such chat so +-- migration tests see a non-empty chat_heartbeats table without +-- hard-coding a specific chat ID. +INSERT INTO chat_heartbeats ( + chat_id, + runner_id, + heartbeat_at +) +SELECT + chats.id, + '00000000-0000-0000-0000-0000000fea51'::uuid, + '2024-01-01 00:00:00+00' +FROM chats +ORDER BY created_at, id +LIMIT 1; diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index f33c047c94..74aa53859c 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -820,6 +820,12 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, &i.Chat.PlanMode, &i.Chat.ClientType, &i.Chat.LastTurnSummary, + &i.Chat.SnapshotVersion, + &i.Chat.HistoryVersion, + &i.Chat.QueueVersion, + &i.Chat.GenerationAttempt, + &i.Chat.RunnerID, + &i.Chat.RequiresActionDeadlineAt, &i.Chat.UserACL, &i.Chat.GroupACL, &i.Chat.OwnerUsername, @@ -887,6 +893,12 @@ func (q *sqlQuerier) GetAuthorizedChatsByChatFileID(ctx context.Context, fileID &i.PlanMode, &i.ClientType, &i.LastTurnSummary, + &i.SnapshotVersion, + &i.HistoryVersion, + &i.QueueVersion, + &i.GenerationAttempt, + &i.RunnerID, + &i.RequiresActionDeadlineAt, &i.UserACL, &i.GroupACL, &i.OwnerUsername, diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 963eec817b..54bd1e1a8a 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -35,6 +35,14 @@ chats_expanded AS ( updated_chats.plan_mode, updated_chats.client_type, updated_chats.last_turn_summary, + updated_chats.snapshot_version, + updated_chats.history_version, + updated_chats.queue_version, + updated_chats.generation_attempt, + updated_chats.retry_state, + updated_chats.retry_state_version, + updated_chats.runner_id, + updated_chats.requires_action_deadline_at, COALESCE(root.user_acl, updated_chats.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chats.group_acl) AS group_acl, owner.username AS owner_username, @@ -90,6 +98,14 @@ chats_expanded AS ( updated_chats.plan_mode, updated_chats.client_type, updated_chats.last_turn_summary, + updated_chats.snapshot_version, + updated_chats.history_version, + updated_chats.queue_version, + updated_chats.generation_attempt, + updated_chats.retry_state, + updated_chats.retry_state_version, + updated_chats.runner_id, + updated_chats.requires_action_deadline_at, COALESCE(root.user_acl, updated_chats.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chats.group_acl) AS group_acl, owner.username AS owner_username, @@ -293,6 +309,15 @@ SELECT * FROM chats_expanded WHERE id = @id::uuid; +-- name: GetChatFamilyIDsByRootID :many +-- Returns the chat IDs of every chat in a family (root + all children) +-- in deterministic order. The id parameter must be the root id; the +-- query does not walk up from a child. +SELECT id +FROM chats +WHERE id = @id::uuid OR root_chat_id = @id::uuid +ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC; + -- name: GetChatACLByID :one SELECT user_acl AS users, @@ -719,6 +744,14 @@ chats_expanded AS ( inserted_chat.plan_mode, inserted_chat.client_type, inserted_chat.last_turn_summary, + inserted_chat.snapshot_version, + inserted_chat.history_version, + inserted_chat.queue_version, + inserted_chat.generation_attempt, + inserted_chat.retry_state, + inserted_chat.retry_state_version, + inserted_chat.runner_id, + inserted_chat.requires_action_deadline_at, COALESCE(root.user_acl, inserted_chat.user_acl) AS user_acl, COALESCE(root.group_acl, inserted_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -854,6 +887,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -909,6 +950,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -962,6 +1011,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1015,6 +1072,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1068,6 +1133,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1120,6 +1193,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1172,6 +1253,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1226,6 +1315,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1242,11 +1339,9 @@ FROM chats_expanded; -- Updates the cached last completed turn summary for sidebar display. -- Empty or whitespace-only summaries are stored as NULL here so direct -- query callers cannot accidentally persist blank sidebar text. --- This intentionally preserves updated_at. The staleness guard relies on --- every new-turn query, such as UpdateChatStatus and AcquireChats, bumping --- updated_at. Future chat-field updates that do not bump updated_at can let --- stale summaries persist. If this query ever bumps updated_at, later --- goroutine summary writes will be rejected as stale. +-- This intentionally preserves updated_at. The staleness guard uses +-- history_version so worker lifecycle transitions that do not change the +-- active message history cannot reject final turn summary writes. -- Two summary workers using the same freshness marker are last-write-wins. UPDATE chats SET @@ -1255,7 +1350,7 @@ SET ), '') WHERE id = @id::uuid - AND updated_at = @expected_updated_at::timestamptz; + AND history_version = @expected_history_version::bigint; -- name: UpdateChatMCPServerIDs :one WITH updated_chat AS ( @@ -1298,6 +1393,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1407,6 +1510,14 @@ chats_expanded AS ( acquired_chats.plan_mode, acquired_chats.client_type, acquired_chats.last_turn_summary, + acquired_chats.snapshot_version, + acquired_chats.history_version, + acquired_chats.queue_version, + acquired_chats.generation_attempt, + acquired_chats.retry_state, + acquired_chats.retry_state_version, + acquired_chats.runner_id, + acquired_chats.requires_action_deadline_at, COALESCE(root.user_acl, acquired_chats.user_acl) AS user_acl, COALESCE(root.group_acl, acquired_chats.group_acl) AS group_acl, owner.username AS owner_username, @@ -1464,6 +1575,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1521,6 +1640,14 @@ chats_expanded AS ( updated_chat.plan_mode, updated_chat.client_type, updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1688,12 +1815,17 @@ RETURNING *; -- name: InsertChatQueuedMessage :one -INSERT INTO chat_queued_messages (chat_id, content, model_config_id) -VALUES ( - @chat_id, - @content, - sqlc.narg('model_config_id')::uuid -) +-- Legacy queue insertion path. When no caller-supplied creator exists, +-- preserve the created_by invariant by attributing the queued row to the +-- chat owner. +INSERT INTO chat_queued_messages (chat_id, content, model_config_id, created_by) +SELECT + @chat_id::uuid, + @content::jsonb, + sqlc.narg('model_config_id')::uuid, + chats.owner_id +FROM chats +WHERE chats.id = @chat_id::uuid RETURNING *; -- name: GetChatQueuedMessages :many @@ -1779,6 +1911,14 @@ chats_expanded AS ( locked_chat.plan_mode, locked_chat.client_type, locked_chat.last_turn_summary, + locked_chat.snapshot_version, + locked_chat.history_version, + locked_chat.queue_version, + locked_chat.generation_attempt, + locked_chat.retry_state, + locked_chat.retry_state_version, + locked_chat.runner_id, + locked_chat.requires_action_deadline_at, COALESCE(root.user_acl, locked_chat.user_acl) AS user_acl, COALESCE(root.group_acl, locked_chat.group_acl) AS group_acl, owner.username AS owner_username, @@ -1791,6 +1931,63 @@ chats_expanded AS ( SELECT * FROM chats_expanded; +-- name: GetChatByIDForShare :one +WITH shared_chat AS ( + SELECT * + FROM chats + WHERE id = @id::uuid + FOR SHARE +), +chats_expanded AS ( + SELECT + shared_chat.id, + shared_chat.owner_id, + shared_chat.workspace_id, + shared_chat.title, + shared_chat.status, + shared_chat.worker_id, + shared_chat.started_at, + shared_chat.heartbeat_at, + shared_chat.created_at, + shared_chat.updated_at, + shared_chat.parent_chat_id, + shared_chat.root_chat_id, + shared_chat.last_model_config_id, + shared_chat.archived, + shared_chat.last_error, + shared_chat.mode, + shared_chat.mcp_server_ids, + shared_chat.labels, + shared_chat.build_id, + shared_chat.agent_id, + shared_chat.pin_order, + shared_chat.last_read_message_id, + shared_chat.last_injected_context, + shared_chat.dynamic_tools, + shared_chat.organization_id, + shared_chat.plan_mode, + shared_chat.client_type, + shared_chat.last_turn_summary, + shared_chat.snapshot_version, + shared_chat.history_version, + shared_chat.queue_version, + shared_chat.generation_attempt, + shared_chat.retry_state, + shared_chat.retry_state_version, + shared_chat.runner_id, + shared_chat.requires_action_deadline_at, + COALESCE(root.user_acl, shared_chat.user_acl) AS user_acl, + COALESCE(root.group_acl, shared_chat.group_acl) AS group_acl, + owner.username AS owner_username, + owner.name AS owner_name + FROM + shared_chat + LEFT JOIN chats root ON root.id = COALESCE(shared_chat.root_chat_id, shared_chat.parent_chat_id) + JOIN visible_users owner ON owner.id = shared_chat.owner_id +) +SELECT * +FROM chats_expanded; + -- name: GetChatsByChatFileID :many SELECT * @@ -2310,59 +2507,404 @@ WHERE chat_id = @chat_id::uuid AND deleted = false AND content::jsonb @> '[{"type": "context-file"}]'; --- name: AutoArchiveInactiveChats :many --- Archives inactive root chats (pinned and already-archived chats skipped), --- cascading to children via root_chat_id. Limits apply to roots, not total --- rows. The Go caller passes @archive_cutoff as UTC midnight so that all --- chats sharing the same last-activity date are archived together. --- Used by dbpurge. -WITH to_archive AS ( - SELECT - c.id, - -- Activity = MAX(cm.created_at) across the family, or c.created_at - -- when the family has no non-deleted messages. - COALESCE(activity.last_activity_at, c.created_at) AS last_activity_at - FROM chats c - LEFT JOIN LATERAL ( - SELECT MAX(cm.created_at) AS last_activity_at - FROM chat_messages cm - JOIN chats fc ON fc.id = cm.chat_id - WHERE (fc.id = c.id OR fc.root_chat_id = c.id) - AND cm.deleted = false - ) activity ON TRUE - WHERE c.archived = false - AND c.pin_order = 0 - AND c.parent_chat_id IS NULL -- roots only - -- Redundant filter helps the planner use the partial index on created_at. - AND c.created_at < @archive_cutoff::timestamptz - -- New active statuses must be added here to prevent archiving. - AND c.status NOT IN ('running', 'pending', 'paused', 'requires_action') - AND COALESCE(activity.last_activity_at, c.created_at) < @archive_cutoff::timestamptz - -- Sorting by created_at lets Postgres drive the scan from the - -- partial index instead of evaluating every LATERAL subquery - -- before sorting. All candidates are past the cutoff, so the - -- archive order is immaterial once the backlog drains. - ORDER BY c.created_at ASC - LIMIT @limit_count -), -archived AS ( - UPDATE chats c - SET archived = true, pin_order = 0, updated_at = NOW() - FROM to_archive t - WHERE (c.id = t.id OR c.root_chat_id = t.id) -- cascade to children - AND c.archived = false - RETURNING c.* -) +-- name: GetChatWorkerAcquisitionCandidates :many +-- Returns worker-runnable chats whose ownership is missing or whose +-- current runner heartbeat is stale. The runner_id IS NULL predicate is +-- a robustness extension for inconsistent rows where a worker_id exists +-- without a runner_id; normal missing ownership is worker_id IS NULL or +-- a missing or stale heartbeat row. SELECT - a.*, - -- Children inherit their root's activity so last_activity_at is never null. - COALESCE( - t.last_activity_at, - (SELECT tr.last_activity_at FROM to_archive tr WHERE tr.id = a.root_chat_id), - a.created_at - )::timestamptz AS last_activity_at -FROM archived a -LEFT JOIN to_archive t ON t.id = a.id --- created_at ASC flows through to dbpurge's digest truncation; see --- buildDigestData in dbpurge.go for the tradeoff rationale. -ORDER BY (a.root_chat_id IS NULL) DESC, a.owner_id ASC, a.created_at ASC, a.id ASC; + chats_expanded.*, + chat_heartbeats.heartbeat_at AS current_heartbeat_at, + NOT EXISTS ( + SELECT 1 + FROM chat_heartbeats current_lease + WHERE current_lease.chat_id = chats_expanded.id + AND current_lease.runner_id = chats_expanded.runner_id + AND current_lease.heartbeat_at > NOW() - (INTERVAL '1 second' * @stale_seconds::int) + ) AS heartbeat_stale +FROM chats_expanded +LEFT JOIN chat_heartbeats + ON chat_heartbeats.chat_id = chats_expanded.id + AND chat_heartbeats.runner_id = chats_expanded.runner_id +WHERE + chats_expanded.status IN ('running'::chat_status, 'interrupting'::chat_status, 'requires_action'::chat_status) + AND chats_expanded.archived = false + AND ( + chats_expanded.worker_id IS NULL + OR chats_expanded.runner_id IS NULL + OR NOT EXISTS ( + SELECT 1 + FROM chat_heartbeats current_lease + WHERE current_lease.chat_id = chats_expanded.id + AND current_lease.runner_id = chats_expanded.runner_id + AND current_lease.heartbeat_at > NOW() - (INTERVAL '1 second' * @stale_seconds::int) + ) + ) +ORDER BY chats_expanded.updated_at ASC, chats_expanded.id ASC +LIMIT @limit_count::int; + +-- name: GetChatsByIDsForRunnerSync :many +SELECT * +FROM chats_expanded +WHERE id = ANY(@ids::uuid[]) +ORDER BY id ASC; + +-- name: UpsertChatHeartbeats :exec +INSERT INTO chat_heartbeats (chat_id, runner_id, heartbeat_at) +SELECT chat_ids.chat_id, runner_ids.runner_id, NOW() +FROM unnest(@chat_ids::uuid[]) WITH ORDINALITY AS chat_ids(chat_id, ord) +JOIN unnest(@runner_ids::uuid[]) WITH ORDINALITY AS runner_ids(runner_id, ord) USING (ord) +ON CONFLICT (chat_id, runner_id) DO UPDATE +SET heartbeat_at = EXCLUDED.heartbeat_at; + +-- name: DeleteStaleChatHeartbeats :execrows +DELETE FROM chat_heartbeats +WHERE heartbeat_at < NOW() - (INTERVAL '1 second' * @stale_seconds::int); + +-- name: GetAutoArchiveInactiveChatCandidates :many +-- Returns read-only root chat candidates for state-machine-backed +-- auto-archive. Activity is computed across the root family. The query +-- limits roots, not total family members. +SELECT + chats_expanded.*, + COALESCE(activity.last_activity_at, chats_expanded.created_at)::timestamptz AS last_activity_at +FROM chats_expanded +LEFT JOIN LATERAL ( + SELECT MAX(chat_messages.created_at) AS last_activity_at + FROM chat_messages + JOIN chats family_chat ON family_chat.id = chat_messages.chat_id + WHERE (family_chat.id = chats_expanded.id OR family_chat.root_chat_id = chats_expanded.id) + AND chat_messages.deleted = false +) activity ON TRUE +WHERE + chats_expanded.archived = false + AND chats_expanded.pin_order = 0 + AND chats_expanded.parent_chat_id IS NULL + AND chats_expanded.created_at < @archive_cutoff::timestamptz + AND chats_expanded.status NOT IN ( + 'running'::chat_status, + 'interrupting'::chat_status, + 'pending'::chat_status, + 'paused'::chat_status, + 'requires_action'::chat_status + ) + AND COALESCE(activity.last_activity_at, chats_expanded.created_at) < @archive_cutoff::timestamptz +ORDER BY chats_expanded.created_at ASC +LIMIT @limit_count::int; + + +-- ===================================================================== +-- chatd core state machine queries. +-- +-- These are consumed by the coderd/x/chatd/chatstate package. They +-- are intentionally kept side-by-side with the legacy chatd queries +-- above so the existing runtime keeps working while the state machine +-- lands behind it. +-- ===================================================================== + +-- name: LockChatAndBumpSnapshotVersion :one +-- Locks the chat row with FOR UPDATE and atomically increments its +-- snapshot_version, returning the post-bump chat. This is the single +-- entry point ChatMachine.Update uses to acquire the row lock and +-- allocate a new snapshot version in one round trip. +WITH bumped_chat AS ( + UPDATE chats + SET snapshot_version = snapshot_version + 1 + WHERE id = ( + SELECT id FROM chats + WHERE id = @id::uuid + FOR UPDATE + ) + RETURNING * +), +chats_expanded AS ( + SELECT + bumped_chat.id, + bumped_chat.owner_id, + bumped_chat.workspace_id, + bumped_chat.title, + bumped_chat.status, + bumped_chat.worker_id, + bumped_chat.started_at, + bumped_chat.heartbeat_at, + bumped_chat.created_at, + bumped_chat.updated_at, + bumped_chat.parent_chat_id, + bumped_chat.root_chat_id, + bumped_chat.last_model_config_id, + bumped_chat.archived, + bumped_chat.last_error, + bumped_chat.mode, + bumped_chat.mcp_server_ids, + bumped_chat.labels, + bumped_chat.build_id, + bumped_chat.agent_id, + bumped_chat.pin_order, + bumped_chat.last_read_message_id, + bumped_chat.last_injected_context, + bumped_chat.dynamic_tools, + bumped_chat.organization_id, + bumped_chat.plan_mode, + bumped_chat.client_type, + bumped_chat.last_turn_summary, + bumped_chat.snapshot_version, + bumped_chat.history_version, + bumped_chat.queue_version, + bumped_chat.generation_attempt, + bumped_chat.retry_state, + bumped_chat.retry_state_version, + bumped_chat.runner_id, + bumped_chat.requires_action_deadline_at, + COALESCE(root.user_acl, bumped_chat.user_acl) AS user_acl, + COALESCE(root.group_acl, bumped_chat.group_acl) AS group_acl, + owner.username AS owner_username, + owner.name AS owner_name + FROM bumped_chat + LEFT JOIN chats root ON root.id = COALESCE(bumped_chat.root_chat_id, bumped_chat.parent_chat_id) + JOIN visible_users owner ON owner.id = bumped_chat.owner_id +) +SELECT * +FROM chats_expanded; + +-- name: UpdateChatExecutionState :one +-- Atomically updates the execution-state-managed fields on a chat: +-- status, archived, last_error, ownership identifiers, and the +-- requires-action deadline. Callers compose this with transition +-- mutations inside a single ChatMachine.Update transaction. +WITH updated_chat AS ( + UPDATE chats + SET + status = @status::chat_status, + archived = @archived::boolean, + worker_id = sqlc.narg('worker_id')::uuid, + runner_id = sqlc.narg('runner_id')::uuid, + last_error = sqlc.narg('last_error')::jsonb, + requires_action_deadline_at = sqlc.narg('requires_action_deadline_at')::timestamptz, + pin_order = CASE WHEN @archived::boolean THEN 0 ELSE pin_order END, + updated_at = NOW() + WHERE id = @id::uuid + RETURNING * +), +chats_expanded AS ( + SELECT + updated_chat.id, + updated_chat.owner_id, + updated_chat.workspace_id, + updated_chat.title, + updated_chat.status, + updated_chat.worker_id, + updated_chat.started_at, + updated_chat.heartbeat_at, + updated_chat.created_at, + updated_chat.updated_at, + updated_chat.parent_chat_id, + updated_chat.root_chat_id, + updated_chat.last_model_config_id, + updated_chat.archived, + updated_chat.last_error, + updated_chat.mode, + updated_chat.mcp_server_ids, + updated_chat.labels, + updated_chat.build_id, + updated_chat.agent_id, + updated_chat.pin_order, + updated_chat.last_read_message_id, + updated_chat.last_injected_context, + updated_chat.dynamic_tools, + updated_chat.organization_id, + updated_chat.plan_mode, + updated_chat.client_type, + updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, + COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, + COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, + owner.username AS owner_username, + owner.name AS owner_name + FROM updated_chat + LEFT JOIN chats root ON root.id = COALESCE(updated_chat.root_chat_id, updated_chat.parent_chat_id) + JOIN visible_users owner ON owner.id = updated_chat.owner_id +) +SELECT * +FROM chats_expanded; + +-- name: UpdateChatRetryState :one +-- Stores the client-visible retry payload. retry_state_version is +-- assigned by trigger from the current snapshot_version. +WITH updated_chat AS ( + UPDATE chats + SET + retry_state = @retry_state::jsonb, + updated_at = NOW() + WHERE id = @id::uuid + RETURNING * +), +chats_expanded AS ( + SELECT + updated_chat.id, + updated_chat.owner_id, + updated_chat.workspace_id, + updated_chat.title, + updated_chat.status, + updated_chat.worker_id, + updated_chat.started_at, + updated_chat.heartbeat_at, + updated_chat.created_at, + updated_chat.updated_at, + updated_chat.parent_chat_id, + updated_chat.root_chat_id, + updated_chat.last_model_config_id, + updated_chat.archived, + updated_chat.last_error, + updated_chat.mode, + updated_chat.mcp_server_ids, + updated_chat.labels, + updated_chat.build_id, + updated_chat.agent_id, + updated_chat.pin_order, + updated_chat.last_read_message_id, + updated_chat.last_injected_context, + updated_chat.dynamic_tools, + updated_chat.organization_id, + updated_chat.plan_mode, + updated_chat.client_type, + updated_chat.last_turn_summary, + updated_chat.snapshot_version, + updated_chat.history_version, + updated_chat.queue_version, + updated_chat.generation_attempt, + updated_chat.retry_state, + updated_chat.retry_state_version, + updated_chat.runner_id, + updated_chat.requires_action_deadline_at, + COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl, + COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl, + owner.username AS owner_username, + owner.name AS owner_name + FROM updated_chat + LEFT JOIN chats root ON root.id = COALESCE(updated_chat.root_chat_id, updated_chat.parent_chat_id) + JOIN visible_users owner ON owner.id = updated_chat.owner_id +) +SELECT * +FROM chats_expanded; + +-- name: IncrementChatGenerationAttempt :one +-- Increments generation_attempt and returns the resulting value. +UPDATE chats +SET generation_attempt = generation_attempt + 1, updated_at = NOW() +WHERE id = @id::uuid +RETURNING generation_attempt; + +-- name: GetDatabaseNow :one +-- Returns the current database timestamp. Used so transitions that +-- record deadlines or heartbeats rely on a clock that is consistent +-- with the database rather than the caller's local clock. +SELECT NOW()::timestamptz AS now; + +-- name: InsertChatQueuedMessageWithCreator :one +-- Inserts a queued message that carries a position (from the default +-- sequence) and an explicit created_by reference. Use this when the +-- queued-message creator differs from the chat owner. +INSERT INTO chat_queued_messages (chat_id, content, model_config_id, created_by) +VALUES ( + @chat_id::uuid, + @content::jsonb, + sqlc.narg('model_config_id')::uuid, + @created_by::uuid +) +RETURNING *; + +-- name: GetChatQueuedMessagesByPosition :many +-- Returns queued messages in state-machine order (position ASC, id ASC). +SELECT * FROM chat_queued_messages +WHERE chat_id = @chat_id::uuid +ORDER BY position ASC, id ASC; + +-- name: CountChatQueuedMessages :one +-- Cheap queue-length check used by ChatMachine.Update when deciding +-- whether the chat is in a "1" sub-state. +SELECT COUNT(*)::bigint AS count +FROM chat_queued_messages +WHERE chat_id = @chat_id::uuid; + +-- name: GetChatQueuedMessageHead :one +-- Returns the queue head (lowest position, then lowest id). +SELECT * FROM chat_queued_messages +WHERE chat_id = @chat_id::uuid +ORDER BY position ASC, id ASC +LIMIT 1; + +-- name: GetChatQueuedMessageByID :one +SELECT * FROM chat_queued_messages +WHERE id = @id::bigint AND chat_id = @chat_id::uuid; + +-- name: DeleteChatQueuedMessageReturningCount :execrows +-- Deletes a queued message, scoped to the parent chat. Returns the +-- number of affected rows so callers can detect missing rows without +-- a follow-up read. +DELETE FROM chat_queued_messages +WHERE id = @id::bigint AND chat_id = @chat_id::uuid; + +-- name: DeleteAllChatQueuedMessagesReturningCount :execrows +DELETE FROM chat_queued_messages +WHERE chat_id = @chat_id::uuid; + +-- name: ReorderChatQueuedMessageToHead :execrows +-- Sets the target queued message's position to one less than the +-- current minimum position for that chat, moving it to the head. +UPDATE chat_queued_messages AS target +SET position = COALESCE( + (SELECT MIN(position) FROM chat_queued_messages WHERE chat_id = @chat_id::uuid), + 0 +) - 1 +WHERE target.id = @id::bigint + AND target.chat_id = @chat_id::uuid + AND target.position > COALESCE( + (SELECT MIN(position) FROM chat_queued_messages WHERE chat_id = @chat_id::uuid), + target.position + ); + +-- name: UpsertChatHeartbeat :exec +-- Upserts a heartbeat row for the (chat_id, runner_id) lease. Uses +-- database time so callers do not depend on a local clock. +INSERT INTO chat_heartbeats (chat_id, runner_id, heartbeat_at) +VALUES (@chat_id::uuid, @runner_id::uuid, NOW()) +ON CONFLICT (chat_id, runner_id) DO UPDATE +SET heartbeat_at = EXCLUDED.heartbeat_at; + +-- name: GetChatHeartbeat :one +SELECT * FROM chat_heartbeats +WHERE chat_id = @chat_id::uuid AND runner_id = @runner_id::uuid; + +-- name: IsChatHeartbeatStale :one +-- Returns true when there is no heartbeat row for (chat_id, runner_id) +-- or the existing row is older than @stale_seconds seconds by database +-- time. chatstate calls this in a single query so the staleness check +-- is atomic and does not depend on the caller's local clock. +SELECT NOT EXISTS ( + SELECT 1 FROM chat_heartbeats + WHERE chat_id = @chat_id::uuid + AND runner_id = @runner_id::uuid + AND heartbeat_at > NOW() - (INTERVAL '1 second' * @stale_seconds::int) +) AS stale; + +-- name: DeleteChatHeartbeats :execrows +-- Deletes heartbeat rows for the supplied (chat_id, runner_id) pairs. +DELETE FROM chat_heartbeats +USING unnest(@chat_ids::uuid[]) WITH ORDINALITY AS chat_ids(chat_id, ord) +JOIN unnest(@runner_ids::uuid[]) WITH ORDINALITY AS runner_ids(runner_id, ord) USING (ord) +WHERE chat_heartbeats.chat_id = chat_ids.chat_id + AND chat_heartbeats.runner_id = runner_ids.runner_id; + +-- name: DeleteAllChatHeartbeats :exec +-- Deletes all heartbeat rows for the chat. Used during ownership +-- transitions that abandon a lease. +DELETE FROM chat_heartbeats WHERE chat_id = @chat_id::uuid; + diff --git a/coderd/pubsub/chatstateupdate.go b/coderd/pubsub/chatstateupdate.go new file mode 100644 index 0000000000..2983b7f8c0 --- /dev/null +++ b/coderd/pubsub/chatstateupdate.go @@ -0,0 +1,84 @@ +package pubsub + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + "golang.org/x/xerrors" +) + +// ChatStateUpdateChannel returns the pubsub channel that receives one +// `chat:update:{chat_id}` message every time the chatstate state +// machine commits a transition for the chat. +func ChatStateUpdateChannel(chatID uuid.UUID) string { + return fmt.Sprintf("chat:update:%s", chatID) +} + +// ChatStateOwnershipChannel is the global pubsub channel that +// receives ownership hints when a chat is runnable but currently has +// missing or stale ownership. Workers listen on this channel to know +// when to attempt acquisition. +const ChatStateOwnershipChannel = "chat:ownership" + +// ChatStateUpdateMessage is the JSON payload published on +// [ChatStateUpdateChannel] after every successful CreateChat or +// ChatMachine.Update commit. It carries the committed post-transition +// versions and ownership identifiers so stream loops and workers can +// decide whether to refetch state. +type ChatStateUpdateMessage struct { + SnapshotVersion int64 `json:"snapshot_version"` + WorkerID *uuid.UUID `json:"worker_id,omitempty"` + RunnerID *uuid.UUID `json:"runner_id,omitempty"` + HistoryVersion int64 `json:"history_version"` + QueueVersion int64 `json:"queue_version"` + RetryStateVersion int64 `json:"retry_state_version"` + GenerationAttempt int64 `json:"generation_attempt"` + Status string `json:"status"` + Archived bool `json:"archived"` +} + +// ChatStateOwnershipMessage is the JSON payload published on +// [ChatStateOwnershipChannel] when ownership is missing or stale for +// a runnable chat. Subscribers should reload the chat row to confirm +// ownership before acting. +type ChatStateOwnershipMessage struct { + ChatID uuid.UUID `json:"chat_id"` + SnapshotVersion int64 `json:"snapshot_version"` +} + +// HandleChatStateUpdate wraps a typed callback for +// [ChatStateUpdateMessage] consumption, following the same pattern as +// HandleChatWatchEvent. +func HandleChatStateUpdate(cb func(ctx context.Context, payload ChatStateUpdateMessage, err error)) func(ctx context.Context, message []byte, err error) { + return func(ctx context.Context, message []byte, err error) { + if err != nil { + cb(ctx, ChatStateUpdateMessage{}, xerrors.Errorf("chat state update pubsub: %w", err)) + return + } + var payload ChatStateUpdateMessage + if uerr := json.Unmarshal(message, &payload); uerr != nil { + cb(ctx, ChatStateUpdateMessage{}, xerrors.Errorf("unmarshal chat state update: %w", uerr)) + return + } + cb(ctx, payload, err) + } +} + +// HandleChatStateOwnership wraps a typed callback for +// [ChatStateOwnershipMessage] consumption. +func HandleChatStateOwnership(cb func(ctx context.Context, payload ChatStateOwnershipMessage, err error)) func(ctx context.Context, message []byte, err error) { + return func(ctx context.Context, message []byte, err error) { + if err != nil { + cb(ctx, ChatStateOwnershipMessage{}, xerrors.Errorf("chat state ownership pubsub: %w", err)) + return + } + var payload ChatStateOwnershipMessage + if uerr := json.Unmarshal(message, &payload); uerr != nil { + cb(ctx, ChatStateOwnershipMessage{}, xerrors.Errorf("unmarshal chat state ownership: %w", uerr)) + return + } + cb(ctx, payload, err) + } +} diff --git a/coderd/x/chatd/chatstate/concurrency_test.go b/coderd/x/chatd/chatstate/concurrency_test.go new file mode 100644 index 0000000000..b53b1ca43e --- /dev/null +++ b/coderd/x/chatd/chatstate/concurrency_test.go @@ -0,0 +1,169 @@ +package chatstate_test + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/testutil" +) + +// waitForChan returns true if c receives a value before ctx is done. +// Helper used in concurrency tests to avoid time.Sleep. +func waitForChan(ctx context.Context, c <-chan struct{}) bool { + select { + case <-c: + return true + case <-ctx.Done(): + return false + } +} + +// stillBlocked returns true if c has NOT received a value yet. The +// caller must already have established a happens-before ordering via +// some other channel so this check is meaningful. +func stillBlocked(c <-chan struct{}) bool { + select { + case <-c: + return false + default: + return true + } +} + +// TestLockLocksChatRow verifies that ChatMachine.Lock holds the chat +// row's FOR UPDATE lock until the callback returns, so a concurrent +// ChatMachine.Update cannot enter its callback until the Lock +// callback releases. +func TestLockLocksChatRow(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitMedium) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + lockEntered := make(chan struct{}) + releaseLock := make(chan struct{}) + updateEntered := make(chan struct{}) + + // Goroutine A: hold a Lock and block. + var lockErr error + var lockWG sync.WaitGroup + lockWG.Add(1) + go func() { + defer lockWG.Done() + lockErr = m.Lock(ctx, func(_ database.Store) error { + close(lockEntered) + <-releaseLock + return nil + }) + }() + + // Wait until A is inside its Lock callback (and therefore holds + // the FOR UPDATE lock). + require.True(t, waitForChan(ctx, lockEntered), "Lock callback never started") + + // Goroutine B: try to Update the same chat. It must block on + // LockChatAndBumpSnapshotVersion until A releases. + var updateErr error + var updateWG sync.WaitGroup + updateWG.Add(1) + go func() { + defer updateWG.Done() + updateErr = m.Update(ctx, func(_ *chatstate.Tx) error { + close(updateEntered) + return nil + }) + }() + + // Loop a few times re-checking that B is still blocked, with a + // fresh round-trip through the database to give B's transaction + // every chance to commit if the lock weren't held. The loop avoids + // time.Sleep by using the database call itself as a "barrier". + for range 5 { + // Force a sync round-trip through the DB. This serves the + // same role as a Sleep but is deterministic: by the time + // this read completes, the scheduler has had a chance to + // run goroutine B if it could make progress. + _, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.True(t, stillBlocked(updateEntered), + "Update entered while Lock was still held") + } + + // Release Lock and confirm Update completes successfully. + close(releaseLock) + require.True(t, waitForChan(ctx, updateEntered), + "Update callback never started after Lock released") + updateWG.Wait() + lockWG.Wait() + require.NoError(t, lockErr) + require.NoError(t, updateErr) +} + +// TestLockRollsBackCallbackError verifies that a Lock callback +// returning an error rolls back the surrounding transaction. +func TestLockRollsBackCallbackError(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + before := f.readChat(ctx, t, created.Chat.ID) + publishedBefore := len(f.Pub.channels) + + sentinel := xerrors.New("lock callback error") + err := m.Lock(ctx, func(store database.Store) error { + // Try a write that should be rolled back. + _, werr := store.UpdateChatByID(ctx, database.UpdateChatByIDParams{ + ID: created.Chat.ID, + Title: "rollback-me", + }) + require.NoError(t, werr) + return sentinel + }) + require.ErrorIs(t, err, sentinel) + + after := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, before.Title, after.Title, "Lock callback error rolls back writes") + require.Equal(t, publishedBefore, len(f.Pub.channels), "Lock publishes nothing on error") +} + +// TestConcurrentUpdatesSerializeOnChatRow verifies that two +// goroutines racing to Update the same chat both succeed but their +// effects serialize on the chat row lock: snapshot_version advances +// by exactly N (one per Update) and each transition observes the +// effects of the prior one. +func TestConcurrentUpdatesSerializeOnChatRow(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitMedium) + created := createTestChat(t, f) + before := f.readChat(ctx, t, created.Chat.ID) + + const updates = 8 + var wg sync.WaitGroup + wg.Add(updates) + errs := make([]error, updates) + for i := 0; i < updates; i++ { + i := i + go func() { + defer wg.Done() + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + errs[i] = m.Update(ctx, func(_ *chatstate.Tx) error { return nil }) + }() + } + wg.Wait() + for i, err := range errs { + require.NoError(t, err, "concurrent update %d failed", i) + } + after := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, before.SnapshotVersion+int64(updates), after.SnapshotVersion, + "snapshot_version advanced by exactly one per update") +} diff --git a/coderd/x/chatd/chatstate/doc.go b/coderd/x/chatd/chatstate/doc.go new file mode 100644 index 0000000000..363a297029 --- /dev/null +++ b/coderd/x/chatd/chatstate/doc.go @@ -0,0 +1,36 @@ +// Package chatstate owns the durable execution-state transitions for +// the chatd subsystem. It implements the core state machine described +// in the chatd RFC: a 13-state execution model plus a 2-state +// ownership model on top of database rows in `chats`, +// `chat_messages`, `chat_queued_messages`, and `chat_heartbeats`. +// +// The package exposes two top-level entry points: +// +// - [CreateChat] creates a brand new chat with its initial history +// in a single transaction. It is standalone because no chat-scoped +// state machine instance can exist before the chat row is written. +// - [ChatMachine] wraps an existing chat. Callers use it to apply +// one or more transitions atomically via [ChatMachine.Update], or +// to read related rows while holding the chat row lock via +// [ChatMachine.Lock]. +// +// Every successful [ChatMachine.Update] call locks the chat row, +// advances `snapshot_version` exactly once, applies transition methods +// in order, and (on commit) publishes a single typed `chat:update` +// pubsub message describing the post-transition snapshot. Optional +// `chat:ownership` hints are published only when the post-transition +// state is runnable and ownership is missing or stale. Stream side +// effects are handled by `chat:update` consumers, and ownership hints +// wake chat workers. +// +// Transition methods are explicit, typed wrappers around the SQL +// mutations needed to move between states. Each transition reads the +// current chat row and queue cardinality, classifies the resulting +// execution state, and rejects with an [*TransitionError] wrapping +// [ErrTransitionNotAllowed] when the transition is not legal from +// that state. The transition matrix and +// state classification helpers live in `state.go` and `transition.go` +// alongside unit-testable classifiers; the SQL is in +// `coderd/database/queries/chats.sql` (e.g. `LockChatAndBumpSnapshotVersion`, +// `UpdateChatExecutionState`). +package chatstate diff --git a/coderd/x/chatd/chatstate/errors.go b/coderd/x/chatd/chatstate/errors.go new file mode 100644 index 0000000000..2dd3874b6d --- /dev/null +++ b/coderd/x/chatd/chatstate/errors.go @@ -0,0 +1,152 @@ +package chatstate + +import ( + "errors" + "fmt" + + "golang.org/x/xerrors" +) + +// Sentinel errors returned by chatstate transitions and helpers. +// Callers should use errors.Is to test for these. +var ( + // ErrTransitionNotAllowed is returned when a transition is applied + // to a chat whose current execution state does not permit it. The + // concrete error returned by transition methods is a + // *TransitionError that wraps this sentinel. + ErrTransitionNotAllowed = xerrors.New("chat state transition not allowed") + + // ErrInvalidState is returned when the chat row, queue, and + // archive flag together produce a combination outside the 13 + // valid execution states described in the RFC. + ErrInvalidState = xerrors.New("chat is in an invalid execution state") + + // ErrQueuedMessageNotFound is returned by queue-targeting + // transitions (delete, promote) when the supplied queued message + // ID does not match a row on the chat. + ErrQueuedMessageNotFound = xerrors.New("queued message not found") + + // ErrMessageNotFound is returned by [Tx.EditMessage] when the + // target chat_messages row is missing or belongs to another chat. + ErrMessageNotFound = xerrors.New("chat message not found") + + // ErrChatNotFound is returned when a non-create transition is + // applied to a chat row that does not exist (or has been deleted + // since the transition started). + ErrChatNotFound = xerrors.New("chat not found") + + // ErrChatNotRoot is returned by family-archive helpers when the + // supplied chat is not a root chat (its parent_chat_id is set). + ErrChatNotRoot = xerrors.New("chat is not a root chat") + + // ErrEditedMessageNotUser is returned by [Tx.EditMessage] when the + // targeted chat_messages row exists but its role is not user. + ErrEditedMessageNotUser = xerrors.New("only user messages can be edited") + + // ErrMessageQueueFull is returned by queue-appending transitions + // when the per-chat queue cap has been reached. The concrete + // error returned by transitions is a *MessageQueueFullError that + // wraps this sentinel. + ErrMessageQueueFull = xerrors.New("chat message queue is full") + + // ErrToolResultDuplicate is returned by [Tx.CompleteRequiresAction] + // when the same tool_call_id appears more than once in the + // submitted results. + ErrToolResultDuplicate = xerrors.New("duplicate tool result") + + // ErrToolResultUnexpected is returned by + // [Tx.CompleteRequiresAction] when a submitted tool_call_id does + // not correspond to a pending dynamic tool call. + ErrToolResultUnexpected = xerrors.New("unexpected tool result") + + // ErrToolResultMissing is returned by [Tx.CompleteRequiresAction] + // when a pending dynamic tool call has no submitted result. + ErrToolResultMissing = xerrors.New("missing tool result") + + // ErrToolResultInvalidJSON is returned by + // [Tx.CompleteRequiresAction] when a submitted tool result output + // is not valid JSON. + ErrToolResultInvalidJSON = xerrors.New("tool result output is not valid JSON") +) + +// MessageQueueFullError carries the per-chat queue cap so HTTP +// endpoints can include the cap in their response detail. It wraps +// [ErrMessageQueueFull] so callers can match it with errors.Is. +type MessageQueueFullError struct { + Max int64 +} + +// Error implements the error interface. +func (e *MessageQueueFullError) Error() string { + return fmt.Sprintf("chat message queue is full (max %d)", e.Max) +} + +// Unwrap returns [ErrMessageQueueFull] so callers can match the +// generic sentinel. +func (*MessageQueueFullError) Unwrap() error { return ErrMessageQueueFull } + +// ToolResultValidationError carries a structured tool-result +// validation failure. It always wraps a specific sentinel +// (ErrToolResultDuplicate, ErrToolResultMissing, +// ErrToolResultUnexpected, ErrToolResultInvalidJSON) so callers can +// match either the generic sentinel or the specific cause. +type ToolResultValidationError struct { + Cause error + ToolCallID string +} + +// Error implements the error interface. +func (e *ToolResultValidationError) Error() string { + if e.ToolCallID != "" { + return fmt.Sprintf("%s: %s", e.Cause.Error(), e.ToolCallID) + } + return e.Cause.Error() +} + +// Unwrap returns the specific cause so callers can match it. +func (e *ToolResultValidationError) Unwrap() error { return e.Cause } + +// TransitionError carries the structured detail for a rejected +// transition. It always wraps [ErrTransitionNotAllowed] so callers can +// match with errors.Is without losing context. When a specific +// chatstate sentinel is the proximate cause, Cause is set and +// errors.Is will match that sentinel too. +type TransitionError struct { + Transition Transition + From ExecutionState + Reason string + Cause error +} + +// Error implements the error interface. +func (e *TransitionError) Error() string { + if e.Reason == "" { + return fmt.Sprintf( + "chat state transition %s not allowed from state %s", + e.Transition, e.From, + ) + } + return fmt.Sprintf( + "chat state transition %s not allowed from state %s: %s", + e.Transition, e.From, e.Reason, + ) +} + +// Unwrap returns the error chain attached to this error. The chain +// always includes [ErrTransitionNotAllowed], and may include a more +// specific cause through errors.Join, so callers can use errors.Is +// without custom matching logic on TransitionError. +func (e *TransitionError) Unwrap() error { return e.Cause } + +// newTransitionError constructs a typed TransitionError. Returning the +// pointer type lets callers inspect the structured fields when needed. +func newTransitionError(t Transition, from ExecutionState, reason string) *TransitionError { + return &TransitionError{Transition: t, From: from, Reason: reason, Cause: ErrTransitionNotAllowed} +} + +// newTransitionErrorWithCause constructs a TransitionError carrying +// a specific underlying sentinel so callers can match the cause with +// errors.Is. +func newTransitionErrorWithCause(t Transition, from ExecutionState, cause error, reason string) *TransitionError { + return &TransitionError{Transition: t, From: from, Reason: reason, Cause: errors.Join(ErrTransitionNotAllowed, cause)} +} diff --git a/coderd/x/chatd/chatstate/family.go b/coderd/x/chatd/chatstate/family.go new file mode 100644 index 0000000000..c35b2ef418 --- /dev/null +++ b/coderd/x/chatd/chatstate/family.go @@ -0,0 +1,130 @@ +package chatstate + +import ( + "context" + "database/sql" + "errors" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" +) + +// SetFamilyArchivedInput configures [SetFamilyArchived]. The struct +// shape avoids a boolean flag parameter at the API surface; callers +// build it explicitly with named fields for clarity. +type SetFamilyArchivedInput struct { + // RootID identifies the family root. SetFamilyArchived rejects + // calls for child chats with [ErrChatNotRoot] and unknown chats + // with [ErrChatNotFound]. + RootID uuid.UUID + // Archived is the desired post-call archived value for every + // family member. + Archived bool +} + +// SetFamilyArchived runs Update for every chat in the root chat's +// family inside one transaction, applying SetArchived when the chat's +// archived flag differs from the requested value. It owns its +// transaction lifecycle and its [PublishBuffer] lifecycle: pubsub +// publications are buffered while the transaction is open and +// flushed only after a successful commit; the deferred Discard +// suppresses every buffered publication on failure. +// +// On success SetFamilyArchived returns one [database.Chat] per +// family member in the order returned by GetChatFamilyIDsByRootID +// (root first, then children). +// +// Family members that are already in the [StateInvalid] execution +// state cause SetFamilyArchived to return [ErrInvalidState] and roll +// back the cascade even when their archived flag already matches the +// desired value; invalid-state detection is never bypassed. +// +// Family members that are valid and already match the desired +// archived value still run through Update, which increments their +// snapshot version and publishes a fresh snapshot without changing +// the archived flag. Advancing the snapshot version without a field +// change is safe, and it keeps publication behavior uniform while a +// partially archived family converges to the desired state. +func SetFamilyArchived( + ctx context.Context, + store database.Store, + publisher Publisher, + input SetFamilyArchivedInput, +) ([]database.Chat, error) { + if store == nil { + return nil, xerrors.New("chatstate: SetFamilyArchived called with nil store") + } + if publisher == nil { + return nil, xerrors.New("chatstate: SetFamilyArchived called with nil publisher") + } + + buffer := NewPublishBuffer(publisher) + defer buffer.Discard() + + var familyChats []database.Chat + err := store.InTx(func(tx database.Store) error { + // Lock the root chat first so concurrent archive races on the + // same family serialize on a stable row. + root, err := tx.GetChatByIDForUpdate(ctx, input.RootID) + if errors.Is(err, sql.ErrNoRows) { + return ErrChatNotFound + } + if err != nil { + return xerrors.Errorf("lock root chat for archive: %w", err) + } + if root.ParentChatID.Valid { + return ErrChatNotRoot + } + ids, err := tx.GetChatFamilyIDsByRootID(ctx, input.RootID) + if err != nil { + return xerrors.Errorf("get chat family: %w", err) + } + if len(ids) == 0 { + return ErrChatNotFound + } + familyChats = make([]database.Chat, 0, len(ids)) + for _, id := range ids { + var chat database.Chat + machine := NewChatMachine(tx, buffer, id, Options{}) + err := machine.Update(ctx, func(state *Tx) error { + // Load the current chat and classify it so we can + // reject invalid-state members with ErrInvalidState + // even when their archived flag already matches. + current, from, lerr := state.loadState() + if lerr != nil { + return lerr + } + if from == StateInvalid { + return ErrInvalidState + } + if current.Archived == input.Archived { + chat = current + return nil + } + if _, err := state.SetArchived(SetArchivedInput{Archived: input.Archived}); err != nil { + return err + } + var err error + chat, err = state.Store().GetChatByID(state.Ctx(), state.ChatID()) + if err != nil { + return xerrors.Errorf("reload archived chat: %w", err) + } + return nil + }) + if err != nil { + return err + } + familyChats = append(familyChats, chat) + } + return nil + }, nil) + if err != nil { + return nil, err + } + if err := buffer.Flush(); err != nil { + return familyChats, err + } + return familyChats, nil +} diff --git a/coderd/x/chatd/chatstate/family_test.go b/coderd/x/chatd/chatstate/family_test.go new file mode 100644 index 0000000000..8fe78bc699 --- /dev/null +++ b/coderd/x/chatd/chatstate/family_test.go @@ -0,0 +1,218 @@ +package chatstate_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// TestSetFamilyArchivedRejectsChildChat asserts the chatstate helper +// rejects calls that target a child chat. Family archive flows must +// always start at the root. +func TestSetFamilyArchivedRejectsChildChat(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + + root := dbgen.Chat(t, f.DB, database.Chat{ + OrganizationID: f.Org.ID, + OwnerID: f.User.ID, + LastModelConfigID: f.Model.ID, + Title: "root", + }) + child := dbgen.Chat(t, f.DB, database.Chat{ + OrganizationID: f.Org.ID, + OwnerID: f.User.ID, + LastModelConfigID: f.Model.ID, + Title: "child", + ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true}, + }) + + _, err := chatstate.SetFamilyArchived(ctx, f.DB, f.Pub, chatstate.SetFamilyArchivedInput{RootID: child.ID, Archived: true}) + require.ErrorIs(t, err, chatstate.ErrChatNotRoot) + + require.False(t, f.readChat(ctx, t, root.ID).Archived, + "failed family archive must not touch the root") + require.False(t, f.readChat(ctx, t, child.ID).Archived, + "failed family archive must not touch the child") +} + +// TestSetFamilyArchivedRollsBackWhenMemberCannotArchive verifies that +// SetFamilyArchived is atomic: when one family member is in a state +// that cannot satisfy the SetArchived transition, the whole cascade +// rolls back and no publications reach the inner publisher. +func TestSetFamilyArchivedRollsBackWhenMemberCannotArchive(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + user, org, model := seedFamilyDeps(t, db) + + // Root chat: waiting is archive-eligible (state W). + root := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + Title: "root", + Status: database.ChatStatusWaiting, + }) + // Child chat: running with no queue is R0 and NOT archive + // eligible per the chatstate transition matrix. + child := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + Title: "child", + Status: database.ChatStatusRunning, + ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true}, + }) + + pub := newRecordingPubsub() + _, err := chatstate.SetFamilyArchived(ctx, db, pub, chatstate.SetFamilyArchivedInput{RootID: root.ID, Archived: true}) + require.Error(t, err, "child in R0 must reject SetArchived") + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed) + + rootAfter, err := db.GetChatByID(ctx, root.ID) + require.NoError(t, err) + require.False(t, rootAfter.Archived, "root archive must roll back when a child cannot archive") + childAfter, err := db.GetChatByID(ctx, child.ID) + require.NoError(t, err) + require.False(t, childAfter.Archived, "child must not be archived in the rolled-back cascade") + + require.Empty(t, pub.channels, + "rolled-back family archive must publish nothing through the inner publisher") +} + +// TestSetFamilyArchivedRejectsInvalidStateEvenWhenAlreadyDesired +// verifies that invalid-state detection is never bypassed: a family +// member in StateInvalid causes the cascade to fail with +// ErrInvalidState even when that member's archived flag already +// matches the desired value. +func TestSetFamilyArchivedRejectsInvalidStateEvenWhenAlreadyDesired(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + user, org, model := seedFamilyDeps(t, db) + + root := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + Title: "root", + Status: database.ChatStatusWaiting, + }) + child := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + Title: "child", + // status=waiting, archived=true; we will add a queued message + // to produce the chatstate-invalid combination (archived chat + // with a queued backlog is outside the 13 valid states). + Status: database.ChatStatusWaiting, + Archived: true, + ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true}, + }) + + // Seed a queued message under the child to push it into the + // chatstate-invalid combination. + rawContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("queued"), + }) + require.NoError(t, err) + _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: child.ID, + Content: rawContent.RawMessage, + ModelConfigID: uuid.NullUUID{}, + }) + require.NoError(t, err) + + pub := newRecordingPubsub() + _, err = chatstate.SetFamilyArchived(ctx, db, pub, chatstate.SetFamilyArchivedInput{ + RootID: root.ID, + Archived: true, + }) + require.ErrorIs(t, err, chatstate.ErrInvalidState, + "invalid-state child blocks the cascade even when archived flag already matches") + + // Root must not be archived because the cascade rolled back. + rootAfter, err := db.GetChatByID(ctx, root.ID) + require.NoError(t, err) + require.False(t, rootAfter.Archived, "root must roll back when a child is in StateInvalid") + + require.Empty(t, pub.channels, + "rolled-back cascade must not publish anything") +} + +// TestSetFamilyArchivedAcceptsAlreadyDesiredMembers verifies that an +// individually archived child does not block a root archive cascade. +// The cascade converges to the desired state even when some family +// members already match it. +func TestSetFamilyArchivedAcceptsAlreadyDesiredMembers(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + user, org, model := seedFamilyDeps(t, db) + + root := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + Title: "root", + Status: database.ChatStatusWaiting, + }) + child := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + Title: "child", + Status: database.ChatStatusWaiting, + ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true}, + Archived: true, + }) + + pub := newRecordingPubsub() + family, err := chatstate.SetFamilyArchived(ctx, db, pub, chatstate.SetFamilyArchivedInput{RootID: root.ID, Archived: true}) + require.NoError(t, err, + "already archived members must not block the cascade") + require.Len(t, family, 2) + + rootAfter, err := db.GetChatByID(ctx, root.ID) + require.NoError(t, err) + require.True(t, rootAfter.Archived) + childAfter, err := db.GetChatByID(ctx, child.ID) + require.NoError(t, err) + require.True(t, childAfter.Archived) +} + +func seedFamilyDeps(t *testing.T, db database.Store) (database.User, database.Organization, database.ChatModelConfig) { + t.Helper() + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "openai", + BaseUrl: "http://example.invalid", + }) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "openai", + IsDefault: true, + }) + return user, org, model +} diff --git a/coderd/x/chatd/chatstate/helpers_test.go b/coderd/x/chatd/chatstate/helpers_test.go new file mode 100644 index 0000000000..d8e92ed274 --- /dev/null +++ b/coderd/x/chatd/chatstate/helpers_test.go @@ -0,0 +1,92 @@ +package chatstate_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/testutil" +) + +// ownershipPublishCount returns the number of `chat:ownership` messages +// recorded so far on the test publisher. Tests use it to assert that +// transitions do or do not publish an ownership hint. +func (r *recordingPubsub) ownershipPublishCount() int { + count := 0 + for _, c := range r.channels { + if c == coderdpubsub.ChatStateOwnershipChannel { + count++ + } + } + return count +} + +// sendQueuedMessage seeds one queued user message via SendMessage with +// BusyBehaviorQueue. The chat must already be in a state that allows +// SendMessage (typically R0, R1, or I*). +func sendQueuedMessage(t *testing.T, f *testFixture, m *chatstate.ChatMachine, body string) chatstate.SendMessageResult { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + var send chatstate.SendMessageResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + send, err = tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage(body, f.User.ID, f.Model.ID), + BusyBehavior: chatstate.BusyBehaviorQueue, + }) + return err + })) + return send +} + +// sendInterruptMessage seeds one queued user message via SendMessage +// with BusyBehaviorInterrupt. From R0/R1 this transitions the chat to +// `interrupting` and appends the new user message to the queue tail. +func sendInterruptMessage(t *testing.T, f *testFixture, m *chatstate.ChatMachine, body string) chatstate.SendMessageResult { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + var send chatstate.SendMessageResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + send, err = tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage(body, f.User.ID, f.Model.ID), + BusyBehavior: chatstate.BusyBehaviorInterrupt, + }) + return err + })) + return send +} + +// queuedIDsByPosition returns the queued-message IDs for the chat in +// queue order. +func queuedIDsByPosition(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) []int64 { + t.Helper() + rows, err := f.DB.GetChatQueuedMessagesByPosition(ctx, chatID) + require.NoError(t, err) + ids := make([]int64, len(rows)) + for i, r := range rows { + ids[i] = r.ID + } + return ids +} + +// historyMessageIDs returns the chat history message IDs ordered by +// row id. Used to assert that PromoteQueuedMessage from R1/I1 does NOT +// insert any history rows. +func historyMessageIDs(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) []int64 { + t.Helper() + msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + }) + require.NoError(t, err) + out := make([]int64, len(msgs)) + for i, m := range msgs { + out[i] = m.ID + } + return out +} diff --git a/coderd/x/chatd/chatstate/machine.go b/coderd/x/chatd/chatstate/machine.go new file mode 100644 index 0000000000..11a903b870 --- /dev/null +++ b/coderd/x/chatd/chatstate/machine.go @@ -0,0 +1,300 @@ +package chatstate + +import ( + "context" + "database/sql" + "errors" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" +) + +// HeartbeatStaleSeconds is the threshold chatstate uses when deciding +// whether to publish a `chat:ownership` hint for a runnable chat. A +// heartbeat older than this many seconds (by database time) counts +// as stale and triggers a hint so an idle worker can attempt a +// takeover. +const HeartbeatStaleSeconds = 30 + +// Options configures a [ChatMachine]. Reserved for future tunables; +// currently empty. +type Options struct{} + +// ChatMachine is a chat-scoped handle for state-machine operations on +// a single chat row. It captures the database store, the pubsub +// publisher, and the chat ID at construction time so callers do not +// have to thread them through Update, Lock, or any transition method. +// +// ChatMachine values are cheap. Create one per chat for the lifetime +// of a request or worker turn; do not cache mutable chat state across +// calls. +type ChatMachine struct { + store database.Store + publisher Publisher + chatID uuid.UUID + opts Options +} + +// NewChatMachine constructs a chat-scoped state machine handle. The +// store may be the root database handle or an existing transaction +// handle; publisher is the pubsub used for `chat:update` and +// `chat:ownership` emissions. Both are required and captured for the +// lifetime of the returned machine. +func NewChatMachine( + store database.Store, + publisher Publisher, + chatID uuid.UUID, + opts Options, +) *ChatMachine { + return &ChatMachine{ + store: store, + publisher: publisher, + chatID: chatID, + opts: opts, + } +} + +// ChatID returns the chat ID this machine is scoped to. +func (m *ChatMachine) ChatID() uuid.UUID { return m.chatID } + +// Tx is the per-transaction handle passed to [ChatMachine.Update] +// callbacks. It carries the active context, the transactional store, +// and the chat ID. Tx does not cache mutable chat state across calls: +// every transition method reads the chat row and queue cardinality +// from the database on entry, so a bundle of transitions inside one +// Update callback always validates against the latest committed state. +type Tx struct { + ctx context.Context + store database.Store + chatID uuid.UUID +} + +// Ctx returns the context the surrounding [ChatMachine.Update] call +// is using. +func (tx *Tx) Ctx() context.Context { return tx.ctx } + +// ChatID returns the chat ID this transaction is scoped to. +func (tx *Tx) ChatID() uuid.UUID { return tx.chatID } + +// Store exposes the active transaction store so callers can perform +// validation reads (for example loading the messages affected by an +// EditMessage transition) and metadata writes (for example updating +// title or labels) that must be atomic with the transition. +// +// Callers MUST NOT use Store to mutate execution-state tables +// (chats.status, chat_messages, chat_queued_messages, chat_heartbeats, +// or the version fields on chats). Those mutations belong to the +// transition methods and are validated against the state machine +// matrix. +func (tx *Tx) Store() database.Store { return tx.store } + +// loadState reads the current chat row and queue cardinality from the +// active transaction, classifies the execution state, and returns the +// inputs every transition method needs. Returns ErrChatNotFound if +// the chat row was deleted in this transaction (or never existed). +func (tx *Tx) loadState() (database.Chat, ExecutionState, error) { + chat, err := tx.store.GetChatByID(tx.ctx, tx.chatID) + if errors.Is(err, sql.ErrNoRows) { + return database.Chat{}, StateN, ErrChatNotFound + } + if err != nil { + return database.Chat{}, "", xerrors.Errorf("load chat: %w", err) + } + count, err := tx.store.CountChatQueuedMessages(tx.ctx, tx.chatID) + if err != nil { + return database.Chat{}, "", xerrors.Errorf("count queued messages: %w", err) + } + return chat, ClassifyExecutionState(chat, count > 0, true), nil +} + +// requireFromAllowed loads the current state and validates t against +// the transition matrix. Returns the loaded chat and execution state +// on success, [ErrInvalidState] when the chat is in an invalid state +// and t is not [TransitionReconcileInvalidState], and a typed +// *TransitionError otherwise. +func (tx *Tx) requireFromAllowed(t Transition) (database.Chat, ExecutionState, error) { + chat, from, err := tx.loadState() + if err != nil { + return chat, from, err + } + if from == StateInvalid && t != TransitionReconcileInvalidState { + return chat, from, ErrInvalidState + } + if err := requireExecutionTransition(t, from); err != nil { + return chat, from, err + } + return chat, from, nil +} + +// Update applies one or more transitions to the machine's chat. +// +// Update opens a transaction on the captured store, atomically locks +// the chat row with FOR UPDATE and increments `snapshot_version` +// exactly once, then runs fn against a fresh [*Tx]. It constructs a +// [PublishBuffer], enqueues `chat:update` (and a `chat:ownership` hint +// when the post-transition state is worker-runnable and ownership is +// missing or stale) inside the transaction, and flushes the buffer only after +// the transaction function succeeds. If the transaction rolls back, +// the deferred Discard suppresses every buffered publication so +// subscribers never see uncommitted state. +// +// If Update is called with a store that is already in a transaction, +// [database.Store.InTx] reuses the active transaction. In that case, +// callers that need outer-transaction publication semantics can pass a +// [PublishBuffer] as the machine publisher. The inner buffer flushes +// into the outer buffer, and the outer owner remains responsible for +// publishing only after the outer transaction commits. +// +// Callers must not pass a store or publisher here; they belong on the +// machine. +// +// If the chat row does not exist, Update returns [ErrChatNotFound] +// without mutating anything. +// +// Callbacks that return an error roll back the transaction (rolling +// back the automatic snapshot bump) and publish nothing. +func (m *ChatMachine) Update( + ctx context.Context, + fn func(*Tx) error, +) error { + if m.store == nil { + return xerrors.New("chatstate: ChatMachine has nil store") + } + if m.publisher == nil { + return xerrors.New("chatstate: ChatMachine has nil publisher") + } + + buffer := NewPublishBuffer(m.publisher) + defer buffer.Discard() + + err := m.store.InTx(func(store database.Store) error { + if _, err := store.LockChatAndBumpSnapshotVersion(ctx, m.chatID); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return ErrChatNotFound + } + return xerrors.Errorf("lock chat and bump snapshot: %w", err) + } + tx := &Tx{ + ctx: ctx, + store: store, + chatID: m.chatID, + } + if err := fn(tx); err != nil { + return err + } + chat, state, err := tx.loadState() + if err != nil { + return err + } + if err := buffer.Publish( + coderdpubsub.ChatStateUpdateChannel(chat.ID), + buildChatUpdateMessage(chat), + ); err != nil { + return xerrors.Errorf("buffer chat update: %w", err) + } + if state.IsRunnable() { + stale, herr := ownershipStaleOrMissing(ctx, store, chat, HeartbeatStaleSeconds) + if herr != nil { + return xerrors.Errorf("evaluate ownership: %w", herr) + } + if stale { + if err := buffer.Publish( + coderdpubsub.ChatStateOwnershipChannel, + buildChatOwnershipMessage(chat), + ); err != nil { + return xerrors.Errorf("buffer ownership hint: %w", err) + } + } + } + return nil + }, nil) + if err != nil { + return err + } + return buffer.Flush() +} + +// Lock locks the chat row with FOR UPDATE and runs fn in a +// transaction without advancing snapshot_version. It uses the store +// captured by [NewChatMachine]. Use it when the caller needs a +// consistent chat snapshot plus related rows such as messages or +// queued messages but is NOT applying a transition. +// +// Callers must not pass a store here; it belongs on the machine. +// +// Lock publishes nothing. Callback errors roll back the transaction +// and propagate to the caller. +func (m *ChatMachine) Lock( + ctx context.Context, + fn func(database.Store) error, +) error { + if m.store == nil { + return xerrors.New("chatstate: ChatMachine has nil store") + } + return m.store.InTx(func(store database.Store) error { + // GetChatByIDForUpdate locks the row WITHOUT bumping snapshot. + _, err := store.GetChatByIDForUpdate(ctx, m.chatID) + if errors.Is(err, sql.ErrNoRows) { + return ErrChatNotFound + } + if err != nil { + return xerrors.Errorf("lock chat: %w", err) + } + return fn(store) + }, nil) +} + +// ReadLock takes a shared lock on the chat row with FOR SHARE and runs +// fn in a transaction without advancing snapshot_version. It uses the +// store captured by [NewChatMachine]. Use it when the caller needs a +// consistent chat snapshot plus related rows such as messages or queued +// messages but is NOT applying a transition and does NOT need to block +// concurrent readers. +// +// Unlike [ChatMachine.Lock], the FOR SHARE lock permits other shared +// lockers to proceed concurrently while still blocking writers that take +// FOR UPDATE (such as [ChatMachine.Update] and [ChatMachine.Lock]) until +// the transaction commits. +// +// Callers must not pass a store here; it belongs on the machine. +// +// ReadLock publishes nothing. Callback errors roll back the transaction +// and propagate to the caller. +func (m *ChatMachine) ReadLock( + ctx context.Context, + fn func(database.Store) error, +) error { + if m.store == nil { + return xerrors.New("chatstate: ChatMachine has nil store") + } + return m.store.InTx(func(store database.Store) error { + // GetChatByIDForShare takes a shared lock on the row WITHOUT + // bumping snapshot. + _, err := store.GetChatByIDForShare(ctx, m.chatID) + if errors.Is(err, sql.ErrNoRows) { + return ErrChatNotFound + } + if err != nil { + return xerrors.Errorf("read lock chat: %w", err) + } + return fn(store) + }, nil) +} + +// ownershipStaleOrMissing reports whether the chat's current +// (chat_id, runner_id) lease is missing or stale. The staleSeconds +// threshold is forwarded to [database.IsChatHeartbeatStale] so the +// comparison runs against database time inside a single SQL query. +func ownershipStaleOrMissing(ctx context.Context, store database.Store, chat database.Chat, staleSeconds int32) (bool, error) { + if !chat.WorkerID.Valid || !chat.RunnerID.Valid { + return true, nil + } + return store.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{ + ChatID: chat.ID, + RunnerID: chat.RunnerID.UUID, + StaleSeconds: staleSeconds, + }) +} diff --git a/coderd/x/chatd/chatstate/machine_test.go b/coderd/x/chatd/chatstate/machine_test.go new file mode 100644 index 0000000000..52b533e707 --- /dev/null +++ b/coderd/x/chatd/chatstate/machine_test.go @@ -0,0 +1,406 @@ +package chatstate_test + +import ( + "context" + "encoding/json" + "sync" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/pubsub" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// testFixture bundles the resources every integration test needs: +// a database, a publisher recorder, a user/org/model triple, and +// helper accessors. It is intentionally NOT a generic chatd test +// fixture; tests outside this package should not depend on it. +type testFixture struct { + DB database.Store + PubSub pubsub.Pubsub + Pub *recordingPubsub + User database.User + Org database.Organization + Model database.ChatModelConfig +} + +func newTestFixture(t *testing.T) *testFixture { + t.Helper() + db, ps := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "openai", + BaseUrl: "http://example.invalid", + }) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "openai", + IsDefault: true, + }) + pub := newRecordingPubsub() + return &testFixture{ + DB: db, + PubSub: ps, + Pub: pub, + User: user, + Org: org, + Model: model, + } +} + +// readChat re-reads the chat from the database. Tests use this to +// verify post-transition state because transition results no longer +// carry the chat snapshot. +func (f *testFixture) readChat(ctx context.Context, t *testing.T, chatID uuid.UUID) database.Chat { + t.Helper() + chat, err := f.DB.GetChatByID(ctx, chatID) + require.NoError(t, err) + return chat +} + +// classify reads the chat plus queue cardinality and returns the RFC +// execution state shorthand. +func (f *testFixture) classify(ctx context.Context, t *testing.T, chatID uuid.UUID) chatstate.ExecutionState { + t.Helper() + chat := f.readChat(ctx, t, chatID) + count, err := f.DB.CountChatQueuedMessages(ctx, chatID) + require.NoError(t, err) + return chatstate.ClassifyExecutionState(chat, count > 0, true) +} + +// recordingPubsub captures every Publish call so tests can assert on +// the chatstate notifications without needing a live subscriber. The +// mutex makes it safe to use from concurrent tests that race multiple +// goroutines through the same publisher (see TestConcurrentUpdatesSerializeOnChatRow). +type recordingPubsub struct { + mu sync.Mutex + channels []string + payloads [][]byte +} + +func newRecordingPubsub() *recordingPubsub { return &recordingPubsub{} } + +func (r *recordingPubsub) Publish(channel string, payload []byte) error { + r.mu.Lock() + defer r.mu.Unlock() + r.channels = append(r.channels, channel) + cp := make([]byte, len(payload)) + copy(cp, payload) + r.payloads = append(r.payloads, cp) + return nil +} + +// expectChatUpdate finds the most recent chat:update message on the +// per-chat channel and asserts that it has snapshot_version == want. +func (r *recordingPubsub) expectChatUpdate(t *testing.T, chatID uuid.UUID, wantSnapshot int64) { + t.Helper() + channel := coderdpubsub.ChatStateUpdateChannel(chatID) + for i := len(r.channels) - 1; i >= 0; i-- { + if r.channels[i] != channel { + continue + } + var msg coderdpubsub.ChatStateUpdateMessage + require.NoError(t, json.Unmarshal(r.payloads[i], &msg)) + require.Equal(t, wantSnapshot, msg.SnapshotVersion) + return + } + t.Fatalf("no chat:update on %s", channel) +} + +func (r *recordingPubsub) hasOwnership() bool { + for _, c := range r.channels { + if c == coderdpubsub.ChatStateOwnershipChannel { + return true + } + } + return false +} + +func userTextMessage(text string, createdBy uuid.UUID, modelConfigID uuid.UUID) chatstate.Message { + parts := []codersdk.ChatMessagePart{codersdk.ChatMessageText(text)} + raw, err := chatprompt.MarshalParts(parts) + if err != nil { + panic(err) + } + return chatstate.Message{ + Role: database.ChatMessageRoleUser, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + CreatedBy: uuid.NullUUID{UUID: createdBy, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, + } +} + +// createTestChat is the standard "fresh R0 chat" helper used by other +// tests. It exercises CreateChat itself. +func createTestChat(t *testing.T, f *testFixture) chatstate.CreateChatResult { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{ + OrganizationID: f.Org.ID, + OwnerID: f.User.ID, + LastModelConfigID: f.Model.ID, + Title: "test", + ClientType: database.ChatClientTypeApi, + InitialMessages: []chatstate.Message{ + userTextMessage("hello", f.User.ID, f.Model.ID), + }, + }) + require.NoError(t, err) + return res +} + +func TestChatMachine_Update_RejectsMissingChat(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + m := chatstate.NewChatMachine(f.DB, f.Pub, uuid.New(), chatstate.Options{}) + err := m.Update(ctx, func(tx *chatstate.Tx) error { return nil }) + require.ErrorIs(t, err, chatstate.ErrChatNotFound) + require.Empty(t, f.Pub.channels) +} + +func TestChatMachine_Lock_DoesNotBumpSnapshot(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + before := f.readChat(ctx, t, created.Chat.ID) + publishedBefore := len(f.Pub.channels) + + require.NoError(t, m.Lock(ctx, func(_ database.Store) error { + return nil + })) + after := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, before.SnapshotVersion, after.SnapshotVersion) + require.Equal(t, publishedBefore, len(f.Pub.channels), "Lock must not publish") +} + +func TestChatMachine_ReadLock_DoesNotBumpSnapshot(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + before := f.readChat(ctx, t, created.Chat.ID) + publishedBefore := len(f.Pub.channels) + + var called bool + require.NoError(t, m.ReadLock(ctx, func(_ database.Store) error { + called = true + return nil + })) + require.True(t, called, "ReadLock must invoke the callback") + after := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, before.SnapshotVersion, after.SnapshotVersion) + require.Equal(t, publishedBefore, len(f.Pub.channels), "ReadLock must not publish") +} + +func TestChatMachine_ReadLock_RejectsMissingChat(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + m := chatstate.NewChatMachine(f.DB, f.Pub, uuid.New(), chatstate.Options{}) + err := m.ReadLock(ctx, func(_ database.Store) error { + t.Fatal("callback must not run when the chat is missing") + return nil + }) + require.ErrorIs(t, err, chatstate.ErrChatNotFound) + require.Empty(t, f.Pub.channels) +} + +func TestChatMachine_UpdatePublishesAfterCommit(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + publishedBefore := len(f.Pub.channels) + // Run a no-op Update; snapshot bump still happens, one update message + // should follow the commit. + require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil })) + channel := coderdpubsub.ChatStateUpdateChannel(created.Chat.ID) + var found bool + for _, c := range f.Pub.channels[publishedBefore:] { + if c == channel { + found = true + break + } + } + require.True(t, found, "expected one chat:update message after commit") +} + +func TestChatMachine_FailedUpdate_PublishesNothing(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + before := f.readChat(ctx, t, created.Chat.ID) + channelsBefore := len(f.Pub.channels) + expected := newSentinel() + cbErr := m.Update(ctx, func(_ *chatstate.Tx) error { return expected }) + require.ErrorIs(t, cbErr, expected) + require.Equal(t, channelsBefore, len(f.Pub.channels), "failed update should not publish") + // snapshot_version should not have advanced. + after := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, before.SnapshotVersion, after.SnapshotVersion) +} + +func TestMessageRevisionTrigger_AssignsRevisionFromSnapshot(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) // snapshot 1, history_version 1 via trigger + + // CommitStep an assistant message; it should land with revision = chat.snapshot_version after the bump. + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + var step chatstate.CommitStepResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + assistant := userTextMessage("assistant", f.User.ID, f.Model.ID) + assistant.Role = database.ChatMessageRoleAssistant + var err error + step, err = tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{assistant}, + }) + return err + })) + require.Len(t, step.InsertedMessages, 1) + after := f.readChat(ctx, t, created.Chat.ID) + // The Update call bumps snapshot_version once before the trigger + // runs, so the new revision should equal the bumped snapshot. + require.Equal(t, after.SnapshotVersion, step.InsertedMessages[0].Revision) + require.Equal(t, after.SnapshotVersion, after.HistoryVersion) + require.Equal(t, int64(0), after.GenerationAttempt, "trigger resets generation_attempt to 0") +} + +func TestQueueVersionTrigger_AdvancesOnInsert(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) // queue_version starts at 0 + + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage("queue", f.User.ID, f.Model.ID), + BusyBehavior: chatstate.BusyBehaviorQueue, + }) + return err + })) + after := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, after.SnapshotVersion, after.QueueVersion) + require.Greater(t, after.QueueVersion, int64(0)) +} + +func TestQueueVersionTrigger_StableForNonQueueMutations(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + assistant := userTextMessage("assistant", f.User.ID, f.Model.ID) + assistant.Role = database.ChatMessageRoleAssistant + _, err := tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{assistant}, + }) + return err + })) + // queue_version must remain unchanged from initial 0. + require.Equal(t, int64(0), f.readChat(ctx, t, created.Chat.ID).QueueVersion) +} + +// TestUpdateFlushesBufferedPublicationsAfterCommit verifies that +// ChatMachine.Update owns the PublishBuffer lifecycle: nothing +// reaches the inner publisher until after the transaction commits, +// and at commit the buffered chat:update is forwarded. +func TestUpdateFlushesBufferedPublicationsAfterCommit(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + channel := coderdpubsub.ChatStateUpdateChannel(created.Chat.ID) + baseline := countChannel(f.Pub.channels, channel) + + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + // During the callback, no new chat:update for this chat may have + // reached the inner publisher because the buffer holds it. + require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { + require.Equal(t, baseline, countChannel(f.Pub.channels, channel), + "inner publisher saw chat:update before transaction committed") + return nil + })) + + require.Equal(t, baseline+1, countChannel(f.Pub.channels, channel), + "exactly one new chat:update reached the inner publisher after commit") +} + +// TestUpdateDiscardsBufferedPublicationsOnCallbackError verifies the +// deferred Discard path: when the callback returns an error the +// transaction rolls back and no buffered messages reach the inner +// publisher. +func TestUpdateDiscardsBufferedPublicationsOnCallbackError(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + before := f.readChat(ctx, t, created.Chat.ID) + channelsBefore := len(f.Pub.channels) + + sentinel := xerrors.New("callback boom") + err := m.Update(ctx, func(_ *chatstate.Tx) error { return sentinel }) + require.ErrorIs(t, err, sentinel) + + require.Equal(t, channelsBefore, len(f.Pub.channels), + "failed update must not flush any buffered publications") + after := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, before.SnapshotVersion, after.SnapshotVersion, + "snapshot bump rolled back when callback returns error") +} + +// ============================================================================= +// helpers +// ============================================================================= + +type sentinelError struct{ msg string } + +func (s *sentinelError) Error() string { return s.msg } + +func newSentinel() error { return &sentinelError{msg: "sentinel"} } + +func countChannel(channels []string, channel string) int { + c := 0 + for _, ch := range channels { + if ch == channel { + c++ + } + } + return c +} diff --git a/coderd/x/chatd/chatstate/messages.go b/coderd/x/chatd/chatstate/messages.go new file mode 100644 index 0000000000..7248509970 --- /dev/null +++ b/coderd/x/chatd/chatstate/messages.go @@ -0,0 +1,112 @@ +package chatstate + +import ( + "database/sql" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + + "github.com/coder/coder/v2/coderd/database" +) + +// Message is the durable message input shape used by chatstate +// transitions. It is intentionally lower level than the SDK message +// request types: callers must produce a fully materialized message +// (parsed parts, calculated cost, resolved model config) before +// passing it in. +// +// The state machine never reshapes a Message except to attach the +// runtime `chat_id`. The `revision` column is assigned by the +// `set_chat_message_revision` trigger; runtime code must not populate +// it. +type Message struct { + Role database.ChatMessageRole + Content pqtype.NullRawMessage + Visibility database.ChatMessageVisibility + ModelConfigID uuid.NullUUID + CreatedBy uuid.NullUUID + ContentVersion int16 + Compressed bool + InputTokens sql.NullInt64 + OutputTokens sql.NullInt64 + TotalTokens sql.NullInt64 + ReasoningTokens sql.NullInt64 + CacheCreationTokens sql.NullInt64 + CacheReadTokens sql.NullInt64 + ContextLimit sql.NullInt64 + TotalCostMicros sql.NullInt64 + RuntimeMs sql.NullInt64 + ProviderResponseID sql.NullString +} + +// toInsertParams converts a batch of Messages into the parallel-array +// shape required by `InsertChatMessages`. The returned struct has all +// arrays sized to len(messages). +// +// The chat ID is supplied by the caller because Message itself does +// not carry one (the chat machine already knows the chat). +func toInsertParams(chatID uuid.UUID, messages []Message) database.InsertChatMessagesParams { + n := len(messages) + params := database.InsertChatMessagesParams{ + ChatID: chatID, + CreatedBy: make([]uuid.UUID, n), + ModelConfigID: make([]uuid.UUID, n), + Role: make([]database.ChatMessageRole, n), + Content: make([]string, n), + ContentVersion: make([]int16, n), + Visibility: make([]database.ChatMessageVisibility, n), + InputTokens: make([]int64, n), + OutputTokens: make([]int64, n), + TotalTokens: make([]int64, n), + ReasoningTokens: make([]int64, n), + CacheCreationTokens: make([]int64, n), + CacheReadTokens: make([]int64, n), + ContextLimit: make([]int64, n), + Compressed: make([]bool, n), + TotalCostMicros: make([]int64, n), + RuntimeMs: make([]int64, n), + ProviderResponseID: make([]string, n), + } + for i, m := range messages { + params.CreatedBy[i] = nullUUIDOrNil(m.CreatedBy) + params.ModelConfigID[i] = nullUUIDOrNil(m.ModelConfigID) + params.Role[i] = m.Role + if m.Content.Valid { + params.Content[i] = string(m.Content.RawMessage) + } else { + // Use the JSON null literal; UNNEST + ::jsonb requires a + // valid JSON value and the trigger leaves it untouched. + params.Content[i] = "null" + } + params.ContentVersion[i] = m.ContentVersion + params.Visibility[i] = m.Visibility + params.InputTokens[i] = nullInt64Or(m.InputTokens, 0) + params.OutputTokens[i] = nullInt64Or(m.OutputTokens, 0) + params.TotalTokens[i] = nullInt64Or(m.TotalTokens, 0) + params.ReasoningTokens[i] = nullInt64Or(m.ReasoningTokens, 0) + params.CacheCreationTokens[i] = nullInt64Or(m.CacheCreationTokens, 0) + params.CacheReadTokens[i] = nullInt64Or(m.CacheReadTokens, 0) + params.ContextLimit[i] = nullInt64Or(m.ContextLimit, 0) + params.Compressed[i] = m.Compressed + params.TotalCostMicros[i] = nullInt64Or(m.TotalCostMicros, 0) + params.RuntimeMs[i] = nullInt64Or(m.RuntimeMs, 0) + if m.ProviderResponseID.Valid { + params.ProviderResponseID[i] = m.ProviderResponseID.String + } + } + return params +} + +func nullUUIDOrNil(u uuid.NullUUID) uuid.UUID { + if u.Valid { + return u.UUID + } + return uuid.Nil +} + +func nullInt64Or(v sql.NullInt64, fallback int64) int64 { + if v.Valid { + return v.Int64 + } + return fallback +} diff --git a/coderd/x/chatd/chatstate/notify.go b/coderd/x/chatd/chatstate/notify.go new file mode 100644 index 0000000000..1392ab88b7 --- /dev/null +++ b/coderd/x/chatd/chatstate/notify.go @@ -0,0 +1,169 @@ +package chatstate + +import ( + "encoding/json" + "fmt" + "sync" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" +) + +// Publisher is the minimal interface chatstate needs to publish +// pubsub messages. It is intentionally compatible with +// database/pubsub.Pubsub: real callers pass the live pubsub directly +// and tests pass a recording fake. Chatstate entry points +// ([ChatMachine.Update], [CreateChat], [SetFamilyArchived]) own the +// internal [PublishBuffer] lifecycle so callers never see it. +type Publisher interface { + Publish(event string, message []byte) error +} + +// PublishBuffer is a [Publisher] that records each Publish call in +// order without forwarding it until [PublishBuffer.Flush] is called. +// It is an internal primitive used by chatstate entry points to +// hold pubsub messages until the surrounding transaction commits, +// and by tests that need to observe buffered output. Normal callers +// do not construct a PublishBuffer themselves and do not invoke +// Flush or Discard; chatstate's entry points own that lifecycle. +type PublishBuffer struct { + inner Publisher + + mu sync.Mutex + pending []bufferedMessage + flushed bool + disabled bool +} + +type bufferedMessage struct { + Channel string + Payload []byte +} + +// NewPublishBuffer constructs a PublishBuffer that, when flushed, will +// forward messages in order to inner. +func NewPublishBuffer(inner Publisher) *PublishBuffer { + return &PublishBuffer{inner: inner} +} + +// Publish records a message. It never forwards to the inner publisher +// until [PublishBuffer.Flush] is called. Returns an error if Flush has +// already happened to make accidental reuse obvious. +func (b *PublishBuffer) Publish(channel string, payload []byte) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.flushed { + return xerrors.Errorf("publish buffer flushed; cannot accept message for %q", channel) + } + if b.disabled { + return nil + } + cp := make([]byte, len(payload)) + copy(cp, payload) + b.pending = append(b.pending, bufferedMessage{Channel: channel, Payload: cp}) + return nil +} + +// Flush forwards every pending message to the inner publisher in the +// order it was buffered, then marks the buffer flushed. The first +// publish error is returned with the channel name annotated; later +// messages still attempt to publish so the inner publisher sees a +// consistent state. +func (b *PublishBuffer) Flush() error { + b.mu.Lock() + defer b.mu.Unlock() + if b.flushed { + return nil + } + b.flushed = true + var firstErr error + for _, msg := range b.pending { + if err := b.inner.Publish(msg.Channel, msg.Payload); err != nil && firstErr == nil { + firstErr = xerrors.Errorf("publish %s: %w", msg.Channel, err) + } + } + return firstErr +} + +// Discard clears the buffered messages without forwarding them. It +// is safe to call multiple times and is harmless after [PublishBuffer.Flush]: +// once Flush has marked the buffer flushed and forwarded its +// pending messages, a subsequent Discard simply clears the (now +// empty) pending slice and sets the buffer to drop any future +// Publish calls. This makes `defer buf.Discard()` a safe pattern +// after a successful flush, including the one chatstate entry +// points use to own the buffer lifecycle. +func (b *PublishBuffer) Discard() { + b.mu.Lock() + defer b.mu.Unlock() + b.pending = nil + b.disabled = true +} + +// pending returns a snapshot of the buffered messages, primarily for +// tests via [PublishBuffer.BufferedChannels]. The returned slice is a +// copy and safe to inspect without holding the buffer lock. +func (b *PublishBuffer) snapshotPending() []bufferedMessage { + b.mu.Lock() + defer b.mu.Unlock() + out := make([]bufferedMessage, len(b.pending)) + copy(out, b.pending) + return out +} + +// BufferedChannels returns just the channels of the pending messages +// in order. Primarily useful for assertions in tests. +func (b *PublishBuffer) BufferedChannels() []string { + pending := b.snapshotPending() + out := make([]string, len(pending)) + for i, m := range pending { + out[i] = m.Channel + } + return out +} + +// buildChatUpdateMessage produces the JSON payload for a +// `chat:update:{chat_id}` message describing the post-transition +// snapshot of chat. +func buildChatUpdateMessage(chat database.Chat) []byte { + msg := coderdpubsub.ChatStateUpdateMessage{ + SnapshotVersion: chat.SnapshotVersion, + HistoryVersion: chat.HistoryVersion, + QueueVersion: chat.QueueVersion, + RetryStateVersion: chat.RetryStateVersion, + GenerationAttempt: chat.GenerationAttempt, + Status: string(chat.Status), + Archived: chat.Archived, + } + if chat.WorkerID.Valid { + id := chat.WorkerID.UUID + msg.WorkerID = &id + } + if chat.RunnerID.Valid { + id := chat.RunnerID.UUID + msg.RunnerID = &id + } + payload, err := json.Marshal(msg) + if err != nil { + // json.Marshal on this struct is total; panic is acceptable + // because the only failure mode would be a bug in this + // package, not user input. + panic(fmt.Sprintf("marshal chat state update: %v", err)) + } + return payload +} + +// buildChatOwnershipMessage produces the JSON payload for the global +// `chat:ownership` ownership hint for chat. +func buildChatOwnershipMessage(chat database.Chat) []byte { + payload, err := json.Marshal(coderdpubsub.ChatStateOwnershipMessage{ + ChatID: chat.ID, + SnapshotVersion: chat.SnapshotVersion, + }) + if err != nil { + panic(fmt.Sprintf("marshal chat state ownership: %v", err)) + } + return payload +} diff --git a/coderd/x/chatd/chatstate/notify_integration_test.go b/coderd/x/chatd/chatstate/notify_integration_test.go new file mode 100644 index 0000000000..2c3a704137 --- /dev/null +++ b/coderd/x/chatd/chatstate/notify_integration_test.go @@ -0,0 +1,373 @@ +package chatstate_test + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/testutil" +) + +// ============================================================================= +// Publication helpers +// ============================================================================= + +// publishedOn returns the indices into f.Pub.channels (and f.Pub.payloads) +// that match the given channel name, in order. +func publishedOn(f *testFixture, channel string) []int { + var idx []int + for i, c := range f.Pub.channels { + if c == channel { + idx = append(idx, i) + } + } + return idx +} + +// TestCreateChatPublishesAfterCommit asserts that a successful +// CreateChat call publishes exactly one chat:update message on the +// per-chat channel after the inner transaction commits. +func TestCreateChatPublishesAfterCommit(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + res := createTestChat(t, f) + + channel := coderdpubsub.ChatStateUpdateChannel(res.Chat.ID) + idx := publishedOn(f, channel) + require.Len(t, idx, 1, "exactly one chat:update for the new chat") + + var msg coderdpubsub.ChatStateUpdateMessage + require.NoError(t, json.Unmarshal(f.Pub.payloads[idx[0]], &msg)) + require.Equal(t, res.Chat.SnapshotVersion, msg.SnapshotVersion) + require.Equal(t, string(database.ChatStatusRunning), msg.Status) +} + +// TestUpdatePublishesAfterCommit asserts that ChatMachine.Update +// publishes one chat:update on the per-chat channel after the inner +// transaction commits, even when the callback performs no transition. +func TestUpdatePublishesAfterCommit(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + createIdx := publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)) + require.Len(t, createIdx, 1, "create published one chat:update") + + require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil })) + + updIdx := publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)) + require.Len(t, updIdx, 2, "no-op Update still publishes a chat:update") + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + var msg coderdpubsub.ChatStateUpdateMessage + require.NoError(t, json.Unmarshal(f.Pub.payloads[updIdx[1]], &msg)) + require.Equal(t, after.SnapshotVersion, msg.SnapshotVersion) +} + +// TestUpdatePublishesOneFinalChatUpdateForTransitionBundle bundles +// several transitions inside one Update callback and verifies the +// commit publishes exactly one chat:update on the per-chat channel +// (not one per transition). +func TestUpdatePublishesOneFinalChatUpdateForTransitionBundle(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + baseUpdates := len(publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID))) + + // R0 -> W (FinishTurn) -> XW (SetArchived true) -> W (SetArchived false). + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + if _, err := tx.FinishTurn(chatstate.FinishTurnInput{}); err != nil { + return err + } + if _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true}); err != nil { + return err + } + if _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: false}); err != nil { + return err + } + return nil + })) + + updIdx := publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)) + require.Equal(t, baseUpdates+1, len(updIdx), + "three-transition bundle publishes exactly one final chat:update") + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusWaiting, after.Status, "ends in W") +} + +// TestUpdateAppliesTransitionBundleSequentially verifies that +// transitions chained inside a single Update callback see each +// other's effects: later transitions validate against the state +// produced by earlier ones (R0 -> W is rejected when called twice +// because the second call sees state W and FinishTurn is no longer +// allowed). +func TestUpdateAppliesTransitionBundleSequentially(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + err := m.Update(ctx, func(tx *chatstate.Tx) error { + if _, err := tx.FinishTurn(chatstate.FinishTurnInput{}); err != nil { + return err + } + // Second FinishTurn should fail because state is now W. + _, err := tx.FinishTurn(chatstate.FinishTurnInput{}) + return err + }) + require.Error(t, err) + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed) + + // Failed bundle rolls back: state must not have advanced past R0. + require.Equal(t, chatstate.StateR0, f.classify(ctx, t, created.Chat.ID), + "failed bundle rolls back the whole transaction") +} + +// TestFailedUpdatePublishesNothing verifies that a callback error +// rolls back the snapshot bump and publishes nothing. +func TestFailedUpdatePublishesNothing(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + publishedBefore := len(f.Pub.channels) + beforeChat := f.readChat(ctx, t, created.Chat.ID) + + sentinel := xerrors.New("forced failure") + err := m.Update(ctx, func(_ *chatstate.Tx) error { return sentinel }) + require.ErrorIs(t, err, sentinel) + require.Equal(t, publishedBefore, len(f.Pub.channels), "failed update publishes nothing") + + after := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, beforeChat.SnapshotVersion, after.SnapshotVersion, + "failed update rolls back snapshot bump") +} + +// TestLockPublishesNothing verifies that Lock does not publish even +// though it locks the chat row. +func TestLockPublishesNothing(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + publishedBefore := len(f.Pub.channels) + require.NoError(t, m.Lock(ctx, func(_ database.Store) error { return nil })) + require.Equal(t, publishedBefore, len(f.Pub.channels), "Lock publishes nothing") +} + +// TestPublishBufferWithRolledBackOuterTransactionPublishesNothing +// wires a chatstate machine through a PublishBuffer and exercises +// the buffer primitive directly: when the caller discards before +// flushing, the inner publisher receives nothing. ChatMachine.Update +// uses the same primitive internally with a deferred Discard; +// callers no longer drive Flush or Discard themselves. +func TestPublishBufferWithRolledBackOuterTransactionPublishesNothing(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + // Run one normal Update to establish a stable baseline channel + // count. CreateChat plus this Update may publish chat:update + // and chat:ownership messages depending on ownership, so we + // take the snapshot after that activity settles. + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil })) + baseline := len(f.Pub.channels) + + // Now exercise the PublishBuffer rollback path explicitly. The + // outer transaction "rolls back": the caller buffers messages, + // discards them, then flushes. The inner publisher must see + // none of the buffered messages. + buf := chatstate.NewPublishBuffer(f.Pub) + require.NoError(t, buf.Publish("chat:update:bogus", []byte("payload"))) + require.NoError(t, buf.Publish("chat:ownership", []byte("payload"))) + buf.Discard() + require.NoError(t, buf.Flush()) + + require.Equal(t, baseline, len(f.Pub.channels), + "discarded buffer publishes nothing through the inner publisher") +} + +// TestChatUpdateMessagePayloadShape verifies the JSON shape of the +// chat:update payload contains every field consumers depend on: +// snapshot_version, history_version, queue_version, +// retry_state_version, generation_attempt, status, archived, and +// optional worker_id / runner_id. +func TestChatUpdateMessagePayloadShape(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + // Acquire ownership so worker_id and runner_id are present. + worker := uuid.New() + runner := uuid.New() + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner}) + return err + })) + + // Find the last chat:update message. + channel := coderdpubsub.ChatStateUpdateChannel(created.Chat.ID) + idx := publishedOn(f, channel) + require.NotEmpty(t, idx) + last := f.Pub.payloads[idx[len(idx)-1]] + + // Strict-decode against the typed struct. + var typed coderdpubsub.ChatStateUpdateMessage + require.NoError(t, json.Unmarshal(last, &typed)) + require.Greater(t, typed.SnapshotVersion, int64(0)) + require.NotNil(t, typed.WorkerID) + require.Equal(t, worker, *typed.WorkerID) + require.NotNil(t, typed.RunnerID) + require.Equal(t, runner, *typed.RunnerID) + require.Equal(t, string(database.ChatStatusRunning), typed.Status) + require.False(t, typed.Archived) + + // Permissive decode to assert exact JSON keys. + var raw map[string]json.RawMessage + require.NoError(t, json.Unmarshal(last, &raw)) + for _, key := range []string{ + "snapshot_version", + "history_version", + "queue_version", + "retry_state_version", + "generation_attempt", + "status", + "archived", + "worker_id", + "runner_id", + } { + _, ok := raw[key] + require.True(t, ok, "payload missing key %q", key) + } +} + +// TestChatOwnershipMessagePayloadShape verifies the JSON shape of +// chat:ownership: chat_id and snapshot_version. +func TestChatOwnershipMessagePayloadShape(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + // CreateChat publishes one ownership hint because the new chat is + // unowned and runnable. + created := createTestChat(t, f) + + idx := publishedOn(f, coderdpubsub.ChatStateOwnershipChannel) + require.NotEmpty(t, idx, "CreateChat publishes at least one chat:ownership hint") + + payload := f.Pub.payloads[idx[len(idx)-1]] + var typed coderdpubsub.ChatStateOwnershipMessage + require.NoError(t, json.Unmarshal(payload, &typed)) + require.Equal(t, created.Chat.ID, typed.ChatID) + require.Greater(t, typed.SnapshotVersion, int64(0)) + + var raw map[string]json.RawMessage + require.NoError(t, json.Unmarshal(payload, &raw)) + for _, key := range []string{"chat_id", "snapshot_version"} { + _, ok := raw[key] + require.True(t, ok, "ownership payload missing key %q", key) + } +} + +// TestOwnershipNotificationUsesDatabaseHeartbeatStaleness verifies +// that an ownership hint fires when the heartbeat is stale by the +// database's clock, regardless of what the local Go clock says. We +// rewrite the heartbeat row to a deterministically old timestamp via +// raw SQL and confirm the post-commit hint is sent on a subsequent +// runnable Update. +func TestOwnershipNotificationUsesDatabaseHeartbeatStaleness(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + // Acquire ownership; this writes a fresh heartbeat. + worker := uuid.New() + runner := uuid.New() + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner}) + return err + })) + hb, err := f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: created.Chat.ID, + RunnerID: runner, + }) + require.NoError(t, err) + require.WithinDuration(t, time.Now(), hb.HeartbeatAt, time.Minute, + "Acquire wrote a fresh heartbeat") + + // Snapshot ownership-hint count before the test trigger. + ownershipBefore := f.Pub.ownershipPublishCount() + + // Force the heartbeat to a deterministically old time. + _, err = tf.sqlDB.ExecContext(ctx, ` + UPDATE chat_heartbeats + SET heartbeat_at = NOW() - INTERVAL '1 hour' + WHERE chat_id = $1 AND runner_id = $2 + `, created.Chat.ID, runner) + require.NoError(t, err) + + // Confirm database-side staleness check agrees. + stale, err := f.DB.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{ + ChatID: created.Chat.ID, + RunnerID: runner, + StaleSeconds: chatstate.HeartbeatStaleSeconds, + }) + require.NoError(t, err) + require.True(t, stale, "heartbeat is stale per database time") + + // Run a no-op Update. The chat is runnable (R0) and the + // heartbeat is stale, so post-commit logic must publish exactly + // one chat:ownership hint. + require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil })) + + ownershipAfter := f.Pub.ownershipPublishCount() + require.Equal(t, ownershipBefore+1, ownershipAfter, + "stale heartbeat triggers a fresh ownership hint") +} + +// TestUpdateContextCancellationPublishesNothing verifies that +// canceling the caller's context (between the inner commit and the +// publish loop's first call) does not corrupt state. We exercise the +// simpler observable contract: when the user cancels before Update +// gets to do anything, nothing is published. The strict before-publish +// race is exercised in concurrency tests with channel sync. +func TestUpdateContextCancellationPublishesNothing(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + publishedBefore := len(f.Pub.channels) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := m.Update(ctx, func(_ *chatstate.Tx) error { return nil }) + require.Error(t, err) + require.Equal(t, publishedBefore, len(f.Pub.channels), + "caller-aborted update publishes nothing") +} diff --git a/coderd/x/chatd/chatstate/notify_internal_test.go b/coderd/x/chatd/chatstate/notify_internal_test.go new file mode 100644 index 0000000000..e04419cb92 --- /dev/null +++ b/coderd/x/chatd/chatstate/notify_internal_test.go @@ -0,0 +1,113 @@ +package chatstate + +import ( + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" +) + +type recordingPublisher struct { + calls []recordedCall + errOn map[string]error + failed map[string]int +} + +type recordedCall struct { + Channel string + Payload []byte +} + +func newRecordingPublisher() *recordingPublisher { + return &recordingPublisher{ + errOn: map[string]error{}, + failed: map[string]int{}, + } +} + +func (r *recordingPublisher) Publish(channel string, payload []byte) error { + r.calls = append(r.calls, recordedCall{Channel: channel, Payload: append([]byte(nil), payload...)}) + if err, ok := r.errOn[channel]; ok { + r.failed[channel]++ + return err + } + return nil +} + +func TestPublishBuffer_DefersPublishUntilFlush(t *testing.T) { + t.Parallel() + inner := newRecordingPublisher() + buf := NewPublishBuffer(inner) + + require.NoError(t, buf.Publish("a", []byte("1"))) + require.NoError(t, buf.Publish("b", []byte("2"))) + + require.Empty(t, inner.calls, "inner publisher should not be called before flush") + require.Equal(t, []string{"a", "b"}, buf.BufferedChannels()) +} + +func TestPublishBuffer_FlushPublishesInOrder(t *testing.T) { + t.Parallel() + inner := newRecordingPublisher() + buf := NewPublishBuffer(inner) + + require.NoError(t, buf.Publish("a", []byte("1"))) + require.NoError(t, buf.Publish("b", []byte("2"))) + require.NoError(t, buf.Publish("c", []byte("3"))) + + require.NoError(t, buf.Flush()) + require.Len(t, inner.calls, 3) + require.Equal(t, "a", inner.calls[0].Channel) + require.Equal(t, "b", inner.calls[1].Channel) + require.Equal(t, "c", inner.calls[2].Channel) + require.Equal(t, []byte("1"), inner.calls[0].Payload) +} + +func TestPublishBuffer_FlushReturnsFirstError(t *testing.T) { + t.Parallel() + inner := newRecordingPublisher() + inner.errOn["b"] = xerrors.New("broken") + buf := NewPublishBuffer(inner) + + require.NoError(t, buf.Publish("a", []byte("1"))) + require.NoError(t, buf.Publish("b", []byte("2"))) + require.NoError(t, buf.Publish("c", []byte("3"))) + + err := buf.Flush() + require.Error(t, err) + require.Contains(t, err.Error(), "publish b:") + // Even after the broken channel, later messages should still be + // attempted so the inner publisher sees them. + require.Len(t, inner.calls, 3) +} + +func TestPublishBuffer_PublishAfterFlushFails(t *testing.T) { + t.Parallel() + inner := newRecordingPublisher() + buf := NewPublishBuffer(inner) + require.NoError(t, buf.Flush()) + require.Error(t, buf.Publish("x", []byte("y"))) +} + +func TestPublishBuffer_DiscardSuppressesPending(t *testing.T) { + t.Parallel() + inner := newRecordingPublisher() + buf := NewPublishBuffer(inner) + require.NoError(t, buf.Publish("a", []byte("1"))) + buf.Discard() + require.NoError(t, buf.Flush()) + require.Empty(t, inner.calls) +} + +func TestPublishBuffer_DiscardBlocksLaterPublishes(t *testing.T) { + t.Parallel() + inner := newRecordingPublisher() + buf := NewPublishBuffer(inner) + buf.Discard() + // Discard sets disabled; subsequent Publish is a no-op (not an + // error) so callers using Discard before/around rollback paths + // do not have to special-case unwind. + require.NoError(t, buf.Publish("a", []byte("1"))) + require.NoError(t, buf.Flush()) + require.Empty(t, inner.calls) +} diff --git a/coderd/x/chatd/chatstate/state.go b/coderd/x/chatd/chatstate/state.go new file mode 100644 index 0000000000..018148d521 --- /dev/null +++ b/coderd/x/chatd/chatstate/state.go @@ -0,0 +1,182 @@ +package chatstate + +import ( + "github.com/coder/coder/v2/coderd/database" +) + +// ExecutionState is the RFC shorthand for a chat's current execution +// state. The valid set contains 13 states; everything outside it is +// represented by [StateInvalid]. +type ExecutionState string + +const ( + // StateN: chat does not exist. + StateN ExecutionState = "N" + // StateW: waiting, empty queue, not archived. + StateW ExecutionState = "W" + // StateE0: error, empty queue, not archived. + StateE0 ExecutionState = "E0" + // StateE1: error, non-empty queue, not archived. + StateE1 ExecutionState = "E1" + // StateR0: running, empty queue, not archived. + StateR0 ExecutionState = "R0" + // StateR1: running, non-empty queue, not archived. + StateR1 ExecutionState = "R1" + // StateI0: interrupting, empty queue, not archived. + StateI0 ExecutionState = "I0" + // StateI1: interrupting, non-empty queue, not archived. + StateI1 ExecutionState = "I1" + // StateA0: requires_action, empty queue, not archived. + StateA0 ExecutionState = "A0" + // StateA1: requires_action, non-empty queue, not archived. + StateA1 ExecutionState = "A1" + // StateXW: archived waiting, empty queue. + StateXW ExecutionState = "XW" + // StateXE0: archived error, empty queue. + StateXE0 ExecutionState = "XE0" + // StateXE1: archived error, non-empty queue. + StateXE1 ExecutionState = "XE1" + + // StateInvalid groups every status/archive/queue combination that + // is not one of the 13 valid states above. The state machine + // refuses non-reconciliation transitions on invalid states and + // exposes the [Tx.ReconcileInvalidState] transition to recover. + StateInvalid ExecutionState = "Invalid" +) + +// String implements fmt.Stringer. +func (s ExecutionState) String() string { return string(s) } + +// AllExecutionStates is the canonical enumeration of every value the +// classifier can return. Tests rely on this list to iterate over every +// state when verifying transition coverage. +var AllExecutionStates = []ExecutionState{ + StateN, + StateW, + StateE0, + StateE1, + StateR0, + StateR1, + StateI0, + StateI1, + StateA0, + StateA1, + StateXW, + StateXE0, + StateXE1, + StateInvalid, +} + +// IsRunnable returns true for the execution states that the chat +// worker is allowed to acquire and drive forward: R0, R1, I0, I1, +// A0, and A1. Requires-action states need worker ownership for +// timeout processing. Other states are idle (W, E*, XW, XE*), absent +// (N), or invalid. +func (s ExecutionState) IsRunnable() bool { + switch s { + case StateR0, StateR1, StateI0, StateI1, StateA0, StateA1: + return true + default: + return false + } +} + +// IsArchived returns true for the three archived execution states. +func (s ExecutionState) IsArchived() bool { + switch s { + case StateXW, StateXE0, StateXE1: + return true + default: + return false + } +} + +// QueueNonEmpty returns true for execution states that require a +// non-empty queue. Useful when seeding test fixtures. +func (s ExecutionState) QueueNonEmpty() bool { + switch s { + case StateE1, StateR1, StateI1, StateA1, StateXE1: + return true + default: + return false + } +} + +// ClassifyExecutionState turns the chat row, queue cardinality, and +// whether the chat row exists into one of the 14 [ExecutionState] +// values. The caller is responsible for loading the chat under the +// row lock and reading the queue count in the same transaction. +// +// Callers that have no chat row (lookup returned sql.ErrNoRows) +// should pass exists=false; the chat, status, and archive arguments +// are then ignored. +// +// The classifier is a single flat switch over the 13 valid +// (status, archived, queue) tuples from the RFC. Anything outside +// that set (legacy pending/paused/completed statuses, archived busy +// states, waiting with a non-empty queue, future enum values) falls +// through to [StateInvalid]. +// +//nolint:revive // queueNonEmpty/exists are simple flags that mirror the RFC inputs. +func ClassifyExecutionState(chat database.Chat, queueNonEmpty, exists bool) ExecutionState { + if !exists { + return StateN + } + switch { + case chat.Status == database.ChatStatusWaiting && !chat.Archived && !queueNonEmpty: + return StateW + case chat.Status == database.ChatStatusWaiting && chat.Archived && !queueNonEmpty: + return StateXW + case chat.Status == database.ChatStatusError && !chat.Archived && !queueNonEmpty: + return StateE0 + case chat.Status == database.ChatStatusError && !chat.Archived && queueNonEmpty: + return StateE1 + case chat.Status == database.ChatStatusError && chat.Archived && !queueNonEmpty: + return StateXE0 + case chat.Status == database.ChatStatusError && chat.Archived && queueNonEmpty: + return StateXE1 + case chat.Status == database.ChatStatusRunning && !chat.Archived && !queueNonEmpty: + return StateR0 + case chat.Status == database.ChatStatusRunning && !chat.Archived && queueNonEmpty: + return StateR1 + case chat.Status == database.ChatStatusInterrupting && !chat.Archived && !queueNonEmpty: + return StateI0 + case chat.Status == database.ChatStatusInterrupting && !chat.Archived && queueNonEmpty: + return StateI1 + case chat.Status == database.ChatStatusRequiresAction && !chat.Archived && !queueNonEmpty: + return StateA0 + case chat.Status == database.ChatStatusRequiresAction && !chat.Archived && queueNonEmpty: + return StateA1 + } + return StateInvalid +} + +// OwnershipState is the RFC shorthand for whether a chat row is +// currently owned by a worker. The state machine treats execution and +// ownership as orthogonal. +type OwnershipState string + +const ( + // StateU: chat has no owner (worker_id IS NULL). + StateU OwnershipState = "U" + // StateO: chat has an owner (worker_id IS NOT NULL). + StateO OwnershipState = "O" +) + +// String implements fmt.Stringer. +func (s OwnershipState) String() string { return string(s) } + +// AllOwnershipStates is the canonical enumeration of ownership states. +var AllOwnershipStates = []OwnershipState{StateU, StateO} + +// ClassifyOwnershipState returns [StateU] when worker_id is NULL and +// [StateO] otherwise. The runner_id field is intentionally not part of +// the ownership classification. Acquire writes both worker_id and +// runner_id; Abandon clears the current row without using runner_id as +// a transition input. +func ClassifyOwnershipState(chat database.Chat) OwnershipState { + if !chat.WorkerID.Valid { + return StateU + } + return StateO +} diff --git a/coderd/x/chatd/chatstate/state_internal_test.go b/coderd/x/chatd/chatstate/state_internal_test.go new file mode 100644 index 0000000000..ade1951dc9 --- /dev/null +++ b/coderd/x/chatd/chatstate/state_internal_test.go @@ -0,0 +1,195 @@ +package chatstate + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" +) + +func chatWithStatus(status database.ChatStatus, archived bool) database.Chat { + return database.Chat{ + ID: uuid.New(), + Status: status, + Archived: archived, + OwnerID: uuid.New(), + } +} + +// TestClassifyExecutionState_Valid covers every valid classification: +// N (missing chat) plus the twelve valid existing-chat states. +func TestClassifyExecutionState_Valid(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + status database.ChatStatus + archived bool + queueNonEmpty bool + exists bool + want ExecutionState + }{ + {name: "N", exists: false, want: StateN}, + {name: "W", status: database.ChatStatusWaiting, exists: true, want: StateW}, + {name: "E0", status: database.ChatStatusError, exists: true, want: StateE0}, + {name: "E1", status: database.ChatStatusError, queueNonEmpty: true, exists: true, want: StateE1}, + {name: "R0", status: database.ChatStatusRunning, exists: true, want: StateR0}, + {name: "R1", status: database.ChatStatusRunning, queueNonEmpty: true, exists: true, want: StateR1}, + {name: "I0", status: database.ChatStatusInterrupting, exists: true, want: StateI0}, + {name: "I1", status: database.ChatStatusInterrupting, queueNonEmpty: true, exists: true, want: StateI1}, + {name: "A0", status: database.ChatStatusRequiresAction, exists: true, want: StateA0}, + {name: "A1", status: database.ChatStatusRequiresAction, queueNonEmpty: true, exists: true, want: StateA1}, + {name: "XW", status: database.ChatStatusWaiting, archived: true, exists: true, want: StateXW}, + {name: "XE0", status: database.ChatStatusError, archived: true, exists: true, want: StateXE0}, + {name: "XE1", status: database.ChatStatusError, archived: true, queueNonEmpty: true, exists: true, want: StateXE1}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + chat := database.Chat{} + if tc.exists { + chat = chatWithStatus(tc.status, tc.archived) + } + require.Equal(t, tc.want, ClassifyExecutionState(chat, tc.queueNonEmpty, tc.exists)) + }) + } +} + +// TestClassifyExecutionState_Invalid covers every documented invalid +// combination: legacy statuses, waiting-with-queue, and archived busy +// statuses. +func TestClassifyExecutionState_Invalid(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + status database.ChatStatus + archived bool + queueNonEmpty bool + }{ + // Legacy statuses (pending/paused/completed) are invalid for + // the new state machine. + {name: "LegacyPending", status: "pending"}, + {name: "LegacyPaused", status: "paused"}, + {name: "LegacyCompleted", status: "completed"}, + + // Waiting must always have an empty queue. + {name: "WaitingWithQueue", status: database.ChatStatusWaiting, queueNonEmpty: true}, + {name: "WaitingArchivedWithQueue", status: database.ChatStatusWaiting, archived: true, queueNonEmpty: true}, + + // Archived busy statuses are invalid. + {name: "ArchivedRunning", status: database.ChatStatusRunning, archived: true}, + {name: "ArchivedInterrupting", status: database.ChatStatusInterrupting, archived: true}, + {name: "ArchivedRequiresAction", status: database.ChatStatusRequiresAction, archived: true}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := ClassifyExecutionState(chatWithStatus(tc.status, tc.archived), tc.queueNonEmpty, true) + require.Equal(t, StateInvalid, got) + }) + } +} + +// TestClassifyExecutionState_RejectsAllUnlistedCombinations enumerates +// every (status, archived, queueNonEmpty) tuple for an existing chat +// and asserts exactly the twelve RFC tuples classify out of +// [StateInvalid]. Missing chats are handled separately via the N case +// in [TestClassifyExecutionState_Valid]. +func TestClassifyExecutionState_RejectsAllUnlistedCombinations(t *testing.T) { + t.Parallel() + allStatuses := []database.ChatStatus{ + database.ChatStatusWaiting, + database.ChatStatusError, + database.ChatStatusRunning, + database.ChatStatusInterrupting, + database.ChatStatusRequiresAction, + "pending", "paused", "completed", + } + validCount := 0 + for _, status := range allStatuses { + for _, archived := range []bool{false, true} { + for _, queueNonEmpty := range []bool{false, true} { + got := ClassifyExecutionState(chatWithStatus(status, archived), queueNonEmpty, true) + if got != StateInvalid { + validCount++ + } + } + } + } + require.Equal(t, 12, validCount, "exactly 12 valid existing-chat (status, archived, queue) tuples") +} + +// TestClassifyOwnershipState covers both ownership classifications. +func TestClassifyOwnershipState(t *testing.T) { + t.Parallel() + cases := []struct { + name string + chat func() database.Chat + want OwnershipState + }{ + { + name: "Unowned", + chat: func() database.Chat { return database.Chat{} }, + want: StateU, + }, + { + name: "Owned", + chat: func() database.Chat { + c := database.Chat{} + c.WorkerID.UUID = uuid.New() + c.WorkerID.Valid = true + return c + }, + want: StateO, + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.want, ClassifyOwnershipState(tc.chat())) + }) + } +} + +// TestAllExecutionStates_Enumeration verifies AllExecutionStates +// contains every declared execution state exactly once. +func TestAllExecutionStates_Enumeration(t *testing.T) { + t.Parallel() + want := map[ExecutionState]bool{ + StateN: true, StateW: true, StateE0: true, StateE1: true, + StateR0: true, StateR1: true, StateI0: true, StateI1: true, + StateA0: true, StateA1: true, StateXW: true, StateXE0: true, + StateXE1: true, StateInvalid: true, + } + require.Len(t, AllExecutionStates, len(want)) + seen := make(map[ExecutionState]bool, len(want)) + for _, s := range AllExecutionStates { + require.True(t, want[s], "unexpected state %s", s) + require.False(t, seen[s], "duplicate state %s", s) + seen[s] = true + } +} + +// TestExecutionState_Predicates covers IsRunnable and QueueNonEmpty +// for every declared execution state. +func TestExecutionState_Predicates(t *testing.T) { + t.Parallel() + + runnable := map[ExecutionState]bool{ + StateR0: true, StateR1: true, StateI0: true, StateI1: true, + StateA0: true, StateA1: true, + } + nonEmpty := map[ExecutionState]bool{ + StateE1: true, StateR1: true, StateI1: true, StateA1: true, StateXE1: true, + } + for _, s := range AllExecutionStates { + require.Equal(t, runnable[s], s.IsRunnable(), "IsRunnable(%s)", s) + require.Equal(t, nonEmpty[s], s.QueueNonEmpty(), "QueueNonEmpty(%s)", s) + } +} diff --git a/coderd/x/chatd/chatstate/synthetic_cancellation_test.go b/coderd/x/chatd/chatstate/synthetic_cancellation_test.go new file mode 100644 index 0000000000..33542bdcf7 --- /dev/null +++ b/coderd/x/chatd/chatstate/synthetic_cancellation_test.go @@ -0,0 +1,485 @@ +package chatstate_test + +import ( + "context" + "encoding/json" + "testing" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// ============================================================================= +// Helpers for synthetic cancellation tests +// ============================================================================= + +// nonDynamicAssistantToolCallMessage builds an assistant message that +// issues a single tool call against a tool that is NOT in the chat's +// dynamic_tools set. The send-message and edit-message paths use the +// "cancel every outstanding tool call regardless of source" variant +// (dynamicOnly=false), so the cancellation must still fire even for +// non-dynamic tools. +func nonDynamicAssistantToolCallMessage(t *testing.T, modelID uuid.UUID, callID string) chatstate.Message { + t.Helper() + raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: callID, + ToolName: "non_dynamic_tool", + Args: json.RawMessage(`{}`), + }}) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleAssistant, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true}, + } +} + +// assertToolResultForCall asserts that msg is a tool-result message +// that resolves a tool call with id wantCallID and is_error=true. +func assertToolResultForCall(t *testing.T, msg database.ChatMessage, wantCallID string) { + t.Helper() + require.Equal(t, database.ChatMessageRoleTool, msg.Role) + parts, err := chatprompt.ParseContent(msg) + require.NoError(t, err) + require.NotEmpty(t, parts) + var found bool + for _, p := range parts { + if p.Type != codersdk.ChatMessagePartTypeToolResult { + continue + } + require.Equal(t, wantCallID, p.ToolCallID, "tool-call id matches") + require.True(t, p.IsError, "synthetic cancellation must be marked is_error=true") + found = true + } + require.True(t, found, "expected at least one tool-result part") +} + +// commitAssistantToolCall pushes an assistant message that calls +// `tool_name` with `callID` into history via CommitStep. Returns the +// inserted assistant ChatMessage. Use the dynamic-tools chat fixture +// (createTestChatWithDynamicTools) when dynamicOnly cancellation +// paths are exercised. +func commitAssistantToolCall( + t *testing.T, + f *testFixture, + m *chatstate.ChatMachine, + msg chatstate.Message, +) database.ChatMessage { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + var step chatstate.CommitStepResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + step, err = tx.CommitStep(chatstate.CommitStepInput{Messages: []chatstate.Message{msg}}) + return err + })) + require.Len(t, step.InsertedMessages, 1) + return step.InsertedMessages[0] +} + +// landInW puts a fresh R0 chat into state W (waiting) via FinishTurn. +func landInW(t *testing.T, f *testFixture, m *chatstate.ChatMachine) { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishTurn(chatstate.FinishTurnInput{}) + return err + })) + require.Equal(t, chatstate.StateW, f.classify(ctx, t, m.ChatID())) +} + +// landInE0 puts a fresh R0 chat into state E0 (error, empty queue). +func landInE0(t *testing.T, f *testFixture, m *chatstate.ChatMachine) { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishError(chatstate.FinishErrorInput{ + LastError: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"boom"}`), + Valid: true, + }, + }) + return err + })) + require.Equal(t, chatstate.StateE0, f.classify(ctx, t, m.ChatID())) +} + +// ============================================================================= +// SendMessage direct-history paths (W, E0) +// ============================================================================= + +// TestSendMessageDirect_W_SynthesizesToolCancellations verifies that +// from W, SendMessage inserts synthetic tool-result rows for every +// outstanding tool call on the last assistant message BEFORE the new +// user message, regardless of whether the tools are dynamic. +func TestSendMessageDirect_W_SynthesizesToolCancellations(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + callID := "call_" + uuid.NewString() + assistant := commitAssistantToolCall(t, f, m, + nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID)) + require.Equal(t, database.ChatMessageRoleAssistant, assistant.Role) + + // R0 -> W. + landInW(t, f, m) + + // SendMessage with a fresh user message. The direct-history path + // must insert a synthetic tool-result (for callID) followed by + // the new user message. + var send chatstate.SendMessageResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + send, err = tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage("after-cancel", f.User.ID, f.Model.ID), + BusyBehavior: chatstate.BusyBehaviorQueue, + }) + return err + })) + + require.Len(t, send.InsertedMessages, 2, "synthetic cancel + new user") + assertToolResultForCall(t, send.InsertedMessages[0], callID) + require.Equal(t, database.ChatMessageRoleUser, send.InsertedMessages[1].Role) + require.Less(t, send.InsertedMessages[0].ID, send.InsertedMessages[1].ID, + "synthetic cancel is inserted before the user message") +} + +// TestSendMessageDirect_E0_SynthesizesToolCancellations exercises +// the same path from E0. +func TestSendMessageDirect_E0_SynthesizesToolCancellations(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + callID := "call_" + uuid.NewString() + commitAssistantToolCall(t, f, m, + nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID)) + + // R0 -> E0. + landInE0(t, f, m) + + var send chatstate.SendMessageResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + send, err = tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage("after-error", f.User.ID, f.Model.ID), + BusyBehavior: chatstate.BusyBehaviorQueue, + }) + return err + })) + + require.Len(t, send.InsertedMessages, 2) + assertToolResultForCall(t, send.InsertedMessages[0], callID) + require.Equal(t, database.ChatMessageRoleUser, send.InsertedMessages[1].Role) +} + +// ============================================================================= +// EditMessage replacement insertion +// ============================================================================= + +// TestEditMessage_SynthesizesToolCancellationsBeforeReplacement +// verifies that EditMessage from a state with an outstanding tool +// call before the edited user message inserts a synthetic +// tool-result before the replacement user message in history. +// +// The scenario is: +// - user message 1 (initial) +// - assistant tool-call (outstanding) +// - user message 2 (the one we will edit) +// +// EditMessage soft-deletes user message 2 and everything after it, +// then synthesizes cancellations for tool calls on the last +// surviving assistant message that have no matching tool-result. +func TestEditMessage_SynthesizesToolCancellationsBeforeReplacement(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + // Build the history described above. CommitStep is happy to insert + // a mixed batch as long as it stays inside R0. + callID := "call_" + uuid.NewString() + assistantTC := nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID) + secondUser := userTextMessage("second user", f.User.ID, f.Model.ID) + var step chatstate.CommitStepResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + step, err = tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{assistantTC, secondUser}, + }) + return err + })) + require.Len(t, step.InsertedMessages, 2) + secondUserID := step.InsertedMessages[1].ID + require.Equal(t, database.ChatMessageRoleUser, step.InsertedMessages[1].Role) + + var edit chatstate.EditMessageResult + editedContent := mustMarshalParts(t, []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("edited"), + }) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + edit, err = tx.EditMessage(chatstate.EditMessageInput{ + MessageID: secondUserID, + CreatedBy: f.User.ID, + Content: editedContent, + }) + return err + })) + + require.Len(t, edit.CancellationMessages, 1, "synthetic cancel inserted") + assertToolResultForCall(t, edit.CancellationMessages[0], callID) + require.Equal(t, database.ChatMessageRoleUser, edit.ReplacementMessage.Role) + require.Less(t, edit.CancellationMessages[0].ID, edit.ReplacementMessage.ID, + "cancellations are inserted before the replacement user message") +} + +// PromoteQueuedMessage synchronous-history paths (E1, A1) +// ============================================================================= + +// TestPromoteQueuedMessage_E1_SynthesizesToolCancellations verifies +// that promoting a queued message from E1 inserts synthetic +// tool-result rows for outstanding tool calls before the promoted +// user message in history. +func TestPromoteQueuedMessage_E1_SynthesizesToolCancellations(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + callID := "call_" + uuid.NewString() + commitAssistantToolCall(t, f, m, + nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID)) + + // Land in R1 with one queued message. + queued := sendQueuedMessage(t, f, m, "queued-for-promote") + require.NotNil(t, queued.QueuedMessage) + require.Equal(t, chatstate.StateR1, f.classify(ctx, t, created.Chat.ID)) + + // R1 -> E1 via FinishError. + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishError(chatstate.FinishErrorInput{ + LastError: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"boom"}`), + Valid: true, + }, + }) + return err + })) + require.Equal(t, chatstate.StateE1, f.classify(ctx, t, created.Chat.ID)) + + var promote chatstate.PromoteQueuedMessageResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + promote, err = tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{ + QueuedMessageID: queued.QueuedMessage.ID, + }) + return err + })) + + require.Len(t, promote.CancellationMessages, 1) + assertToolResultForCall(t, promote.CancellationMessages[0], callID) + require.NotNil(t, promote.InsertedMessage) + require.Equal(t, database.ChatMessageRoleUser, promote.InsertedMessage.Role) + require.Less(t, promote.CancellationMessages[0].ID, promote.InsertedMessage.ID, + "cancel is inserted before the promoted user message") +} + +// TestPromoteQueuedMessage_A1_SynthesizesDynamicToolCancellations +// verifies that the dynamic outstanding tool call is canceled when +// promoting from A1. The non-dynamic cancellation counterpart lives +// in TestPromoteQueuedMessage_A1_CancelsAllOutstandingToolCalls. +func TestPromoteQueuedMessage_A1_SynthesizesDynamicToolCancellations(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + + toolName := "dyn_promote_a1" + created := createTestChatWithDynamicTools(t, f, toolName) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + dynCallID := "call_" + uuid.NewString() + commitAssistantToolCall(t, f, m, + assistantToolCallMessage(t, f.Model.ID, toolName, dynCallID)) + + // Land in A0. + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{}) + return err + })) + require.Equal(t, chatstate.StateA0, f.classify(ctx, t, created.Chat.ID)) + + // A0 -> A1 with one queued user message. + queued := sendQueuedMessage(t, f, m, "queued-for-a1-promote") + require.NotNil(t, queued.QueuedMessage) + require.Equal(t, chatstate.StateA1, f.classify(ctx, t, created.Chat.ID)) + + var promote chatstate.PromoteQueuedMessageResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + promote, err = tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{ + QueuedMessageID: queued.QueuedMessage.ID, + }) + return err + })) + + require.Len(t, promote.CancellationMessages, 1, "dynamic tool call canceled") + assertToolResultForCall(t, promote.CancellationMessages[0], dynCallID) + require.NotNil(t, promote.InsertedMessage) + require.Equal(t, database.ChatMessageRoleUser, promote.InsertedMessage.Role) +} + +// ============================================================================= +// FinishTurn queue-promotion paths +// ============================================================================= + +// TestFinishTurn_R1_SynthesizesToolCancellationsBeforePromotion +// verifies that finishing a turn while a queued message exists +// synthesizes outstanding tool cancellations before promoting the +// queue head into history. +func TestFinishTurn_R1_SynthesizesToolCancellationsBeforePromotion(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + callID := "call_" + uuid.NewString() + commitAssistantToolCall(t, f, m, + nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID)) + + queued := sendQueuedMessage(t, f, m, "queued-for-finish") + require.NotNil(t, queued.QueuedMessage) + require.Equal(t, chatstate.StateR1, f.classify(ctx, t, created.Chat.ID)) + + beforeIDs := historyMessageIDs(ctx, t, f, created.Chat.ID) + + var finish chatstate.FinishTurnResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + finish, err = tx.FinishTurn(chatstate.FinishTurnInput{}) + return err + })) + require.NotNil(t, finish.PromotedMessage) + require.Equal(t, database.ChatMessageRoleUser, finish.PromotedMessage.Role) + + afterIDs := historyMessageIDs(ctx, t, f, created.Chat.ID) + require.Equal(t, len(beforeIDs)+2, len(afterIDs), + "finish inserts both a tool cancel and the promoted user") + + // The two newly inserted messages are tool-result then user. + newIDs := afterIDs[len(beforeIDs):] + cancel, err := f.DB.GetChatMessageByID(ctx, newIDs[0]) + require.NoError(t, err) + assertToolResultForCall(t, cancel, callID) + require.Equal(t, finish.PromotedMessage.ID, newIDs[1]) +} + +// ============================================================================= +// FinishInterruption queue-promotion paths and outstanding tool calls +// ============================================================================= + +// TestFinishInterruption_I1_PromotesQueueHead verifies that +// FinishInterruption from I1 with no outstanding tool calls +// promotes the queue head into history. +func TestFinishInterruption_I1_PromotesQueueHead(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + // Reach R1 with one queued message. + queued := sendQueuedMessage(t, f, m, "queued-for-interruption") + require.NotNil(t, queued.QueuedMessage) + // R1 -> I1 via Interrupt. + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"}) + return err + })) + require.Equal(t, chatstate.StateI1, f.classify(ctx, t, created.Chat.ID)) + + beforeIDs := historyMessageIDs(ctx, t, f, created.Chat.ID) + + var finish chatstate.FinishInterruptionResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + finish, err = tx.FinishInterruption(chatstate.FinishInterruptionInput{}) + return err + })) + require.NotNil(t, finish.PromotedMessage) + require.Equal(t, database.ChatMessageRoleUser, finish.PromotedMessage.Role) + + afterIDs := historyMessageIDs(ctx, t, f, created.Chat.ID) + require.Equal(t, len(beforeIDs)+1, len(afterIDs)) + require.Equal(t, chatstate.StateR0, f.classify(ctx, t, created.Chat.ID)) +} + +// TestFinishInterruption_RejectsOutstandingToolCalls verifies that +// FinishInterruption fails (TransitionNotAllowed-shaped) when the +// chat still has an outstanding dynamic tool call after the partial +// commit. The non-dynamic counterpart lives in +// TestFinishInterruption_RejectsNonDynamicOutstandingToolCall. The +// chat must remain in its prior state. +func TestFinishInterruption_RejectsOutstandingToolCalls(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + + toolName := "dyn_finish_reject" + created := createTestChatWithDynamicTools(t, f, toolName) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + dynCallID := "call_" + uuid.NewString() + commitAssistantToolCall(t, f, m, + assistantToolCallMessage(t, f.Model.ID, toolName, dynCallID)) + + // R0 -> I0 via Interrupt. Interrupt closes pending dynamic calls + // when transitioning from A0/A1, but from R0 it does NOT, so the + // chat keeps its outstanding dynamic call. + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"}) + return err + })) + require.Equal(t, chatstate.StateI0, f.classify(ctx, t, created.Chat.ID)) + + stateBefore := f.classify(ctx, t, created.Chat.ID) + historyBefore := historyMessageIDs(ctx, t, f, created.Chat.ID) + publishedBefore := len(f.Pub.channels) + + // FinishInterruption with no partial commits should reject + // because the dynamic call is still outstanding. + err := m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishInterruption(chatstate.FinishInterruptionInput{}) + return err + }) + require.Error(t, err) + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed) + + require.Equal(t, stateBefore, f.classify(ctx, t, created.Chat.ID), "state unchanged") + require.Equal(t, historyBefore, historyMessageIDs(ctx, t, f, created.Chat.ID), + "history unchanged on rejected finish") + require.Equal(t, publishedBefore, len(f.Pub.channels), + "failed FinishInterruption publishes nothing") +} + +// ensure unused imports don't break the build if any helper is +// removed later. +var _ = context.Background diff --git a/coderd/x/chatd/chatstate/synthetics.go b/coderd/x/chatd/chatstate/synthetics.go new file mode 100644 index 0000000000..28f5ac2dfd --- /dev/null +++ b/coderd/x/chatd/chatstate/synthetics.go @@ -0,0 +1,234 @@ +package chatstate + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/codersdk" +) + +// synthesizePendingToolCancellations builds [Message] inserts that +// satisfy every outstanding tool call on the chat's last assistant +// message with a synthetic cancellation tool-result message. +// +// "Outstanding" means a tool call present on the last assistant +// message that does not yet have a matching tool-result message in +// the active history after it. The caller controls whether to limit +// to dynamic-tool calls (true) or close every outstanding tool call +// regardless of source (false). The dynamic-only variant is used by +// requires-action interrupts; the all-tool variant is used by any +// transition that needs to insert a new user message into history. +// +// The synthetic results use the supplied chat's last_model_config_id. +// Returns (nil, nil) when there is nothing to synthesize. +// +//nolint:revive // dynamicOnly is a domain flag, not a control flag. +func synthesizePendingToolCancellations( + ctx context.Context, + store database.Store, + chat database.Chat, + reason string, + dynamicOnly bool, +) ([]Message, error) { + var dynamicToolNames map[string]bool + if dynamicOnly { + var err error + dynamicToolNames, err = parseDynamicToolNamesFromRaw(chat.DynamicTools) + if err != nil { + return nil, xerrors.Errorf("parse dynamic tool names: %w", err) + } + if len(dynamicToolNames) == 0 { + return nil, nil + } + } + + lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{ + ChatID: chat.ID, + Role: database.ChatMessageRoleAssistant, + }) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, xerrors.Errorf("get last assistant message: %w", err) + } + assistantParts, err := chatprompt.ParseContent(lastAssistant) + if err != nil { + return nil, xerrors.Errorf("parse assistant message: %w", err) + } + afterMsgs, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: lastAssistant.ID, + }) + if err != nil { + return nil, xerrors.Errorf("get messages after assistant: %w", err) + } + handled := make(map[string]bool) + for _, msg := range afterMsgs { + if msg.Role != database.ChatMessageRoleTool { + continue + } + parts, perr := chatprompt.ParseContent(msg) + if perr != nil { + // Don't fail the whole cancellation just because one + // historical message is unparseable; treat its tool + // results as unknown. + continue + } + for _, p := range parts { + if p.Type == codersdk.ChatMessagePartTypeToolResult { + handled[p.ToolCallID] = true + } + } + } + out := make([]Message, 0) + for _, part := range assistantParts { + if part.Type != codersdk.ChatMessagePartTypeToolCall { + continue + } + if dynamicOnly && !dynamicToolNames[part.ToolName] { + continue + } + if handled[part.ToolCallID] { + continue + } + resultPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeToolResult, + ToolCallID: part.ToolCallID, + ToolName: part.ToolName, + Result: json.RawMessage(fmt.Sprintf("%q", reason)), + IsError: true, + } + raw, merr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{resultPart}) + if merr != nil { + return nil, xerrors.Errorf("marshal synthetic tool result: %w", merr) + } + out = append(out, Message{ + Role: database.ChatMessageRoleTool, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: chat.LastModelConfigID, Valid: true}, + }) + } + if len(out) == 0 { + return nil, nil + } + return out, nil +} + +// pendingDynamicToolCallIDs returns the dynamic tool-call IDs on the +// chat's last assistant message that do not yet have a matching +// tool-result message in active history. The returned map is keyed by +// tool-call ID and valued by tool name so callers can build matching +// result messages without re-parsing the assistant content. +func pendingDynamicToolCallIDs(ctx context.Context, store database.Store, chat database.Chat) (map[string]string, error) { + dynamic, err := parseDynamicToolNamesFromRaw(chat.DynamicTools) + if err != nil { + return nil, err + } + if len(dynamic) == 0 { + return map[string]string{}, nil + } + return outstandingToolCallIDs(ctx, store, chat, func(toolName string) bool { + return dynamic[toolName] + }) +} + +// pendingAllToolCallIDs returns the tool-call IDs of every outstanding +// tool call on the chat's last assistant message, regardless of +// whether the tool is dynamic. The returned map is keyed by tool-call +// ID and valued by tool name. Callers that must guarantee a valid +// LLM message history (e.g. before promoting a user message into +// active history, or after committing an interruption's partial +// messages) should use this variant so non-dynamic tool calls do not +// silently bypass the check. +func pendingAllToolCallIDs(ctx context.Context, store database.Store, chat database.Chat) (map[string]string, error) { + return outstandingToolCallIDs(ctx, store, chat, func(string) bool { return true }) +} + +// outstandingToolCallIDs walks the chat's last assistant message and +// returns the subset of its tool calls that have no matching +// tool-result message in the active history after it. The accept +// callback can be used to restrict the walk to a subset of tools +// (e.g. dynamic-only). +func outstandingToolCallIDs(ctx context.Context, store database.Store, chat database.Chat, accept func(toolName string) bool) (map[string]string, error) { + lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{ + ChatID: chat.ID, + Role: database.ChatMessageRoleAssistant, + }) + if errors.Is(err, sql.ErrNoRows) { + return map[string]string{}, nil + } + if err != nil { + return nil, xerrors.Errorf("get last assistant: %w", err) + } + parts, err := chatprompt.ParseContent(lastAssistant) + if err != nil { + return nil, xerrors.Errorf("parse assistant: %w", err) + } + afterMsgs, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: lastAssistant.ID, + }) + if err != nil { + return nil, xerrors.Errorf("get messages after assistant: %w", err) + } + handled := make(map[string]bool) + for _, msg := range afterMsgs { + if msg.Role != database.ChatMessageRoleTool { + continue + } + ps, perr := chatprompt.ParseContent(msg) + if perr != nil { + continue + } + for _, p := range ps { + if p.Type == codersdk.ChatMessagePartTypeToolResult { + handled[p.ToolCallID] = true + } + } + } + out := make(map[string]string) + for _, p := range parts { + if p.Type != codersdk.ChatMessagePartTypeToolCall { + continue + } + if !accept(p.ToolName) { + continue + } + if handled[p.ToolCallID] { + continue + } + out[p.ToolCallID] = p.ToolName + } + return out, nil +} + +// parseDynamicToolNamesFromRaw is a private mirror of +// chatd.parseDynamicToolNames so chatstate does not pull a runtime +// dependency on the chatd package. It accepts a nullable raw JSON +// blob and returns a name set. +func parseDynamicToolNamesFromRaw(raw pqtype.NullRawMessage) (map[string]bool, error) { + if !raw.Valid || len(raw.RawMessage) == 0 { + return map[string]bool{}, nil + } + var tools []codersdk.DynamicTool + if err := json.Unmarshal(raw.RawMessage, &tools); err != nil { + return nil, err + } + out := make(map[string]bool, len(tools)) + for _, t := range tools { + out[t.Name] = true + } + return out, nil +} diff --git a/coderd/x/chatd/chatstate/transition.go b/coderd/x/chatd/chatstate/transition.go new file mode 100644 index 0000000000..e1fdbe9f3d --- /dev/null +++ b/coderd/x/chatd/chatstate/transition.go @@ -0,0 +1,226 @@ +package chatstate + +// Transition is the enumeration of transitions implemented by the +// state machine. Values intentionally match the names of the public +// methods on [Tx] (and [CreateChat]). The transition matrix below +// declares the legal (from -> to) execution-state mappings used by +// each transition method for validation. +type Transition string + +const ( + TransitionCreateChat Transition = "CreateChat" + TransitionSetArchived Transition = "SetArchived" + TransitionSendMessage Transition = "SendMessage" + TransitionEditMessage Transition = "EditMessage" + TransitionDeleteQueuedMessage Transition = "DeleteQueuedMessage" + TransitionPromoteQueuedMessage Transition = "PromoteQueuedMessage" + TransitionInterrupt Transition = "Interrupt" + TransitionCompleteRequiresAction Transition = "CompleteRequiresAction" + TransitionAcquire Transition = "Acquire" + TransitionAbandon Transition = "Abandon" + TransitionRecordGenerationAttempt Transition = "RecordGenerationAttempt" + TransitionRecordRetryState Transition = "RecordRetryState" + TransitionCommitStep Transition = "CommitStep" + TransitionEnterRequiresAction Transition = "EnterRequiresAction" + TransitionFinishInterruption Transition = "FinishInterruption" + TransitionFinishTurn Transition = "FinishTurn" + TransitionFinishError Transition = "FinishError" + TransitionCancelRequiresAction Transition = "CancelRequiresAction" + TransitionReconcileInvalidState Transition = "ReconcileInvalidState" +) + +// String implements fmt.Stringer. +func (t Transition) String() string { return string(t) } + +// AllExecutionTransitions is the canonical enumeration of every +// execution-state transition that has an entry in the matrix below. +// Ownership transitions (Acquire, Abandon) are intentionally not part +// of this slice because they are validated independently and do not +// have a (from->to) execution mapping. +var AllExecutionTransitions = []Transition{ + TransitionCreateChat, + TransitionSetArchived, + TransitionSendMessage, + TransitionEditMessage, + TransitionDeleteQueuedMessage, + TransitionPromoteQueuedMessage, + TransitionInterrupt, + TransitionCompleteRequiresAction, + TransitionRecordGenerationAttempt, + TransitionRecordRetryState, + TransitionCommitStep, + TransitionEnterRequiresAction, + TransitionFinishInterruption, + TransitionFinishTurn, + TransitionFinishError, + TransitionCancelRequiresAction, + TransitionReconcileInvalidState, +} + +// transitionMatrix is the in-code mirror of the RFC's execution-state +// transition table. Each entry maps an input state to the set of +// allowed transitions together with the possible classified output +// states that the transition implementation may land in. Outputs may +// depend on the post-mutation queue cardinality (for example +// DeleteQueuedMessage from E1 lands in E0 when the deleted row was the +// last queued message, or stays in E1 otherwise), which is why several +// entries list more than one output. +// +// Ownership transitions (Acquire, Abandon) are intentionally not +// included; they are orthogonal to execution state. +var transitionMatrix = map[ExecutionState]map[Transition][]ExecutionState{ + StateN: { + TransitionCreateChat: {StateR0}, + }, + StateW: { + TransitionSetArchived: {StateXW}, + TransitionSendMessage: {StateR0}, + TransitionEditMessage: {StateR0}, + }, + StateE0: { + TransitionSetArchived: {StateXE0}, + TransitionSendMessage: {StateR0}, + TransitionEditMessage: {StateR0}, + }, + StateE1: { + TransitionSetArchived: {StateXE1}, + TransitionSendMessage: {StateR1}, + TransitionEditMessage: {StateR0}, + TransitionDeleteQueuedMessage: {StateE0, StateE1}, + TransitionPromoteQueuedMessage: {StateR0, StateR1}, + }, + StateR0: { + TransitionSendMessage: {StateR1, StateI1}, + TransitionEditMessage: {StateR0}, + TransitionInterrupt: {StateI0}, + TransitionRecordGenerationAttempt: {StateR0}, + TransitionRecordRetryState: {StateR0}, + TransitionCommitStep: {StateR0}, + TransitionEnterRequiresAction: {StateA0}, + TransitionFinishTurn: {StateW}, + TransitionFinishError: {StateE0}, + }, + StateR1: { + TransitionSendMessage: {StateR1, StateI1}, + TransitionEditMessage: {StateR0}, + TransitionDeleteQueuedMessage: {StateR0, StateR1}, + TransitionPromoteQueuedMessage: {StateI1}, + TransitionInterrupt: {StateI1}, + TransitionRecordGenerationAttempt: {StateR1}, + TransitionRecordRetryState: {StateR1}, + TransitionCommitStep: {StateR1}, + TransitionEnterRequiresAction: {StateA1}, + TransitionFinishTurn: {StateR0, StateR1}, + TransitionFinishError: {StateE1}, + }, + StateI0: { + TransitionSendMessage: {StateI1}, + TransitionEditMessage: {StateR0}, + TransitionFinishInterruption: {StateW}, + }, + StateI1: { + TransitionSendMessage: {StateI1}, + TransitionEditMessage: {StateR0}, + TransitionDeleteQueuedMessage: {StateI0, StateI1}, + TransitionPromoteQueuedMessage: {StateI1}, + TransitionFinishInterruption: {StateR0, StateR1}, + }, + StateA0: { + TransitionSendMessage: {StateA1, StateR1}, + TransitionEditMessage: {StateR0}, + TransitionInterrupt: {StateR0}, + TransitionCompleteRequiresAction: {StateR0}, + TransitionCancelRequiresAction: {StateR0}, + }, + StateA1: { + TransitionSendMessage: {StateA1, StateR1}, + TransitionEditMessage: {StateR0}, + TransitionDeleteQueuedMessage: {StateA0, StateA1}, + TransitionPromoteQueuedMessage: {StateR0, StateR1}, + TransitionInterrupt: {StateR1}, + TransitionCompleteRequiresAction: {StateR1}, + TransitionCancelRequiresAction: {StateR1}, + }, + StateXW: { + TransitionSetArchived: {StateW}, + }, + StateXE0: { + TransitionSetArchived: {StateE0}, + }, + StateXE1: { + TransitionSetArchived: {StateE1}, + }, + StateInvalid: { + TransitionReconcileInvalidState: {StateE0, StateE1}, + }, +} + +// isExecutionTransitionAllowed reports whether a transition is legal +// from the supplied input state per the matrix above. Ownership +// transitions are not stored in the matrix and always return false. +func isExecutionTransitionAllowed(t Transition, from ExecutionState) bool { + allowed, ok := transitionMatrix[from] + if !ok { + return false + } + _, ok = allowed[t] + return ok +} + +// requireExecutionTransition validates that t is legal from `from` +// and returns a typed *TransitionError otherwise. +func requireExecutionTransition(t Transition, from ExecutionState) error { + if isExecutionTransitionAllowed(t, from) { + return nil + } + return newTransitionError(t, from, "") +} + +// AllowedExecutionTransitionsFrom returns a deterministic slice of +// transitions legal from `from`. Mostly used by tests to enumerate the +// matrix without leaking the internal map. +func AllowedExecutionTransitionsFrom(from ExecutionState) []Transition { + allowed := transitionMatrix[from] + out := make([]Transition, 0, len(allowed)) + for _, t := range AllExecutionTransitions { + if _, ok := allowed[t]; ok { + out = append(out, t) + } + } + return out +} + +// AllowedInputStates returns a deterministic slice of execution states +// from which `tr` is legal per the matrix above. Mostly used by tests +// to enumerate the matrix without leaking the internal map. +func AllowedInputStates(tr Transition) []ExecutionState { + var out []ExecutionState + for _, from := range AllExecutionStates { + if isExecutionTransitionAllowed(tr, from) { + out = append(out, from) + } + } + return out +} + +// AllowedExecutionTransitionOutputs returns the set of classified +// post-states that the transition `tr` may produce from `from` per +// the matrix above. The returned slice is a copy so callers may mutate +// it without affecting the underlying matrix. +// +// When `tr` is not allowed from `from`, an empty (nil) slice is +// returned. Tests use this helper to enumerate the (transition, from, +// want) triples that must be exercised by the row-level matrix tests. +func AllowedExecutionTransitionOutputs(from ExecutionState, tr Transition) []ExecutionState { + allowed, ok := transitionMatrix[from] + if !ok { + return nil + } + outputs, ok := allowed[tr] + if !ok { + return nil + } + cp := make([]ExecutionState, len(outputs)) + copy(cp, outputs) + return cp +} diff --git a/coderd/x/chatd/chatstate/transitions.go b/coderd/x/chatd/chatstate/transitions.go new file mode 100644 index 0000000000..9c142f8605 --- /dev/null +++ b/coderd/x/chatd/chatstate/transitions.go @@ -0,0 +1,1495 @@ +package chatstate + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "time" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/codersdk" +) + +// ============================================================================= +// CreateChat +// +// CreateChat is the standalone N -> R0 entry point. It is a +// package-level function rather than a method on [ChatMachine] +// because no chat-scoped machine can exist before the chat row is +// written. +// ============================================================================= + +// CreateChatInput configures [CreateChat]. +type CreateChatInput struct { + OrganizationID uuid.UUID + OwnerID uuid.UUID + WorkspaceID uuid.NullUUID + BuildID uuid.NullUUID + AgentID uuid.NullUUID + ParentChatID uuid.NullUUID + RootChatID uuid.NullUUID + LastModelConfigID uuid.UUID + Title string + Mode database.NullChatMode + PlanMode database.NullChatPlanMode + MCPServerIDs []uuid.UUID + Labels pqtype.NullRawMessage + DynamicTools pqtype.NullRawMessage + ClientType database.ChatClientType + InitialMessages []Message +} + +// CreateChatResult is the value returned by [CreateChat]. It carries +// the new chat row and the inserted initial history. +type CreateChatResult struct { + Chat database.Chat + InitialMessages []database.ChatMessage +} + +// CreateChat creates a brand new chat with initial history in a single +// transaction. +// +// Validation: +// - InitialMessages must be non-empty. +// +// After commit CreateChat publishes a `chat:update` message describing +// the new chat snapshot. Because the new chat has no worker assigned, +// CreateChat also publishes an ownership hint so workers can race to +// acquire the runnable chat. +func CreateChat( + ctx context.Context, + store database.Store, + publisher Publisher, + input CreateChatInput, +) (CreateChatResult, error) { + if store == nil { + return CreateChatResult{}, xerrors.New("chatstate: CreateChat called with nil store") + } + if publisher == nil { + return CreateChatResult{}, xerrors.New("chatstate: CreateChat called with nil publisher") + } + if len(input.InitialMessages) == 0 { + return CreateChatResult{}, newTransitionError( + TransitionCreateChat, StateN, + "initial messages must include at least one message", + ) + } + var result CreateChatResult + buffer := NewPublishBuffer(publisher) + defer buffer.Discard() + err := store.InTx(func(store database.Store) error { + chat, ierr := store.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: input.OrganizationID, + OwnerID: input.OwnerID, + WorkspaceID: input.WorkspaceID, + BuildID: input.BuildID, + AgentID: input.AgentID, + ParentChatID: input.ParentChatID, + RootChatID: input.RootChatID, + LastModelConfigID: input.LastModelConfigID, + Title: input.Title, + Mode: input.Mode, + PlanMode: input.PlanMode, + Status: database.ChatStatusRunning, + MCPServerIDs: input.MCPServerIDs, + Labels: input.Labels, + DynamicTools: input.DynamicTools, + ClientType: input.ClientType, + }) + if ierr != nil { + return xerrors.Errorf("insert chat: %w", ierr) + } + // Insert the initial history under the new chat row. The + // message revision trigger advances `history_version` to the + // current `snapshot_version` (which is 1 for a brand new chat). + inserted, ierr := store.InsertChatMessages(ctx, toInsertParams(chat.ID, input.InitialMessages)) + if ierr != nil { + return xerrors.Errorf("insert initial messages: %w", ierr) + } + refreshed, gerr := store.GetChatByID(ctx, chat.ID) + if gerr != nil { + return xerrors.Errorf("reload chat after initial messages: %w", gerr) + } + result = CreateChatResult{ + Chat: refreshed, + InitialMessages: inserted, + } + if perr := buffer.Publish( + coderdpubsub.ChatStateUpdateChannel(refreshed.ID), + buildChatUpdateMessage(refreshed), + ); perr != nil { + return xerrors.Errorf("buffer chat update: %w", perr) + } + if ClassifyExecutionState(refreshed, false, true).IsRunnable() { + if perr := buffer.Publish( + coderdpubsub.ChatStateOwnershipChannel, + buildChatOwnershipMessage(refreshed), + ); perr != nil { + return xerrors.Errorf("buffer ownership hint: %w", perr) + } + } + return nil + }, nil) + if err != nil { + return CreateChatResult{}, err + } + if perr := buffer.Flush(); perr != nil { + return result, perr + } + return result, nil +} + +// ============================================================================= +// Shared transition helpers +// ============================================================================= + +// applyExecutionStateUpdate is a small adapter so transition methods +// do not have to repeat the UpdateChatExecutionState boilerplate. +// The state machine writes status, archived, last_error, ownership +// identifiers, and the requires-action deadline as one atomic update. +type executionStateUpdate struct { + Status database.ChatStatus + Archived bool + WorkerID uuid.NullUUID + RunnerID uuid.NullUUID + LastError pqtype.NullRawMessage + RequiresActionDeadlineAt sql.NullTime +} + +func (tx *Tx) applyExecutionState(u executionStateUpdate) (database.Chat, error) { + return tx.store.UpdateChatExecutionState(tx.ctx, database.UpdateChatExecutionStateParams{ + ID: tx.chatID, + Status: u.Status, + Archived: u.Archived, + WorkerID: u.WorkerID, + RunnerID: u.RunnerID, + LastError: u.LastError, + RequiresActionDeadlineAt: u.RequiresActionDeadlineAt, + }) +} + +// insertMessages inserts the given Message batch under the current +// chat. The message-revision trigger handles history_version and +// generation_attempt bookkeeping automatically. +func (tx *Tx) insertMessages(messages []Message) ([]database.ChatMessage, error) { + if len(messages) == 0 { + return nil, nil + } + inserted, err := tx.store.InsertChatMessages(tx.ctx, toInsertParams(tx.chatID, messages)) + if err != nil { + return nil, xerrors.Errorf("insert messages: %w", err) + } + return inserted, nil +} + +// clearQueue deletes all queued messages on the chat and returns the +// IDs that were deleted in queue order. +func (tx *Tx) clearQueue() ([]int64, error) { + queued, err := tx.store.GetChatQueuedMessagesByPosition(tx.ctx, tx.chatID) + if err != nil { + return nil, xerrors.Errorf("get queued for clear: %w", err) + } + if len(queued) == 0 { + return nil, nil + } + if _, err := tx.store.DeleteAllChatQueuedMessagesReturningCount(tx.ctx, tx.chatID); err != nil { + return nil, xerrors.Errorf("delete queued: %w", err) + } + ids := make([]int64, len(queued)) + for i, q := range queued { + ids[i] = q.ID + } + return ids, nil +} + +// MaxQueueSize is the maximum number of queued user messages per chat. +// Queue-appending transitions reject inserts that would exceed this +// cap with a *MessageQueueFullError that wraps [ErrMessageQueueFull]. +const MaxQueueSize = 20 + +// requireQueueCapacity rejects the call when the chat already has +// MaxQueueSize queued messages. Queue-appending transitions invoke +// this helper inside the transaction immediately before inserting a +// new queued message so the check is atomic with the insert. +func (tx *Tx) requireQueueCapacity() error { + count, err := tx.store.CountChatQueuedMessages(tx.ctx, tx.chatID) + if err != nil { + return xerrors.Errorf("count queued messages: %w", err) + } + if count >= MaxQueueSize { + return &MessageQueueFullError{Max: MaxQueueSize} + } + return nil +} + +// insertQueuedMessage inserts a queued user message. created_by falls +// back to chats.owner_id only when the message does not supply one. +func (tx *Tx) insertQueuedMessage(ownerFallback uuid.UUID, m Message) (database.ChatQueuedMessage, error) { + createdBy := ownerFallback + if m.CreatedBy.Valid { + createdBy = m.CreatedBy.UUID + } + rawContent := m.Content.RawMessage + if !m.Content.Valid || len(rawContent) == 0 { + rawContent = json.RawMessage("null") + } + if err := tx.requireQueueCapacity(); err != nil { + return database.ChatQueuedMessage{}, err + } + return tx.store.InsertChatQueuedMessageWithCreator(tx.ctx, database.InsertChatQueuedMessageWithCreatorParams{ + ChatID: tx.chatID, + Content: rawContent, + ModelConfigID: m.ModelConfigID, + CreatedBy: createdBy, + }) +} + +// messageFromQueuedRow synthesizes a Message from a stored queued row, +// suitable for promoting into active history. +func messageFromQueuedRow(q database.ChatQueuedMessage) Message { + return Message{ + Role: database.ChatMessageRoleUser, + Content: pqtype.NullRawMessage{RawMessage: q.Content, Valid: q.Content != nil}, + Visibility: database.ChatMessageVisibilityBoth, + ModelConfigID: q.ModelConfigID, + CreatedBy: uuid.NullUUID{UUID: q.CreatedBy, Valid: true}, + ContentVersion: chatprompt.CurrentContentVersion, + } +} + +// ============================================================================= +// SetArchived +// ============================================================================= + +// SetArchivedInput configures [Tx.SetArchived]. +type SetArchivedInput struct { + Archived bool +} + +// SetArchivedResult is returned by [Tx.SetArchived]. +type SetArchivedResult struct{} + +// SetArchived sets or clears the chat's archived marker. +func (tx *Tx) SetArchived(input SetArchivedInput) (SetArchivedResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionSetArchived) + if err != nil { + return SetArchivedResult{}, err + } + if input.Archived == chat.Archived { + // The matrix only allows SetArchived(true) from W/E0/E1 and + // SetArchived(false) from XW/XE0/XE1. A request whose Archived + // field already matches the chat's current archived flag is + // the wrong direction (or a no-op) and must be rejected so we + // do not silently roll the snapshot or publish a chat:update. + return SetArchivedResult{}, newTransitionError( + TransitionSetArchived, from, + "SetArchived input matches the current archived flag", + ) + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: chat.Status, + Archived: input.Archived, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt, + }); err != nil { + return SetArchivedResult{}, xerrors.Errorf("update archive: %w", err) + } + return SetArchivedResult{}, nil +} + +// ============================================================================= +// SendMessage +// ============================================================================= + +// BusyBehavior controls how SendMessage behaves when the chat is +// currently busy (R*/I*/A*). From idle/error states the two behaviors +// are equivalent. +type BusyBehavior string + +const ( + BusyBehaviorQueue BusyBehavior = "queue" + BusyBehaviorInterrupt BusyBehavior = "interrupt" +) + +// SendMessageInput configures [Tx.SendMessage]. +type SendMessageInput struct { + Message Message + BusyBehavior BusyBehavior +} + +// SendMessageResult is returned by [Tx.SendMessage]. +type SendMessageResult struct { + InsertedMessages []database.ChatMessage + QueuedMessage *database.ChatQueuedMessage +} + +// SendMessage admits a new user message. Depending on input state and +// BusyBehavior, the message lands directly in history, in the queue, +// or replaces the queue head as part of a running-state promotion. +func (tx *Tx) SendMessage(input SendMessageInput) (SendMessageResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionSendMessage) + if err != nil { + return SendMessageResult{}, err + } + if input.Message.Role != database.ChatMessageRoleUser { + return SendMessageResult{}, newTransitionError( + TransitionSendMessage, from, + "SendMessage requires a user message", + ) + } + switch input.BusyBehavior { + case BusyBehaviorQueue, BusyBehaviorInterrupt: + // ok + default: + // Reject unknown / empty BusyBehavior up front so an invalid + // value cannot fall through to the queue path on busy states + // or be silently ignored on idle states. The callers in chatd + // default empty to queue; chatstate is the lower-level API + // and refuses to guess. + return SendMessageResult{}, newTransitionError( + TransitionSendMessage, from, + "invalid BusyBehavior", + ) + } + switch from { + // Idle / empty-queue error: insert directly into history, clear + // last_error, leave queue alone. + case StateW, StateE0: + return tx.sendMessageDirect(chat, input.Message) + + // Error-with-queue: append to tail, promote previous head into + // history, clear last_error. + case StateE1: + return tx.sendMessageE1(chat, input.Message) + + // Running with no queue. + case StateR0: + if input.BusyBehavior == BusyBehaviorInterrupt { + return tx.sendMessageQueueAndSetStatus(chat, input.Message, database.ChatStatusInterrupting, chat.LastError, chat.RequiresActionDeadlineAt) + } + return tx.sendMessageQueueAndSetStatus(chat, input.Message, chat.Status, chat.LastError, chat.RequiresActionDeadlineAt) + + // Running with queue. + case StateR1: + if input.BusyBehavior == BusyBehaviorInterrupt { + return tx.sendMessageQueueAndSetStatus(chat, input.Message, database.ChatStatusInterrupting, chat.LastError, chat.RequiresActionDeadlineAt) + } + return tx.sendMessageQueueAndSetStatus(chat, input.Message, chat.Status, chat.LastError, chat.RequiresActionDeadlineAt) + + // Interrupting: queue regardless of busy behavior. + case StateI0, StateI1: + return tx.sendMessageQueueAndSetStatus(chat, input.Message, chat.Status, chat.LastError, chat.RequiresActionDeadlineAt) + + // Requires-action: queue keeps A*; interrupt cancels pending + // dynamic calls and resumes in running. + case StateA0, StateA1: + if input.BusyBehavior == BusyBehaviorInterrupt { + return tx.sendMessageInterruptRequiresAction(chat, input.Message) + } + return tx.sendMessageQueueAndSetStatus(chat, input.Message, chat.Status, chat.LastError, chat.RequiresActionDeadlineAt) + } + return SendMessageResult{}, newTransitionError(TransitionSendMessage, from, "unhandled state in SendMessage") +} + +func (tx *Tx) sendMessageDirect(chat database.Chat, m Message) (SendMessageResult, error) { + cancels, err := synthesizePendingToolCancellations(tx.ctx, tx.store, chat, "Tool execution interrupted by new user message", false) + if err != nil { + return SendMessageResult{}, err + } + inserted, err := tx.insertMessages(append(cancels, m)) + if err != nil { + return SendMessageResult{}, err + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRunning, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: pqtype.NullRawMessage{}, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return SendMessageResult{}, xerrors.Errorf("set running: %w", err) + } + return SendMessageResult{ + InsertedMessages: inserted, + }, nil +} + +func (tx *Tx) sendMessageE1(chat database.Chat, m Message) (SendMessageResult, error) { + queued, err := tx.insertQueuedMessage(chat.OwnerID, m) + if err != nil { + return SendMessageResult{}, xerrors.Errorf("insert queued: %w", err) + } + head, err := tx.store.GetChatQueuedMessageHead(tx.ctx, tx.chatID) + if err != nil { + return SendMessageResult{}, xerrors.Errorf("get queue head: %w", err) + } + cancels, err := synthesizePendingToolCancellations(tx.ctx, tx.store, chat, "Tool execution interrupted by queued message promotion", false) + if err != nil { + return SendMessageResult{}, err + } + promoted := messageFromQueuedRow(head) + inserted, err := tx.insertMessages(append(cancels, promoted)) + if err != nil { + return SendMessageResult{}, err + } + if _, err := tx.store.DeleteChatQueuedMessageReturningCount(tx.ctx, database.DeleteChatQueuedMessageReturningCountParams{ + ID: head.ID, + ChatID: tx.chatID, + }); err != nil { + return SendMessageResult{}, xerrors.Errorf("delete promoted queued head: %w", err) + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRunning, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: pqtype.NullRawMessage{}, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return SendMessageResult{}, xerrors.Errorf("set running: %w", err) + } + return SendMessageResult{ + InsertedMessages: inserted, + QueuedMessage: &queued, + }, nil +} + +func (tx *Tx) sendMessageQueueAndSetStatus( + chat database.Chat, + m Message, + status database.ChatStatus, + lastError pqtype.NullRawMessage, + deadline sql.NullTime, +) (SendMessageResult, error) { + queued, err := tx.insertQueuedMessage(chat.OwnerID, m) + if err != nil { + return SendMessageResult{}, xerrors.Errorf("insert queued: %w", err) + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: status, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: lastError, + RequiresActionDeadlineAt: deadline, + }); err != nil { + return SendMessageResult{}, xerrors.Errorf("update status: %w", err) + } + return SendMessageResult{ + QueuedMessage: &queued, + }, nil +} + +func (tx *Tx) sendMessageInterruptRequiresAction(chat database.Chat, m Message) (SendMessageResult, error) { + cancels, err := synthesizePendingToolCancellations(tx.ctx, tx.store, chat, "Tool execution interrupted by user message", true) + if err != nil { + return SendMessageResult{}, err + } + if _, err := tx.insertMessages(cancels); err != nil { + return SendMessageResult{}, err + } + return tx.sendMessageQueueAndSetStatus(chat, m, database.ChatStatusRunning, chat.LastError, sql.NullTime{}) +} + +// ============================================================================= +// EditMessage +// ============================================================================= + +// EditMessageInput configures [Tx.EditMessage]. +type EditMessageInput struct { + MessageID int64 + CreatedBy uuid.UUID + Content pqtype.NullRawMessage + ModelConfigIDOverride uuid.NullUUID +} + +// EditMessageResult is returned by [Tx.EditMessage]. +type EditMessageResult struct { + ReplacementMessage database.ChatMessage + DeletedMessageIDs []int64 + DeletedQueuedMessageIDs []int64 + CancellationMessages []database.ChatMessage +} + +// EditMessage replaces an earlier user message and discards the +// active-history suffix that followed it. +func (tx *Tx) EditMessage(input EditMessageInput) (EditMessageResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionEditMessage) + if err != nil { + return EditMessageResult{}, err + } + target, err := tx.store.GetChatMessageByID(tx.ctx, input.MessageID) + if errors.Is(err, sql.ErrNoRows) { + return EditMessageResult{}, ErrMessageNotFound + } + if err != nil { + return EditMessageResult{}, xerrors.Errorf("get target message: %w", err) + } + if target.ChatID != tx.chatID { + return EditMessageResult{}, ErrMessageNotFound + } + if target.Deleted { + return EditMessageResult{}, ErrMessageNotFound + } + if target.Role != database.ChatMessageRoleUser { + return EditMessageResult{}, newTransitionErrorWithCause( + TransitionEditMessage, from, + ErrEditedMessageNotUser, + "only user messages can be edited", + ) + } + + suffix, err := tx.store.GetChatMessagesByChatID(tx.ctx, database.GetChatMessagesByChatIDParams{ + ChatID: tx.chatID, + AfterID: target.ID - 1, // include target and everything after + }) + if err != nil { + return EditMessageResult{}, xerrors.Errorf("get suffix messages: %w", err) + } + deletedIDs := make([]int64, 0, len(suffix)) + for _, m := range suffix { + if !m.Deleted { + deletedIDs = append(deletedIDs, m.ID) + } + } + + if err := tx.store.SoftDeleteChatMessageByID(tx.ctx, target.ID); err != nil { + return EditMessageResult{}, xerrors.Errorf("soft-delete target: %w", err) + } + if err := tx.store.SoftDeleteChatMessagesAfterID(tx.ctx, database.SoftDeleteChatMessagesAfterIDParams{ + ChatID: tx.chatID, + AfterID: target.ID, + }); err != nil { + return EditMessageResult{}, xerrors.Errorf("soft-delete suffix: %w", err) + } + + cancels, err := synthesizePendingToolCancellations(tx.ctx, tx.store, chat, "Tool execution interrupted by message edit", false) + if err != nil { + return EditMessageResult{}, err + } + cancellationMessages, err := tx.insertMessages(cancels) + if err != nil { + return EditMessageResult{}, err + } + + modelConfig := target.ModelConfigID + if input.ModelConfigIDOverride.Valid { + modelConfig = input.ModelConfigIDOverride + } + replacement := Message{ + Role: database.ChatMessageRoleUser, + Content: input.Content, + Visibility: target.Visibility, + ModelConfigID: modelConfig, + CreatedBy: uuid.NullUUID{UUID: input.CreatedBy, Valid: true}, + ContentVersion: chatprompt.CurrentContentVersion, + } + insertedReplacement, err := tx.insertMessages([]Message{replacement}) + if err != nil { + return EditMessageResult{}, err + } + var replacementRow database.ChatMessage + if len(insertedReplacement) == 1 { + replacementRow = insertedReplacement[0] + } + + deletedQueuedIDs, err := tx.clearQueue() + if err != nil { + return EditMessageResult{}, err + } + + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRunning, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: pqtype.NullRawMessage{}, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return EditMessageResult{}, xerrors.Errorf("set running: %w", err) + } + return EditMessageResult{ + ReplacementMessage: replacementRow, + DeletedMessageIDs: deletedIDs, + DeletedQueuedMessageIDs: deletedQueuedIDs, + CancellationMessages: cancellationMessages, + }, nil +} + +// ============================================================================= +// DeleteQueuedMessage +// ============================================================================= + +// DeleteQueuedMessageInput configures [Tx.DeleteQueuedMessage]. +type DeleteQueuedMessageInput struct { + QueuedMessageID int64 +} + +// DeleteQueuedMessageResult is returned by [Tx.DeleteQueuedMessage]. +type DeleteQueuedMessageResult struct { + DeletedQueuedMessage database.ChatQueuedMessage +} + +// DeleteQueuedMessage removes a single queued user message. +func (tx *Tx) DeleteQueuedMessage(input DeleteQueuedMessageInput) (DeleteQueuedMessageResult, error) { + _, _, err := tx.requireFromAllowed(TransitionDeleteQueuedMessage) + if err != nil { + return DeleteQueuedMessageResult{}, err + } + target, err := tx.store.GetChatQueuedMessageByID(tx.ctx, database.GetChatQueuedMessageByIDParams{ + ID: input.QueuedMessageID, + ChatID: tx.chatID, + }) + if errors.Is(err, sql.ErrNoRows) { + return DeleteQueuedMessageResult{}, ErrQueuedMessageNotFound + } + if err != nil { + return DeleteQueuedMessageResult{}, xerrors.Errorf("get queued: %w", err) + } + rows, err := tx.store.DeleteChatQueuedMessageReturningCount(tx.ctx, database.DeleteChatQueuedMessageReturningCountParams{ + ID: input.QueuedMessageID, + ChatID: tx.chatID, + }) + if err != nil { + return DeleteQueuedMessageResult{}, xerrors.Errorf("delete queued: %w", err) + } + if rows == 0 { + return DeleteQueuedMessageResult{}, ErrQueuedMessageNotFound + } + return DeleteQueuedMessageResult{ + DeletedQueuedMessage: target, + }, nil +} + +// ============================================================================= +// PromoteQueuedMessage +// ============================================================================= + +// PromoteQueuedMessageInput configures [Tx.PromoteQueuedMessage]. +type PromoteQueuedMessageInput struct { + QueuedMessageID int64 +} + +// PromoteQueuedMessageResult is returned by [Tx.PromoteQueuedMessage]. +type PromoteQueuedMessageResult struct { + QueuedMessage database.ChatQueuedMessage + InsertedMessage *database.ChatMessage + ReorderedQueueOnly bool + CancellationMessages []database.ChatMessage +} + +// PromoteQueuedMessage promotes the target queued message to the +// queue head; from E1/A1 it also pops it into active history. +func (tx *Tx) PromoteQueuedMessage(input PromoteQueuedMessageInput) (PromoteQueuedMessageResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionPromoteQueuedMessage) + if err != nil { + return PromoteQueuedMessageResult{}, err + } + target, err := tx.store.GetChatQueuedMessageByID(tx.ctx, database.GetChatQueuedMessageByIDParams{ + ID: input.QueuedMessageID, + ChatID: tx.chatID, + }) + if errors.Is(err, sql.ErrNoRows) { + return PromoteQueuedMessageResult{}, ErrQueuedMessageNotFound + } + if err != nil { + return PromoteQueuedMessageResult{}, xerrors.Errorf("get queued: %w", err) + } + rows, err := tx.store.ReorderChatQueuedMessageToHead(tx.ctx, database.ReorderChatQueuedMessageToHeadParams{ + ID: input.QueuedMessageID, + ChatID: tx.chatID, + }) + if err != nil { + return PromoteQueuedMessageResult{}, xerrors.Errorf("reorder queue: %w", err) + } + reorderOnly := rows > 0 + + // R1/I1: leave the target at the queue head and transition to + // status `interrupting` so the worker can drain the in-flight + // generation before promoting the queue head into active history. + // No history row is inserted here and no queue rows are deleted. + if from == StateR1 || from == StateI1 { + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusInterrupting, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt, + }); err != nil { + return PromoteQueuedMessageResult{}, xerrors.Errorf("set interrupting: %w", err) + } + return PromoteQueuedMessageResult{ + QueuedMessage: target, + ReorderedQueueOnly: reorderOnly, + }, nil + } + + // E1/A1: synthesize cancellations, pop the head, insert into + // history, set running. Both paths insert a queued user message + // into active history, so every outstanding tool call must be + // closed (not just dynamic ones) to keep the LLM history valid. + cancels, err := synthesizePendingToolCancellations(tx.ctx, tx.store, chat, "Tool execution interrupted by queued message promotion", false) + if err != nil { + return PromoteQueuedMessageResult{}, err + } + promotedMsg := messageFromQueuedRow(target) + inserted, err := tx.insertMessages(append(cancels, promotedMsg)) + if err != nil { + return PromoteQueuedMessageResult{}, err + } + if _, err := tx.store.DeleteChatQueuedMessageReturningCount(tx.ctx, database.DeleteChatQueuedMessageReturningCountParams{ + ID: target.ID, + ChatID: tx.chatID, + }); err != nil { + return PromoteQueuedMessageResult{}, xerrors.Errorf("delete promoted queued: %w", err) + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRunning, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: pqtype.NullRawMessage{}, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return PromoteQueuedMessageResult{}, xerrors.Errorf("set running: %w", err) + } + cancellations := inserted[:len(inserted)-1] + insertedUserMsg := inserted[len(inserted)-1] + return PromoteQueuedMessageResult{ + QueuedMessage: target, + InsertedMessage: &insertedUserMsg, + CancellationMessages: cancellations, + ReorderedQueueOnly: reorderOnly, + }, nil +} + +// ============================================================================= +// Interrupt +// ============================================================================= + +// InterruptInput configures [Tx.Interrupt]. +type InterruptInput struct { + Reason string +} + +// InterruptResult is returned by [Tx.Interrupt]. +type InterruptResult struct { + CancellationMessages []database.ChatMessage +} + +// Interrupt requests interruption of an active or requires-action +// chat. +func (tx *Tx) Interrupt(input InterruptInput) (InterruptResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionInterrupt) + if err != nil { + return InterruptResult{}, err + } + switch from { + case StateR0, StateR1: + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusInterrupting, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt, + }); err != nil { + return InterruptResult{}, xerrors.Errorf("set interrupting: %w", err) + } + return InterruptResult{}, nil + case StateA0, StateA1: + reason := input.Reason + if reason == "" { + reason = "Tool execution interrupted by user" + } + cancels, err := synthesizePendingToolCancellations(tx.ctx, tx.store, chat, reason, true) + if err != nil { + return InterruptResult{}, err + } + inserted, err := tx.insertMessages(cancels) + if err != nil { + return InterruptResult{}, err + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRunning, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return InterruptResult{}, xerrors.Errorf("set running: %w", err) + } + return InterruptResult{ + CancellationMessages: inserted, + }, nil + } + return InterruptResult{}, newTransitionError(TransitionInterrupt, from, "unhandled state in Interrupt") +} + +// ============================================================================= +// CompleteRequiresAction +// ============================================================================= + +// ToolResultInput is one submitted dynamic-tool result. +type ToolResultInput struct { + ToolCallID string + Output json.RawMessage + IsError bool +} + +// CompleteRequiresActionInput configures [Tx.CompleteRequiresAction]. +type CompleteRequiresActionInput struct { + CreatedBy uuid.UUID + ModelConfigID uuid.UUID + Results []ToolResultInput +} + +// CompleteRequiresActionResult is returned by [Tx.CompleteRequiresAction]. +type CompleteRequiresActionResult struct { + InsertedMessages []database.ChatMessage +} + +// CompleteRequiresAction validates and stores user-submitted tool +// results that satisfy the chat's pending dynamic tool calls, then +// returns the chat to running. +func (tx *Tx) CompleteRequiresAction(input CompleteRequiresActionInput) (CompleteRequiresActionResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionCompleteRequiresAction) + if err != nil { + return CompleteRequiresActionResult{}, err + } + pending, err := pendingDynamicToolCallIDs(tx.ctx, tx.store, chat) + if err != nil { + return CompleteRequiresActionResult{}, err + } + submitted := make(map[string]ToolResultInput, len(input.Results)) + for _, r := range input.Results { + if _, dup := submitted[r.ToolCallID]; dup { + return CompleteRequiresActionResult{}, newTransitionErrorWithCause( + TransitionCompleteRequiresAction, from, + &ToolResultValidationError{Cause: ErrToolResultDuplicate, ToolCallID: r.ToolCallID}, + "duplicate tool_call_id submitted", + ) + } + if !json.Valid(r.Output) { + return CompleteRequiresActionResult{}, newTransitionErrorWithCause( + TransitionCompleteRequiresAction, from, + &ToolResultValidationError{Cause: ErrToolResultInvalidJSON, ToolCallID: r.ToolCallID}, + "tool result output is not valid JSON", + ) + } + submitted[r.ToolCallID] = r + } + for id := range pending { + if _, ok := submitted[id]; !ok { + return CompleteRequiresActionResult{}, newTransitionErrorWithCause( + TransitionCompleteRequiresAction, from, + &ToolResultValidationError{Cause: ErrToolResultMissing, ToolCallID: id}, + "submitted tool results do not match pending tool calls", + ) + } + } + for id := range submitted { + if _, ok := pending[id]; !ok { + return CompleteRequiresActionResult{}, newTransitionErrorWithCause( + TransitionCompleteRequiresAction, from, + &ToolResultValidationError{Cause: ErrToolResultUnexpected, ToolCallID: id}, + "submitted tool_call_id does not match a pending dynamic tool call", + ) + } + } + messages := make([]Message, 0, len(input.Results)) + for _, r := range input.Results { + part := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeToolResult, + ToolCallID: r.ToolCallID, + ToolName: pending[r.ToolCallID], + Result: r.Output, + IsError: r.IsError, + } + raw, merr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{part}) + if merr != nil { + return CompleteRequiresActionResult{}, xerrors.Errorf("marshal tool result: %w", merr) + } + messages = append(messages, Message{ + Role: database.ChatMessageRoleTool, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + CreatedBy: uuid.NullUUID{UUID: input.CreatedBy, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: input.ModelConfigID, Valid: true}, + ContentVersion: chatprompt.CurrentContentVersion, + }) + } + inserted, err := tx.insertMessages(messages) + if err != nil { + return CompleteRequiresActionResult{}, err + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRunning, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return CompleteRequiresActionResult{}, xerrors.Errorf("set running: %w", err) + } + return CompleteRequiresActionResult{ + InsertedMessages: inserted, + }, nil +} + +// ============================================================================= +// Ownership transitions: Acquire, Abandon +// ============================================================================= + +// AcquireInput configures [Tx.Acquire]. +type AcquireInput struct { + WorkerID uuid.UUID + RunnerID uuid.UUID +} + +// AcquireResult is returned by [Tx.Acquire]. +type AcquireResult struct{} + +// Acquire claims the chat for a worker/runner pair. Execution state +// is preserved. +// +// Acquire never inspects the chat's current ownership: it simply +// overwrites worker_id/runner_id with the supplied identifiers and +// upserts a fresh heartbeat. Detecting and recovering from stale +// leases is a worker-side fence concern outside the state machine. +// Callers that need to coordinate takeovers with the previous owner +// must arrange that out-of-band before calling Acquire. +func (tx *Tx) Acquire(input AcquireInput) (AcquireResult, error) { + chat, _, err := tx.loadState() + if err != nil { + return AcquireResult{}, err + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: chat.Status, + Archived: chat.Archived, + WorkerID: uuid.NullUUID{UUID: input.WorkerID, Valid: true}, + RunnerID: uuid.NullUUID{UUID: input.RunnerID, Valid: true}, + LastError: chat.LastError, + RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt, + }); err != nil { + return AcquireResult{}, xerrors.Errorf("set ownership: %w", err) + } + if err := tx.store.UpsertChatHeartbeat(tx.ctx, database.UpsertChatHeartbeatParams{ + ChatID: tx.chatID, + RunnerID: input.RunnerID, + }); err != nil { + return AcquireResult{}, xerrors.Errorf("upsert heartbeat: %w", err) + } + // Acquire writes a fresh heartbeat itself, so the post-commit + // ownership-hint logic in Update will evaluate the heartbeat as + // fresh and skip publishing a `chat:ownership` hint. + return AcquireResult{}, nil +} + +// AbandonInput is intentionally empty. Ownership-fence checks belong +// outside the transition in caller code that reads the locked row before +// invoking Abandon. +type AbandonInput struct{} + +// AbandonResult is returned by [Tx.Abandon]. +type AbandonResult struct{} + +// Abandon clears worker_id and runner_id from the locked chat row. It +// rejects calls when the chat is not currently owned (worker_id IS NULL). +// Callers that need to verify their own identity before abandoning should +// read the locked row through tx.Store() and compare values before +// invoking Abandon. +func (tx *Tx) Abandon(_ AbandonInput) (AbandonResult, error) { + chat, from, err := tx.loadState() + if err != nil { + return AbandonResult{}, err + } + if !chat.WorkerID.Valid { + return AbandonResult{}, newTransitionError(TransitionAbandon, from, "chat is not owned") + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: chat.Status, + Archived: chat.Archived, + WorkerID: uuid.NullUUID{}, + RunnerID: uuid.NullUUID{}, + LastError: chat.LastError, + RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt, + }); err != nil { + return AbandonResult{}, xerrors.Errorf("clear ownership: %w", err) + } + return AbandonResult{}, nil +} + +// ============================================================================= +// Worker-only transitions on running chats +// ============================================================================= + +// RecordGenerationAttemptInput is intentionally empty. +type RecordGenerationAttemptInput struct{} + +// RecordGenerationAttemptResult is returned by [Tx.RecordGenerationAttempt]. +type RecordGenerationAttemptResult struct { + GenerationAttempt int64 +} + +// RecordGenerationAttempt durably records that the worker is +// attempting another generation under the current history version. +func (tx *Tx) RecordGenerationAttempt(_ RecordGenerationAttemptInput) (RecordGenerationAttemptResult, error) { + _, _, err := tx.requireFromAllowed(TransitionRecordGenerationAttempt) + if err != nil { + return RecordGenerationAttemptResult{}, err + } + value, err := tx.store.IncrementChatGenerationAttempt(tx.ctx, tx.chatID) + if err != nil { + return RecordGenerationAttemptResult{}, xerrors.Errorf("increment generation attempt: %w", err) + } + return RecordGenerationAttemptResult{ + GenerationAttempt: value, + }, nil +} + +// RecordRetryStateInput configures [Tx.RecordRetryState]. +type RecordRetryStateInput struct { + RetryState pqtype.NullRawMessage +} + +// RecordRetryStateResult is returned by [Tx.RecordRetryState]. +type RecordRetryStateResult struct { + Chat database.Chat +} + +// RecordRetryState stores the client-visible retry payload for the +// current generation attempt. +func (tx *Tx) RecordRetryState(input RecordRetryStateInput) (RecordRetryStateResult, error) { + _, from, err := tx.requireFromAllowed(TransitionRecordRetryState) + if err != nil { + return RecordRetryStateResult{}, err + } + if !input.RetryState.Valid || len(input.RetryState.RawMessage) == 0 { + return RecordRetryStateResult{}, newTransitionError( + TransitionRecordRetryState, from, + "RecordRetryState requires a retry payload", + ) + } + if !json.Valid(input.RetryState.RawMessage) { + return RecordRetryStateResult{}, newTransitionError( + TransitionRecordRetryState, from, + "retry payload is not valid JSON", + ) + } + chat, err := tx.store.UpdateChatRetryState(tx.ctx, database.UpdateChatRetryStateParams{ + ID: tx.chatID, + RetryState: input.RetryState.RawMessage, + }) + if err != nil { + return RecordRetryStateResult{}, xerrors.Errorf("update retry state: %w", err) + } + return RecordRetryStateResult{Chat: chat}, nil +} + +// CommitStepInput configures [Tx.CommitStep]. +type CommitStepInput struct { + Messages []Message +} + +// CommitStepResult is returned by [Tx.CommitStep]. +type CommitStepResult struct { + InsertedMessages []database.ChatMessage +} + +// CommitStep stores one durable message suffix while remaining +// running. +func (tx *Tx) CommitStep(input CommitStepInput) (CommitStepResult, error) { + _, from, err := tx.requireFromAllowed(TransitionCommitStep) + if err != nil { + return CommitStepResult{}, err + } + if len(input.Messages) == 0 { + return CommitStepResult{}, newTransitionError( + TransitionCommitStep, from, + "CommitStep requires at least one message", + ) + } + inserted, err := tx.insertMessages(input.Messages) + if err != nil { + return CommitStepResult{}, err + } + return CommitStepResult{ + InsertedMessages: inserted, + }, nil +} + +// ============================================================================= +// EnterRequiresAction +// ============================================================================= + +// EnterRequiresActionInput is intentionally empty. +type EnterRequiresActionInput struct{} + +// EnterRequiresActionResult is returned by [Tx.EnterRequiresAction]. +type EnterRequiresActionResult struct { + RequiresActionDeadlineAt sql.NullTime +} + +// EnterRequiresAction parks the chat in requires_action with a +// database-time deadline of now() + 5 minutes. +func (tx *Tx) EnterRequiresAction(_ EnterRequiresActionInput) (EnterRequiresActionResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionEnterRequiresAction) + if err != nil { + return EnterRequiresActionResult{}, err + } + pending, err := pendingDynamicToolCallIDs(tx.ctx, tx.store, chat) + if err != nil { + return EnterRequiresActionResult{}, err + } + if len(pending) == 0 { + return EnterRequiresActionResult{}, newTransitionError( + TransitionEnterRequiresAction, from, + "no pending dynamic tool calls", + ) + } + now, err := tx.store.GetDatabaseNow(tx.ctx) + if err != nil { + return EnterRequiresActionResult{}, xerrors.Errorf("get db now: %w", err) + } + deadline := sql.NullTime{Time: now.Add(5 * time.Minute), Valid: true} + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRequiresAction, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: deadline, + }); err != nil { + return EnterRequiresActionResult{}, xerrors.Errorf("set requires_action: %w", err) + } + return EnterRequiresActionResult{ + RequiresActionDeadlineAt: deadline, + }, nil +} + +// ============================================================================= +// FinishInterruption +// ============================================================================= + +// FinishInterruptionInput configures [Tx.FinishInterruption]. +type FinishInterruptionInput struct { + PartialMessages []Message +} + +// FinishInterruptionResult is returned by [Tx.FinishInterruption]. +type FinishInterruptionResult struct { + InsertedMessages []database.ChatMessage + PromotedMessage *database.ChatMessage +} + +// FinishInterruption commits an optional partial assistant/tool suffix +// and lands the chat in waiting (I0) or running with the next queued +// message promoted (I1). +func (tx *Tx) FinishInterruption(input FinishInterruptionInput) (FinishInterruptionResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionFinishInterruption) + if err != nil { + return FinishInterruptionResult{}, err + } + insertedPartial, err := tx.insertMessages(input.PartialMessages) + if err != nil { + return FinishInterruptionResult{}, err + } + pendingAll, err := pendingAllToolCallIDs(tx.ctx, tx.store, chat) + if err != nil { + return FinishInterruptionResult{}, err + } + if len(pendingAll) > 0 { + return FinishInterruptionResult{}, newTransitionError( + TransitionFinishInterruption, from, + "outstanding tool calls remain after partial commit", + ) + } + + if from == StateI0 { + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusWaiting, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return FinishInterruptionResult{}, xerrors.Errorf("set waiting: %w", err) + } + return FinishInterruptionResult{ + InsertedMessages: insertedPartial, + }, nil + } + + // I1: promote queue head into history. + head, err := tx.store.GetChatQueuedMessageHead(tx.ctx, tx.chatID) + if err != nil { + return FinishInterruptionResult{}, xerrors.Errorf("get queue head: %w", err) + } + promotedMsg := messageFromQueuedRow(head) + insertedHead, err := tx.insertMessages([]Message{promotedMsg}) + if err != nil { + return FinishInterruptionResult{}, err + } + if _, err := tx.store.DeleteChatQueuedMessageReturningCount(tx.ctx, database.DeleteChatQueuedMessageReturningCountParams{ + ID: head.ID, + ChatID: tx.chatID, + }); err != nil { + return FinishInterruptionResult{}, xerrors.Errorf("delete promoted head: %w", err) + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRunning, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return FinishInterruptionResult{}, xerrors.Errorf("set running: %w", err) + } + all := append([]database.ChatMessage{}, insertedPartial...) + all = append(all, insertedHead...) + var promoted *database.ChatMessage + if len(insertedHead) == 1 { + promoted = &insertedHead[0] + } + return FinishInterruptionResult{ + InsertedMessages: all, + PromotedMessage: promoted, + }, nil +} + +// ============================================================================= +// FinishTurn +// ============================================================================= + +// FinishTurnInput is intentionally empty. +type FinishTurnInput struct{} + +// FinishTurnResult is returned by [Tx.FinishTurn]. +type FinishTurnResult struct { + PromotedMessage *database.ChatMessage +} + +// FinishTurn completes a running turn. +func (tx *Tx) FinishTurn(_ FinishTurnInput) (FinishTurnResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionFinishTurn) + if err != nil { + return FinishTurnResult{}, err + } + if from == StateR0 { + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusWaiting, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return FinishTurnResult{}, xerrors.Errorf("set waiting: %w", err) + } + return FinishTurnResult{}, nil + } + // R1. + head, err := tx.store.GetChatQueuedMessageHead(tx.ctx, tx.chatID) + if err != nil { + return FinishTurnResult{}, xerrors.Errorf("get queue head: %w", err) + } + cancels, err := synthesizePendingToolCancellations(tx.ctx, tx.store, chat, "Tool execution interrupted by queued message promotion", false) + if err != nil { + return FinishTurnResult{}, err + } + promotedMsg := messageFromQueuedRow(head) + inserted, err := tx.insertMessages(append(cancels, promotedMsg)) + if err != nil { + return FinishTurnResult{}, err + } + if _, err := tx.store.DeleteChatQueuedMessageReturningCount(tx.ctx, database.DeleteChatQueuedMessageReturningCountParams{ + ID: head.ID, + ChatID: tx.chatID, + }); err != nil { + return FinishTurnResult{}, xerrors.Errorf("delete promoted head: %w", err) + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRunning, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return FinishTurnResult{}, xerrors.Errorf("set running: %w", err) + } + var promoted *database.ChatMessage + if len(inserted) > 0 { + promoted = &inserted[len(inserted)-1] + } + return FinishTurnResult{ + PromotedMessage: promoted, + }, nil +} + +// ============================================================================= +// FinishError +// ============================================================================= + +// FinishErrorInput configures [Tx.FinishError]. +type FinishErrorInput struct { + LastError pqtype.NullRawMessage +} + +// FinishErrorResult is returned by [Tx.FinishError]. +type FinishErrorResult struct{} + +// FinishError parks the chat in error with the supplied last_error. +func (tx *Tx) FinishError(input FinishErrorInput) (FinishErrorResult, error) { + chat, _, err := tx.requireFromAllowed(TransitionFinishError) + if err != nil { + return FinishErrorResult{}, err + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusError, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: input.LastError, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return FinishErrorResult{}, xerrors.Errorf("set error: %w", err) + } + return FinishErrorResult{}, nil +} + +// ============================================================================= +// CancelRequiresAction +// ============================================================================= + +// CancelRequiresActionInput configures [Tx.CancelRequiresAction]. +type CancelRequiresActionInput struct { + Reason string +} + +// CancelRequiresActionResult is returned by [Tx.CancelRequiresAction]. +type CancelRequiresActionResult struct { + CancellationMessages []database.ChatMessage +} + +// CancelRequiresAction synthesizes cancellation results for every +// pending dynamic tool call and returns the chat to running. +func (tx *Tx) CancelRequiresAction(input CancelRequiresActionInput) (CancelRequiresActionResult, error) { + chat, from, err := tx.requireFromAllowed(TransitionCancelRequiresAction) + if err != nil { + return CancelRequiresActionResult{}, err + } + reason := input.Reason + if reason == "" { + reason = "Tool execution timed out" + } + cancels, err := synthesizePendingToolCancellations(tx.ctx, tx.store, chat, reason, true) + if err != nil { + return CancelRequiresActionResult{}, err + } + if len(cancels) == 0 { + return CancelRequiresActionResult{}, newTransitionError( + TransitionCancelRequiresAction, from, + "no pending dynamic tool calls to cancel", + ) + } + inserted, err := tx.insertMessages(cancels) + if err != nil { + return CancelRequiresActionResult{}, err + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusRunning, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return CancelRequiresActionResult{}, xerrors.Errorf("set running: %w", err) + } + return CancelRequiresActionResult{ + CancellationMessages: inserted, + }, nil +} + +// ============================================================================= +// ReconcileInvalidState +// ============================================================================= + +// ReconcileInvalidStateInput configures [Tx.ReconcileInvalidState]. +type ReconcileInvalidStateInput struct { + LastError pqtype.NullRawMessage + CancellationReason string +} + +// ReconcileInvalidStateResult is returned by [Tx.ReconcileInvalidState]. +type ReconcileInvalidStateResult struct { + CancellationMessages []database.ChatMessage +} + +// ReconcileInvalidState moves an invalid execution-state combination +// into a valid error state. Queued messages are preserved; pending +// dynamic-tool calls are closed with synthetic cancellation results. +func (tx *Tx) ReconcileInvalidState(input ReconcileInvalidStateInput) (ReconcileInvalidStateResult, error) { + chat, from, err := tx.loadState() + if err != nil { + return ReconcileInvalidStateResult{}, err + } + if from != StateInvalid { + return ReconcileInvalidStateResult{}, newTransitionError( + TransitionReconcileInvalidState, from, + "reconcile is only valid for invalid states", + ) + } + reason := input.CancellationReason + if reason == "" { + reason = "Tool execution canceled due to invalid chat state" + } + cancels, err := synthesizePendingToolCancellations(tx.ctx, tx.store, chat, reason, true) + if err != nil { + return ReconcileInvalidStateResult{}, err + } + var inserted []database.ChatMessage + if len(cancels) > 0 { + inserted, err = tx.insertMessages(cancels) + if err != nil { + return ReconcileInvalidStateResult{}, err + } + } + lastErr := input.LastError + if !lastErr.Valid { + lastErr = pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"chat was in an invalid state; send a new message or edit history to continue"}`), + Valid: true, + } + } + if _, err := tx.applyExecutionState(executionStateUpdate{ + Status: database.ChatStatusError, + Archived: false, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: lastErr, + RequiresActionDeadlineAt: sql.NullTime{}, + }); err != nil { + return ReconcileInvalidStateResult{}, xerrors.Errorf("set error: %w", err) + } + return ReconcileInvalidStateResult{ + CancellationMessages: inserted, + }, nil +} diff --git a/coderd/x/chatd/chatstate/transitions_helpers_test.go b/coderd/x/chatd/chatstate/transitions_helpers_test.go new file mode 100644 index 0000000000..5eb818bfcc --- /dev/null +++ b/coderd/x/chatd/chatstate/transitions_helpers_test.go @@ -0,0 +1,907 @@ +package chatstate_test + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// ============================================================================= +// State seeding helpers +// +// seededChat is the shared output of seedState. Some transition tests +// need extra context beyond the chat ID (for example, the queued +// message ID to delete, or the message ID to edit), so this struct +// surfaces what each state was seeded with. +// ============================================================================= + +type seededChat struct { + chatID uuid.UUID + exists bool + initialUserMessageID int64 + assistantToolCallMsgID int64 + queuedMessageIDs []int64 + // queuedMessageBodies is parallel to queuedMessageIDs and records + // the text body each queued message was seeded with. Cases that + // promote queued messages into history use this to assert the + // promoted message content matches what was originally queued. + queuedMessageBodies []string + queuedMessageCreatedBy []uuid.UUID + dynamicToolName string + pendingToolCallID string + pendingToolCallIDs []string +} + +// dynamicToolJSON returns the canonical [{name,description,input_schema}] +// payload used to seed dynamic_tools on a chat. Tests that need +// pending dynamic tool calls (A0, A1) reuse this and reference the +// returned tool name in their assistant tool-call message. +func dynamicToolJSON(name string) []byte { + tools := []codersdk.DynamicTool{{ + Name: name, + Description: "test tool", + InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), + }} + raw, err := json.Marshal(tools) + if err != nil { + panic(err) + } + return raw +} + +// assistantToolCallMessage builds a chatstate.Message for an +// assistant message that issues one tool call against the supplied +// dynamic tool name. The tool-call ID is unique per call so multiple +// messages do not collide. +func assistantToolCallMessage(t *testing.T, modelID uuid.UUID, toolName, callID string) chatstate.Message { + t.Helper() + raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: callID, + ToolName: toolName, + Args: json.RawMessage(`{}`), + }}) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleAssistant, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true}, + } +} + +func mixedAssistantToolCallMessage(t *testing.T, modelID uuid.UUID, dynamicTool, dynCallID, nonDynCallID string) chatstate.Message { + t.Helper() + parts := []codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: dynCallID, + ToolName: dynamicTool, + Args: json.RawMessage(`{}`), + }, + { + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: nonDynCallID, + ToolName: "non_dynamic_tool", + Args: json.RawMessage(`{}`), + }, + } + raw, err := chatprompt.MarshalParts(parts) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleAssistant, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true}, + } +} + +// createTestChatWithDynamicTools mirrors createTestChat but seeds the +// chat with a non-empty dynamic_tools blob so EnterRequiresAction, +// CompleteRequiresAction, and CancelRequiresAction can find pending +// dynamic tool calls. +func createTestChatWithDynamicTools(t *testing.T, f *testFixture, toolName string) chatstate.CreateChatResult { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{ + OrganizationID: f.Org.ID, + OwnerID: f.User.ID, + LastModelConfigID: f.Model.ID, + Title: "test", + ClientType: database.ChatClientTypeApi, + DynamicTools: pqtype.NullRawMessage{ + RawMessage: dynamicToolJSON(toolName), + Valid: true, + }, + InitialMessages: []chatstate.Message{ + userTextMessage("hello", f.User.ID, f.Model.ID), + }, + }) + require.NoError(t, err) + return res +} + +// seedAOrA1 seeds a chat into A0 (queuedExtras=0) or A1 +// (queuedExtras>=1) with a real pending dynamic tool call. Used by +// cases that need A0 or A1 with a configurable queue cardinality. +func seedAOrA1(t *testing.T, f *testFixture, queuedExtras int, namePrefix string) seededChat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + toolName := namePrefix + callID := "call_" + uuid.NewString() + created := createTestChatWithDynamicTools(t, f, toolName) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + var step chatstate.CommitStepResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + step, err = tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{ + assistantToolCallMessage(t, f.Model.ID, toolName, callID), + }, + }) + return err + })) + require.Len(t, step.InsertedMessages, 1) + // R0 -> A0. + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{}) + return err + })) + var ( + queuedIDs []int64 + queuedBodies []string + ) + for i := 0; i < queuedExtras; i++ { + body := fmt.Sprintf("queued-%s-%d", namePrefix, i) + sm := sendQueuedMessage(t, f, m, body) + require.NotNil(t, sm.QueuedMessage) + queuedIDs = append(queuedIDs, sm.QueuedMessage.ID) + queuedBodies = append(queuedBodies, body) + } + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + assistantToolCallMsgID: step.InsertedMessages[0].ID, + queuedMessageIDs: queuedIDs, + queuedMessageBodies: queuedBodies, + dynamicToolName: toolName, + pendingToolCallID: callID, + } +} + +// seedState seeds a chat into the supplied execution state and +// returns identifying handles useful for downstream assertions. For +// [chatstate.StateN] the returned chatID is a fresh UUID that does +// not exist in the database. Multi-queued seeds (for E1, R1, I1, +// A1 with 2 queued messages, and Invalid with a non-empty queue) live in +// seedStateMultiQueued. +func seedState(t *testing.T, f *testFixture, state chatstate.ExecutionState) seededChat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + + switch state { + case chatstate.StateN: + return seededChat{chatID: uuid.New(), exists: false} + + case chatstate.StateR0: + created := createTestChat(t, f) + initial := firstUserMessageID(ctx, t, f, created.Chat.ID) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: initial, + } + + case chatstate.StateW: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishTurn(chatstate.FinishTurnInput{}) + return err + })) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + } + + case chatstate.StateE0: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishError(chatstate.FinishErrorInput{ + LastError: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"boom"}`), + Valid: true, + }, + }) + return err + })) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + } + + case chatstate.StateE1: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + // R0 -> R1 + queuedBody := "queued-for-E1" + queued := sendQueuedMessage(t, f, m, queuedBody) + require.NotNil(t, queued.QueuedMessage) + // R1 -> E1 + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishError(chatstate.FinishErrorInput{ + LastError: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"boom"}`), + Valid: true, + }, + }) + return err + })) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + queuedMessageIDs: []int64{queued.QueuedMessage.ID}, + queuedMessageBodies: []string{queuedBody}, + } + + case chatstate.StateR1: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + queuedBody := "queued-for-R1" + queued := sendQueuedMessage(t, f, m, queuedBody) + require.NotNil(t, queued.QueuedMessage) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + queuedMessageIDs: []int64{queued.QueuedMessage.ID}, + queuedMessageBodies: []string{queuedBody}, + } + + case chatstate.StateI0: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Interrupt(chatstate.InterruptInput{Reason: "seed"}) + return err + })) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + } + + case chatstate.StateI1: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + // R0 -> I1: SendMessage with interrupt behavior queues the + // message and sets status to interrupting. + queuedBody := "queued-for-I1" + sm := sendInterruptMessage(t, f, m, queuedBody) + require.NotNil(t, sm.QueuedMessage) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + queuedMessageIDs: []int64{sm.QueuedMessage.ID}, + queuedMessageBodies: []string{queuedBody}, + } + + case chatstate.StateA0: + return seedAOrA1(t, f, 0, "seed_tool_a0") + + case chatstate.StateA1: + return seedAOrA1(t, f, 1, "seed_tool_a1") + + case chatstate.StateXW: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishTurn(chatstate.FinishTurnInput{}) + return err + })) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true}) + return err + })) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + } + + case chatstate.StateXE0: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishError(chatstate.FinishErrorInput{ + LastError: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"boom"}`), + Valid: true, + }, + }) + return err + })) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true}) + return err + })) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + } + + case chatstate.StateXE1: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + queuedBody := "queued-for-XE1" + queued := sendQueuedMessage(t, f, m, queuedBody) + require.NotNil(t, queued.QueuedMessage) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishError(chatstate.FinishErrorInput{ + LastError: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"boom"}`), + Valid: true, + }, + }) + return err + })) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true}) + return err + })) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + queuedMessageIDs: []int64{queued.QueuedMessage.ID}, + queuedMessageBodies: []string{queuedBody}, + } + + case chatstate.StateInvalid: + created := createTestChat(t, f) + // Force running + archived, a deliberately invalid + // combination per the classifier. + _, err := f.DB.UpdateChatExecutionState(ctx, database.UpdateChatExecutionStateParams{ + ID: created.Chat.ID, + Status: database.ChatStatusRunning, + Archived: true, + WorkerID: created.Chat.WorkerID, + RunnerID: created.Chat.RunnerID, + LastError: created.Chat.LastError, + RequiresActionDeadlineAt: created.Chat.RequiresActionDeadlineAt, + }) + require.NoError(t, err) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + } + } + t.Fatalf("seedState: unsupported execution state %s", state) + return seededChat{} +} + +// seedStateMultiQueued seeds a state with two queued messages. Used +// by cases that need the post-mutation queue to remain non-empty. +// Supported states: E1, R1, I1, A1. +func seedStateMultiQueued(t *testing.T, f *testFixture, state chatstate.ExecutionState) seededChat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + switch state { + case chatstate.StateE1: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + firstBody := "queued-e1-a" + first := sendQueuedMessage(t, f, m, firstBody) + require.NotNil(t, first.QueuedMessage) + secondBody := "queued-e1-b" + second := sendQueuedMessage(t, f, m, secondBody) + require.NotNil(t, second.QueuedMessage) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishError(chatstate.FinishErrorInput{ + LastError: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"boom"}`), + Valid: true, + }, + }) + return err + })) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + queuedMessageIDs: []int64{first.QueuedMessage.ID, second.QueuedMessage.ID}, + queuedMessageBodies: []string{firstBody, secondBody}, + } + + case chatstate.StateR1: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + firstBody := "queued-r1-a" + first := sendQueuedMessage(t, f, m, firstBody) + require.NotNil(t, first.QueuedMessage) + secondBody := "queued-r1-b" + second := sendQueuedMessage(t, f, m, secondBody) + require.NotNil(t, second.QueuedMessage) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + queuedMessageIDs: []int64{first.QueuedMessage.ID, second.QueuedMessage.ID}, + queuedMessageBodies: []string{firstBody, secondBody}, + } + + case chatstate.StateI1: + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + firstBody := "queued-i1-a" + first := sendQueuedMessage(t, f, m, firstBody) + require.NotNil(t, first.QueuedMessage) + // R1 -> I1 via interrupt-mode SendMessage queues a second + // message and flips status to interrupting. + secondBody := "queued-i1-b" + second := sendInterruptMessage(t, f, m, secondBody) + require.NotNil(t, second.QueuedMessage) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + queuedMessageIDs: []int64{first.QueuedMessage.ID, second.QueuedMessage.ID}, + queuedMessageBodies: []string{firstBody, secondBody}, + } + + case chatstate.StateA1: + return seedAOrA1(t, f, 2, "seed_tool_a1_multi") + } + t.Fatalf("seedStateMultiQueued: unsupported execution state %s", state) + return seededChat{} +} + +// seedA1WithMixedOutstandingToolCalls seeds A1 with one queued message +// and one assistant message carrying both a dynamic and non-dynamic +// outstanding tool call. It is used by PromoteQueuedMessage(A1) to +// prove all tool calls are closed before inserting the promoted user. +func seedA1WithMixedOutstandingToolCalls(t *testing.T, f *testFixture, queuedExtras int, namePrefix string) seededChat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + toolName := namePrefix + dynCallID := "call_" + uuid.NewString() + nonDynCallID := "call_" + uuid.NewString() + created := createTestChatWithDynamicTools(t, f, toolName) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + var step chatstate.CommitStepResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + step, err = tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{ + mixedAssistantToolCallMessage(t, f.Model.ID, toolName, dynCallID, nonDynCallID), + }, + }) + return err + })) + require.Len(t, step.InsertedMessages, 1) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{}) + return err + })) + var ( + queuedIDs []int64 + queuedBodies []string + queuedCreatedBy []uuid.UUID + ) + for i := range queuedExtras { + body := fmt.Sprintf("queued-%s-%d", namePrefix, i) + createdBy := uuid.New() + queued, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{ + ChatID: created.Chat.ID, + Content: userMessageContent(t, body), + ModelConfigID: uuid.NullUUID{UUID: f.Model.ID, Valid: true}, + CreatedBy: createdBy, + }) + require.NoError(t, err) + queuedIDs = append(queuedIDs, queued.ID) + queuedBodies = append(queuedBodies, body) + queuedCreatedBy = append(queuedCreatedBy, createdBy) + } + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + assistantToolCallMsgID: step.InsertedMessages[0].ID, + queuedMessageIDs: queuedIDs, + queuedMessageBodies: queuedBodies, + queuedMessageCreatedBy: queuedCreatedBy, + dynamicToolName: toolName, + pendingToolCallID: dynCallID, + pendingToolCallIDs: []string{dynCallID, nonDynCallID}, + } +} + +// seedInvalidWithQueue seeds Invalid with a single queued message so +// ReconcileInvalidState lands in E1 (non-empty queue) instead of E0. +func seedInvalidWithQueue(t *testing.T, f *testFixture) seededChat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + queuedBody := "queued-invalid" + queued := sendQueuedMessage(t, f, m, queuedBody) + require.NotNil(t, queued.QueuedMessage) + // Force the deliberately invalid running + archived combo on + // top of the queue. + chat, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + _, err = f.DB.UpdateChatExecutionState(ctx, database.UpdateChatExecutionStateParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + Archived: true, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt, + }) + require.NoError(t, err) + return seededChat{ + chatID: chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, chat.ID), + queuedMessageIDs: []int64{queued.QueuedMessage.ID}, + queuedMessageBodies: []string{queuedBody}, + } +} + +// firstUserMessageID returns the lowest-id non-deleted user message +// on the chat. Most transition tests reuse this when they need a +// user message to edit. +func firstUserMessageID(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) int64 { + t.Helper() + msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + }) + require.NoError(t, err) + for _, m := range msgs { + if m.Role == database.ChatMessageRoleUser && !m.Deleted { + return m.ID + } + } + t.Fatalf("firstUserMessageID: chat %s has no user messages", chatID) + return 0 +} + +func firstAssistantMessageID(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) int64 { + t.Helper() + msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + }) + require.NoError(t, err) + for _, m := range msgs { + if m.Role == database.ChatMessageRoleAssistant && !m.Deleted { + return m.ID + } + } + t.Fatalf("firstAssistantMessageID: chat %s has no assistant messages", chatID) + return 0 +} + +// seedForEnterRequiresAction extends seedState for R0 and R1 with a +// chat that has dynamic_tools plus an assistant tool-call message in +// history. EnterRequiresAction's precondition rejects R0/R1 without +// pending dynamic tool calls, so the generic seedState path will not +// do. Other states fall through to the default seedState. +func seedForEnterRequiresAction(t *testing.T, f *testFixture, state chatstate.ExecutionState) seededChat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + switch state { + case chatstate.StateR0: + toolName := "ra_tool_r0" + callID := "call_" + uuid.NewString() + created := createTestChatWithDynamicTools(t, f, toolName) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + var step chatstate.CommitStepResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + step, err = tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{ + assistantToolCallMessage(t, f.Model.ID, toolName, callID), + }, + }) + return err + })) + require.Len(t, step.InsertedMessages, 1) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + assistantToolCallMsgID: step.InsertedMessages[0].ID, + dynamicToolName: toolName, + pendingToolCallID: callID, + pendingToolCallIDs: []string{callID}, + } + case chatstate.StateR1: + toolName := "ra_tool_r1" + callID := "call_" + uuid.NewString() + created := createTestChatWithDynamicTools(t, f, toolName) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + var step chatstate.CommitStepResult + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + var err error + step, err = tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{ + assistantToolCallMessage(t, f.Model.ID, toolName, callID), + }, + }) + return err + })) + // R0 -> R1 with a queued message. + queuedBody := "queued-for-RA-r1" + sm := sendQueuedMessage(t, f, m, queuedBody) + require.NotNil(t, sm.QueuedMessage) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + assistantToolCallMsgID: step.InsertedMessages[0].ID, + queuedMessageIDs: []int64{sm.QueuedMessage.ID}, + queuedMessageBodies: []string{queuedBody}, + dynamicToolName: toolName, + pendingToolCallID: callID, + pendingToolCallIDs: []string{callID}, + } + } + return seedState(t, f, state) +} + +// ============================================================================= +// Snapshot baselines and shared assertion helpers used by every matrix +// case. captureBaseline records the chat's snapshot_version and the +// publisher's recorded channel count immediately before a transition +// runs; assertSnapshotBumpedOnce and assertNoMutationOrPublish use the +// baseline to verify either a single snapshot bump and one chat:update +// on success, or zero mutation and zero publishes on failure. +// ============================================================================= + +// activeHistoryIDs returns the ids of non-deleted history messages +// for the chat in row-id order. Useful for verifying CommitStep, +// EditMessage replacement, and PromoteQueuedMessage head insertion. +func activeHistoryIDs(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) []int64 { + t.Helper() + msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + }) + require.NoError(t, err) + out := make([]int64, 0, len(msgs)) + for _, m := range msgs { + if !m.Deleted { + out = append(out, m.ID) + } + } + return out +} + +func requireChatMessageByID(ctx context.Context, t *testing.T, f *testFixture, id int64) database.ChatMessage { + t.Helper() + msg, err := f.DB.GetChatMessageByID(ctx, id) + require.NoError(t, err) + return msg +} + +func requireQueuedMessageByID(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, id int64) database.ChatQueuedMessage { + t.Helper() + msg, err := f.DB.GetChatQueuedMessageByID(ctx, database.GetChatQueuedMessageByIDParams{ + ID: id, + ChatID: chatID, + }) + require.NoError(t, err) + return msg +} + +func requireQueuedMessageDeleted(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, id int64) { + t.Helper() + _, err := f.DB.GetChatQueuedMessageByID(ctx, database.GetChatQueuedMessageByIDParams{ + ID: id, + ChatID: chatID, + }) + require.Error(t, err) +} + +func assertFetchedUserMessage(ctx context.Context, t *testing.T, f *testFixture, msg database.ChatMessage) database.ChatMessage { + t.Helper() + fetched := requireChatMessageByID(ctx, t, f, msg.ID) + require.Equal(t, msg.ChatID, fetched.ChatID) + require.Equal(t, database.ChatMessageRoleUser, fetched.Role) + require.True(t, fetched.CreatedBy.Valid) + require.Equal(t, f.User.ID, fetched.CreatedBy.UUID) + require.True(t, fetched.ModelConfigID.Valid) + require.Equal(t, f.Model.ID, fetched.ModelConfigID.UUID) + require.Equal(t, chatprompt.CurrentContentVersion, fetched.ContentVersion) + return fetched +} + +func assertFetchedQueuedMessage(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, queued database.ChatQueuedMessage) database.ChatQueuedMessage { + t.Helper() + fetched := requireQueuedMessageByID(ctx, t, f, chatID, queued.ID) + require.Equal(t, chatID, fetched.ChatID) + require.Equal(t, f.User.ID, fetched.CreatedBy) + require.True(t, fetched.ModelConfigID.Valid) + require.Equal(t, f.Model.ID, fetched.ModelConfigID.UUID) + require.NotEmpty(t, fetched.Content) + return fetched +} + +func newActiveMessageIDs(base snapshotBaseline, after []int64) []int64 { + seen := make(map[int64]struct{}, len(base.historyIDs)) + for _, id := range base.historyIDs { + seen[id] = struct{}{} + } + out := make([]int64, 0, len(after)) + for _, id := range after { + if _, ok := seen[id]; !ok { + out = append(out, id) + } + } + return out +} + +// assertToolResultForCallNoError asserts that msg is a tool-result +// message that resolves a tool call with id wantCallID, is_error=false, +// and that the result JSON matches wantResultJSON. Complements +// assertToolResultForCall in synthetic_cancellation_test.go which +// asserts is_error=true. +func assertToolResultForCallNoError(t *testing.T, msg database.ChatMessage, wantCallID, wantResultJSON string) { + t.Helper() + require.Equal(t, database.ChatMessageRoleTool, msg.Role) + parts, err := chatprompt.ParseContent(msg) + require.NoError(t, err) + require.NotEmpty(t, parts) + var found bool + for _, p := range parts { + if p.Type != codersdk.ChatMessagePartTypeToolResult { + continue + } + require.Equal(t, wantCallID, p.ToolCallID, "tool-call id matches") + require.False(t, p.IsError, "CompleteRequiresAction tool result must not be is_error") + require.JSONEq(t, wantResultJSON, string(p.Result), "CompleteRequiresAction tool result JSON matches submitted output") + found = true + } + require.True(t, found, "expected at least one tool-result part") +} + +// assertChatMessageText asserts that the persisted content of msg +// decodes to a single text part with the supplied body. Used by +// matrix cases that need to verify the actual text submitted via +// SendMessage / EditMessage / CommitStep, or the text that was +// promoted out of the queue into history. +func assertChatMessageText(t *testing.T, msg database.ChatMessage, want string) { + t.Helper() + parts, err := chatprompt.ParseContent(msg) + require.NoError(t, err, "parse chat message content") + require.Len(t, parts, 1, "expected exactly one content part") + require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type, + "expected a text content part") + require.Equal(t, want, parts[0].Text, "unexpected chat message text") +} + +// assertQueuedMessageText asserts that the JSON content of queued +// decodes to a single text part with the supplied body. Used by +// matrix cases that need to verify the body inserted into +// chat_queued_messages via SendMessage. +func assertQueuedMessageText(t *testing.T, queued database.ChatQueuedMessage, want string) { + t.Helper() + var parts []codersdk.ChatMessagePart + require.NoError(t, json.Unmarshal(queued.Content, &parts), "unmarshal queued content") + require.Len(t, parts, 1, "expected exactly one queued content part") + require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type, + "expected a text content part") + require.Equal(t, want, parts[0].Text, "unexpected queued message text") +} + +// assertQueueBodiesInOrder fetches the queued messages for the chat +// in queue order and asserts each row's text body matches the +// supplied bodies. Used by matrix cases that need to verify the +// remaining queue content after a promote / finish-turn / +// finish-interruption. +func assertQueueBodiesInOrder(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, want []string) { + t.Helper() + rows, err := f.DB.GetChatQueuedMessagesByPosition(ctx, chatID) + require.NoError(t, err) + require.Len(t, rows, len(want), "queue length must match expected bodies") + for i, r := range rows { + assertQueuedMessageText(t, r, want[i]) + } +} + +// snapshotBaseline records the chat's snapshot_version and the +// publisher's recorded channel count immediately before a transition +// runs. Tests use it to verify either a single snapshot bump and one +// chat:update on success, or zero mutation and zero publishes on +// failure. +type snapshotBaseline struct { + exists bool + chat database.Chat + snapshot int64 + historyVersion int64 + queueVersion int64 + retryStateVersion int64 + generationAttempt int64 + queueCount int64 + queueIDs []int64 + historyIDs []int64 + channels int +} + +func captureBaseline(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat) snapshotBaseline { + t.Helper() + base := snapshotBaseline{ + exists: seeded.exists, + channels: len(f.Pub.channels), + } + if !seeded.exists { + return base + } + chat, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + base.chat = chat + base.snapshot = chat.SnapshotVersion + base.historyVersion = chat.HistoryVersion + base.queueVersion = chat.QueueVersion + base.retryStateVersion = chat.RetryStateVersion + base.generationAttempt = chat.GenerationAttempt + base.queueIDs = queuedIDsByPosition(ctx, t, f, seeded.chatID) + count, err := f.DB.CountChatQueuedMessages(ctx, seeded.chatID) + require.NoError(t, err) + base.queueCount = count + base.historyIDs = activeHistoryIDs(ctx, t, f, seeded.chatID) + return base +} + +// assertSnapshotBumpedOnce asserts that one Update committed; that is, +// snapshot_version advanced by exactly one and the publisher saw at +// least one chat:update on the per-chat channel after the baseline. +func assertSnapshotBumpedOnce(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, base snapshotBaseline) { + t.Helper() + after, err := f.DB.GetChatByID(ctx, chatID) + require.NoError(t, err) + require.Equal(t, base.snapshot+1, after.SnapshotVersion, "snapshot_version must bump exactly once") + channel := coderdpubsub.ChatStateUpdateChannel(chatID) + found := false + for _, c := range f.Pub.channels[base.channels:] { + if c == channel { + found = true + break + } + } + require.True(t, found, "expected one chat:update on %s after commit", channel) +} + +// assertNoMutationOrPublish asserts a failed transition rolled back +// the automatic snapshot bump and published nothing. +func assertNoMutationOrPublish(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, base snapshotBaseline) { + t.Helper() + require.Equal(t, base.channels, len(f.Pub.channels), "failed transition must not publish") + if base.exists { + after, err := f.DB.GetChatByID(ctx, chatID) + require.NoError(t, err) + require.Equal(t, base.snapshot, after.SnapshotVersion, "failed transition must not advance snapshot_version") + } +} diff --git a/coderd/x/chatd/chatstate/transitions_matrix_test.go b/coderd/x/chatd/chatstate/transitions_matrix_test.go new file mode 100644 index 0000000000..221c7d3049 --- /dev/null +++ b/coderd/x/chatd/chatstate/transitions_matrix_test.go @@ -0,0 +1,1856 @@ +package chatstate_test + +import ( + "context" + "encoding/json" + "fmt" + "slices" + "sync" + "testing" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// ============================================================================= +// Matrix harness: spec types, scenario labels, appliers, case runners, +// and the single entry point that walks the production transition +// matrix to confirm every allowed combination has positive coverage +// and every disallowed combination surfaces the right sentinel error. +// ============================================================================= + +// scenario is a typed, semantic label that distinguishes positive +// matrix cases that share the same (transition, from, want) key. +// Empty scenario is fine when no label is needed. The constants +// below enumerate every label used by matrixCases(). +type scenario string + +const ( + // scenarioQueue marks SendMessage cases driven by + // BusyBehaviorQueue. + scenarioQueue scenario = "queue" + // scenarioInterrupt marks SendMessage cases driven by + // BusyBehaviorInterrupt. + scenarioInterrupt scenario = "interrupt" + // scenarioMulti marks cases seeded with multiple queued + // messages so the post-mutation queue stays non-empty. + scenarioMulti scenario = "multi" + // scenarioHeadTarget marks multi-queued PromoteQueuedMessage + // cases that target the queue head. For R1/I1 head-target is + // reorder-only: no rows are updated, so queue order and + // queue_version are unchanged. For E1/A1 head-target still + // pops the head into history. + scenarioHeadTarget scenario = "head_target" + // scenarioNonHead marks multi-queued PromoteQueuedMessage cases + // that target a non-head queued message so the target moves to + // the head and queue_version advances. + scenarioNonHead scenario = "non_head" + // scenarioWithQueue marks ReconcileInvalidState cases seeded + // with a non-empty queue. + scenarioWithQueue scenario = "with_queue" + // scenarioRejectNonDynamicOutstandingToolCall marks the + // FinishInterruption case that exercises the precondition + // rejecting outstanding non-dynamic tool calls. + scenarioRejectNonDynamicOutstandingToolCall scenario = "reject_non_dynamic_outstanding_tool_call" +) + +// ============================================================================= +// Matrix lookup helpers +// ============================================================================= + +func transitionAllowed(tr chatstate.Transition, from chatstate.ExecutionState) bool { + return slices.Contains(chatstate.AllowedExecutionTransitionsFrom(from), tr) +} + +// expectedErrorForDisallowed returns the sentinel chatstate package +// returns when a transition is attempted from a state where the +// matrix forbids it. N (missing chat) becomes ErrChatNotFound; +// Invalid becomes ErrInvalidState (except for ReconcileInvalidState +// which is allowed); everything else becomes ErrTransitionNotAllowed. +func expectedErrorForDisallowed(tr chatstate.Transition, from chatstate.ExecutionState) error { + switch from { + case chatstate.StateN: + if tr == chatstate.TransitionCreateChat { + // CreateChat is not exercised through ChatMachine.Update, + // so this branch is unused in practice. Returning the + // not-allowed sentinel keeps the helper total. + return chatstate.ErrTransitionNotAllowed + } + return chatstate.ErrChatNotFound + case chatstate.StateInvalid: + if tr == chatstate.TransitionReconcileInvalidState { + return nil + } + return chatstate.ErrInvalidState + } + return chatstate.ErrTransitionNotAllowed +} + +// ============================================================================= +// Transition appliers +// +// Each transition has one default applier that exercises it with +// inputs derived from the seeded chat. Positive case specs reuse these +// appliers unless a case needs a different input shape (for example, +// SendMessage queue versus interrupt from the same source state). +// The disallowed coverage path also uses these defaults. +// ============================================================================= + +func applySetArchived(t *testing.T, f *testFixture, tx *chatstate.Tx, seeded seededChat, from chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + _ = f + _ = seeded + _ = result + // Archived states unarchive, others archive. For disallowed + // states the value does not matter; the transition fails first. + archived := true + switch from { + case chatstate.StateXW, chatstate.StateXE0, chatstate.StateXE1: + archived = false + } + _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: archived}) + return err +} + +func applySendMessageQueue(t *testing.T, f *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.sendMessage, err = tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage("sm-queue", f.User.ID, f.Model.ID), + BusyBehavior: chatstate.BusyBehaviorQueue, + }) + return err +} + +func applySendMessageInterrupt(t *testing.T, f *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.sendMessage, err = tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage("sm-interrupt", f.User.ID, f.Model.ID), + BusyBehavior: chatstate.BusyBehaviorInterrupt, + }) + return err +} + +func applyEditMessage(t *testing.T, f *testFixture, tx *chatstate.Tx, seeded seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + content := mustMarshalParts(t, []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")}) + var err error + result.editMessage, err = tx.EditMessage(chatstate.EditMessageInput{ + MessageID: seeded.initialUserMessageID, + CreatedBy: f.User.ID, + Content: content, + }) + return err +} + +func applyDeleteQueuedMessage(t *testing.T, _ *testFixture, tx *chatstate.Tx, seeded seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var targetQueueID int64 + if len(seeded.queuedMessageIDs) > 0 { + targetQueueID = seeded.queuedMessageIDs[0] + } + var err error + result.deleteQueuedMessage, err = tx.DeleteQueuedMessage(chatstate.DeleteQueuedMessageInput{ + QueuedMessageID: targetQueueID, + }) + return err +} + +func applyPromoteQueuedMessage(t *testing.T, _ *testFixture, tx *chatstate.Tx, seeded seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var targetQueueID int64 + if len(seeded.queuedMessageIDs) > 0 { + targetQueueID = seeded.queuedMessageIDs[0] + } + var err error + result.promoteQueuedMessage, err = tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{ + QueuedMessageID: targetQueueID, + }) + return err +} + +func applyInterrupt(t *testing.T, _ *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.interrupt, err = tx.Interrupt(chatstate.InterruptInput{Reason: "test"}) + return err +} + +func applyCompleteRequiresAction(t *testing.T, f *testFixture, tx *chatstate.Tx, seeded seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var results []chatstate.ToolResultInput + if seeded.pendingToolCallID != "" { + results = []chatstate.ToolResultInput{{ + ToolCallID: seeded.pendingToolCallID, + Output: json.RawMessage(`{"ok":true}`), + IsError: false, + }} + } + var err error + result.completeRequiresAction, err = tx.CompleteRequiresAction(chatstate.CompleteRequiresActionInput{ + CreatedBy: f.User.ID, + ModelConfigID: f.Model.ID, + Results: results, + }) + return err +} + +func applyRecordGenerationAttempt(t *testing.T, _ *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.recordGenerationAttempt, err = tx.RecordGenerationAttempt(chatstate.RecordGenerationAttemptInput{}) + return err +} + +func applyRecordRetryState(t *testing.T, _ *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.recordRetryState, err = tx.RecordRetryState(chatstate.RecordRetryStateInput{ + RetryState: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`), + Valid: true, + }, + }) + return err +} + +func applyCommitStep(t *testing.T, f *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + assistant := userTextMessage("assistant", f.User.ID, f.Model.ID) + assistant.Role = database.ChatMessageRoleAssistant + var err error + result.commitStep, err = tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{assistant}, + }) + return err +} + +func applyEnterRequiresAction(t *testing.T, _ *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.enterRequiresAction, err = tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{}) + return err +} + +func applyFinishInterruption(t *testing.T, _ *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.finishInterruption, err = tx.FinishInterruption(chatstate.FinishInterruptionInput{}) + return err +} + +func applyFinishTurn(t *testing.T, _ *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.finishTurn, err = tx.FinishTurn(chatstate.FinishTurnInput{}) + return err +} + +func applyFinishError(t *testing.T, _ *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.finishError, err = tx.FinishError(chatstate.FinishErrorInput{ + LastError: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"finish-error"}`), + Valid: true, + }, + }) + return err +} + +func applyCancelRequiresAction(t *testing.T, _ *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.cancelRequiresAction, err = tx.CancelRequiresAction(chatstate.CancelRequiresActionInput{ + Reason: "cancel from test", + }) + return err +} + +func applyReconcileInvalidState(t *testing.T, _ *testFixture, tx *chatstate.Tx, _ seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + var err error + result.reconcileInvalidState, err = tx.ReconcileInvalidState(chatstate.ReconcileInvalidStateInput{}) + return err +} + +// defaultApplier returns the canonical applier for tr. Used by the +// disallowed coverage path where the input shape does not matter +// because the transition fails before the inputs are consumed. +func defaultApplier(tr chatstate.Transition) applierFn { + switch tr { + case chatstate.TransitionSetArchived: + return applySetArchived + case chatstate.TransitionSendMessage: + return applySendMessageQueue + case chatstate.TransitionEditMessage: + return applyEditMessage + case chatstate.TransitionDeleteQueuedMessage: + return applyDeleteQueuedMessage + case chatstate.TransitionPromoteQueuedMessage: + return applyPromoteQueuedMessage + case chatstate.TransitionInterrupt: + return applyInterrupt + case chatstate.TransitionCompleteRequiresAction: + return applyCompleteRequiresAction + case chatstate.TransitionRecordGenerationAttempt: + return applyRecordGenerationAttempt + case chatstate.TransitionRecordRetryState: + return applyRecordRetryState + case chatstate.TransitionCommitStep: + return applyCommitStep + case chatstate.TransitionEnterRequiresAction: + return applyEnterRequiresAction + case chatstate.TransitionFinishInterruption: + return applyFinishInterruption + case chatstate.TransitionFinishTurn: + return applyFinishTurn + case chatstate.TransitionFinishError: + return applyFinishError + case chatstate.TransitionCancelRequiresAction: + return applyCancelRequiresAction + case chatstate.TransitionReconcileInvalidState: + return applyReconcileInvalidState + } + return nil +} + +// mustMarshalParts is a tiny test helper that fails the test on +// marshal error rather than forcing every call site to handle it. +func mustMarshalParts(t *testing.T, parts []codersdk.ChatMessagePart) pqtype.NullRawMessage { + t.Helper() + raw, err := chatprompt.MarshalParts(parts) + require.NoError(t, err) + return raw +} + +// ============================================================================= +// Case-level transition matrix spec. +// +// Each entry in matrixCases is one positive (transition, from, want) +// triple. The coverage key is (transition, from, want); scenario is a +// readability-and-semantic suffix for the subtest name that +// distinguishes multiple cases sharing the same coverage key. +// Disallowed combinations are enumerated separately from +// AllowedExecutionTransitionsFrom and AllowedExecutionTransitionOutputs. +// ============================================================================= + +type transitionCaseResult struct { + sendMessage chatstate.SendMessageResult + editMessage chatstate.EditMessageResult + deleteQueuedMessage chatstate.DeleteQueuedMessageResult + promoteQueuedMessage chatstate.PromoteQueuedMessageResult + interrupt chatstate.InterruptResult + completeRequiresAction chatstate.CompleteRequiresActionResult + recordGenerationAttempt chatstate.RecordGenerationAttemptResult + recordRetryState chatstate.RecordRetryStateResult + commitStep chatstate.CommitStepResult + enterRequiresAction chatstate.EnterRequiresActionResult + finishInterruption chatstate.FinishInterruptionResult + finishTurn chatstate.FinishTurnResult + finishError chatstate.FinishErrorResult + cancelRequiresAction chatstate.CancelRequiresActionResult + reconcileInvalidState chatstate.ReconcileInvalidStateResult +} + +type applierFn func(t *testing.T, f *testFixture, tx *chatstate.Tx, seeded seededChat, from chatstate.ExecutionState, result *transitionCaseResult) error + +type assertFn func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) + +// seederFn produces a seededChat for a case. Cases that omit a custom +// seeder use seedState by default. Custom seeders are required when +// the case needs more than one queued message, an Invalid chat with a +// non-empty queue, or a transition that needs a fresh A0/A1 seed. +type seederFn func(t *testing.T, f *testFixture, from chatstate.ExecutionState) seededChat + +type transitionCaseSpec struct { + transition chatstate.Transition + from chatstate.ExecutionState + want chatstate.ExecutionState + // scenario is a semantic label appended to the subtest name + // when the same (transition, from, want) key needs to run more + // than once. It is not part of the coverage key but is part of + // the duplicate-detection key. + scenario scenario + + seed seederFn + apply applierFn + assert assertFn + assertFailure func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, err error) +} + +// caseKey is the unit of coverage for positive cases. scenario is +// intentionally not part of the key so cases with different scenarios +// can still satisfy the same coverage cell. +type caseKey struct { + transition chatstate.Transition + from chatstate.ExecutionState + want chatstate.ExecutionState +} + +// fullCaseKey extends caseKey with scenario. Used for duplicate +// detection: two cases must not share the same full key. +type fullCaseKey struct { + transition chatstate.Transition + from chatstate.ExecutionState + want chatstate.ExecutionState + scenario scenario +} + +// queueShape selects the seed variant for transition case builders. +// A typed enum is used instead of a bool to avoid the revive +// flag-parameter rule and to make call sites self-documenting. +type queueShape int + +const ( + // queueShapeDefault routes through seedState, which produces the + // canonical single-queued seed for queue-bearing states and the + // empty queue for non-queue states. + queueShapeDefault queueShape = iota + // queueShapeMulti routes through seedStateMultiQueued (or + // seedInvalidWithQueue for ReconcileInvalidState) so the + // post-mutation queue can remain non-empty. + queueShapeMulti +) + +func (s queueShape) isMulti() bool { return s == queueShapeMulti } + +func (s transitionCaseSpec) key() caseKey { + return caseKey{transition: s.transition, from: s.from, want: s.want} +} + +func (s transitionCaseSpec) fullKey() fullCaseKey { + return fullCaseKey{ + transition: s.transition, + from: s.from, + want: s.want, + scenario: s.scenario, + } +} + +func (s transitionCaseSpec) subtestName() string { + name := fmt.Sprintf("%s/%s_to_%s", s.transition, s.from, s.want) + if s.scenario != "" { + name += "/" + string(s.scenario) + } + return name +} + +// disallowedCaseKey is the unit of coverage for negative cases. +type disallowedCaseKey struct { + transition chatstate.Transition + from chatstate.ExecutionState +} + +// remainingExcluding returns ids with the entry at exclude removed. +// The order of the surviving entries is preserved. +func remainingExcluding(ids []int64, exclude int) []int64 { + out := make([]int64, 0, len(ids)) + for i, id := range ids { + if i == exclude { + continue + } + out = append(out, id) + } + return out +} + +// remainingBodiesExcluding returns bodies with the entry at exclude +// removed. The order of the surviving entries is preserved. +func remainingBodiesExcluding(bodies []string, exclude int) []string { + out := make([]string, 0, len(bodies)) + for i, b := range bodies { + if i == exclude { + continue + } + out = append(out, b) + } + return out +} + +// ============================================================================= +// Test runner +// ============================================================================= + +// runPositiveCase seeds the chat, runs the transition, and asserts the +// post-state plus case-specific effects. +func runPositiveCase(t *testing.T, spec transitionCaseSpec) { + t.Helper() + require.NotNil(t, spec.apply, "case %s missing apply", spec.subtestName()) + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + + seeder := spec.seed + if seeder == nil { + seeder = seedState + } + seeded := seeder(t, f, spec.from) + if seeded.exists { + require.Equal(t, spec.from, f.classify(ctx, t, seeded.chatID), + "seed must land in %s", spec.from) + } + base := captureBaseline(ctx, t, f, seeded) + + m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{}) + var result transitionCaseResult + err := m.Update(ctx, func(tx *chatstate.Tx) error { + return spec.apply(t, f, tx, seeded, spec.from, &result) + }) + if spec.assertFailure != nil { + spec.assertFailure(ctx, t, f, seeded, base, err) + return + } + require.NoError(t, err, "%s from %s must succeed", spec.transition, spec.from) + assertSnapshotBumpedOnce(ctx, t, f, seeded.chatID, base) + require.Equal(t, spec.want, f.classify(ctx, t, seeded.chatID), + "%s: %s -> %s", spec.transition, spec.from, spec.want) + if spec.assert != nil { + spec.assert(ctx, t, f, seeded, base, result) + } +} + +// runDisallowedCase seeds the chat, runs the transition with default +// inputs, and asserts that the chatstate package surfaces the right +// sentinel error and rolled the snapshot bump back. +func runDisallowedCase(t *testing.T, tr chatstate.Transition, from chatstate.ExecutionState) { + t.Helper() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + seeded := seedState(t, f, from) + if seeded.exists { + require.Equal(t, from, f.classify(ctx, t, seeded.chatID), + "disallowed seed must land in %s", from) + } + base := captureBaseline(ctx, t, f, seeded) + + applier := defaultApplier(tr) + require.NotNil(t, applier, "no default applier for transition %s", tr) + m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{}) + var result transitionCaseResult + err := m.Update(ctx, func(tx *chatstate.Tx) error { + return applier(t, f, tx, seeded, from, &result) + }) + + if tr == chatstate.TransitionReconcileInvalidState && from != chatstate.StateN { + // ReconcileInvalidState does not use requireFromAllowed. + // It hits loadState successfully, sees the state is not + // Invalid, and returns a TransitionError directly. + require.Error(t, err) + var te *chatstate.TransitionError + require.ErrorAs(t, err, &te, + "reconcile from non-invalid state must return TransitionError") + require.Equal(t, chatstate.TransitionReconcileInvalidState, te.Transition) + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed) + assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base) + return + } + + expectErr := expectedErrorForDisallowed(tr, from) + require.Error(t, err) + require.ErrorIs(t, err, expectErr) + assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base) +} + +// TestTransitionMatrix_AllCombinations is the single entry point for +// the case-level transition matrix coverage. Each positive case in +// matrixCases() is one (transition, from, want) triple with a focused +// effect assertion. Disallowed combinations are enumerated from +// transition.go to confirm every non-CreateChat (transition, from) +// pair outside the allowed set surfaces the right sentinel error. +// +// After all parallel subtests complete the test verifies that the +// positive coverage matches AllowedExecutionTransitionOutputs (no +// missing key, no unexpected key) and that every disallowed +// (transition, from) pair was exercised exactly once. +// +// Input-specific rejection tests live in TestTransitionInputValidation +// and are intentionally not part of this matrix entry point so the +// matrix focus stays on positive cases and generated disallowed cases. +func TestTransitionMatrix_AllCombinations(t *testing.T) { + t.Parallel() + + cases := matrixCases() + + // Detect duplicate full keys and duplicate subtest names. The + // coverage key intentionally ignores scenario, so two cases may + // share the same (transition, from, want) only when their + // scenarios differ. + seenFullKeys := make(map[fullCaseKey]string, len(cases)) + seenNames := make(map[string]struct{}, len(cases)) + for _, tc := range cases { + full := tc.fullKey() + name := tc.subtestName() + if prev, ok := seenFullKeys[full]; ok { + t.Fatalf("duplicate matrix case %+v: previous %s, new %s", full, prev, name) + } + seenFullKeys[full] = name + if _, ok := seenNames[name]; ok { + t.Fatalf("duplicate matrix subtest name %s", name) + } + seenNames[name] = struct{}{} + } + + // Build the expected positive set from the matrix in + // transition.go. CreateChat is intentionally excluded because + // it is not exercised via ChatMachine.Update. + expectedPositive := make(map[caseKey]struct{}) + for _, from := range chatstate.AllExecutionStates { + for _, tr := range chatstate.AllowedExecutionTransitionsFrom(from) { + if tr == chatstate.TransitionCreateChat { + continue + } + for _, to := range chatstate.AllowedExecutionTransitionOutputs(from, tr) { + expectedPositive[caseKey{transition: tr, from: from, want: to}] = struct{}{} + } + } + } + + // Build the expected disallowed set: for each non-CreateChat + // transition, every state where the transition is not allowed. + expectedDisallowed := make(map[disallowedCaseKey]struct{}) + for _, tr := range chatstate.AllExecutionTransitions { + if tr == chatstate.TransitionCreateChat { + continue + } + for _, from := range chatstate.AllExecutionStates { + if transitionAllowed(tr, from) { + continue + } + expectedDisallowed[disallowedCaseKey{transition: tr, from: from}] = struct{}{} + } + } + + // Validate that every case in matrixCases describes a + // (transition, from, want) combination that the matrix actually + // admits. This guards against typos in matrixCases wiring up + // nonsense cases that happen to compile. + for _, tc := range cases { + if tc.assertFailure != nil { + continue + } + key := tc.key() + _, ok := expectedPositive[key] + require.True(t, ok, + "case %s is not in the allowed (transition, from, want) set", tc.subtestName()) + } + + // actualPositive and actualDisallowed are mutated under mu from + // parallel subtests. The final comparison runs in t.Cleanup, + // which fires only after every parallel child finishes. + var mu sync.Mutex + actualPositive := make(map[caseKey]struct{}, len(expectedPositive)) + actualDisallowed := make(map[disallowedCaseKey]struct{}, len(expectedDisallowed)) + + t.Cleanup(func() { + mu.Lock() + defer mu.Unlock() + for k := range expectedPositive { + if _, ok := actualPositive[k]; !ok { + t.Errorf("matrix coverage: missing positive case %+v", k) + } + } + for k := range actualPositive { + if _, ok := expectedPositive[k]; !ok { + t.Errorf("matrix coverage: unexpected positive case %+v", k) + } + } + for k := range expectedDisallowed { + if _, ok := actualDisallowed[k]; !ok { + t.Errorf("matrix coverage: missing disallowed case %+v", k) + } + } + for k := range actualDisallowed { + if _, ok := expectedDisallowed[k]; !ok { + t.Errorf("matrix coverage: unexpected disallowed case %+v", k) + } + } + }) + + // Positive cases: one parallel subtest per case. + t.Run("positive", func(t *testing.T) { + t.Parallel() + for _, tc := range cases { + tc := tc + t.Run(tc.subtestName(), func(t *testing.T) { + t.Parallel() + if tc.assertFailure == nil { + mu.Lock() + actualPositive[tc.key()] = struct{}{} + mu.Unlock() + } + runPositiveCase(t, tc) + }) + } + }) + + // Negative cases: one parallel subtest per (transition, from) + // pair where the transition is not allowed. Iterate over + // transitions in canonical order, and within each transition + // iterate states in canonical AllExecutionStates order, so + // subtest names are stable. + t.Run("disallowed", func(t *testing.T) { + t.Parallel() + // Sort disallowed keys for deterministic subtest names. + // AllExecutionTransitions and AllExecutionStates are + // already canonical, so iterate in their order. + for _, tr := range chatstate.AllExecutionTransitions { + tr := tr + if tr == chatstate.TransitionCreateChat { + continue + } + t.Run(string(tr), func(t *testing.T) { + t.Parallel() + for _, from := range chatstate.AllExecutionStates { + from := from + if transitionAllowed(tr, from) { + continue + } + t.Run(string(from), func(t *testing.T) { + t.Parallel() + mu.Lock() + actualDisallowed[disallowedCaseKey{transition: tr, from: from}] = struct{}{} + mu.Unlock() + runDisallowedCase(t, tr, from) + }) + } + }) + } + }) +} + +// ============================================================================= +// Positive case specs. +// +// Each case asserts (at minimum) the resulting classified post-state +// matches want, plus one transition-specific effect. Helpers reused +// from other tests handle the snapshot bump and the chat:update +// publish; per-case assertions focus on what the transition meant to +// change. +// ============================================================================= + +func matrixCases() []transitionCaseSpec { + return []transitionCaseSpec{ + // SetArchived cases: each archived/unarchived pair flips the + // archived flag, preserves status, history and last_error, + // and does not insert anything new. + setArchivedCase(chatstate.StateW, chatstate.StateXW, database.ChatStatusWaiting), + setArchivedCase(chatstate.StateE0, chatstate.StateXE0, database.ChatStatusError), + setArchivedCase(chatstate.StateE1, chatstate.StateXE1, database.ChatStatusError), + setArchivedCase(chatstate.StateXW, chatstate.StateW, database.ChatStatusWaiting), + setArchivedCase(chatstate.StateXE0, chatstate.StateE0, database.ChatStatusError), + setArchivedCase(chatstate.StateXE1, chatstate.StateE1, database.ChatStatusError), + + // SendMessage(queue) cases: idle states insert directly, + // busy states append to the queue tail. + sendMessageQueueCase(chatstate.StateW, chatstate.StateR0, true, 0), + sendMessageQueueCase(chatstate.StateE0, chatstate.StateR0, true, 0), + // E1 promotes the queue head and queues the new tail, so + // the net queue delta is zero. + sendMessageQueueCase(chatstate.StateE1, chatstate.StateR1, false, 0), + sendMessageQueueCase(chatstate.StateR0, chatstate.StateR1, false, +1), + sendMessageQueueCase(chatstate.StateR1, chatstate.StateR1, false, +1), + sendMessageQueueCase(chatstate.StateI0, chatstate.StateI1, false, +1), + sendMessageQueueCase(chatstate.StateI1, chatstate.StateI1, false, +1), + sendMessageQueueCase(chatstate.StateA0, chatstate.StateA1, false, +1), + sendMessageQueueCase(chatstate.StateA1, chatstate.StateA1, false, +1), + + // SendMessage(interrupt) cases. The interrupt applier runs + // with body "sm-interrupt" so the assertion can prove the + // interrupt input path was taken. From W/E0/E1/I0/I1 the + // resulting (transition, from, want) coverage key is + // identical to the queue case, but we still exercise the + // interrupt entry point to guard against a future bug where + // it stops routing through the correct direct-insert / + // queue-tail / promotion paths. From the busy R0/R1/A0/A1 + // states the interrupt destination differs from the queue + // destination so the scenario label is the only case for that key. + sendMessageInterruptCase(chatstate.StateW, chatstate.StateR0), + sendMessageInterruptCase(chatstate.StateE0, chatstate.StateR0), + sendMessageInterruptCase(chatstate.StateE1, chatstate.StateR1), + sendMessageInterruptCase(chatstate.StateR0, chatstate.StateI1), + sendMessageInterruptCase(chatstate.StateR1, chatstate.StateI1), + sendMessageInterruptCase(chatstate.StateI0, chatstate.StateI1), + sendMessageInterruptCase(chatstate.StateI1, chatstate.StateI1), + sendMessageInterruptCase(chatstate.StateA0, chatstate.StateR1), + sendMessageInterruptCase(chatstate.StateA1, chatstate.StateR1), + + // EditMessage cases: every allowed source state lands in R0 + // with the queue cleared, last_error reset, and a + // replacement user message in active history. + editMessageCase(chatstate.StateW), + editMessageCase(chatstate.StateE0), + editMessageCase(chatstate.StateE1), + editMessageCase(chatstate.StateR0), + editMessageCase(chatstate.StateR1), + editMessageCase(chatstate.StateI0), + editMessageCase(chatstate.StateI1), + editMessageCase(chatstate.StateA0), + editMessageCase(chatstate.StateA1), + + // DeleteQueuedMessage cases. Empty-tail want collapses the + // classified state (E1->E0, R1->R0, I1->I0, A1->A0). The + // non-empty-tail cases need a multi-queued seed. + deleteQueuedCase(chatstate.StateE1, chatstate.StateE0, queueShapeDefault), + deleteQueuedCase(chatstate.StateE1, chatstate.StateE1, queueShapeMulti), + deleteQueuedCase(chatstate.StateR1, chatstate.StateR0, queueShapeDefault), + deleteQueuedCase(chatstate.StateR1, chatstate.StateR1, queueShapeMulti), + deleteQueuedCase(chatstate.StateI1, chatstate.StateI0, queueShapeDefault), + deleteQueuedCase(chatstate.StateI1, chatstate.StateI1, queueShapeMulti), + deleteQueuedCase(chatstate.StateA1, chatstate.StateA0, queueShapeDefault), + deleteQueuedCase(chatstate.StateA1, chatstate.StateA1, queueShapeMulti), + + // PromoteQueuedMessage cases. E1/A1 pop the head into + // history; R1/I1 only reorder the queue without + // inserting history. R1/I1 has both a head-target + // scenario (zero rows updated, queue_version unchanged) + // and a non-head scenario (target moves to head, + // queue_version advances). + promoteQueuedCase(chatstate.StateE1, chatstate.StateR0, queueShapeDefault, 0), + promoteQueuedCase(chatstate.StateE1, chatstate.StateR1, queueShapeMulti, 0), + promoteQueuedCase(chatstate.StateR1, chatstate.StateI1, queueShapeMulti, 0), + promoteQueuedCase(chatstate.StateR1, chatstate.StateI1, queueShapeMulti, 1), + promoteQueuedCase(chatstate.StateI1, chatstate.StateI1, queueShapeMulti, 1), + promoteQueuedCase(chatstate.StateA1, chatstate.StateR0, queueShapeDefault, 0), + promoteQueuedCase(chatstate.StateA1, chatstate.StateR1, queueShapeMulti, 0), + + // Interrupt cases. + interruptCase(chatstate.StateR0, chatstate.StateI0), + interruptCase(chatstate.StateR1, chatstate.StateI1), + interruptCase(chatstate.StateA0, chatstate.StateR0), + interruptCase(chatstate.StateA1, chatstate.StateR1), + + // CompleteRequiresAction cases: A0->R0, A1->R1. + completeRequiresActionCase(chatstate.StateA0, chatstate.StateR0), + completeRequiresActionCase(chatstate.StateA1, chatstate.StateR1), + + // CancelRequiresAction cases: A0->R0, A1->R1. + cancelRequiresActionCase(chatstate.StateA0, chatstate.StateR0), + cancelRequiresActionCase(chatstate.StateA1, chatstate.StateR1), + + // RecordGenerationAttempt cases: from-state preserved. + recordGenerationAttemptCase(chatstate.StateR0), + recordGenerationAttemptCase(chatstate.StateR1), + + // RecordRetryState cases: from-state preserved. + recordRetryStateCase(chatstate.StateR0), + recordRetryStateCase(chatstate.StateR1), + + // CommitStep cases: from-state preserved, history grows by + // one message. + commitStepCase(chatstate.StateR0), + commitStepCase(chatstate.StateR1), + + // EnterRequiresAction cases. R0/R1 need a pending tool call + // seeded; use seedForEnterRequiresAction so the precondition + // is met. + enterRequiresActionCase(chatstate.StateR0, chatstate.StateA0), + enterRequiresActionCase(chatstate.StateR1, chatstate.StateA1), + + // FinishInterruption cases: I0->W, I1->R0 (head promoted into + // history when only one queued), I1->R1 (with more than one + // queued, the head is promoted but the queue stays + // non-empty). + finishInterruptionCase(chatstate.StateI0, chatstate.StateW, queueShapeDefault), + finishInterruptionRejectsOutstandingToolCallCase(), + finishInterruptionCase(chatstate.StateI1, chatstate.StateR0, queueShapeDefault), + finishInterruptionCase(chatstate.StateI1, chatstate.StateR1, queueShapeMulti), + + // FinishTurn cases. + finishTurnCase(chatstate.StateR0, chatstate.StateW, queueShapeDefault), + finishTurnCase(chatstate.StateR1, chatstate.StateR0, queueShapeDefault), + finishTurnCase(chatstate.StateR1, chatstate.StateR1, queueShapeMulti), + + // FinishError cases. + finishErrorCase(chatstate.StateR0, chatstate.StateE0), + finishErrorCase(chatstate.StateR1, chatstate.StateE1), + + // ReconcileInvalidState cases: Invalid with empty queue + // lands in E0; Invalid with non-empty queue lands in E1. + reconcileInvalidStateCase(chatstate.StateE0, queueShapeDefault), + reconcileInvalidStateCase(chatstate.StateE1, queueShapeMulti), + } +} + +// ----------------------------------------------------------------------------- +// Case builders. Each helper returns a transitionCaseSpec wired with +// the per-state-specific applier and assertion. Keeping the helpers +// small and focused keeps matrixCases() readable. +// ----------------------------------------------------------------------------- + +func setArchivedCase(from, want chatstate.ExecutionState, wantStatus database.ChatStatus) transitionCaseSpec { + wantArchived := false + switch want { + case chatstate.StateXW, chatstate.StateXE0, chatstate.StateXE1: + wantArchived = true + } + return transitionCaseSpec{ + transition: chatstate.TransitionSetArchived, + from: from, + want: want, + apply: applySetArchived, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + _ = result + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Equal(t, wantArchived, after.Archived, + "SetArchived must set archived=%v", wantArchived) + require.Equal(t, wantStatus, after.Status, + "SetArchived preserves chat status") + require.Equal(t, base.chat.LastError, after.LastError, + "SetArchived preserves last_error") + require.Equal(t, base.historyVersion, after.HistoryVersion, + "SetArchived does not insert history") + require.Equal(t, base.historyIDs, activeHistoryIDs(ctx, t, f, seeded.chatID), + "SetArchived leaves history messages unchanged") + require.Equal(t, base.queueVersion, after.QueueVersion, + "SetArchived does not mutate queued messages") + require.Equal(t, base.queueIDs, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "SetArchived leaves queued messages unchanged") + }, + } +} + +func sendMessageQueueCase(from, want chatstate.ExecutionState, directInsert bool, queueDelta int64) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionSendMessage, + from: from, + want: want, + scenario: scenarioQueue, + apply: applySendMessageQueue, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + afterQueue, err := f.DB.CountChatQueuedMessages(ctx, seeded.chatID) + require.NoError(t, err) + afterHistory := activeHistoryIDs(ctx, t, f, seeded.chatID) + afterQueueIDs := queuedIDsByPosition(ctx, t, f, seeded.chatID) + + require.Equal(t, base.queueCount+queueDelta, afterQueue, + "SendMessage(queue): unexpected queue count delta") + + switch { + case directInsert: + // W/E0: insert directly into history, no queue + // mutation. result.InsertedMessages contains exactly + // the new user message. + require.Len(t, result.sendMessage.InsertedMessages, 1, + "SendMessage(queue) into W/E0 inserts exactly one history message") + require.Nil(t, result.sendMessage.QueuedMessage, + "SendMessage(queue) into W/E0 does not queue") + inserted := assertFetchedUserMessage(ctx, t, f, result.sendMessage.InsertedMessages[0]) + require.Equal(t, seeded.chatID, inserted.ChatID) + assertChatMessageText(t, inserted, "sm-queue") + require.False(t, after.LastError.Valid, + "SendMessage(queue) clears last_error when transitioning out of an error state") + require.Equal(t, database.ChatStatusRunning, after.Status, + "SendMessage(queue) into W/E0 lands in running") + require.Equal(t, base.queueIDs, afterQueueIDs, + "SendMessage(queue) into W/E0 must not touch queued messages") + require.Equal(t, base.queueVersion, after.QueueVersion, + "SendMessage(queue) into W/E0 must not bump queue_version") + require.Equal(t, append([]int64{}, base.historyIDs...), afterHistory[:len(base.historyIDs)], + "SendMessage(queue) into W/E0 leaves the existing history prefix intact") + require.Equal(t, []int64{inserted.ID}, newActiveMessageIDs(base, afterHistory), + "SendMessage(queue) into W/E0 appends exactly the new user message") + + case from == chatstate.StateE1: + // E1: the previous head is promoted into history + // and replaced by the new tail. Net queue size + // unchanged. + require.NotNil(t, result.sendMessage.QueuedMessage, + "SendMessage(queue) from E1 returns the new queued tail") + require.Len(t, result.sendMessage.InsertedMessages, 1, + "SendMessage(queue) from E1 promotes the previous head into history") + promoted := assertFetchedUserMessage(ctx, t, f, result.sendMessage.InsertedMessages[0]) + require.Equal(t, seeded.chatID, promoted.ChatID) + require.NotEmpty(t, seeded.queuedMessageBodies, + "E1 seed must record the queue head body") + assertChatMessageText(t, promoted, seeded.queuedMessageBodies[0]) + newQueued := assertFetchedQueuedMessage(ctx, t, f, seeded.chatID, *result.sendMessage.QueuedMessage) + assertQueuedMessageText(t, newQueued, "sm-queue") + // Previous head queued message is gone from the + // queue and now lives in history. + require.NotEmpty(t, base.queueIDs, "E1 seed must have a queue head") + requireQueuedMessageDeleted(ctx, t, f, seeded.chatID, base.queueIDs[0]) + require.Equal(t, []int64{newQueued.ID}, afterQueueIDs, + "E1 -> R1: queue must end with only the new tail") + require.False(t, after.LastError.Valid, "E1 -> R1 clears last_error") + require.Equal(t, database.ChatStatusRunning, after.Status) + require.Equal(t, []int64{promoted.ID}, newActiveMessageIDs(base, afterHistory), + "E1 -> R1 inserts only the promoted user message") + require.Greater(t, after.QueueVersion, base.queueVersion, + "E1 -> R1 advances queue_version") + + default: + // Busy states: the new user message is appended at + // the queue tail; history is untouched. + require.NotNil(t, result.sendMessage.QueuedMessage, + "SendMessage(queue) from busy states returns the queued message") + require.Empty(t, result.sendMessage.InsertedMessages, + "SendMessage(queue) from busy states does not insert history") + newQueued := assertFetchedQueuedMessage(ctx, t, f, seeded.chatID, *result.sendMessage.QueuedMessage) + assertQueuedMessageText(t, newQueued, "sm-queue") + wantQueue := append(append([]int64{}, base.queueIDs...), newQueued.ID) + require.Equal(t, wantQueue, afterQueueIDs, + "SendMessage(queue) from busy states appends to the queue tail") + require.Equal(t, base.historyIDs, afterHistory, + "SendMessage(queue) from busy states does not change history") + require.Greater(t, after.QueueVersion, base.queueVersion, + "SendMessage(queue) from busy states advances queue_version") + switch from { + case chatstate.StateA0, chatstate.StateA1: + require.True(t, after.RequiresActionDeadlineAt.Valid, + "SendMessage(queue) from A* preserves requires_action_deadline_at") + require.Equal(t, base.chat.RequiresActionDeadlineAt, after.RequiresActionDeadlineAt, + "SendMessage(queue) from A* preserves the deadline value") + require.Equal(t, database.ChatStatusRequiresAction, after.Status) + case chatstate.StateI0, chatstate.StateI1: + require.Equal(t, database.ChatStatusInterrupting, after.Status) + case chatstate.StateR0, chatstate.StateR1: + require.Equal(t, database.ChatStatusRunning, after.Status) + } + } + }, + } +} + +func sendMessageInterruptCase(from, want chatstate.ExecutionState) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionSendMessage, + from: from, + want: want, + scenario: scenarioInterrupt, + apply: applySendMessageInterrupt, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + afterQueue, err := f.DB.CountChatQueuedMessages(ctx, seeded.chatID) + require.NoError(t, err) + afterHistory := activeHistoryIDs(ctx, t, f, seeded.chatID) + afterQueueIDs := queuedIDsByPosition(ctx, t, f, seeded.chatID) + + switch from { + case chatstate.StateW, chatstate.StateE0: + // W/E0 with interrupt-mode behaves like the direct + // insert from queue-mode: the new user message lands + // directly in history, the queue is left untouched, + // last_error is cleared, and the chat lands in R0. + require.Equal(t, base.queueCount, afterQueue, + "SendMessage(interrupt) into W/E0 must not queue") + require.Nil(t, result.sendMessage.QueuedMessage, + "SendMessage(interrupt) into W/E0 does not return a queued message") + require.Len(t, result.sendMessage.InsertedMessages, 1, + "SendMessage(interrupt) into W/E0 inserts exactly one history message") + inserted := assertFetchedUserMessage(ctx, t, f, result.sendMessage.InsertedMessages[0]) + require.Equal(t, seeded.chatID, inserted.ChatID) + assertChatMessageText(t, inserted, "sm-interrupt") + require.False(t, after.LastError.Valid, + "SendMessage(interrupt) into W/E0 clears last_error") + require.Equal(t, database.ChatStatusRunning, after.Status, + "SendMessage(interrupt) into W/E0 lands in running") + require.Equal(t, base.queueIDs, afterQueueIDs, + "SendMessage(interrupt) into W/E0 must not touch queued messages") + require.Equal(t, []int64{inserted.ID}, newActiveMessageIDs(base, afterHistory), + "SendMessage(interrupt) into W/E0 appends exactly the new user message") + + case chatstate.StateE1: + // E1 with interrupt-mode mirrors queue-mode: the + // previous head is promoted into history and the new + // tail replaces it in the queue. Net queue size + // unchanged, last_error cleared. + require.Equal(t, base.queueCount, afterQueue, + "SendMessage(interrupt) from E1 leaves queue size unchanged") + require.NotNil(t, result.sendMessage.QueuedMessage, + "SendMessage(interrupt) from E1 returns the new queued tail") + require.Len(t, result.sendMessage.InsertedMessages, 1, + "SendMessage(interrupt) from E1 promotes the previous head into history") + promoted := assertFetchedUserMessage(ctx, t, f, result.sendMessage.InsertedMessages[0]) + require.Equal(t, seeded.chatID, promoted.ChatID) + require.NotEmpty(t, seeded.queuedMessageBodies, "E1 seed must record queue head body") + assertChatMessageText(t, promoted, seeded.queuedMessageBodies[0]) + newQueued := assertFetchedQueuedMessage(ctx, t, f, seeded.chatID, *result.sendMessage.QueuedMessage) + assertQueuedMessageText(t, newQueued, "sm-interrupt") + require.NotEmpty(t, base.queueIDs, "E1 seed must have a queue head") + requireQueuedMessageDeleted(ctx, t, f, seeded.chatID, base.queueIDs[0]) + require.Equal(t, []int64{newQueued.ID}, afterQueueIDs, + "E1 -> R1 interrupt: queue must end with only the new tail") + require.False(t, after.LastError.Valid) + require.Equal(t, database.ChatStatusRunning, after.Status) + require.Equal(t, []int64{promoted.ID}, newActiveMessageIDs(base, afterHistory), + "E1 -> R1 interrupt inserts only the promoted user message") + require.Greater(t, after.QueueVersion, base.queueVersion, + "E1 -> R1 interrupt advances queue_version") + + case chatstate.StateI0, chatstate.StateI1: + // I*: append to queue tail, history untouched, status + // stays interrupting. + require.Equal(t, base.queueCount+1, afterQueue, + "SendMessage(interrupt) from I* appends one queued message") + require.NotNil(t, result.sendMessage.QueuedMessage, + "SendMessage(interrupt) from I* returns the queued tail") + require.Empty(t, result.sendMessage.InsertedMessages, + "SendMessage(interrupt) from I* does not insert history") + newQueued := assertFetchedQueuedMessage(ctx, t, f, seeded.chatID, *result.sendMessage.QueuedMessage) + assertQueuedMessageText(t, newQueued, "sm-interrupt") + wantQueue := append(append([]int64{}, base.queueIDs...), newQueued.ID) + require.Equal(t, wantQueue, afterQueueIDs, + "SendMessage(interrupt) from I* appends to the queue tail") + require.Equal(t, base.historyIDs, afterHistory, + "SendMessage(interrupt) from I* must not touch history") + require.Equal(t, database.ChatStatusInterrupting, after.Status, + "SendMessage(interrupt) from I* keeps status interrupting") + require.Greater(t, after.QueueVersion, base.queueVersion, + "SendMessage(interrupt) from I* advances queue_version") + + case chatstate.StateR0, chatstate.StateR1: + require.Equal(t, base.queueCount+1, afterQueue, + "SendMessage(interrupt) from R* appends one queued message") + require.NotNil(t, result.sendMessage.QueuedMessage, + "SendMessage(interrupt) from R* returns the queued tail") + newQueued := assertFetchedQueuedMessage(ctx, t, f, seeded.chatID, *result.sendMessage.QueuedMessage) + assertQueuedMessageText(t, newQueued, "sm-interrupt") + wantQueue := append(append([]int64{}, base.queueIDs...), newQueued.ID) + require.Equal(t, wantQueue, afterQueueIDs, + "SendMessage(interrupt) from R* appends to the queue tail") + require.Greater(t, after.QueueVersion, base.queueVersion, + "SendMessage(interrupt) from R* advances queue_version") + require.Equal(t, database.ChatStatusInterrupting, after.Status, + "R* -> I1 sets status interrupting") + require.Equal(t, base.historyIDs, afterHistory, + "SendMessage(interrupt) from R* must not touch history") + + case chatstate.StateA0, chatstate.StateA1: + require.Equal(t, base.queueCount+1, afterQueue, + "SendMessage(interrupt) from A* appends one queued message") + require.NotNil(t, result.sendMessage.QueuedMessage, + "SendMessage(interrupt) from A* returns the queued tail") + newQueued := assertFetchedQueuedMessage(ctx, t, f, seeded.chatID, *result.sendMessage.QueuedMessage) + assertQueuedMessageText(t, newQueued, "sm-interrupt") + wantQueue := append(append([]int64{}, base.queueIDs...), newQueued.ID) + require.Equal(t, wantQueue, afterQueueIDs, + "SendMessage(interrupt) from A* appends to the queue tail") + require.Greater(t, after.QueueVersion, base.queueVersion, + "SendMessage(interrupt) from A* advances queue_version") + require.Equal(t, database.ChatStatusRunning, after.Status, + "A* -> R1 cancels pending dynamic calls and resumes running") + require.False(t, after.RequiresActionDeadlineAt.Valid, + "A* -> R1 clears requires_action_deadline_at") + // Cancellation messages for the pending dynamic + // tool call should land in active history. They are + // not returned via SendMessageResult, so we fetch + // them by diffing the active history set. + newIDs := newActiveMessageIDs(base, afterHistory) + require.Len(t, newIDs, 1, + "SendMessage(interrupt) from A* synthesizes exactly one tool-result cancellation") + cancel := requireChatMessageByID(ctx, t, f, newIDs[0]) + assertToolResultForCall(t, cancel, seeded.pendingToolCallID) + } + }, + } +} + +func editMessageCase(from chatstate.ExecutionState) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionEditMessage, + from: from, + want: chatstate.StateR0, + apply: applyEditMessage, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusRunning, after.Status, + "EditMessage always lands in running") + require.False(t, after.Archived, "EditMessage clears archived") + require.False(t, after.LastError.Valid, + "EditMessage clears last_error") + count, err := f.DB.CountChatQueuedMessages(ctx, seeded.chatID) + require.NoError(t, err) + require.Zero(t, count, "EditMessage clears the queue") + require.Empty(t, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "EditMessage leaves no queued messages") + + // Replacement message must be a fresh user message that + // replaces the original target and lives in active history. + require.NotZero(t, result.editMessage.ReplacementMessage.ID, + "EditMessage returns the replacement message") + replacement := assertFetchedUserMessage(ctx, t, f, result.editMessage.ReplacementMessage) + require.Equal(t, seeded.chatID, replacement.ChatID) + require.NotEqual(t, seeded.initialUserMessageID, replacement.ID, + "EditMessage inserts a new replacement message") + assertChatMessageText(t, replacement, "edited") + + // Every history message from the edited message onward, + // inclusive, must be soft-deleted. base.historyIDs is the + // active history in order before the transition, so the + // expected deleted suffix is everything from the target's + // position to the end of that slice. GetChatMessageByID + // filters deleted=false, so it must return an error for + // each deleted ID. + require.NotEmpty(t, result.editMessage.DeletedMessageIDs, + "EditMessage deletes at least the target user message") + targetIdx := slices.Index(base.historyIDs, seeded.initialUserMessageID) + require.GreaterOrEqual(t, targetIdx, 0, + "baseline active history must contain the edited message") + wantDeleted := append([]int64{}, base.historyIDs[targetIdx:]...) + require.Equal(t, wantDeleted, result.editMessage.DeletedMessageIDs, + "EditMessage soft-deletes the edited message and every later active history message in order") + for _, id := range result.editMessage.DeletedMessageIDs { + _, err := f.DB.GetChatMessageByID(ctx, id) + require.Error(t, err, + "EditMessage: deleted message %d must not be active", id) + } + // Every deleted queued message must be gone from the queue. + for _, id := range result.editMessage.DeletedQueuedMessageIDs { + requireQueuedMessageDeleted(ctx, t, f, seeded.chatID, id) + } + for _, id := range base.queueIDs { + requireQueuedMessageDeleted(ctx, t, f, seeded.chatID, id) + } + }, + } +} + +func deleteQueuedCase(from, want chatstate.ExecutionState, shape queueShape) transitionCaseSpec { + spec := transitionCaseSpec{ + transition: chatstate.TransitionDeleteQueuedMessage, + from: from, + want: want, + apply: applyDeleteQueuedMessage, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + afterQueue, err := f.DB.CountChatQueuedMessages(ctx, seeded.chatID) + require.NoError(t, err) + require.Equal(t, base.queueCount-1, afterQueue, + "DeleteQueuedMessage removes exactly one queued message") + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Greater(t, after.QueueVersion, base.queueVersion, + "DeleteQueuedMessage advances queue_version") + + // The target queued message is the seeded head. It must + // be returned in DeletedQueuedMessage, and it must no + // longer be fetchable. + require.NotEmpty(t, seeded.queuedMessageIDs) + targetID := seeded.queuedMessageIDs[0] + require.Equal(t, targetID, result.deleteQueuedMessage.DeletedQueuedMessage.ID, + "DeletedQueuedMessage returns the targeted queued message") + require.Equal(t, seeded.chatID, result.deleteQueuedMessage.DeletedQueuedMessage.ChatID) + requireQueuedMessageDeleted(ctx, t, f, seeded.chatID, targetID) + + // Remaining queue IDs are the baseline tail. + wantRemaining := append([]int64{}, base.queueIDs[1:]...) + require.Equal(t, wantRemaining, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "DeleteQueuedMessage preserves remaining queue order") + require.Equal(t, base.historyIDs, activeHistoryIDs(ctx, t, f, seeded.chatID), + "DeleteQueuedMessage does not touch history") + }, + } + if shape.isMulti() { + spec.scenario = scenarioMulti + spec.seed = func(t *testing.T, f *testFixture, _ chatstate.ExecutionState) seededChat { + return seedStateMultiQueued(t, f, from) + } + } + return spec +} + +func promoteQueuedCase(from, want chatstate.ExecutionState, shape queueShape, targetIdx int) transitionCaseSpec { + var sc scenario + if shape.isMulti() { + switch targetIdx { + case 0: + sc = scenarioHeadTarget + default: + sc = scenarioNonHead + } + } + apply := func(t *testing.T, _ *testFixture, tx *chatstate.Tx, seeded seededChat, _ chatstate.ExecutionState, result *transitionCaseResult) error { + t.Helper() + require.Less(t, targetIdx, len(seeded.queuedMessageIDs), "promote target index out of range") + var err error + result.promoteQueuedMessage, err = tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{ + QueuedMessageID: seeded.queuedMessageIDs[targetIdx], + }) + return err + } + spec := transitionCaseSpec{ + transition: chatstate.TransitionPromoteQueuedMessage, + from: from, + want: want, + scenario: sc, + apply: apply, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + afterHistory := activeHistoryIDs(ctx, t, f, seeded.chatID) + afterQueueIDs := queuedIDsByPosition(ctx, t, f, seeded.chatID) + afterQueue := int64(len(afterQueueIDs)) + + require.NotEmpty(t, seeded.queuedMessageIDs) + require.Less(t, targetIdx, len(seeded.queuedMessageIDs)) + targetID := seeded.queuedMessageIDs[targetIdx] + require.Equal(t, targetID, result.promoteQueuedMessage.QueuedMessage.ID, + "PromoteQueuedMessage returns the targeted queued message") + + switch from { + case chatstate.StateE1, chatstate.StateA1: + // Head is popped into history. + require.Equal(t, base.queueCount-1, afterQueue, + "E1/A1 promote pops the head into history") + require.Equal(t, database.ChatStatusRunning, after.Status, + "E1/A1 promote lands in running") + require.False(t, after.LastError.Valid, + "E1/A1 promote clears last_error") + require.False(t, after.RequiresActionDeadlineAt.Valid, + "E1/A1 promote clears requires_action_deadline_at") + require.NotNil(t, result.promoteQueuedMessage.InsertedMessage, + "E1/A1 promote inserts a user history message") + inserted := requireChatMessageByID(ctx, t, f, result.promoteQueuedMessage.InsertedMessage.ID) + require.Equal(t, seeded.chatID, inserted.ChatID) + require.Equal(t, database.ChatMessageRoleUser, inserted.Role) + require.True(t, inserted.ModelConfigID.Valid) + require.Equal(t, f.Model.ID, inserted.ModelConfigID.UUID) + require.Equal(t, chatprompt.CurrentContentVersion, inserted.ContentVersion) + require.True(t, inserted.CreatedBy.Valid) + require.Equal(t, result.promoteQueuedMessage.QueuedMessage.CreatedBy, inserted.CreatedBy.UUID, + "promoted history message preserves queued created_by") + if len(seeded.queuedMessageCreatedBy) > targetIdx { + require.Equal(t, seeded.queuedMessageCreatedBy[targetIdx], inserted.CreatedBy.UUID, + "promoted history message preserves non-owner queued creator") + } + require.NotEmpty(t, seeded.queuedMessageBodies, + "E1/A1 seed must record queued message bodies") + assertChatMessageText(t, inserted, seeded.queuedMessageBodies[targetIdx]) + requireQueuedMessageDeleted(ctx, t, f, seeded.chatID, targetID) + wantRemaining := remainingExcluding(base.queueIDs, targetIdx) + require.Equal(t, wantRemaining, afterQueueIDs, + "E1/A1 promote leaves the remaining queue order intact") + assertQueueBodiesInOrder(ctx, t, f, seeded.chatID, + remainingBodiesExcluding(seeded.queuedMessageBodies, targetIdx)) + // New active history adds exactly the inserted + // user message plus any synthetic cancellations. + newIDs := newActiveMessageIDs(base, afterHistory) + require.Contains(t, newIDs, inserted.ID, + "newly-active history contains the promoted user message") + if from == chatstate.StateA1 { + // A1: every outstanding tool call must be + // canceled before the promoted user message. + require.Len(t, result.promoteQueuedMessage.CancellationMessages, len(seeded.pendingToolCallIDs), + "A1 promote synthesizes one tool-result cancellation per outstanding call") + gotIDs := make(map[string]bool) + for _, cancelMsg := range result.promoteQueuedMessage.CancellationMessages { + cancel := requireChatMessageByID(ctx, t, f, cancelMsg.ID) + require.Less(t, cancel.ID, inserted.ID, + "A1 promote inserts cancellations before the promoted user message") + parts, err := chatprompt.ParseContent(cancel) + require.NoError(t, err) + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeToolResult { + continue + } + require.True(t, part.IsError, + "A1 promote synthetic cancellation is marked as an error") + gotIDs[part.ToolCallID] = true + } + } + for _, callID := range seeded.pendingToolCallIDs { + require.True(t, gotIDs[callID], + "A1 promote cancels outstanding tool call %s", callID) + } + } else { + require.Empty(t, result.promoteQueuedMessage.CancellationMessages, + "E1 promote has no synthetic cancellations") + } + case chatstate.StateR1, chatstate.StateI1: + // Reorder-only: status flips to interrupting, no + // history insert, queue cardinality unchanged. + require.Equal(t, base.queueCount, afterQueue, + "R1/I1 promote leaves queue cardinality unchanged") + require.Equal(t, database.ChatStatusInterrupting, after.Status, + "R1/I1 promote lands in interrupting") + require.Nil(t, result.promoteQueuedMessage.InsertedMessage, + "R1/I1 promote must not insert a history message") + require.Empty(t, result.promoteQueuedMessage.CancellationMessages, + "R1/I1 promote has no synthetic cancellations") + require.Equal(t, base.historyIDs, afterHistory, + "R1/I1 promote leaves history unchanged") + // Target must still be present and now at the head. + queued := requireQueuedMessageByID(ctx, t, f, seeded.chatID, targetID) + require.Equal(t, targetID, queued.ID) + require.NotEmpty(t, afterQueueIDs) + require.Equal(t, targetID, afterQueueIDs[0], + "R1/I1 promote brings the target to the queue head") + require.NotEmpty(t, seeded.queuedMessageBodies, + "R1/I1 seed must record queued message bodies") + if targetIdx == 0 { + // Head-target: zero rows updated, so the + // queue order is unchanged and queue_version + // stays put. + require.Equal(t, base.queueIDs, afterQueueIDs, + "head-target promote preserves queue order") + require.Equal(t, base.queueVersion, after.QueueVersion, + "head-target promote leaves queue_version unchanged") + assertQueueBodiesInOrder(ctx, t, f, seeded.chatID, seeded.queuedMessageBodies) + } else { + // Non-head: target moves to the head, the rest + // of the original order is preserved. + wantQueue := append([]int64{targetID}, remainingExcluding(base.queueIDs, targetIdx)...) + require.Equal(t, wantQueue, afterQueueIDs, + "non-head promote moves the target to the head and preserves the rest") + require.Greater(t, after.QueueVersion, base.queueVersion, + "non-head promote advances queue_version") + wantBodies := append([]string{seeded.queuedMessageBodies[targetIdx]}, + remainingBodiesExcluding(seeded.queuedMessageBodies, targetIdx)...) + assertQueueBodiesInOrder(ctx, t, f, seeded.chatID, wantBodies) + } + } + }, + } + if from == chatstate.StateA1 { + spec.seed = func(t *testing.T, f *testFixture, _ chatstate.ExecutionState) seededChat { + queuedExtras := 1 + if shape.isMulti() { + queuedExtras = 2 + } + return seedA1WithMixedOutstandingToolCalls(t, f, queuedExtras, "seed_tool_a1_promote") + } + } else if shape.isMulti() { + spec.seed = func(t *testing.T, f *testFixture, _ chatstate.ExecutionState) seededChat { + return seedStateMultiQueued(t, f, from) + } + } + return spec +} + +func interruptCase(from, want chatstate.ExecutionState) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionInterrupt, + from: from, + want: want, + apply: applyInterrupt, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + afterHistory := activeHistoryIDs(ctx, t, f, seeded.chatID) + afterQueueIDs := queuedIDsByPosition(ctx, t, f, seeded.chatID) + require.Equal(t, base.queueIDs, afterQueueIDs, + "Interrupt does not touch queued messages") + + switch from { + case chatstate.StateR0, chatstate.StateR1: + require.Equal(t, database.ChatStatusInterrupting, after.Status, + "Interrupt from R* sets status interrupting") + require.Equal(t, base.historyIDs, afterHistory, + "Interrupt from R* leaves history unchanged") + require.Empty(t, result.interrupt.CancellationMessages, + "Interrupt from R* does not synthesize tool cancellations") + case chatstate.StateA0, chatstate.StateA1: + require.Equal(t, database.ChatStatusRunning, after.Status, + "Interrupt from A* cancels pending dynamic calls and resumes running") + require.False(t, after.RequiresActionDeadlineAt.Valid, + "Interrupt from A* clears requires_action_deadline_at") + require.Len(t, result.interrupt.CancellationMessages, 1, + "Interrupt from A* synthesizes one tool-result cancellation") + cancel := requireChatMessageByID(ctx, t, f, + result.interrupt.CancellationMessages[0].ID) + assertToolResultForCall(t, cancel, seeded.pendingToolCallID) + } + }, + } +} + +func completeRequiresActionCase(from, want chatstate.ExecutionState) transitionCaseSpec { + // Re-seed A0/A1 fresh per case so the pending tool call ID is + // available on the seeded chat. + return transitionCaseSpec{ + transition: chatstate.TransitionCompleteRequiresAction, + from: from, + want: want, + apply: applyCompleteRequiresAction, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusRunning, after.Status, + "CompleteRequiresAction sets status running") + require.False(t, after.RequiresActionDeadlineAt.Valid, + "CompleteRequiresAction clears requires_action_deadline_at") + require.Equal(t, base.queueIDs, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "CompleteRequiresAction preserves queued messages") + + // The user-submitted tool result must be inserted as a + // tool-role message that references the seeded + // pendingToolCallID with is_error=false. + require.Len(t, result.completeRequiresAction.InsertedMessages, 1, + "CompleteRequiresAction inserts one tool-result message per pending call") + inserted := requireChatMessageByID(ctx, t, f, + result.completeRequiresAction.InsertedMessages[0].ID) + assertToolResultForCallNoError(t, inserted, seeded.pendingToolCallID, `{"ok":true}`) + }, + } +} + +func cancelRequiresActionCase(from, want chatstate.ExecutionState) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionCancelRequiresAction, + from: from, + want: want, + apply: applyCancelRequiresAction, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusRunning, after.Status, + "CancelRequiresAction sets status running") + require.False(t, after.RequiresActionDeadlineAt.Valid, + "CancelRequiresAction clears requires_action_deadline_at") + require.Equal(t, base.queueIDs, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "CancelRequiresAction preserves queued messages") + + // One synthetic tool-result cancellation per pending call. + require.Len(t, result.cancelRequiresAction.CancellationMessages, 1, + "CancelRequiresAction synthesizes one tool-result per pending call") + cancel := requireChatMessageByID(ctx, t, f, + result.cancelRequiresAction.CancellationMessages[0].ID) + assertToolResultForCall(t, cancel, seeded.pendingToolCallID) + }, + } +} + +func recordGenerationAttemptCase(from chatstate.ExecutionState) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionRecordGenerationAttempt, + from: from, + want: from, // state preserved + apply: applyRecordGenerationAttempt, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Equal(t, int64(1), after.GenerationAttempt, + "RecordGenerationAttempt increments generation_attempt by one") + require.Equal(t, result.recordGenerationAttempt.GenerationAttempt, after.GenerationAttempt, + "RecordGenerationAttempt result mirrors the persisted value") + require.Equal(t, base.historyVersion, after.HistoryVersion, + "RecordGenerationAttempt does not change history_version") + require.Equal(t, base.queueVersion, after.QueueVersion, + "RecordGenerationAttempt does not change queue_version") + require.Equal(t, base.queueIDs, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "RecordGenerationAttempt does not change queue order") + require.Equal(t, base.historyIDs, activeHistoryIDs(ctx, t, f, seeded.chatID), + "RecordGenerationAttempt does not change history messages") + }, + } +} + +func recordRetryStateCase(from chatstate.ExecutionState) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionRecordRetryState, + from: from, + want: from, // state preserved + apply: applyRecordRetryState, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.True(t, after.RetryState.Valid, + "RecordRetryState stores retry_state") + require.JSONEq(t, + string(result.recordRetryState.Chat.RetryState.RawMessage), + string(after.RetryState.RawMessage), + "RecordRetryState result mirrors persisted retry_state") + require.JSONEq(t, + `{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`, + string(after.RetryState.RawMessage), + "RecordRetryState stores the expected payload") + require.Equal(t, after.SnapshotVersion, after.RetryStateVersion, + "RecordRetryState sets retry_state_version to snapshot_version") + require.Greater(t, after.RetryStateVersion, base.retryStateVersion, + "RecordRetryState advances retry_state_version") + require.Equal(t, base.historyVersion, after.HistoryVersion, + "RecordRetryState does not change history_version") + require.Equal(t, base.queueVersion, after.QueueVersion, + "RecordRetryState does not change queue_version") + require.Equal(t, base.generationAttempt, after.GenerationAttempt, + "RecordRetryState does not change generation_attempt") + require.Equal(t, base.queueIDs, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "RecordRetryState does not change queue order") + require.Equal(t, base.historyIDs, activeHistoryIDs(ctx, t, f, seeded.chatID), + "RecordRetryState does not change history messages") + }, + } +} + +func commitStepCase(from chatstate.ExecutionState) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionCommitStep, + from: from, + want: from, // state preserved + apply: applyCommitStep, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + afterHistory := activeHistoryIDs(ctx, t, f, seeded.chatID) + require.Equal(t, len(base.historyIDs)+1, len(afterHistory), + "CommitStep appends exactly one history message") + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Greater(t, after.HistoryVersion, base.historyVersion, + "CommitStep advances history_version") + + require.Len(t, result.commitStep.InsertedMessages, 1, + "CommitStep returns the inserted assistant message") + inserted := requireChatMessageByID(ctx, t, f, + result.commitStep.InsertedMessages[0].ID) + require.Equal(t, seeded.chatID, inserted.ChatID) + require.Equal(t, database.ChatMessageRoleAssistant, inserted.Role, + "CommitStep inserts an assistant-role message") + assertChatMessageText(t, inserted, "assistant") + require.Equal(t, base.queueIDs, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "CommitStep does not change queue order") + require.Equal(t, base.queueVersion, after.QueueVersion, + "CommitStep does not change queue_version") + }, + } +} + +func enterRequiresActionCase(from, want chatstate.ExecutionState) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionEnterRequiresAction, + from: from, + want: want, + seed: seedForEnterRequiresAction, + apply: applyEnterRequiresAction, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusRequiresAction, after.Status, + "EnterRequiresAction sets status requires_action") + require.True(t, after.RequiresActionDeadlineAt.Valid, + "EnterRequiresAction populates requires_action_deadline_at") + require.True(t, result.enterRequiresAction.RequiresActionDeadlineAt.Valid, + "EnterRequiresAction returns the deadline") + require.Equal(t, result.enterRequiresAction.RequiresActionDeadlineAt, after.RequiresActionDeadlineAt, + "EnterRequiresAction returned deadline matches the persisted value") + require.Equal(t, base.queueIDs, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "EnterRequiresAction preserves queued messages") + require.Equal(t, base.queueVersion, after.QueueVersion, + "EnterRequiresAction does not bump queue_version") + require.Equal(t, base.historyIDs, activeHistoryIDs(ctx, t, f, seeded.chatID), + "EnterRequiresAction does not insert history") + }, + } +} + +func finishInterruptionRejectsOutstandingToolCallCase() transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionFinishInterruption, + from: chatstate.StateI0, + want: chatstate.StateI0, + scenario: scenarioRejectNonDynamicOutstandingToolCall, + seed: func(t *testing.T, f *testFixture, _ chatstate.ExecutionState) seededChat { + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + nonDynCallID := "call_" + uuid.NewString() + commitAssistantToolCall(t, f, m, + nonDynamicAssistantToolCallMessage(t, f.Model.ID, nonDynCallID)) + + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"}) + return err + })) + return seededChat{ + chatID: created.Chat.ID, + exists: true, + initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID), + assistantToolCallMsgID: firstAssistantMessageID(ctx, t, f, created.Chat.ID), + pendingToolCallIDs: []string{nonDynCallID}, + } + }, + apply: applyFinishInterruption, + assertFailure: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, err error) { + require.Error(t, err, "FinishInterruption must reject an outstanding tool call") + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed, + "rejection must wrap ErrTransitionNotAllowed") + var te *chatstate.TransitionError + require.ErrorAs(t, err, &te, + "FinishInterruption must return a typed TransitionError") + require.Equal(t, chatstate.TransitionFinishInterruption, te.Transition) + require.Equal(t, chatstate.StateI0, te.From) + assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base) + }, + } +} + +func finishInterruptionCase(from, want chatstate.ExecutionState, shape queueShape) transitionCaseSpec { + spec := transitionCaseSpec{ + transition: chatstate.TransitionFinishInterruption, + from: from, + want: want, + apply: applyFinishInterruption, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + afterHistory := activeHistoryIDs(ctx, t, f, seeded.chatID) + afterQueueIDs := queuedIDsByPosition(ctx, t, f, seeded.chatID) + switch from { + case chatstate.StateI0: + require.Equal(t, database.ChatStatusWaiting, after.Status, + "FinishInterruption from I0 lands in waiting") + require.Nil(t, result.finishInterruption.PromotedMessage, + "FinishInterruption from I0 promotes nothing") + require.Equal(t, base.queueIDs, afterQueueIDs, + "FinishInterruption from I0 leaves queued messages unchanged") + require.Equal(t, base.historyIDs, afterHistory, + "FinishInterruption from I0 with no partial messages leaves history unchanged") + case chatstate.StateI1: + require.Equal(t, database.ChatStatusRunning, after.Status, + "FinishInterruption from I1 lands in running") + require.NotNil(t, result.finishInterruption.PromotedMessage, + "FinishInterruption from I1 promotes the head into history") + promoted := assertFetchedUserMessage(ctx, t, f, + *result.finishInterruption.PromotedMessage) + require.Equal(t, seeded.chatID, promoted.ChatID) + require.Contains(t, newActiveMessageIDs(base, afterHistory), promoted.ID, + "FinishInterruption from I1 inserts the promoted user message") + require.NotEmpty(t, seeded.queuedMessageBodies, + "I1 seed must record queued message bodies") + assertChatMessageText(t, promoted, seeded.queuedMessageBodies[0]) + require.NotEmpty(t, base.queueIDs) + requireQueuedMessageDeleted(ctx, t, f, seeded.chatID, base.queueIDs[0]) + wantRemaining := append([]int64{}, base.queueIDs[1:]...) + require.Equal(t, wantRemaining, afterQueueIDs, + "FinishInterruption from I1 preserves the queue tail order") + assertQueueBodiesInOrder(ctx, t, f, seeded.chatID, + seeded.queuedMessageBodies[1:]) + } + }, + } + if shape.isMulti() { + spec.scenario = scenarioMulti + spec.seed = func(t *testing.T, f *testFixture, _ chatstate.ExecutionState) seededChat { + return seedStateMultiQueued(t, f, from) + } + } + return spec +} + +func finishTurnCase(from, want chatstate.ExecutionState, shape queueShape) transitionCaseSpec { + spec := transitionCaseSpec{ + transition: chatstate.TransitionFinishTurn, + from: from, + want: want, + apply: applyFinishTurn, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + afterHistory := activeHistoryIDs(ctx, t, f, seeded.chatID) + afterQueueIDs := queuedIDsByPosition(ctx, t, f, seeded.chatID) + switch from { + case chatstate.StateR0: + require.Equal(t, database.ChatStatusWaiting, after.Status, + "FinishTurn from R0 lands in waiting") + require.Nil(t, result.finishTurn.PromotedMessage, + "FinishTurn from R0 promotes nothing") + require.Equal(t, base.queueIDs, afterQueueIDs, + "FinishTurn from R0 leaves queued messages unchanged") + require.Equal(t, base.historyIDs, afterHistory, + "FinishTurn from R0 leaves history unchanged") + case chatstate.StateR1: + require.Equal(t, database.ChatStatusRunning, after.Status, + "FinishTurn from R1 lands in running") + require.NotNil(t, result.finishTurn.PromotedMessage, + "FinishTurn from R1 promotes the head into history") + promoted := assertFetchedUserMessage(ctx, t, f, + *result.finishTurn.PromotedMessage) + require.Equal(t, seeded.chatID, promoted.ChatID) + require.Contains(t, newActiveMessageIDs(base, afterHistory), promoted.ID, + "FinishTurn from R1 inserts the promoted user message") + require.NotEmpty(t, seeded.queuedMessageBodies, + "R1 seed must record queued message bodies") + assertChatMessageText(t, promoted, seeded.queuedMessageBodies[0]) + require.NotEmpty(t, base.queueIDs) + requireQueuedMessageDeleted(ctx, t, f, seeded.chatID, base.queueIDs[0]) + wantRemaining := append([]int64{}, base.queueIDs[1:]...) + require.Equal(t, wantRemaining, afterQueueIDs, + "FinishTurn from R1 preserves the queue tail order") + assertQueueBodiesInOrder(ctx, t, f, seeded.chatID, + seeded.queuedMessageBodies[1:]) + } + }, + } + if shape.isMulti() { + spec.scenario = scenarioMulti + spec.seed = func(t *testing.T, f *testFixture, _ chatstate.ExecutionState) seededChat { + return seedStateMultiQueued(t, f, from) + } + } + return spec +} + +func finishErrorCase(from, want chatstate.ExecutionState) transitionCaseSpec { + return transitionCaseSpec{ + transition: chatstate.TransitionFinishError, + from: from, + want: want, + apply: applyFinishError, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + _ = result + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusError, after.Status, + "FinishError sets status error") + require.True(t, after.LastError.Valid, + "FinishError stores last_error") + require.JSONEq(t, `{"message":"finish-error"}`, string(after.LastError.RawMessage), + "FinishError persists the input last_error JSON") + require.Equal(t, base.historyVersion, after.HistoryVersion, + "FinishError does not change history_version") + require.Equal(t, base.queueVersion, after.QueueVersion, + "FinishError does not change queue_version") + require.Equal(t, base.queueIDs, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "FinishError preserves queued messages") + require.Equal(t, base.historyIDs, activeHistoryIDs(ctx, t, f, seeded.chatID), + "FinishError preserves history messages") + }, + } +} + +func reconcileInvalidStateCase(want chatstate.ExecutionState, shape queueShape) transitionCaseSpec { + spec := transitionCaseSpec{ + transition: chatstate.TransitionReconcileInvalidState, + from: chatstate.StateInvalid, + want: want, + apply: applyReconcileInvalidState, + assert: func(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat, base snapshotBaseline, result transitionCaseResult) { + after, err := f.DB.GetChatByID(ctx, seeded.chatID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusError, after.Status, + "ReconcileInvalidState lands in error") + require.False(t, after.Archived, + "ReconcileInvalidState clears archived") + require.True(t, after.LastError.Valid, + "ReconcileInvalidState sets a default last_error") + require.Equal(t, base.queueIDs, queuedIDsByPosition(ctx, t, f, seeded.chatID), + "ReconcileInvalidState preserves queued messages") + // For the current invalid seeds there are no pending + // dynamic tool calls, so no cancellation messages are + // expected. Still, if any are returned we fetch them + // to verify they were persisted as tool-role messages. + for _, c := range result.reconcileInvalidState.CancellationMessages { + msg := requireChatMessageByID(ctx, t, f, c.ID) + require.Equal(t, database.ChatMessageRoleTool, msg.Role) + } + }, + } + if shape.isMulti() { + spec.scenario = scenarioWithQueue + spec.seed = func(t *testing.T, f *testFixture, _ chatstate.ExecutionState) seededChat { + return seedInvalidWithQueue(t, f) + } + } + return spec +} diff --git a/coderd/x/chatd/chatstate/transitions_test.go b/coderd/x/chatd/chatstate/transitions_test.go new file mode 100644 index 0000000000..82fcb4924c --- /dev/null +++ b/coderd/x/chatd/chatstate/transitions_test.go @@ -0,0 +1,747 @@ +package chatstate_test + +import ( + "encoding/json" + "testing" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// ============================================================================= +// CreateChat tests. +// +// CreateChat is the only transition that originates from StateN and it +// is not exercised through ChatMachine.Update, so it lives outside +// TestTransitionMatrix_AllCombinations. +// ============================================================================= + +// TestTransitionCreate_NToR0 verifies that CreateChat lands a fresh +// chat in R0 with snapshot_version 1, the initial user message +// recorded at revision 1, queue_version still 0, and the post-commit +// publish requesting an ownership hint plus a chat:update. +func TestTransitionCreate_NToR0(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + res := createTestChat(t, f) + + require.Equal(t, database.ChatStatusRunning, res.Chat.Status) + require.False(t, res.Chat.Archived) + require.Equal(t, int64(1), res.Chat.SnapshotVersion, "snapshot_version starts at 1") + require.Equal(t, int64(1), res.Chat.HistoryVersion, "history_version set by trigger after initial insert") + require.Equal(t, int64(0), res.Chat.QueueVersion, "queue_version stays 0 when no queue rows") + require.Equal(t, int64(0), res.Chat.GenerationAttempt) + require.NotEmpty(t, res.InitialMessages) + require.Equal(t, int64(1), res.InitialMessages[0].Revision) + require.Equal(t, chatstate.StateR0, f.classify(ctx, t, res.Chat.ID)) + require.True(t, f.Pub.hasOwnership(), "newly created chat is runnable and unowned") + f.Pub.expectChatUpdate(t, res.Chat.ID, 1) +} + +// TestCreateChat_RejectsEmptyInitialMessages verifies that CreateChat +// rejects an empty InitialMessages slice with ErrTransitionNotAllowed +// and does not publish anything. +func TestCreateChat_RejectsEmptyInitialMessages(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + _, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{ + OrganizationID: f.Org.ID, + OwnerID: f.User.ID, + LastModelConfigID: f.Model.ID, + ClientType: database.ChatClientTypeApi, + Title: "t", + InitialMessages: nil, + }) + require.Error(t, err) + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed) + require.Empty(t, f.Pub.channels, "rejected create must not publish") +} + +func TestCreateChat_AllowsNoUserMessages(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + assistant := userTextMessage("oops", f.User.ID, f.Model.ID) + assistant.Role = database.ChatMessageRoleAssistant + res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{ + OrganizationID: f.Org.ID, + OwnerID: f.User.ID, + LastModelConfigID: f.Model.ID, + Title: "t", + ClientType: database.ChatClientTypeApi, + InitialMessages: []chatstate.Message{assistant}, + }) + require.NoError(t, err) + require.Len(t, res.InitialMessages, 1) +} + +func TestCreateChat_AllowsNonFinalUserMessage(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{ + OrganizationID: f.Org.ID, + OwnerID: f.User.ID, + LastModelConfigID: f.Model.ID, + Title: "t", + ClientType: database.ChatClientTypeApi, + InitialMessages: []chatstate.Message{ + userTextMessage("context user", f.User.ID, f.Model.ID), + userTextMessage("final user", f.User.ID, f.Model.ID), + }, + }) + require.NoError(t, err) + require.Len(t, res.InitialMessages, 2) +} + +// ============================================================================= +// Input-specific rejection cases. +// +// These tests cover the same matrix rows as TestTransitionMatrix_AllCombinations +// but exercise legal source states with invalid transition inputs. They are +// intentionally outside the matrix entry point so the matrix focus stays on +// positive cases and generated disallowed cases. +// ============================================================================= + +type setArchivedWrongDirectionCase struct { + from chatstate.ExecutionState + wantArchive bool + label string +} + +func setArchivedWrongDirectionCases() []setArchivedWrongDirectionCase { + return []setArchivedWrongDirectionCase{ + // Non-archived states with archived=false: no-op. + {from: chatstate.StateW, wantArchive: false, label: "W_to_W"}, + {from: chatstate.StateE0, wantArchive: false, label: "E0_to_E0"}, + {from: chatstate.StateE1, wantArchive: false, label: "E1_to_E1"}, + // Archived states with archived=true: no-op. + {from: chatstate.StateXW, wantArchive: true, label: "XW_to_XW"}, + {from: chatstate.StateXE0, wantArchive: true, label: "XE0_to_XE0"}, + {from: chatstate.StateXE1, wantArchive: true, label: "XE1_to_XE1"}, + } +} + +var invalidBusyBehaviors = []chatstate.BusyBehavior{ + chatstate.BusyBehavior(""), + chatstate.BusyBehavior("not-a-real-mode"), +} + +func runSetArchivedWrongDirectionCase(t *testing.T, tc setArchivedWrongDirectionCase) { + t.Helper() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + seeded := seedState(t, f, tc.from) + require.Equal(t, tc.from, f.classify(ctx, t, seeded.chatID)) + + base := captureBaseline(ctx, t, f, seeded) + + m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{}) + err := m.Update(ctx, func(tx *chatstate.Tx) error { + _, serr := tx.SetArchived(chatstate.SetArchivedInput{Archived: tc.wantArchive}) + return serr + }) + require.Error(t, err, "SetArchived must reject when Archived matches the current value") + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed, + "SetArchived must wrap ErrTransitionNotAllowed") + var te *chatstate.TransitionError + require.ErrorAs(t, err, &te, + "SetArchived must return a typed TransitionError") + require.Equal(t, chatstate.TransitionSetArchived, te.Transition) + require.Equal(t, tc.from, te.From, "TransitionError records the loaded from-state") + + require.Equal(t, tc.from, f.classify(ctx, t, seeded.chatID), + "rejected SetArchived must leave the chat in the same state") + assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base) +} + +func runInvalidBusyBehaviorCase(t *testing.T, from chatstate.ExecutionState, bb chatstate.BusyBehavior) { + t.Helper() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + seeded := seedState(t, f, from) + require.Equal(t, from, f.classify(ctx, t, seeded.chatID), + "seed must land in %s", from) + base := captureBaseline(ctx, t, f, seeded) + + m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{}) + err := m.Update(ctx, func(tx *chatstate.Tx) error { + _, serr := tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage("invalid-bb", f.User.ID, f.Model.ID), + BusyBehavior: bb, + }) + return serr + }) + require.Error(t, err, "SendMessage must reject invalid BusyBehavior") + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed, + "SendMessage rejection must wrap ErrTransitionNotAllowed") + var te *chatstate.TransitionError + require.ErrorAs(t, err, &te, + "SendMessage must return a typed TransitionError") + require.Equal(t, chatstate.TransitionSendMessage, te.Transition) + require.Equal(t, from, te.From, + "TransitionError records the source state") + + require.Equal(t, from, f.classify(ctx, t, seeded.chatID), + "rejected SendMessage must leave the chat in the same state") + assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base) +} + +type completeRequiresActionRejectCase struct { + name string + results func(seeded seededChat) []chatstate.ToolResultInput +} + +type recordRetryStateRejectCase struct { + name string + retryState pqtype.NullRawMessage +} + +func completeRequiresActionRejectCases() []completeRequiresActionRejectCase { + valid := func(id string) chatstate.ToolResultInput { + return chatstate.ToolResultInput{ + ToolCallID: id, + Output: json.RawMessage(`{"ok":true}`), + } + } + return []completeRequiresActionRejectCase{ + { + name: "missing_required_tool_result", + results: func(seeded seededChat) []chatstate.ToolResultInput { return nil }, + }, + { + name: "extra_tool_result", + results: func(seeded seededChat) []chatstate.ToolResultInput { + return []chatstate.ToolResultInput{valid(seeded.pendingToolCallID), valid("call_extra")} + }, + }, + { + name: "duplicate_tool_call_id", + results: func(seeded seededChat) []chatstate.ToolResultInput { + return []chatstate.ToolResultInput{valid(seeded.pendingToolCallID), valid(seeded.pendingToolCallID)} + }, + }, + { + name: "mismatched_tool_call_id", + results: func(seeded seededChat) []chatstate.ToolResultInput { + return []chatstate.ToolResultInput{valid("call_mismatch")} + }, + }, + { + name: "invalid_json_output", + results: func(seeded seededChat) []chatstate.ToolResultInput { + return []chatstate.ToolResultInput{{ToolCallID: seeded.pendingToolCallID, Output: json.RawMessage(`{`)}} + }, + }, + } +} + +func recordRetryStateRejectCases() []recordRetryStateRejectCase { + return []recordRetryStateRejectCase{ + { + name: "sql_null_payload", + }, + { + name: "empty_payload", + retryState: pqtype.NullRawMessage{RawMessage: json.RawMessage(``), Valid: true}, + }, + { + name: "invalid_json_payload", + retryState: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{`), Valid: true}, + }, + } +} + +func runCompleteRequiresActionRejectCase(t *testing.T, tc completeRequiresActionRejectCase) { + t.Helper() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + seeded := seedAOrA1(t, f, 0, "reject_complete_requires_action") + require.Equal(t, chatstate.StateA0, f.classify(ctx, t, seeded.chatID)) + base := captureBaseline(ctx, t, f, seeded) + + m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{}) + err := m.Update(ctx, func(tx *chatstate.Tx) error { + _, cerr := tx.CompleteRequiresAction(chatstate.CompleteRequiresActionInput{ + CreatedBy: f.User.ID, + ModelConfigID: f.Model.ID, + Results: tc.results(seeded), + }) + return cerr + }) + require.Error(t, err) + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed) + var te *chatstate.TransitionError + require.ErrorAs(t, err, &te) + require.Equal(t, chatstate.TransitionCompleteRequiresAction, te.Transition) + require.Equal(t, chatstate.StateA0, te.From) + assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base) +} + +func runRecordRetryStateRejectCase(t *testing.T, tc recordRetryStateRejectCase) { + t.Helper() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + seeded := seedState(t, f, chatstate.StateR0) + require.Equal(t, chatstate.StateR0, f.classify(ctx, t, seeded.chatID)) + base := captureBaseline(ctx, t, f, seeded) + + m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{}) + err := m.Update(ctx, func(tx *chatstate.Tx) error { + _, rerr := tx.RecordRetryState(chatstate.RecordRetryStateInput{ + RetryState: tc.retryState, + }) + return rerr + }) + require.Error(t, err) + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed) + var te *chatstate.TransitionError + require.ErrorAs(t, err, &te) + require.Equal(t, chatstate.TransitionRecordRetryState, te.Transition) + require.Equal(t, chatstate.StateR0, te.From) + assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base) +} + +// TestTransitionInputValidation groups every input-specific rejection +// test. The matrix coverage entry point in +// TestTransitionMatrix_AllCombinations intentionally focuses on +// positive cases and generated disallowed cases; rejection cases that +// exercise legal matrix rows with invalid inputs live here so the +// matrix entry point stays focused. +func TestTransitionInputValidation(t *testing.T) { + t.Parallel() + + t.Run("SetArchived_wrong_direction", func(t *testing.T) { + t.Parallel() + for _, tc := range setArchivedWrongDirectionCases() { + t.Run(tc.label, func(t *testing.T) { + t.Parallel() + runSetArchivedWrongDirectionCase(t, tc) + }) + } + }) + + t.Run("SendMessage_invalid_busy_behavior", func(t *testing.T) { + t.Parallel() + for _, from := range chatstate.AllowedInputStates(chatstate.TransitionSendMessage) { + for _, bb := range invalidBusyBehaviors { + label := from.String() + "/" + string(bb) + if bb == "" { + label = from.String() + "/empty" + } + t.Run(label, func(t *testing.T) { + t.Parallel() + runInvalidBusyBehaviorCase(t, from, bb) + }) + } + } + }) + + t.Run("CompleteRequiresAction_invalid_results", func(t *testing.T) { + t.Parallel() + for _, tc := range completeRequiresActionRejectCases() { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + runCompleteRequiresActionRejectCase(t, tc) + }) + } + }) + + t.Run("RecordRetryState_invalid_payload", func(t *testing.T) { + t.Parallel() + for _, tc := range recordRetryStateRejectCases() { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + runRecordRetryStateRejectCase(t, tc) + }) + } + }) +} + +// TestSendMessageQueueCapRejectsQueueAppend seeds a chat with the +// maximum queued messages and asserts that the next SendMessage in +// a queue-appending state returns chatstate.ErrMessageQueueFull and +// rolls back without persisting another queued row. +func TestSendMessageQueueCapRejectsQueueAppend(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + // createTestChat lands the chat in R0; SendMessage in R0 with + // BusyBehaviorQueue queues. Fill the queue to MaxQueueSize. + for i := 0; i < chatstate.MaxQueueSize; i++ { + sendQueuedMessage(t, f, m, "filler") + } + count, err := f.DB.CountChatQueuedMessages(ctx, created.Chat.ID) + require.NoError(t, err) + require.EqualValues(t, chatstate.MaxQueueSize, count) + chatBefore := f.readChat(ctx, t, created.Chat.ID) + + // The next queue append must fail with ErrMessageQueueFull and a + // typed wrapper that exposes the cap. + err = m.Update(ctx, func(tx *chatstate.Tx) error { + _, serr := tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage("overflow", f.User.ID, f.Model.ID), + BusyBehavior: chatstate.BusyBehaviorQueue, + }) + return serr + }) + require.Error(t, err) + require.ErrorIs(t, err, chatstate.ErrMessageQueueFull, + "queue-append over the cap returns ErrMessageQueueFull") + var typed *chatstate.MessageQueueFullError + require.ErrorAs(t, err, &typed, "ErrMessageQueueFull is carried as a typed error") + require.EqualValues(t, chatstate.MaxQueueSize, typed.Max) + + // The transaction rolled back: queue size, snapshot version, + // and queue version are unchanged. + countAfter, err := f.DB.CountChatQueuedMessages(ctx, created.Chat.ID) + require.NoError(t, err) + require.EqualValues(t, chatstate.MaxQueueSize, countAfter, + "queue size must not change when the cap rejects the append") + chatAfter := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, chatBefore.SnapshotVersion, chatAfter.SnapshotVersion, + "failed queue append must not bump snapshot_version") + require.Equal(t, chatBefore.QueueVersion, chatAfter.QueueVersion, + "failed queue append must not bump queue_version") +} + +// TestEditMessageNonUserReturnsSentinel asserts that editing a +// non-user message returns chatstate.ErrEditedMessageNotUser via +// the TransitionError cause chain, and still matches the generic +// transition sentinel. +func TestEditMessageNonUserReturnsSentinel(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + // Insert an assistant message via CommitStep so we have a + // non-user message to target. + var assistantID int64 + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + assistant := userTextMessage("assistant", f.User.ID, f.Model.ID) + assistant.Role = database.ChatMessageRoleAssistant + step, err := tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{assistant}, + }) + if err != nil { + return err + } + require.Len(t, step.InsertedMessages, 1) + assistantID = step.InsertedMessages[0].ID + return nil + })) + + rawContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("new content"), + }) + require.NoError(t, err) + + editErr := m.Update(ctx, func(tx *chatstate.Tx) error { + _, eerr := tx.EditMessage(chatstate.EditMessageInput{ + MessageID: assistantID, + CreatedBy: f.User.ID, + Content: rawContent, + }) + return eerr + }) + require.Error(t, editErr) + require.ErrorIs(t, editErr, chatstate.ErrEditedMessageNotUser, + "non-user edit returns ErrEditedMessageNotUser via TransitionError cause") + require.ErrorIs(t, editErr, chatstate.ErrTransitionNotAllowed, + "ErrEditedMessageNotUser still matches the generic transition sentinel") +} + +// TestTransitionAbandon_RejectsUnowned verifies that calling Abandon +// on a chat the runner does not own returns ErrTransitionNotAllowed +// wrapped in a TransitionError that records the loaded from-state, +// without mutating chat state or publishing anything. +func TestTransitionAbandon_RejectsUnowned(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + seeded := seededChat{chatID: created.Chat.ID, exists: true} + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + base := captureBaseline(ctx, t, f, seeded) + + err := m.Update(ctx, func(tx *chatstate.Tx) error { + _, aerr := tx.Abandon(chatstate.AbandonInput{}) + return aerr + }) + require.Error(t, err) + require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed) + var te *chatstate.TransitionError + require.ErrorAs(t, err, &te) + require.Equal(t, chatstate.TransitionAbandon, te.Transition) + // createTestChat lands the chat in R0; Abandon's precondition + // rejects an unowned chat there. + require.Equal(t, chatstate.StateR0, te.From) + assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base) +} + +// TestTransitionAbandon_ClearsOwnership verifies the Acquire/Abandon +// round-trip: after Acquire the chat carries a worker+runner and a +// fresh heartbeat row exists, and after Abandon both ownership fields +// are cleared. The heartbeat row is not deleted by Abandon; heartbeat +// cleanup is a separate concern. +func TestTransitionAbandon_ClearsOwnership(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + worker := uuid.New() + runner := uuid.New() + + // Acquire writes ownership and a fresh heartbeat row. + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner}) + return err + })) + owned := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, worker, owned.WorkerID.UUID) + require.Equal(t, runner, owned.RunnerID.UUID) + hb, err := f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: created.Chat.ID, + RunnerID: runner, + }) + require.NoError(t, err, "Acquire writes a fresh heartbeat row") + require.Equal(t, runner, hb.RunnerID) + + // Abandon clears ownership but leaves the heartbeat row intact. + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Abandon(chatstate.AbandonInput{}) + return err + })) + hb, err = f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: created.Chat.ID, + RunnerID: runner, + }) + require.NoError(t, err, "Abandon does not delete the heartbeat row") + abandoned := f.readChat(ctx, t, created.Chat.ID) + require.False(t, abandoned.WorkerID.Valid, "Abandon clears worker_id") + require.False(t, abandoned.RunnerID.Valid, "Abandon clears runner_id") +} + +// TestTransitionAcquire_OverwritesFreshOwnership verifies that Acquire +// is an unconditional ownership handoff: a second worker calling +// Acquire on a chat that was *just* acquired by another worker +// successfully replaces ownership without inspecting heartbeat +// freshness. It also asserts that Acquire itself does not request an +// ownership hint, so the post-commit publish stays quiet on +// `chat:ownership` when the resulting heartbeat is fresh. +func TestTransitionAcquire_OverwritesFreshOwnership(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + + firstWorker := uuid.New() + firstRunner := uuid.New() + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: firstWorker, RunnerID: firstRunner}) + return err + })) + + // The chat is now owned with a fresh (chat_id, firstRunner) + // heartbeat written by the first Acquire. + firstChat := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, firstWorker, firstChat.WorkerID.UUID) + require.Equal(t, firstRunner, firstChat.RunnerID.UUID) + _, err := f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: created.Chat.ID, + RunnerID: firstRunner, + }) + require.NoError(t, err, "first Acquire wrote a fresh heartbeat") + // Sanity check: heartbeat is not stale by the same threshold the + // machine uses for ownership-hint decisions. + stale, err := f.DB.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{ + ChatID: created.Chat.ID, + RunnerID: firstRunner, + StaleSeconds: chatstate.HeartbeatStaleSeconds, + }) + require.NoError(t, err) + require.False(t, stale, "first runner's heartbeat is fresh before the second Acquire") + + // Snapshot publish counts before the takeover so we can assert + // Acquire does not publish an ownership hint itself. + ownershipBefore := f.Pub.ownershipPublishCount() + beforeChat := f.readChat(ctx, t, created.Chat.ID) + + secondWorker := uuid.New() + secondRunner := uuid.New() + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: secondWorker, RunnerID: secondRunner}) + return err + })) + + after := f.readChat(ctx, t, created.Chat.ID) + require.Equal(t, secondWorker, after.WorkerID.UUID, "ownership replaced") + require.Equal(t, secondRunner, after.RunnerID.UUID, "runner replaced") + require.Equal(t, beforeChat.SnapshotVersion+1, after.SnapshotVersion, "snapshot bumps exactly once") + f.Pub.expectChatUpdate(t, created.Chat.ID, after.SnapshotVersion) + + // The new (chat_id, secondRunner) heartbeat exists. The old + // (chat_id, firstRunner) row may or may not exist; Acquire is not + // responsible for cleaning it up. + _, err = f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: created.Chat.ID, + RunnerID: secondRunner, + }) + require.NoError(t, err, "second Acquire wrote a heartbeat for the new runner") + + // Acquire does not publish an ownership hint when it writes a fresh + // heartbeat. The post-commit ownership-hint logic in Update stays + // quiet because the new heartbeat is fresh, so no `chat:ownership` + // notification fires. + require.Equal(t, ownershipBefore, f.Pub.ownershipPublishCount(), + "Acquire must not publish an ownership hint when the resulting heartbeat is fresh") +} + +// TestTransitionAcquire_ExecutionStateOrthogonal verifies that Acquire +// preserves every execution-state field on the chat across +// representative valid execution states, including idle, runnable, and +// archived states. The transition only mutates ownership. +func TestTransitionAcquire_ExecutionStateOrthogonal(t *testing.T) { + t.Parallel() + + // Each setup leaves the chat in the named state and returns the + // chat ID for downstream assertions. + cases := []struct { + name string + state chatstate.ExecutionState + setup func(t *testing.T, f *testFixture) uuid.UUID + }{ + { + name: "R0", + state: chatstate.StateR0, + setup: func(t *testing.T, f *testFixture) uuid.UUID { + return createTestChat(t, f).Chat.ID + }, + }, + { + name: "W", + state: chatstate.StateW, + setup: func(t *testing.T, f *testFixture) uuid.UUID { + created := createTestChat(t, f) + ctx := testutil.Context(t, testutil.WaitShort) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishTurn(chatstate.FinishTurnInput{}) + return err + })) + return created.Chat.ID + }, + }, + { + name: "E0", + state: chatstate.StateE0, + setup: func(t *testing.T, f *testFixture) uuid.UUID { + created := createTestChat(t, f) + ctx := testutil.Context(t, testutil.WaitShort) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishError(chatstate.FinishErrorInput{ + LastError: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"message":"boom"}`), + Valid: true, + }, + }) + return err + })) + return created.Chat.ID + }, + }, + { + name: "I0", + state: chatstate.StateI0, + setup: func(t *testing.T, f *testFixture) uuid.UUID { + created := createTestChat(t, f) + ctx := testutil.Context(t, testutil.WaitShort) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"}) + return err + })) + return created.Chat.ID + }, + }, + { + name: "XW", + state: chatstate.StateXW, + setup: func(t *testing.T, f *testFixture) uuid.UUID { + created := createTestChat(t, f) + ctx := testutil.Context(t, testutil.WaitShort) + m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishTurn(chatstate.FinishTurnInput{}) + return err + })) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true}) + return err + })) + return created.Chat.ID + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + f := newTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + chatID := tc.setup(t, f) + require.Equal(t, tc.state, f.classify(ctx, t, chatID), "test setup must leave chat in %s", tc.state) + + before := f.readChat(ctx, t, chatID) + queueBefore, err := f.DB.CountChatQueuedMessages(ctx, chatID) + require.NoError(t, err) + historyBefore := historyMessageIDs(ctx, t, f, chatID) + + worker := uuid.New() + runner := uuid.New() + m := chatstate.NewChatMachine(f.DB, f.Pub, chatID, chatstate.Options{}) + require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner}) + return err + })) + + after := f.readChat(ctx, t, chatID) + // Ownership updated. + require.Equal(t, worker, after.WorkerID.UUID) + require.Equal(t, runner, after.RunnerID.UUID) + // Execution state preserved. + require.Equal(t, before.Status, after.Status, "status preserved") + require.Equal(t, before.Archived, after.Archived, "archived flag preserved") + require.Equal(t, before.RequiresActionDeadlineAt, after.RequiresActionDeadlineAt, "requires-action deadline preserved") + require.Equal(t, before.LastError, after.LastError, "last_error preserved") + require.Equal(t, before.HistoryVersion, after.HistoryVersion, "history_version preserved") + require.Equal(t, before.QueueVersion, after.QueueVersion, "queue_version preserved") + require.Equal(t, before.GenerationAttempt, after.GenerationAttempt, "generation_attempt preserved") + // Classified state unchanged. + require.Equal(t, tc.state, f.classify(ctx, t, chatID), "execution state preserved by Acquire") + // Queue and history rows untouched. + queueAfter, err := f.DB.CountChatQueuedMessages(ctx, chatID) + require.NoError(t, err) + require.Equal(t, queueBefore, queueAfter, "queue cardinality preserved") + require.Equal(t, historyBefore, historyMessageIDs(ctx, t, f, chatID), "history preserved") + }) + } +} diff --git a/coderd/x/chatd/chatstate/trigger_test.go b/coderd/x/chatd/chatstate/trigger_test.go new file mode 100644 index 0000000000..3bc0435802 --- /dev/null +++ b/coderd/x/chatd/chatstate/trigger_test.go @@ -0,0 +1,631 @@ +package chatstate_test + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// triggerFixture is a slim variant of testFixture that also exposes a +// raw *sql.DB so the trigger tests can run UPDATE/INSERT statements +// that bypass the typed sqlc layer. Tests that only need the typed +// store should keep using newTestFixture. +type triggerFixture struct { + f *testFixture + sqlDB *sql.DB +} + +func newTriggerFixture(t *testing.T) *triggerFixture { + t.Helper() + db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t) + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "openai", + BaseUrl: "http://example.invalid", + }) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "openai", + IsDefault: true, + }) + f := &testFixture{ + DB: db, + PubSub: ps, + Pub: newRecordingPubsub(), + User: user, + Org: org, + Model: model, + } + return &triggerFixture{f: f, sqlDB: sqlDB} +} + +// userMessageContent returns a marshaled user message body suitable +// for raw INSERT into chat_messages. +func userMessageContent(t *testing.T, text string) []byte { + t.Helper() + raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}) + require.NoError(t, err) + return raw.RawMessage +} + +// TestMessageInsertAssignsRevisionAndHistoryVersion verifies that +// inserting a chat message via the legacy InsertChatMessages query +// assigns NEW.revision from chats.snapshot_version (BEFORE trigger) +// and bumps chats.history_version + resets generation_attempt (AFTER +// STATEMENT trigger). +func TestMessageInsertAssignsRevisionAndHistoryVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + + created := createTestChat(t, f) + require.Equal(t, int64(1), created.Chat.SnapshotVersion) + require.Equal(t, int64(1), created.Chat.HistoryVersion) + + // Force generation_attempt > 0 so we can prove the trigger + // resets it on a new history change. + _, err := f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID) + require.NoError(t, err) + before, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, int64(1), before.GenerationAttempt) + + // Bump snapshot_version directly to simulate a transition having + // taken the row lock. + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, before.SnapshotVersion+1, bumped.SnapshotVersion) + + // Insert a new assistant message via raw SQL so we know the + // BEFORE+AFTER triggers (and only those) decide revision and + // history_version. + content := userMessageContent(t, "hello-after-bump") + _, err = tf.sqlDB.ExecContext(ctx, ` + INSERT INTO chat_messages (chat_id, role, content, content_version, visibility) + VALUES ($1, 'assistant', $2::jsonb, $3, 'both') + `, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion)) + require.NoError(t, err) + + // History version equals snapshot_version, generation_attempt resets. + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, bumped.SnapshotVersion, after.HistoryVersion) + require.Equal(t, int64(0), after.GenerationAttempt) + + // The inserted message picked up revision = bumped snapshot. + msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: created.Chat.ID, + }) + require.NoError(t, err) + require.NotEmpty(t, msgs) + last := msgs[len(msgs)-1] + require.Equal(t, database.ChatMessageRoleAssistant, last.Role) + require.Equal(t, bumped.SnapshotVersion, last.Revision) +} + +// TestMessageUpdateAssignsNewRevisionAndHistoryVersion verifies that +// updating a chat message's content advances NEW.revision to the +// current chats.snapshot_version and that chats.history_version +// bumps to match. +func TestMessageUpdateAssignsNewRevisionAndHistoryVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: created.Chat.ID, + }) + require.NoError(t, err) + require.NotEmpty(t, msgs) + target := msgs[0] + originalRevision := target.Revision + + // Bump the snapshot so the trigger sees a new revision target. + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + require.Greater(t, bumped.SnapshotVersion, originalRevision) + + newContent := userMessageContent(t, "edited content") + _, err = tf.sqlDB.ExecContext(ctx, ` + UPDATE chat_messages SET content = $1::jsonb WHERE id = $2 + `, string(newContent), target.ID) + require.NoError(t, err) + + reloaded, err := f.DB.GetChatMessageByID(ctx, target.ID) + require.NoError(t, err) + require.Equal(t, bumped.SnapshotVersion, reloaded.Revision, + "updated message picks up the current snapshot version") + + chatAfter, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, bumped.SnapshotVersion, chatAfter.HistoryVersion) + require.Equal(t, int64(0), chatAfter.GenerationAttempt, + "history change resets generation_attempt") +} + +// TestMessageRevisionCannotBeSetByRuntimeCode verifies the BEFORE +// trigger rejects explicit revision values on INSERT and rejects +// revision changes on UPDATE. +func TestMessageRevisionCannotBeSetByRuntimeCode(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + content := userMessageContent(t, "explicit revision") + _, err := tf.sqlDB.ExecContext(ctx, ` + INSERT INTO chat_messages (chat_id, role, content, content_version, visibility, revision) + VALUES ($1, 'user', $2::jsonb, $3, 'both', 999) + `, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion)) + require.Error(t, err, "INSERT with explicit revision must be rejected") + require.Contains(t, err.Error(), "revision must be assigned by trigger") + + msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: created.Chat.ID, + }) + require.NoError(t, err) + require.NotEmpty(t, msgs) + target := msgs[0] + + _, err = tf.sqlDB.ExecContext(ctx, ` + UPDATE chat_messages SET revision = revision + 100 WHERE id = $1 + `, target.ID) + require.Error(t, err, "UPDATE that changes revision must be rejected") + require.Contains(t, err.Error(), "revision must be assigned by trigger") +} + +// TestMessageChatIDCannotChange verifies the BEFORE trigger rejects +// updates that change chat_messages.chat_id. +func TestMessageChatIDCannotChange(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + first := createTestChat(t, f) + second := createTestChat(t, f) + + firstMsgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: first.Chat.ID, + }) + require.NoError(t, err) + require.NotEmpty(t, firstMsgs) + target := firstMsgs[0] + + _, err = tf.sqlDB.ExecContext(ctx, ` + UPDATE chat_messages SET chat_id = $1 WHERE id = $2 + `, second.Chat.ID, target.ID) + require.Error(t, err, "UPDATE that changes chat_id must be rejected") + require.Contains(t, err.Error(), "chat_id is immutable") +} + +// TestNoopMessageUpdateDoesNotAdvanceHistoryVersion verifies that a +// no-op UPDATE on a chat_messages row (one whose OLD and NEW are +// indistinguishable) does NOT advance chats.history_version even +// when the snapshot was previously bumped. This guards against the +// AFTER UPDATE STATEMENT trigger naively reacting to every touched +// row id regardless of whether the row actually changed. +func TestNoopMessageUpdateDoesNotAdvanceHistoryVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: created.Chat.ID, + }) + require.NoError(t, err) + require.NotEmpty(t, msgs) + target := msgs[0] + originalRevision := target.Revision + + // Bump snapshot so the AFTER STATEMENT guard + // (history_version != snapshot_version) is now true. + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + require.NotEqual(t, bumped.SnapshotVersion, bumped.HistoryVersion, + "snapshot bump leaves history_version trailing") + + // No-op UPDATE: SET content = content. OLD IS NOT DISTINCT FROM NEW. + _, err = tf.sqlDB.ExecContext(ctx, ` + UPDATE chat_messages SET content = content WHERE id = $1 + `, target.ID) + require.NoError(t, err) + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, bumped.HistoryVersion, after.HistoryVersion, + "no-op update must NOT advance history_version") + + // And the row's revision is untouched. + reloaded, err := f.DB.GetChatMessageByID(ctx, target.ID) + require.NoError(t, err) + require.Equal(t, originalRevision, reloaded.Revision, + "no-op update must NOT advance message revision") +} + +// ============================================================================= +// Queue version triggers +// ============================================================================= + +// TestQueueInsertUpdatesQueueVersion verifies that an INSERT into +// chat_queued_messages bumps chats.queue_version to the current +// snapshot_version. +func TestQueueInsertUpdatesQueueVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + before, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, int64(0), before.QueueVersion) + + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + + content := userMessageContent(t, "queued") + _, err = f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{ + ChatID: created.Chat.ID, + Content: content, + CreatedBy: f.User.ID, + }) + require.NoError(t, err) + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, bumped.SnapshotVersion, after.QueueVersion, + "INSERT into chat_queued_messages bumps queue_version") +} + +// TestQueuedMessageCreatedByIsRequired verifies the database enforces +// creator metadata for every queued message row. +func TestQueuedMessageCreatedByIsRequired(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + content := userMessageContent(t, "queued-without-creator") + _, err := tf.sqlDB.ExecContext(ctx, ` + INSERT INTO chat_queued_messages (chat_id, content, model_config_id, created_by) + VALUES ($1, $2::jsonb, NULL, NULL) + `, created.Chat.ID, string(content)) + require.Error(t, err) + require.Contains(t, err.Error(), "created_by") +} + +func TestLegacyQueuedMessageInsertUsesChatOwnerAsCreator(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + queued, err := f.DB.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: created.Chat.ID, + Content: userMessageContent(t, "legacy-queued"), + }) + require.NoError(t, err) + require.Equal(t, created.Chat.OwnerID, queued.CreatedBy) +} + +// TestQueueUpdateContentUpdatesQueueVersion verifies that an UPDATE +// of chat_queued_messages.content bumps queue_version. The +// AFTER UPDATE trigger explicitly listens for content changes. +func TestQueueUpdateContentUpdatesQueueVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + queued, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{ + ChatID: created.Chat.ID, + Content: userMessageContent(t, "initial"), + CreatedBy: f.User.ID, + }) + require.NoError(t, err) + + before, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + require.Greater(t, bumped.SnapshotVersion, before.QueueVersion) + + updated := userMessageContent(t, "updated") + _, err = tf.sqlDB.ExecContext(ctx, ` + UPDATE chat_queued_messages SET content = $1::jsonb WHERE id = $2 + `, string(updated), queued.ID) + require.NoError(t, err) + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, bumped.SnapshotVersion, after.QueueVersion, + "UPDATE of queued content bumps queue_version") +} + +// TestQueueUpdatePositionUpdatesQueueVersion verifies that an UPDATE +// of chat_queued_messages.position (such as the reorder-to-head +// path) bumps queue_version. +func TestQueueUpdatePositionUpdatesQueueVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + q1, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{ + ChatID: created.Chat.ID, + Content: userMessageContent(t, "first"), + CreatedBy: f.User.ID, + }) + require.NoError(t, err) + q2, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{ + ChatID: created.Chat.ID, + Content: userMessageContent(t, "second"), + CreatedBy: f.User.ID, + }) + require.NoError(t, err) + require.NotEqual(t, q1.ID, q2.ID) + + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + + // Move q2 to head by setting its position to q1.position - 1. + _, err = tf.sqlDB.ExecContext(ctx, ` + UPDATE chat_queued_messages SET position = $1 WHERE id = $2 + `, q1.Position-1, q2.ID) + require.NoError(t, err) + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, bumped.SnapshotVersion, after.QueueVersion, + "UPDATE of queued position bumps queue_version") +} + +// TestQueueDeleteUpdatesQueueVersion verifies that DELETE from +// chat_queued_messages bumps queue_version. +func TestQueueDeleteUpdatesQueueVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + queued, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{ + ChatID: created.Chat.ID, + Content: userMessageContent(t, "to delete"), + CreatedBy: f.User.ID, + }) + require.NoError(t, err) + + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + + rows, err := f.DB.DeleteChatQueuedMessageReturningCount(ctx, database.DeleteChatQueuedMessageReturningCountParams{ + ID: queued.ID, + ChatID: created.Chat.ID, + }) + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, bumped.SnapshotVersion, after.QueueVersion, + "DELETE from queue bumps queue_version") +} + +// TestNonQueueUpdateDoesNotUpdateQueueVersion verifies that mutations +// on other chat-related tables do NOT bump queue_version. The +// canonical case is inserting a chat message: it must update +// history_version but leave queue_version untouched. +func TestNonQueueUpdateDoesNotUpdateQueueVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + before, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + + content := userMessageContent(t, "non-queue mutation") + _, err = tf.sqlDB.ExecContext(ctx, ` + INSERT INTO chat_messages (chat_id, role, content, content_version, visibility) + VALUES ($1, 'assistant', $2::jsonb, $3, 'both') + `, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion)) + require.NoError(t, err) + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, before.QueueVersion, after.QueueVersion, + "chat_messages INSERT must not bump queue_version") + // Sanity: history_version DID move. + require.Equal(t, bumped.SnapshotVersion, after.HistoryVersion) +} + +// ============================================================================= +// Retry state triggers +// ============================================================================= + +func TestRetryStateDefaults(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + chat, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.False(t, chat.RetryState.Valid) + require.Equal(t, int64(0), chat.RetryStateVersion) +} + +func TestRetryStateUpdateSetsRetryStateVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + + after, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{ + ID: created.Chat.ID, + RetryState: []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`), + }) + require.NoError(t, err) + require.True(t, after.RetryState.Valid) + require.JSONEq(t, + `{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`, + string(after.RetryState.RawMessage)) + require.Equal(t, bumped.SnapshotVersion, after.RetryStateVersion) +} + +func TestRetryStateSameValueDoesNotUpdateRetryStateVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + payload := []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`) + _, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + first, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{ + ID: created.Chat.ID, + RetryState: payload, + }) + require.NoError(t, err) + + _, err = f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + second, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{ + ID: created.Chat.ID, + RetryState: payload, + }) + require.NoError(t, err) + require.Equal(t, first.RetryStateVersion, second.RetryStateVersion, + "same retry_state payload must not update retry_state_version") +} + +func TestGenerationAttemptClearsRetryState(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + _, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + withRetry, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{ + ID: created.Chat.ID, + RetryState: []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`), + }) + require.NoError(t, err) + require.True(t, withRetry.RetryState.Valid) + + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + attempt, err := f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, int64(1), attempt) + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.False(t, after.RetryState.Valid) + require.Equal(t, bumped.SnapshotVersion, after.RetryStateVersion, + "clearing retry_state on generation attempt bumps retry_state_version") +} + +func TestGenerationAttemptWithNullRetryStateDoesNotUpdateRetryStateVersion(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + before, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.False(t, before.RetryState.Valid) + + _, err = f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + _, err = f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID) + require.NoError(t, err) + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.False(t, after.RetryState.Valid) + require.Equal(t, before.RetryStateVersion, after.RetryStateVersion, + "generation attempt with null retry_state leaves retry_state_version unchanged") +} + +func TestRetryStateVersionCannotBeSetByRuntimeCode(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + _, err := tf.sqlDB.ExecContext(ctx, ` + UPDATE chats SET retry_state_version = retry_state_version + 1 WHERE id = $1 + `, created.Chat.ID) + require.Error(t, err) + require.Contains(t, err.Error(), "retry_state_version must be assigned by trigger") +} + +func TestHistoryChangeClearsRetryState(t *testing.T) { + t.Parallel() + tf := newTriggerFixture(t) + f := tf.f + ctx := testutil.Context(t, testutil.WaitShort) + created := createTestChat(t, f) + + _, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + _, err = f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID) + require.NoError(t, err) + _, err = f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{ + ID: created.Chat.ID, + RetryState: []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`), + }) + require.NoError(t, err) + + bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID) + require.NoError(t, err) + content := userMessageContent(t, "history clears retry state") + _, err = tf.sqlDB.ExecContext(ctx, ` + INSERT INTO chat_messages (chat_id, role, content, content_version, visibility) + VALUES ($1, 'assistant', $2::jsonb, $3, 'both') + `, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion)) + require.NoError(t, err) + + after, err := f.DB.GetChatByID(ctx, created.Chat.ID) + require.NoError(t, err) + require.Equal(t, int64(0), after.GenerationAttempt) + require.False(t, after.RetryState.Valid) + require.Equal(t, bumped.SnapshotVersion, after.RetryStateVersion, + "history reset of generation_attempt clears retry_state") +} diff --git a/enterprise/audit/table.go b/enterprise/audit/table.go index e97a76daed..9f31417795 100644 --- a/enterprise/audit/table.go +++ b/enterprise/audit/table.go @@ -414,38 +414,46 @@ var auditableResourcesTypes = map[any]map[string]Action{ "deleted_at": ActionIgnore, // Changes, but is implicit when a delete event is fired. }, &database.Chat{}: { - "id": ActionTrack, - "owner_id": ActionTrack, - "owner_username": ActionIgnore, - "owner_name": ActionIgnore, - "organization_id": ActionIgnore, // Never changes after creation. - "workspace_id": ActionTrack, - "build_id": ActionIgnore, // Internal lifecycle. - "agent_id": ActionIgnore, // Internal lifecycle. - "title": ActionSecret, // May contain sensitive content. - "status": ActionIgnore, // Churns every message. - "worker_id": ActionIgnore, // Internal. - "started_at": ActionIgnore, - "heartbeat_at": ActionIgnore, // Internal. - "created_at": ActionIgnore, // Never changes. - "updated_at": ActionIgnore, // Bumped on every mutation. - "parent_chat_id": ActionIgnore, // Immutable after creation. - "root_chat_id": ActionIgnore, // Immutable after creation. - "last_model_config_id": ActionIgnore, // Churns every message. - "archived": ActionTrack, - "last_error": ActionIgnore, // Internal. - "last_turn_summary": ActionIgnore, // Internal cached display text. - "mode": ActionTrack, - "mcp_server_ids": ActionTrack, - "labels": ActionTrack, - "user_acl": ActionTrack, - "group_acl": ActionTrack, - "pin_order": ActionTrack, - "last_read_message_id": ActionIgnore, // User-scoped read cursor. - "last_injected_context": ActionIgnore, // Internal lifecycle. - "dynamic_tools": ActionIgnore, // Internal lifecycle. - "plan_mode": ActionIgnore, // Can flip back and forth during a session. - "client_type": ActionIgnore, // Set at creation. + "id": ActionTrack, + "owner_id": ActionTrack, + "owner_username": ActionIgnore, + "owner_name": ActionIgnore, + "organization_id": ActionIgnore, // Never changes after creation. + "workspace_id": ActionTrack, + "build_id": ActionIgnore, // Internal lifecycle. + "agent_id": ActionIgnore, // Internal lifecycle. + "title": ActionSecret, // May contain sensitive content. + "status": ActionIgnore, // Churns every message. + "worker_id": ActionIgnore, // Internal. + "started_at": ActionIgnore, + "heartbeat_at": ActionIgnore, // Internal. + "created_at": ActionIgnore, // Never changes. + "updated_at": ActionIgnore, // Bumped on every mutation. + "parent_chat_id": ActionIgnore, // Immutable after creation. + "root_chat_id": ActionIgnore, // Immutable after creation. + "last_model_config_id": ActionIgnore, // Churns every message. + "archived": ActionTrack, + "last_error": ActionIgnore, // Internal. + "last_turn_summary": ActionIgnore, // Internal cached display text. + "mode": ActionTrack, + "mcp_server_ids": ActionTrack, + "labels": ActionTrack, + "user_acl": ActionTrack, + "group_acl": ActionTrack, + "pin_order": ActionTrack, + "last_read_message_id": ActionIgnore, // User-scoped read cursor. + "last_injected_context": ActionIgnore, // Internal lifecycle. + "dynamic_tools": ActionIgnore, // Internal lifecycle. + "plan_mode": ActionIgnore, // Can flip back and forth during a session. + "client_type": ActionIgnore, // Set at creation. + "snapshot_version": ActionIgnore, // Internal state machine version. + "history_version": ActionIgnore, // Internal state machine version. + "queue_version": ActionIgnore, // Internal state machine version. + "retry_state": ActionIgnore, // Internal transient retry UI state. + "retry_state_version": ActionIgnore, // Internal state machine version. + "generation_attempt": ActionIgnore, // Internal retry counter. + "runner_id": ActionIgnore, // Internal ownership identifier. + "requires_action_deadline_at": ActionIgnore, // Internal pending-action deadline. }, &database.UserSkill{}: { "id": ActionTrack,