source files

This commit is contained in:
Hugo Dutka
2026-05-22 07:18:40 +00:00
parent a4afb9dfc6
commit 61a21e097b
29 changed files with 10442 additions and 99 deletions
@@ -0,0 +1,106 @@
-- Rollback for the chatd core state machine foundation migration.
-- 1. Recreate chats_expanded without the new chat fields. We must drop
-- the view first because the subsequent column drops would fail with
-- "view depends on column".
DROP VIEW IF EXISTS chats_expanded;
-- 2. Drop the retry state trigger and function.
DROP TRIGGER IF EXISTS trigger_sync_chat_retry_state ON chats;
DROP FUNCTION IF EXISTS sync_chat_retry_state();
-- 3. Drop the queue version triggers and function.
DROP TRIGGER IF EXISTS trigger_bump_chat_queue_version_on_queued_message_delete ON chat_queued_messages;
DROP TRIGGER IF EXISTS trigger_bump_chat_queue_version_on_queued_message_update ON chat_queued_messages;
DROP TRIGGER IF EXISTS trigger_bump_chat_queue_version_on_queued_message_insert ON chat_queued_messages;
DROP FUNCTION IF EXISTS bump_chat_queue_version_on_queued_message_change();
-- 4. Drop the message revision triggers and functions.
DROP TRIGGER IF EXISTS trigger_update_chat_history_after_message_update ON chat_messages;
DROP TRIGGER IF EXISTS trigger_update_chat_history_after_message_insert ON chat_messages;
DROP TRIGGER IF EXISTS trigger_set_chat_message_revision_on_update ON chat_messages;
DROP TRIGGER IF EXISTS trigger_set_chat_message_revision_on_insert ON chat_messages;
DROP FUNCTION IF EXISTS update_chat_history_after_message_update();
DROP FUNCTION IF EXISTS update_chat_history_after_message_insert();
-- The pre-split function name is kept here for backward compatibility
-- with environments that may have applied an earlier draft of the up
-- migration. DROP FUNCTION IF EXISTS is a no-op if the function is
-- absent.
DROP FUNCTION IF EXISTS update_chat_history_after_message_changes();
DROP FUNCTION IF EXISTS set_chat_message_revision_before();
DROP FUNCTION IF EXISTS set_chat_message_revision();
-- 5. Drop chat_heartbeats (and its index by association).
DROP TABLE IF EXISTS chat_heartbeats;
-- 6. Drop the queue-order index.
DROP INDEX IF EXISTS idx_chat_queued_messages_chat_position_id;
-- 7. Drop chat_queued_messages.position and its default sequence, plus
-- created_by.
ALTER TABLE chat_queued_messages
ALTER COLUMN position DROP DEFAULT;
ALTER TABLE chat_queued_messages
DROP COLUMN IF EXISTS position,
DROP COLUMN IF EXISTS created_by;
DROP SEQUENCE IF EXISTS chat_queued_messages_position_seq;
-- 8. Drop chat_messages.revision.
ALTER TABLE chat_messages
DROP COLUMN IF EXISTS revision;
-- 9. Drop the new chats columns.
ALTER TABLE chats
DROP COLUMN IF EXISTS snapshot_version,
DROP COLUMN IF EXISTS history_version,
DROP COLUMN IF EXISTS queue_version,
DROP COLUMN IF EXISTS generation_attempt,
DROP COLUMN IF EXISTS retry_state,
DROP COLUMN IF EXISTS retry_state_version,
DROP COLUMN IF EXISTS runner_id,
DROP COLUMN IF EXISTS requires_action_deadline_at;
-- 10. Recreate chats_expanded with the pre-migration field list.
CREATE VIEW chats_expanded AS
SELECT
c.id,
c.owner_id,
c.workspace_id,
c.title,
c.status,
c.worker_id,
c.started_at,
c.heartbeat_at,
c.created_at,
c.updated_at,
c.parent_chat_id,
c.root_chat_id,
c.last_model_config_id,
c.archived,
c.last_error,
c.mode,
c.mcp_server_ids,
c.labels,
c.build_id,
c.agent_id,
c.pin_order,
c.last_read_message_id,
c.last_injected_context,
c.dynamic_tools,
c.organization_id,
c.plan_mode,
c.client_type,
c.last_turn_summary,
COALESCE(root.user_acl, c.user_acl) AS user_acl,
COALESCE(root.group_acl, c.group_acl) AS group_acl,
owner.username AS owner_username,
owner.name AS owner_name
FROM
chats c
LEFT JOIN chats root ON root.id = COALESCE(c.root_chat_id, c.parent_chat_id)
JOIN visible_users owner ON owner.id = c.owner_id;
-- 11. The `interrupting` chat_status enum value is intentionally left
-- in place. Postgres does not support dropping a single enum value
-- without recreating the entire type, which would require rewriting
-- every chat row and is unsafe inside a transactional rollback.
@@ -0,0 +1,346 @@
-- Foundation for the chatd core state machine refactor (PR 1).
-- Adds new versioning fields to chats, a revision column to chat_messages,
-- positional ordering and creator tracking to chat_queued_messages, an
-- unlogged chat_heartbeats table for ownership leases, and Postgres
-- triggers that keep history/queue versioning consistent.
--
-- This migration does not switch any runtime code over to the new
-- state machine. Existing chatd code keeps running against the legacy
-- "pending"/"paused"/"completed" enum values and the legacy ownership
-- columns (chats.started_at, chats.heartbeat_at). Those columns and
-- enum values stay in place for compatibility.
-- 1. Add `interrupting` to the chat_status enum.
ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'interrupting';
-- 2. Add new versioning, ownership, retry, and pending-action fields to chats.
ALTER TABLE chats
ADD COLUMN snapshot_version bigint NOT NULL DEFAULT 1,
ADD COLUMN history_version bigint NOT NULL DEFAULT 0,
ADD COLUMN queue_version bigint NOT NULL DEFAULT 0,
ADD COLUMN generation_attempt bigint NOT NULL DEFAULT 0,
ADD COLUMN retry_state jsonb,
ADD COLUMN retry_state_version bigint NOT NULL DEFAULT 0,
ADD COLUMN runner_id uuid,
ADD COLUMN requires_action_deadline_at timestamp with time zone;
-- 3. Add `revision` to chat_messages (nullable for now, backfilled below).
ALTER TABLE chat_messages
ADD COLUMN revision bigint;
-- 4. Backfill chat_messages.revision = 1 for existing rows.
UPDATE chat_messages SET revision = 1 WHERE revision IS NULL;
-- 5. Backfill chats.history_version = 1 for chats that already have at
-- least one message. We avoid recursive trigger fire by performing the
-- backfill before the triggers are created.
UPDATE chats
SET history_version = 1
WHERE EXISTS (
SELECT 1 FROM chat_messages WHERE chat_messages.chat_id = chats.id
);
-- 6. Enforce NOT NULL on chat_messages.revision.
ALTER TABLE chat_messages
ALTER COLUMN revision SET NOT NULL;
-- 7. Add `position` and `created_by` to chat_queued_messages.
ALTER TABLE chat_queued_messages
ADD COLUMN position bigint,
ADD COLUMN created_by uuid;
-- 8. Backfill chat_queued_messages.position per chat using row_number(),
-- ordering by created_at and breaking ties by id.
WITH ordered AS (
SELECT
id,
row_number() OVER (
PARTITION BY chat_id
ORDER BY created_at, id
) AS rn
FROM chat_queued_messages
)
UPDATE chat_queued_messages
SET position = ordered.rn
FROM ordered
WHERE chat_queued_messages.id = ordered.id;
-- 9. Backfill chat_queued_messages.created_by from chats.owner_id.
UPDATE chat_queued_messages
SET created_by = chats.owner_id
FROM chats
WHERE chat_queued_messages.chat_id = chats.id
AND chat_queued_messages.created_by IS NULL;
-- 10. Enforce NOT NULL on chat_queued_messages.position and
-- created_by. Legacy queued-message inserts are updated to populate
-- created_by from the chat owner when no explicit creator exists.
ALTER TABLE chat_queued_messages
ALTER COLUMN position SET NOT NULL,
ALTER COLUMN created_by SET NOT NULL;
-- 11. Index for queue-order reads and head selection.
CREATE INDEX IF NOT EXISTS idx_chat_queued_messages_chat_position_id
ON chat_queued_messages(chat_id, position, id);
-- 12. Default sequence for new queued-message positions.
-- A global sequence is acceptable because ordering only needs to be
-- stable within a chat.
CREATE SEQUENCE IF NOT EXISTS chat_queued_messages_position_seq AS bigint START WITH 1;
SELECT setval(
'chat_queued_messages_position_seq',
GREATEST((SELECT COALESCE(MAX(position), 0) FROM chat_queued_messages), 1)
);
ALTER TABLE chat_queued_messages
ALTER COLUMN position SET DEFAULT nextval('chat_queued_messages_position_seq');
-- 13. Backfill chats.queue_version = 1 for chats that already have queued
-- messages. Same trigger-avoidance reasoning as for history_version.
UPDATE chats
SET queue_version = 1
WHERE EXISTS (
SELECT 1 FROM chat_queued_messages WHERE chat_queued_messages.chat_id = chats.id
);
-- 14. chat_heartbeats: unlogged table for ownership leases. Keyed by
-- (chat_id, runner_id) so a single chat can briefly have entries from
-- multiple runners during failover.
CREATE UNLOGGED TABLE IF NOT EXISTS chat_heartbeats (
chat_id uuid NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
runner_id uuid NOT NULL,
heartbeat_at timestamp with time zone NOT NULL,
PRIMARY KEY (chat_id, runner_id)
);
CREATE INDEX IF NOT EXISTS chat_heartbeats_heartbeat_at_idx
ON chat_heartbeats (heartbeat_at);
-- 15. Message revision trigger.
-- The BEFORE-trigger only assigns NEW.revision from chats.snapshot_version
-- and validates immutability. The chats.history_version /
-- generation_attempt update is performed by an AFTER STATEMENT trigger
-- so it doesn't conflict with CTE updates on the chats row in the same
-- command (the legacy InsertChatMessages query updates last_model_config_id
-- in a CTE on chats and then inserts messages).
CREATE FUNCTION set_chat_message_revision_before()
RETURNS trigger AS $$
DECLARE
chat_snapshot_version bigint;
BEGIN
IF TG_OP = 'INSERT' AND NEW.revision IS NOT NULL THEN
RAISE EXCEPTION 'chat_messages.revision must be assigned by trigger';
END IF;
IF TG_OP = 'UPDATE' THEN
IF OLD.chat_id IS DISTINCT FROM NEW.chat_id THEN
RAISE EXCEPTION 'chat_messages.chat_id is immutable';
END IF;
IF OLD.revision IS DISTINCT FROM NEW.revision THEN
RAISE EXCEPTION 'chat_messages.revision must be assigned by trigger';
END IF;
IF OLD IS NOT DISTINCT FROM NEW THEN
RETURN NEW;
END IF;
END IF;
SELECT snapshot_version INTO chat_snapshot_version
FROM chats WHERE id = NEW.chat_id;
IF chat_snapshot_version IS NULL THEN
RAISE EXCEPTION 'chat % does not exist', NEW.chat_id;
END IF;
NEW.revision = chat_snapshot_version;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- AFTER STATEMENT trigger functions. Use the transition tables to
-- update chats.history_version / generation_attempt once per chat per
-- command. Running AFTER row inserts/updates complete lets a CTE
-- update on the same chats row in the same command finalize before
-- this trigger needs to update it.
--
-- The INSERT and UPDATE variants are split so the UPDATE variant can
-- reference both the OLD and NEW transition tables and skip rows that
-- did not actually change. Without that filter, a no-op UPDATE on a
-- chat_messages row (one whose OLD IS NOT DISTINCT FROM NEW) would
-- still advance chats.history_version whenever the chat's snapshot
-- had previously been bumped.
CREATE FUNCTION update_chat_history_after_message_insert()
RETURNS trigger AS $$
BEGIN
UPDATE chats c
SET history_version = c.snapshot_version,
generation_attempt = 0
FROM (
SELECT DISTINCT chat_id FROM chat_message_history_new_rows
) AS affected
WHERE c.id = affected.chat_id
AND (
c.history_version IS DISTINCT FROM c.snapshot_version
OR c.generation_attempt <> 0
);
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
CREATE FUNCTION update_chat_history_after_message_update()
RETURNS trigger AS $$
BEGIN
UPDATE chats c
SET history_version = c.snapshot_version,
generation_attempt = 0
FROM (
SELECT DISTINCT n.chat_id
FROM chat_message_history_new_rows n
JOIN chat_message_history_old_rows o ON o.id = n.id
WHERE o IS DISTINCT FROM n
) AS affected
WHERE c.id = affected.chat_id
AND (
c.history_version IS DISTINCT FROM c.snapshot_version
OR c.generation_attempt <> 0
);
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_set_chat_message_revision_on_insert
BEFORE INSERT ON chat_messages
FOR EACH ROW
EXECUTE FUNCTION set_chat_message_revision_before();
CREATE TRIGGER trigger_set_chat_message_revision_on_update
BEFORE UPDATE ON chat_messages
FOR EACH ROW
EXECUTE FUNCTION set_chat_message_revision_before();
CREATE TRIGGER trigger_update_chat_history_after_message_insert
AFTER INSERT ON chat_messages
REFERENCING NEW TABLE AS chat_message_history_new_rows
FOR EACH STATEMENT
EXECUTE FUNCTION update_chat_history_after_message_insert();
CREATE TRIGGER trigger_update_chat_history_after_message_update
AFTER UPDATE ON chat_messages
REFERENCING OLD TABLE AS chat_message_history_old_rows NEW TABLE AS chat_message_history_new_rows
FOR EACH STATEMENT
EXECUTE FUNCTION update_chat_history_after_message_update();
-- 16. Queue version trigger function.
CREATE FUNCTION bump_chat_queue_version_on_queued_message_change()
RETURNS trigger AS $$
DECLARE
changed_chat_id uuid;
BEGIN
IF TG_OP = 'DELETE' THEN
changed_chat_id = OLD.chat_id;
ELSE
changed_chat_id = NEW.chat_id;
END IF;
UPDATE chats
SET queue_version = snapshot_version
WHERE id = changed_chat_id;
IF TG_OP = 'DELETE' THEN
RETURN OLD;
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_bump_chat_queue_version_on_queued_message_insert
AFTER INSERT ON chat_queued_messages
FOR EACH ROW
EXECUTE FUNCTION bump_chat_queue_version_on_queued_message_change();
CREATE TRIGGER trigger_bump_chat_queue_version_on_queued_message_update
AFTER UPDATE OF content, model_config_id, position, created_by
ON chat_queued_messages
FOR EACH ROW
EXECUTE FUNCTION bump_chat_queue_version_on_queued_message_change();
CREATE TRIGGER trigger_bump_chat_queue_version_on_queued_message_delete
AFTER DELETE ON chat_queued_messages
FOR EACH ROW
EXECUTE FUNCTION bump_chat_queue_version_on_queued_message_change();
-- 17. Retry state trigger function.
CREATE FUNCTION sync_chat_retry_state()
RETURNS trigger AS $$
BEGIN
IF OLD.retry_state_version IS DISTINCT FROM NEW.retry_state_version THEN
RAISE EXCEPTION 'chats.retry_state_version must be assigned by trigger';
END IF;
IF NEW.generation_attempt IS DISTINCT FROM OLD.generation_attempt THEN
NEW.retry_state = NULL;
END IF;
IF NEW.retry_state IS DISTINCT FROM OLD.retry_state THEN
NEW.retry_state_version = NEW.snapshot_version;
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_sync_chat_retry_state
BEFORE UPDATE OF retry_state, retry_state_version, generation_attempt
ON chats
FOR EACH ROW
EXECUTE FUNCTION sync_chat_retry_state();
-- 18. Refresh chats_expanded to include the new chat fields. Drop and
-- recreate so column ordering is stable.
DROP VIEW IF EXISTS chats_expanded;
CREATE VIEW chats_expanded AS
SELECT
c.id,
c.owner_id,
c.workspace_id,
c.title,
c.status,
c.worker_id,
c.started_at,
c.heartbeat_at,
c.created_at,
c.updated_at,
c.parent_chat_id,
c.root_chat_id,
c.last_model_config_id,
c.archived,
c.last_error,
c.mode,
c.mcp_server_ids,
c.labels,
c.build_id,
c.agent_id,
c.pin_order,
c.last_read_message_id,
c.last_injected_context,
c.dynamic_tools,
c.organization_id,
c.plan_mode,
c.client_type,
c.last_turn_summary,
c.snapshot_version,
c.history_version,
c.queue_version,
c.generation_attempt,
c.retry_state,
c.retry_state_version,
c.runner_id,
c.requires_action_deadline_at,
COALESCE(root.user_acl, c.user_acl) AS user_acl,
COALESCE(root.group_acl, c.group_acl) AS group_acl,
owner.username AS owner_username,
owner.name AS owner_name
FROM
chats c
LEFT JOIN chats root ON root.id = COALESCE(c.root_chat_id, c.parent_chat_id)
JOIN visible_users owner ON owner.id = c.owner_id;
@@ -0,0 +1,17 @@
-- Fixture coverage for the chat_heartbeats table introduced in
-- migration 000500. The earlier chat fixtures already insert at least
-- one row into chats; we attach a heartbeat for the first such chat so
-- migration tests see a non-empty chat_heartbeats table without
-- hard-coding a specific chat ID.
INSERT INTO chat_heartbeats (
chat_id,
runner_id,
heartbeat_at
)
SELECT
chats.id,
'00000000-0000-0000-0000-0000000fea51'::uuid,
'2024-01-01 00:00:00+00'
FROM chats
ORDER BY created_at, id
LIMIT 1;
+12
View File
@@ -820,6 +820,12 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
&i.Chat.PlanMode,
&i.Chat.ClientType,
&i.Chat.LastTurnSummary,
&i.Chat.SnapshotVersion,
&i.Chat.HistoryVersion,
&i.Chat.QueueVersion,
&i.Chat.GenerationAttempt,
&i.Chat.RunnerID,
&i.Chat.RequiresActionDeadlineAt,
&i.Chat.UserACL,
&i.Chat.GroupACL,
&i.Chat.OwnerUsername,
@@ -887,6 +893,12 @@ func (q *sqlQuerier) GetAuthorizedChatsByChatFileID(ctx context.Context, fileID
&i.PlanMode,
&i.ClientType,
&i.LastTurnSummary,
&i.SnapshotVersion,
&i.HistoryVersion,
&i.QueueVersion,
&i.GenerationAttempt,
&i.RunnerID,
&i.RequiresActionDeadlineAt,
&i.UserACL,
&i.GroupACL,
&i.OwnerUsername,
+609 -67
View File
@@ -35,6 +35,14 @@ chats_expanded AS (
updated_chats.plan_mode,
updated_chats.client_type,
updated_chats.last_turn_summary,
updated_chats.snapshot_version,
updated_chats.history_version,
updated_chats.queue_version,
updated_chats.generation_attempt,
updated_chats.retry_state,
updated_chats.retry_state_version,
updated_chats.runner_id,
updated_chats.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chats.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chats.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -90,6 +98,14 @@ chats_expanded AS (
updated_chats.plan_mode,
updated_chats.client_type,
updated_chats.last_turn_summary,
updated_chats.snapshot_version,
updated_chats.history_version,
updated_chats.queue_version,
updated_chats.generation_attempt,
updated_chats.retry_state,
updated_chats.retry_state_version,
updated_chats.runner_id,
updated_chats.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chats.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chats.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -293,6 +309,15 @@ SELECT *
FROM chats_expanded
WHERE id = @id::uuid;
-- name: GetChatFamilyIDsByRootID :many
-- Returns the chat IDs of every chat in a family (root + all children)
-- in deterministic order. The id parameter must be the root id; the
-- query does not walk up from a child.
SELECT id
FROM chats
WHERE id = @id::uuid OR root_chat_id = @id::uuid
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
-- name: GetChatACLByID :one
SELECT
user_acl AS users,
@@ -719,6 +744,14 @@ chats_expanded AS (
inserted_chat.plan_mode,
inserted_chat.client_type,
inserted_chat.last_turn_summary,
inserted_chat.snapshot_version,
inserted_chat.history_version,
inserted_chat.queue_version,
inserted_chat.generation_attempt,
inserted_chat.retry_state,
inserted_chat.retry_state_version,
inserted_chat.runner_id,
inserted_chat.requires_action_deadline_at,
COALESCE(root.user_acl, inserted_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, inserted_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -854,6 +887,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -909,6 +950,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -962,6 +1011,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1015,6 +1072,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1068,6 +1133,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1120,6 +1193,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1172,6 +1253,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1226,6 +1315,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1242,11 +1339,9 @@ FROM chats_expanded;
-- Updates the cached last completed turn summary for sidebar display.
-- Empty or whitespace-only summaries are stored as NULL here so direct
-- query callers cannot accidentally persist blank sidebar text.
-- This intentionally preserves updated_at. The staleness guard relies on
-- every new-turn query, such as UpdateChatStatus and AcquireChats, bumping
-- updated_at. Future chat-field updates that do not bump updated_at can let
-- stale summaries persist. If this query ever bumps updated_at, later
-- goroutine summary writes will be rejected as stale.
-- This intentionally preserves updated_at. The staleness guard uses
-- history_version so worker lifecycle transitions that do not change the
-- active message history cannot reject final turn summary writes.
-- Two summary workers using the same freshness marker are last-write-wins.
UPDATE chats
SET
@@ -1255,7 +1350,7 @@ SET
), '')
WHERE
id = @id::uuid
AND updated_at = @expected_updated_at::timestamptz;
AND history_version = @expected_history_version::bigint;
-- name: UpdateChatMCPServerIDs :one
WITH updated_chat AS (
@@ -1298,6 +1393,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1407,6 +1510,14 @@ chats_expanded AS (
acquired_chats.plan_mode,
acquired_chats.client_type,
acquired_chats.last_turn_summary,
acquired_chats.snapshot_version,
acquired_chats.history_version,
acquired_chats.queue_version,
acquired_chats.generation_attempt,
acquired_chats.retry_state,
acquired_chats.retry_state_version,
acquired_chats.runner_id,
acquired_chats.requires_action_deadline_at,
COALESCE(root.user_acl, acquired_chats.user_acl) AS user_acl,
COALESCE(root.group_acl, acquired_chats.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1464,6 +1575,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1521,6 +1640,14 @@ chats_expanded AS (
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1688,12 +1815,17 @@ RETURNING
*;
-- name: InsertChatQueuedMessage :one
INSERT INTO chat_queued_messages (chat_id, content, model_config_id)
VALUES (
@chat_id,
@content,
sqlc.narg('model_config_id')::uuid
)
-- Legacy queue insertion path. When no caller-supplied creator exists,
-- preserve the created_by invariant by attributing the queued row to the
-- chat owner.
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, created_by)
SELECT
@chat_id::uuid,
@content::jsonb,
sqlc.narg('model_config_id')::uuid,
chats.owner_id
FROM chats
WHERE chats.id = @chat_id::uuid
RETURNING *;
-- name: GetChatQueuedMessages :many
@@ -1779,6 +1911,14 @@ chats_expanded AS (
locked_chat.plan_mode,
locked_chat.client_type,
locked_chat.last_turn_summary,
locked_chat.snapshot_version,
locked_chat.history_version,
locked_chat.queue_version,
locked_chat.generation_attempt,
locked_chat.retry_state,
locked_chat.retry_state_version,
locked_chat.runner_id,
locked_chat.requires_action_deadline_at,
COALESCE(root.user_acl, locked_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, locked_chat.group_acl) AS group_acl,
owner.username AS owner_username,
@@ -1791,6 +1931,63 @@ chats_expanded AS (
SELECT *
FROM chats_expanded;
-- name: GetChatByIDForShare :one
WITH shared_chat AS (
SELECT *
FROM chats
WHERE id = @id::uuid
FOR SHARE
),
chats_expanded AS (
SELECT
shared_chat.id,
shared_chat.owner_id,
shared_chat.workspace_id,
shared_chat.title,
shared_chat.status,
shared_chat.worker_id,
shared_chat.started_at,
shared_chat.heartbeat_at,
shared_chat.created_at,
shared_chat.updated_at,
shared_chat.parent_chat_id,
shared_chat.root_chat_id,
shared_chat.last_model_config_id,
shared_chat.archived,
shared_chat.last_error,
shared_chat.mode,
shared_chat.mcp_server_ids,
shared_chat.labels,
shared_chat.build_id,
shared_chat.agent_id,
shared_chat.pin_order,
shared_chat.last_read_message_id,
shared_chat.last_injected_context,
shared_chat.dynamic_tools,
shared_chat.organization_id,
shared_chat.plan_mode,
shared_chat.client_type,
shared_chat.last_turn_summary,
shared_chat.snapshot_version,
shared_chat.history_version,
shared_chat.queue_version,
shared_chat.generation_attempt,
shared_chat.retry_state,
shared_chat.retry_state_version,
shared_chat.runner_id,
shared_chat.requires_action_deadline_at,
COALESCE(root.user_acl, shared_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, shared_chat.group_acl) AS group_acl,
owner.username AS owner_username,
owner.name AS owner_name
FROM
shared_chat
LEFT JOIN chats root ON root.id = COALESCE(shared_chat.root_chat_id, shared_chat.parent_chat_id)
JOIN visible_users owner ON owner.id = shared_chat.owner_id
)
SELECT *
FROM chats_expanded;
-- name: GetChatsByChatFileID :many
SELECT
*
@@ -2310,59 +2507,404 @@ WHERE chat_id = @chat_id::uuid
AND deleted = false
AND content::jsonb @> '[{"type": "context-file"}]';
-- name: AutoArchiveInactiveChats :many
-- Archives inactive root chats (pinned and already-archived chats skipped),
-- cascading to children via root_chat_id. Limits apply to roots, not total
-- rows. The Go caller passes @archive_cutoff as UTC midnight so that all
-- chats sharing the same last-activity date are archived together.
-- Used by dbpurge.
WITH to_archive AS (
SELECT
c.id,
-- Activity = MAX(cm.created_at) across the family, or c.created_at
-- when the family has no non-deleted messages.
COALESCE(activity.last_activity_at, c.created_at) AS last_activity_at
FROM chats c
LEFT JOIN LATERAL (
SELECT MAX(cm.created_at) AS last_activity_at
FROM chat_messages cm
JOIN chats fc ON fc.id = cm.chat_id
WHERE (fc.id = c.id OR fc.root_chat_id = c.id)
AND cm.deleted = false
) activity ON TRUE
WHERE c.archived = false
AND c.pin_order = 0
AND c.parent_chat_id IS NULL -- roots only
-- Redundant filter helps the planner use the partial index on created_at.
AND c.created_at < @archive_cutoff::timestamptz
-- New active statuses must be added here to prevent archiving.
AND c.status NOT IN ('running', 'pending', 'paused', 'requires_action')
AND COALESCE(activity.last_activity_at, c.created_at) < @archive_cutoff::timestamptz
-- Sorting by created_at lets Postgres drive the scan from the
-- partial index instead of evaluating every LATERAL subquery
-- before sorting. All candidates are past the cutoff, so the
-- archive order is immaterial once the backlog drains.
ORDER BY c.created_at ASC
LIMIT @limit_count
),
archived AS (
UPDATE chats c
SET archived = true, pin_order = 0, updated_at = NOW()
FROM to_archive t
WHERE (c.id = t.id OR c.root_chat_id = t.id) -- cascade to children
AND c.archived = false
RETURNING c.*
)
-- name: GetChatWorkerAcquisitionCandidates :many
-- Returns worker-runnable chats whose ownership is missing or whose
-- current runner heartbeat is stale. The runner_id IS NULL predicate is
-- a robustness extension for inconsistent rows where a worker_id exists
-- without a runner_id; normal missing ownership is worker_id IS NULL or
-- a missing or stale heartbeat row.
SELECT
a.*,
-- Children inherit their root's activity so last_activity_at is never null.
COALESCE(
t.last_activity_at,
(SELECT tr.last_activity_at FROM to_archive tr WHERE tr.id = a.root_chat_id),
a.created_at
)::timestamptz AS last_activity_at
FROM archived a
LEFT JOIN to_archive t ON t.id = a.id
-- created_at ASC flows through to dbpurge's digest truncation; see
-- buildDigestData in dbpurge.go for the tradeoff rationale.
ORDER BY (a.root_chat_id IS NULL) DESC, a.owner_id ASC, a.created_at ASC, a.id ASC;
chats_expanded.*,
chat_heartbeats.heartbeat_at AS current_heartbeat_at,
NOT EXISTS (
SELECT 1
FROM chat_heartbeats current_lease
WHERE current_lease.chat_id = chats_expanded.id
AND current_lease.runner_id = chats_expanded.runner_id
AND current_lease.heartbeat_at > NOW() - (INTERVAL '1 second' * @stale_seconds::int)
) AS heartbeat_stale
FROM chats_expanded
LEFT JOIN chat_heartbeats
ON chat_heartbeats.chat_id = chats_expanded.id
AND chat_heartbeats.runner_id = chats_expanded.runner_id
WHERE
chats_expanded.status IN ('running'::chat_status, 'interrupting'::chat_status, 'requires_action'::chat_status)
AND chats_expanded.archived = false
AND (
chats_expanded.worker_id IS NULL
OR chats_expanded.runner_id IS NULL
OR NOT EXISTS (
SELECT 1
FROM chat_heartbeats current_lease
WHERE current_lease.chat_id = chats_expanded.id
AND current_lease.runner_id = chats_expanded.runner_id
AND current_lease.heartbeat_at > NOW() - (INTERVAL '1 second' * @stale_seconds::int)
)
)
ORDER BY chats_expanded.updated_at ASC, chats_expanded.id ASC
LIMIT @limit_count::int;
-- name: GetChatsByIDsForRunnerSync :many
SELECT *
FROM chats_expanded
WHERE id = ANY(@ids::uuid[])
ORDER BY id ASC;
-- name: UpsertChatHeartbeats :exec
INSERT INTO chat_heartbeats (chat_id, runner_id, heartbeat_at)
SELECT chat_ids.chat_id, runner_ids.runner_id, NOW()
FROM unnest(@chat_ids::uuid[]) WITH ORDINALITY AS chat_ids(chat_id, ord)
JOIN unnest(@runner_ids::uuid[]) WITH ORDINALITY AS runner_ids(runner_id, ord) USING (ord)
ON CONFLICT (chat_id, runner_id) DO UPDATE
SET heartbeat_at = EXCLUDED.heartbeat_at;
-- name: DeleteStaleChatHeartbeats :execrows
DELETE FROM chat_heartbeats
WHERE heartbeat_at < NOW() - (INTERVAL '1 second' * @stale_seconds::int);
-- name: GetAutoArchiveInactiveChatCandidates :many
-- Returns read-only root chat candidates for state-machine-backed
-- auto-archive. Activity is computed across the root family. The query
-- limits roots, not total family members.
SELECT
chats_expanded.*,
COALESCE(activity.last_activity_at, chats_expanded.created_at)::timestamptz AS last_activity_at
FROM chats_expanded
LEFT JOIN LATERAL (
SELECT MAX(chat_messages.created_at) AS last_activity_at
FROM chat_messages
JOIN chats family_chat ON family_chat.id = chat_messages.chat_id
WHERE (family_chat.id = chats_expanded.id OR family_chat.root_chat_id = chats_expanded.id)
AND chat_messages.deleted = false
) activity ON TRUE
WHERE
chats_expanded.archived = false
AND chats_expanded.pin_order = 0
AND chats_expanded.parent_chat_id IS NULL
AND chats_expanded.created_at < @archive_cutoff::timestamptz
AND chats_expanded.status NOT IN (
'running'::chat_status,
'interrupting'::chat_status,
'pending'::chat_status,
'paused'::chat_status,
'requires_action'::chat_status
)
AND COALESCE(activity.last_activity_at, chats_expanded.created_at) < @archive_cutoff::timestamptz
ORDER BY chats_expanded.created_at ASC
LIMIT @limit_count::int;
-- =====================================================================
-- chatd core state machine queries.
--
-- These are consumed by the coderd/x/chatd/chatstate package. They
-- are intentionally kept side-by-side with the legacy chatd queries
-- above so the existing runtime keeps working while the state machine
-- lands behind it.
-- =====================================================================
-- name: LockChatAndBumpSnapshotVersion :one
-- Locks the chat row with FOR UPDATE and atomically increments its
-- snapshot_version, returning the post-bump chat. This is the single
-- entry point ChatMachine.Update uses to acquire the row lock and
-- allocate a new snapshot version in one round trip.
WITH bumped_chat AS (
UPDATE chats
SET snapshot_version = snapshot_version + 1
WHERE id = (
SELECT id FROM chats
WHERE id = @id::uuid
FOR UPDATE
)
RETURNING *
),
chats_expanded AS (
SELECT
bumped_chat.id,
bumped_chat.owner_id,
bumped_chat.workspace_id,
bumped_chat.title,
bumped_chat.status,
bumped_chat.worker_id,
bumped_chat.started_at,
bumped_chat.heartbeat_at,
bumped_chat.created_at,
bumped_chat.updated_at,
bumped_chat.parent_chat_id,
bumped_chat.root_chat_id,
bumped_chat.last_model_config_id,
bumped_chat.archived,
bumped_chat.last_error,
bumped_chat.mode,
bumped_chat.mcp_server_ids,
bumped_chat.labels,
bumped_chat.build_id,
bumped_chat.agent_id,
bumped_chat.pin_order,
bumped_chat.last_read_message_id,
bumped_chat.last_injected_context,
bumped_chat.dynamic_tools,
bumped_chat.organization_id,
bumped_chat.plan_mode,
bumped_chat.client_type,
bumped_chat.last_turn_summary,
bumped_chat.snapshot_version,
bumped_chat.history_version,
bumped_chat.queue_version,
bumped_chat.generation_attempt,
bumped_chat.retry_state,
bumped_chat.retry_state_version,
bumped_chat.runner_id,
bumped_chat.requires_action_deadline_at,
COALESCE(root.user_acl, bumped_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, bumped_chat.group_acl) AS group_acl,
owner.username AS owner_username,
owner.name AS owner_name
FROM bumped_chat
LEFT JOIN chats root ON root.id = COALESCE(bumped_chat.root_chat_id, bumped_chat.parent_chat_id)
JOIN visible_users owner ON owner.id = bumped_chat.owner_id
)
SELECT *
FROM chats_expanded;
-- name: UpdateChatExecutionState :one
-- Atomically updates the execution-state-managed fields on a chat:
-- status, archived, last_error, ownership identifiers, and the
-- requires-action deadline. Callers compose this with transition
-- mutations inside a single ChatMachine.Update transaction.
WITH updated_chat AS (
UPDATE chats
SET
status = @status::chat_status,
archived = @archived::boolean,
worker_id = sqlc.narg('worker_id')::uuid,
runner_id = sqlc.narg('runner_id')::uuid,
last_error = sqlc.narg('last_error')::jsonb,
requires_action_deadline_at = sqlc.narg('requires_action_deadline_at')::timestamptz,
pin_order = CASE WHEN @archived::boolean THEN 0 ELSE pin_order END,
updated_at = NOW()
WHERE id = @id::uuid
RETURNING *
),
chats_expanded AS (
SELECT
updated_chat.id,
updated_chat.owner_id,
updated_chat.workspace_id,
updated_chat.title,
updated_chat.status,
updated_chat.worker_id,
updated_chat.started_at,
updated_chat.heartbeat_at,
updated_chat.created_at,
updated_chat.updated_at,
updated_chat.parent_chat_id,
updated_chat.root_chat_id,
updated_chat.last_model_config_id,
updated_chat.archived,
updated_chat.last_error,
updated_chat.mode,
updated_chat.mcp_server_ids,
updated_chat.labels,
updated_chat.build_id,
updated_chat.agent_id,
updated_chat.pin_order,
updated_chat.last_read_message_id,
updated_chat.last_injected_context,
updated_chat.dynamic_tools,
updated_chat.organization_id,
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
owner.name AS owner_name
FROM updated_chat
LEFT JOIN chats root ON root.id = COALESCE(updated_chat.root_chat_id, updated_chat.parent_chat_id)
JOIN visible_users owner ON owner.id = updated_chat.owner_id
)
SELECT *
FROM chats_expanded;
-- name: UpdateChatRetryState :one
-- Stores the client-visible retry payload. retry_state_version is
-- assigned by trigger from the current snapshot_version.
WITH updated_chat AS (
UPDATE chats
SET
retry_state = @retry_state::jsonb,
updated_at = NOW()
WHERE id = @id::uuid
RETURNING *
),
chats_expanded AS (
SELECT
updated_chat.id,
updated_chat.owner_id,
updated_chat.workspace_id,
updated_chat.title,
updated_chat.status,
updated_chat.worker_id,
updated_chat.started_at,
updated_chat.heartbeat_at,
updated_chat.created_at,
updated_chat.updated_at,
updated_chat.parent_chat_id,
updated_chat.root_chat_id,
updated_chat.last_model_config_id,
updated_chat.archived,
updated_chat.last_error,
updated_chat.mode,
updated_chat.mcp_server_ids,
updated_chat.labels,
updated_chat.build_id,
updated_chat.agent_id,
updated_chat.pin_order,
updated_chat.last_read_message_id,
updated_chat.last_injected_context,
updated_chat.dynamic_tools,
updated_chat.organization_id,
updated_chat.plan_mode,
updated_chat.client_type,
updated_chat.last_turn_summary,
updated_chat.snapshot_version,
updated_chat.history_version,
updated_chat.queue_version,
updated_chat.generation_attempt,
updated_chat.retry_state,
updated_chat.retry_state_version,
updated_chat.runner_id,
updated_chat.requires_action_deadline_at,
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
owner.username AS owner_username,
owner.name AS owner_name
FROM updated_chat
LEFT JOIN chats root ON root.id = COALESCE(updated_chat.root_chat_id, updated_chat.parent_chat_id)
JOIN visible_users owner ON owner.id = updated_chat.owner_id
)
SELECT *
FROM chats_expanded;
-- name: IncrementChatGenerationAttempt :one
-- Increments generation_attempt and returns the resulting value.
UPDATE chats
SET generation_attempt = generation_attempt + 1, updated_at = NOW()
WHERE id = @id::uuid
RETURNING generation_attempt;
-- name: GetDatabaseNow :one
-- Returns the current database timestamp. Used so transitions that
-- record deadlines or heartbeats rely on a clock that is consistent
-- with the database rather than the caller's local clock.
SELECT NOW()::timestamptz AS now;
-- name: InsertChatQueuedMessageWithCreator :one
-- Inserts a queued message that carries a position (from the default
-- sequence) and an explicit created_by reference. Use this when the
-- queued-message creator differs from the chat owner.
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, created_by)
VALUES (
@chat_id::uuid,
@content::jsonb,
sqlc.narg('model_config_id')::uuid,
@created_by::uuid
)
RETURNING *;
-- name: GetChatQueuedMessagesByPosition :many
-- Returns queued messages in state-machine order (position ASC, id ASC).
SELECT * FROM chat_queued_messages
WHERE chat_id = @chat_id::uuid
ORDER BY position ASC, id ASC;
-- name: CountChatQueuedMessages :one
-- Cheap queue-length check used by ChatMachine.Update when deciding
-- whether the chat is in a "1" sub-state.
SELECT COUNT(*)::bigint AS count
FROM chat_queued_messages
WHERE chat_id = @chat_id::uuid;
-- name: GetChatQueuedMessageHead :one
-- Returns the queue head (lowest position, then lowest id).
SELECT * FROM chat_queued_messages
WHERE chat_id = @chat_id::uuid
ORDER BY position ASC, id ASC
LIMIT 1;
-- name: GetChatQueuedMessageByID :one
SELECT * FROM chat_queued_messages
WHERE id = @id::bigint AND chat_id = @chat_id::uuid;
-- name: DeleteChatQueuedMessageReturningCount :execrows
-- Deletes a queued message, scoped to the parent chat. Returns the
-- number of affected rows so callers can detect missing rows without
-- a follow-up read.
DELETE FROM chat_queued_messages
WHERE id = @id::bigint AND chat_id = @chat_id::uuid;
-- name: DeleteAllChatQueuedMessagesReturningCount :execrows
DELETE FROM chat_queued_messages
WHERE chat_id = @chat_id::uuid;
-- name: ReorderChatQueuedMessageToHead :execrows
-- Sets the target queued message's position to one less than the
-- current minimum position for that chat, moving it to the head.
UPDATE chat_queued_messages AS target
SET position = COALESCE(
(SELECT MIN(position) FROM chat_queued_messages WHERE chat_id = @chat_id::uuid),
0
) - 1
WHERE target.id = @id::bigint
AND target.chat_id = @chat_id::uuid
AND target.position > COALESCE(
(SELECT MIN(position) FROM chat_queued_messages WHERE chat_id = @chat_id::uuid),
target.position
);
-- name: UpsertChatHeartbeat :exec
-- Upserts a heartbeat row for the (chat_id, runner_id) lease. Uses
-- database time so callers do not depend on a local clock.
INSERT INTO chat_heartbeats (chat_id, runner_id, heartbeat_at)
VALUES (@chat_id::uuid, @runner_id::uuid, NOW())
ON CONFLICT (chat_id, runner_id) DO UPDATE
SET heartbeat_at = EXCLUDED.heartbeat_at;
-- name: GetChatHeartbeat :one
SELECT * FROM chat_heartbeats
WHERE chat_id = @chat_id::uuid AND runner_id = @runner_id::uuid;
-- name: IsChatHeartbeatStale :one
-- Returns true when there is no heartbeat row for (chat_id, runner_id)
-- or the existing row is older than @stale_seconds seconds by database
-- time. chatstate calls this in a single query so the staleness check
-- is atomic and does not depend on the caller's local clock.
SELECT NOT EXISTS (
SELECT 1 FROM chat_heartbeats
WHERE chat_id = @chat_id::uuid
AND runner_id = @runner_id::uuid
AND heartbeat_at > NOW() - (INTERVAL '1 second' * @stale_seconds::int)
) AS stale;
-- name: DeleteChatHeartbeats :execrows
-- Deletes heartbeat rows for the supplied (chat_id, runner_id) pairs.
DELETE FROM chat_heartbeats
USING unnest(@chat_ids::uuid[]) WITH ORDINALITY AS chat_ids(chat_id, ord)
JOIN unnest(@runner_ids::uuid[]) WITH ORDINALITY AS runner_ids(runner_id, ord) USING (ord)
WHERE chat_heartbeats.chat_id = chat_ids.chat_id
AND chat_heartbeats.runner_id = runner_ids.runner_id;
-- name: DeleteAllChatHeartbeats :exec
-- Deletes all heartbeat rows for the chat. Used during ownership
-- transitions that abandon a lease.
DELETE FROM chat_heartbeats WHERE chat_id = @chat_id::uuid;
+84
View File
@@ -0,0 +1,84 @@
package pubsub
import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
)
// ChatStateUpdateChannel returns the pubsub channel that receives one
// `chat:update:{chat_id}` message every time the chatstate state
// machine commits a transition for the chat.
func ChatStateUpdateChannel(chatID uuid.UUID) string {
return fmt.Sprintf("chat:update:%s", chatID)
}
// ChatStateOwnershipChannel is the global pubsub channel that
// receives ownership hints when a chat is runnable but currently has
// missing or stale ownership. Workers listen on this channel to know
// when to attempt acquisition.
const ChatStateOwnershipChannel = "chat:ownership"
// ChatStateUpdateMessage is the JSON payload published on
// [ChatStateUpdateChannel] after every successful CreateChat or
// ChatMachine.Update commit. It carries the committed post-transition
// versions and ownership identifiers so stream loops and workers can
// decide whether to refetch state.
type ChatStateUpdateMessage struct {
SnapshotVersion int64 `json:"snapshot_version"`
WorkerID *uuid.UUID `json:"worker_id,omitempty"`
RunnerID *uuid.UUID `json:"runner_id,omitempty"`
HistoryVersion int64 `json:"history_version"`
QueueVersion int64 `json:"queue_version"`
RetryStateVersion int64 `json:"retry_state_version"`
GenerationAttempt int64 `json:"generation_attempt"`
Status string `json:"status"`
Archived bool `json:"archived"`
}
// ChatStateOwnershipMessage is the JSON payload published on
// [ChatStateOwnershipChannel] when ownership is missing or stale for
// a runnable chat. Subscribers should reload the chat row to confirm
// ownership before acting.
type ChatStateOwnershipMessage struct {
ChatID uuid.UUID `json:"chat_id"`
SnapshotVersion int64 `json:"snapshot_version"`
}
// HandleChatStateUpdate wraps a typed callback for
// [ChatStateUpdateMessage] consumption, following the same pattern as
// HandleChatWatchEvent.
func HandleChatStateUpdate(cb func(ctx context.Context, payload ChatStateUpdateMessage, err error)) func(ctx context.Context, message []byte, err error) {
return func(ctx context.Context, message []byte, err error) {
if err != nil {
cb(ctx, ChatStateUpdateMessage{}, xerrors.Errorf("chat state update pubsub: %w", err))
return
}
var payload ChatStateUpdateMessage
if uerr := json.Unmarshal(message, &payload); uerr != nil {
cb(ctx, ChatStateUpdateMessage{}, xerrors.Errorf("unmarshal chat state update: %w", uerr))
return
}
cb(ctx, payload, err)
}
}
// HandleChatStateOwnership wraps a typed callback for
// [ChatStateOwnershipMessage] consumption.
func HandleChatStateOwnership(cb func(ctx context.Context, payload ChatStateOwnershipMessage, err error)) func(ctx context.Context, message []byte, err error) {
return func(ctx context.Context, message []byte, err error) {
if err != nil {
cb(ctx, ChatStateOwnershipMessage{}, xerrors.Errorf("chat state ownership pubsub: %w", err))
return
}
var payload ChatStateOwnershipMessage
if uerr := json.Unmarshal(message, &payload); uerr != nil {
cb(ctx, ChatStateOwnershipMessage{}, xerrors.Errorf("unmarshal chat state ownership: %w", uerr))
return
}
cb(ctx, payload, err)
}
}
@@ -0,0 +1,169 @@
package chatstate_test
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/testutil"
)
// waitForChan returns true if c receives a value before ctx is done.
// Helper used in concurrency tests to avoid time.Sleep.
func waitForChan(ctx context.Context, c <-chan struct{}) bool {
select {
case <-c:
return true
case <-ctx.Done():
return false
}
}
// stillBlocked returns true if c has NOT received a value yet. The
// caller must already have established a happens-before ordering via
// some other channel so this check is meaningful.
func stillBlocked(c <-chan struct{}) bool {
select {
case <-c:
return false
default:
return true
}
}
// TestLockLocksChatRow verifies that ChatMachine.Lock holds the chat
// row's FOR UPDATE lock until the callback returns, so a concurrent
// ChatMachine.Update cannot enter its callback until the Lock
// callback releases.
func TestLockLocksChatRow(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitMedium)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
lockEntered := make(chan struct{})
releaseLock := make(chan struct{})
updateEntered := make(chan struct{})
// Goroutine A: hold a Lock and block.
var lockErr error
var lockWG sync.WaitGroup
lockWG.Add(1)
go func() {
defer lockWG.Done()
lockErr = m.Lock(ctx, func(_ database.Store) error {
close(lockEntered)
<-releaseLock
return nil
})
}()
// Wait until A is inside its Lock callback (and therefore holds
// the FOR UPDATE lock).
require.True(t, waitForChan(ctx, lockEntered), "Lock callback never started")
// Goroutine B: try to Update the same chat. It must block on
// LockChatAndBumpSnapshotVersion until A releases.
var updateErr error
var updateWG sync.WaitGroup
updateWG.Add(1)
go func() {
defer updateWG.Done()
updateErr = m.Update(ctx, func(_ *chatstate.Tx) error {
close(updateEntered)
return nil
})
}()
// Loop a few times re-checking that B is still blocked, with a
// fresh round-trip through the database to give B's transaction
// every chance to commit if the lock weren't held. The loop avoids
// time.Sleep by using the database call itself as a "barrier".
for range 5 {
// Force a sync round-trip through the DB. This serves the
// same role as a Sleep but is deterministic: by the time
// this read completes, the scheduler has had a chance to
// run goroutine B if it could make progress.
_, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.True(t, stillBlocked(updateEntered),
"Update entered while Lock was still held")
}
// Release Lock and confirm Update completes successfully.
close(releaseLock)
require.True(t, waitForChan(ctx, updateEntered),
"Update callback never started after Lock released")
updateWG.Wait()
lockWG.Wait()
require.NoError(t, lockErr)
require.NoError(t, updateErr)
}
// TestLockRollsBackCallbackError verifies that a Lock callback
// returning an error rolls back the surrounding transaction.
func TestLockRollsBackCallbackError(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
before := f.readChat(ctx, t, created.Chat.ID)
publishedBefore := len(f.Pub.channels)
sentinel := xerrors.New("lock callback error")
err := m.Lock(ctx, func(store database.Store) error {
// Try a write that should be rolled back.
_, werr := store.UpdateChatByID(ctx, database.UpdateChatByIDParams{
ID: created.Chat.ID,
Title: "rollback-me",
})
require.NoError(t, werr)
return sentinel
})
require.ErrorIs(t, err, sentinel)
after := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, before.Title, after.Title, "Lock callback error rolls back writes")
require.Equal(t, publishedBefore, len(f.Pub.channels), "Lock publishes nothing on error")
}
// TestConcurrentUpdatesSerializeOnChatRow verifies that two
// goroutines racing to Update the same chat both succeed but their
// effects serialize on the chat row lock: snapshot_version advances
// by exactly N (one per Update) and each transition observes the
// effects of the prior one.
func TestConcurrentUpdatesSerializeOnChatRow(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitMedium)
created := createTestChat(t, f)
before := f.readChat(ctx, t, created.Chat.ID)
const updates = 8
var wg sync.WaitGroup
wg.Add(updates)
errs := make([]error, updates)
for i := 0; i < updates; i++ {
i := i
go func() {
defer wg.Done()
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
errs[i] = m.Update(ctx, func(_ *chatstate.Tx) error { return nil })
}()
}
wg.Wait()
for i, err := range errs {
require.NoError(t, err, "concurrent update %d failed", i)
}
after := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, before.SnapshotVersion+int64(updates), after.SnapshotVersion,
"snapshot_version advanced by exactly one per update")
}
+36
View File
@@ -0,0 +1,36 @@
// Package chatstate owns the durable execution-state transitions for
// the chatd subsystem. It implements the core state machine described
// in the chatd RFC: a 13-state execution model plus a 2-state
// ownership model on top of database rows in `chats`,
// `chat_messages`, `chat_queued_messages`, and `chat_heartbeats`.
//
// The package exposes two top-level entry points:
//
// - [CreateChat] creates a brand new chat with its initial history
// in a single transaction. It is standalone because no chat-scoped
// state machine instance can exist before the chat row is written.
// - [ChatMachine] wraps an existing chat. Callers use it to apply
// one or more transitions atomically via [ChatMachine.Update], or
// to read related rows while holding the chat row lock via
// [ChatMachine.Lock].
//
// Every successful [ChatMachine.Update] call locks the chat row,
// advances `snapshot_version` exactly once, applies transition methods
// in order, and (on commit) publishes a single typed `chat:update`
// pubsub message describing the post-transition snapshot. Optional
// `chat:ownership` hints are published only when the post-transition
// state is runnable and ownership is missing or stale. Stream side
// effects are handled by `chat:update` consumers, and ownership hints
// wake chat workers.
//
// Transition methods are explicit, typed wrappers around the SQL
// mutations needed to move between states. Each transition reads the
// current chat row and queue cardinality, classifies the resulting
// execution state, and rejects with an [*TransitionError] wrapping
// [ErrTransitionNotAllowed] when the transition is not legal from
// that state. The transition matrix and
// state classification helpers live in `state.go` and `transition.go`
// alongside unit-testable classifiers; the SQL is in
// `coderd/database/queries/chats.sql` (e.g. `LockChatAndBumpSnapshotVersion`,
// `UpdateChatExecutionState`).
package chatstate
+152
View File
@@ -0,0 +1,152 @@
package chatstate
import (
"errors"
"fmt"
"golang.org/x/xerrors"
)
// Sentinel errors returned by chatstate transitions and helpers.
// Callers should use errors.Is to test for these.
var (
// ErrTransitionNotAllowed is returned when a transition is applied
// to a chat whose current execution state does not permit it. The
// concrete error returned by transition methods is a
// *TransitionError that wraps this sentinel.
ErrTransitionNotAllowed = xerrors.New("chat state transition not allowed")
// ErrInvalidState is returned when the chat row, queue, and
// archive flag together produce a combination outside the 13
// valid execution states described in the RFC.
ErrInvalidState = xerrors.New("chat is in an invalid execution state")
// ErrQueuedMessageNotFound is returned by queue-targeting
// transitions (delete, promote) when the supplied queued message
// ID does not match a row on the chat.
ErrQueuedMessageNotFound = xerrors.New("queued message not found")
// ErrMessageNotFound is returned by [Tx.EditMessage] when the
// target chat_messages row is missing or belongs to another chat.
ErrMessageNotFound = xerrors.New("chat message not found")
// ErrChatNotFound is returned when a non-create transition is
// applied to a chat row that does not exist (or has been deleted
// since the transition started).
ErrChatNotFound = xerrors.New("chat not found")
// ErrChatNotRoot is returned by family-archive helpers when the
// supplied chat is not a root chat (its parent_chat_id is set).
ErrChatNotRoot = xerrors.New("chat is not a root chat")
// ErrEditedMessageNotUser is returned by [Tx.EditMessage] when the
// targeted chat_messages row exists but its role is not user.
ErrEditedMessageNotUser = xerrors.New("only user messages can be edited")
// ErrMessageQueueFull is returned by queue-appending transitions
// when the per-chat queue cap has been reached. The concrete
// error returned by transitions is a *MessageQueueFullError that
// wraps this sentinel.
ErrMessageQueueFull = xerrors.New("chat message queue is full")
// ErrToolResultDuplicate is returned by [Tx.CompleteRequiresAction]
// when the same tool_call_id appears more than once in the
// submitted results.
ErrToolResultDuplicate = xerrors.New("duplicate tool result")
// ErrToolResultUnexpected is returned by
// [Tx.CompleteRequiresAction] when a submitted tool_call_id does
// not correspond to a pending dynamic tool call.
ErrToolResultUnexpected = xerrors.New("unexpected tool result")
// ErrToolResultMissing is returned by [Tx.CompleteRequiresAction]
// when a pending dynamic tool call has no submitted result.
ErrToolResultMissing = xerrors.New("missing tool result")
// ErrToolResultInvalidJSON is returned by
// [Tx.CompleteRequiresAction] when a submitted tool result output
// is not valid JSON.
ErrToolResultInvalidJSON = xerrors.New("tool result output is not valid JSON")
)
// MessageQueueFullError carries the per-chat queue cap so HTTP
// endpoints can include the cap in their response detail. It wraps
// [ErrMessageQueueFull] so callers can match it with errors.Is.
type MessageQueueFullError struct {
Max int64
}
// Error implements the error interface.
func (e *MessageQueueFullError) Error() string {
return fmt.Sprintf("chat message queue is full (max %d)", e.Max)
}
// Unwrap returns [ErrMessageQueueFull] so callers can match the
// generic sentinel.
func (*MessageQueueFullError) Unwrap() error { return ErrMessageQueueFull }
// ToolResultValidationError carries a structured tool-result
// validation failure. It always wraps a specific sentinel
// (ErrToolResultDuplicate, ErrToolResultMissing,
// ErrToolResultUnexpected, ErrToolResultInvalidJSON) so callers can
// match either the generic sentinel or the specific cause.
type ToolResultValidationError struct {
Cause error
ToolCallID string
}
// Error implements the error interface.
func (e *ToolResultValidationError) Error() string {
if e.ToolCallID != "" {
return fmt.Sprintf("%s: %s", e.Cause.Error(), e.ToolCallID)
}
return e.Cause.Error()
}
// Unwrap returns the specific cause so callers can match it.
func (e *ToolResultValidationError) Unwrap() error { return e.Cause }
// TransitionError carries the structured detail for a rejected
// transition. It always wraps [ErrTransitionNotAllowed] so callers can
// match with errors.Is without losing context. When a specific
// chatstate sentinel is the proximate cause, Cause is set and
// errors.Is will match that sentinel too.
type TransitionError struct {
Transition Transition
From ExecutionState
Reason string
Cause error
}
// Error implements the error interface.
func (e *TransitionError) Error() string {
if e.Reason == "" {
return fmt.Sprintf(
"chat state transition %s not allowed from state %s",
e.Transition, e.From,
)
}
return fmt.Sprintf(
"chat state transition %s not allowed from state %s: %s",
e.Transition, e.From, e.Reason,
)
}
// Unwrap returns the error chain attached to this error. The chain
// always includes [ErrTransitionNotAllowed], and may include a more
// specific cause through errors.Join, so callers can use errors.Is
// without custom matching logic on TransitionError.
func (e *TransitionError) Unwrap() error { return e.Cause }
// newTransitionError constructs a typed TransitionError. Returning the
// pointer type lets callers inspect the structured fields when needed.
func newTransitionError(t Transition, from ExecutionState, reason string) *TransitionError {
return &TransitionError{Transition: t, From: from, Reason: reason, Cause: ErrTransitionNotAllowed}
}
// newTransitionErrorWithCause constructs a TransitionError carrying
// a specific underlying sentinel so callers can match the cause with
// errors.Is.
func newTransitionErrorWithCause(t Transition, from ExecutionState, cause error, reason string) *TransitionError {
return &TransitionError{Transition: t, From: from, Reason: reason, Cause: errors.Join(ErrTransitionNotAllowed, cause)}
}
+130
View File
@@ -0,0 +1,130 @@
package chatstate
import (
"context"
"database/sql"
"errors"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
)
// SetFamilyArchivedInput configures [SetFamilyArchived]. The struct
// shape avoids a boolean flag parameter at the API surface; callers
// build it explicitly with named fields for clarity.
type SetFamilyArchivedInput struct {
// RootID identifies the family root. SetFamilyArchived rejects
// calls for child chats with [ErrChatNotRoot] and unknown chats
// with [ErrChatNotFound].
RootID uuid.UUID
// Archived is the desired post-call archived value for every
// family member.
Archived bool
}
// SetFamilyArchived runs Update for every chat in the root chat's
// family inside one transaction, applying SetArchived when the chat's
// archived flag differs from the requested value. It owns its
// transaction lifecycle and its [PublishBuffer] lifecycle: pubsub
// publications are buffered while the transaction is open and
// flushed only after a successful commit; the deferred Discard
// suppresses every buffered publication on failure.
//
// On success SetFamilyArchived returns one [database.Chat] per
// family member in the order returned by GetChatFamilyIDsByRootID
// (root first, then children).
//
// Family members that are already in the [StateInvalid] execution
// state cause SetFamilyArchived to return [ErrInvalidState] and roll
// back the cascade even when their archived flag already matches the
// desired value; invalid-state detection is never bypassed.
//
// Family members that are valid and already match the desired
// archived value still run through Update, which increments their
// snapshot version and publishes a fresh snapshot without changing
// the archived flag. Advancing the snapshot version without a field
// change is safe, and it keeps publication behavior uniform while a
// partially archived family converges to the desired state.
func SetFamilyArchived(
ctx context.Context,
store database.Store,
publisher Publisher,
input SetFamilyArchivedInput,
) ([]database.Chat, error) {
if store == nil {
return nil, xerrors.New("chatstate: SetFamilyArchived called with nil store")
}
if publisher == nil {
return nil, xerrors.New("chatstate: SetFamilyArchived called with nil publisher")
}
buffer := NewPublishBuffer(publisher)
defer buffer.Discard()
var familyChats []database.Chat
err := store.InTx(func(tx database.Store) error {
// Lock the root chat first so concurrent archive races on the
// same family serialize on a stable row.
root, err := tx.GetChatByIDForUpdate(ctx, input.RootID)
if errors.Is(err, sql.ErrNoRows) {
return ErrChatNotFound
}
if err != nil {
return xerrors.Errorf("lock root chat for archive: %w", err)
}
if root.ParentChatID.Valid {
return ErrChatNotRoot
}
ids, err := tx.GetChatFamilyIDsByRootID(ctx, input.RootID)
if err != nil {
return xerrors.Errorf("get chat family: %w", err)
}
if len(ids) == 0 {
return ErrChatNotFound
}
familyChats = make([]database.Chat, 0, len(ids))
for _, id := range ids {
var chat database.Chat
machine := NewChatMachine(tx, buffer, id, Options{})
err := machine.Update(ctx, func(state *Tx) error {
// Load the current chat and classify it so we can
// reject invalid-state members with ErrInvalidState
// even when their archived flag already matches.
current, from, lerr := state.loadState()
if lerr != nil {
return lerr
}
if from == StateInvalid {
return ErrInvalidState
}
if current.Archived == input.Archived {
chat = current
return nil
}
if _, err := state.SetArchived(SetArchivedInput{Archived: input.Archived}); err != nil {
return err
}
var err error
chat, err = state.Store().GetChatByID(state.Ctx(), state.ChatID())
if err != nil {
return xerrors.Errorf("reload archived chat: %w", err)
}
return nil
})
if err != nil {
return err
}
familyChats = append(familyChats, chat)
}
return nil
}, nil)
if err != nil {
return nil, err
}
if err := buffer.Flush(); err != nil {
return familyChats, err
}
return familyChats, nil
}
+218
View File
@@ -0,0 +1,218 @@
package chatstate_test
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestSetFamilyArchivedRejectsChildChat asserts the chatstate helper
// rejects calls that target a child chat. Family archive flows must
// always start at the root.
func TestSetFamilyArchivedRejectsChildChat(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
root := dbgen.Chat(t, f.DB, database.Chat{
OrganizationID: f.Org.ID,
OwnerID: f.User.ID,
LastModelConfigID: f.Model.ID,
Title: "root",
})
child := dbgen.Chat(t, f.DB, database.Chat{
OrganizationID: f.Org.ID,
OwnerID: f.User.ID,
LastModelConfigID: f.Model.ID,
Title: "child",
ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
})
_, err := chatstate.SetFamilyArchived(ctx, f.DB, f.Pub, chatstate.SetFamilyArchivedInput{RootID: child.ID, Archived: true})
require.ErrorIs(t, err, chatstate.ErrChatNotRoot)
require.False(t, f.readChat(ctx, t, root.ID).Archived,
"failed family archive must not touch the root")
require.False(t, f.readChat(ctx, t, child.ID).Archived,
"failed family archive must not touch the child")
}
// TestSetFamilyArchivedRollsBackWhenMemberCannotArchive verifies that
// SetFamilyArchived is atomic: when one family member is in a state
// that cannot satisfy the SetArchived transition, the whole cascade
// rolls back and no publications reach the inner publisher.
func TestSetFamilyArchivedRollsBackWhenMemberCannotArchive(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
user, org, model := seedFamilyDeps(t, db)
// Root chat: waiting is archive-eligible (state W).
root := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: "root",
Status: database.ChatStatusWaiting,
})
// Child chat: running with no queue is R0 and NOT archive
// eligible per the chatstate transition matrix.
child := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: "child",
Status: database.ChatStatusRunning,
ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
})
pub := newRecordingPubsub()
_, err := chatstate.SetFamilyArchived(ctx, db, pub, chatstate.SetFamilyArchivedInput{RootID: root.ID, Archived: true})
require.Error(t, err, "child in R0 must reject SetArchived")
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
rootAfter, err := db.GetChatByID(ctx, root.ID)
require.NoError(t, err)
require.False(t, rootAfter.Archived, "root archive must roll back when a child cannot archive")
childAfter, err := db.GetChatByID(ctx, child.ID)
require.NoError(t, err)
require.False(t, childAfter.Archived, "child must not be archived in the rolled-back cascade")
require.Empty(t, pub.channels,
"rolled-back family archive must publish nothing through the inner publisher")
}
// TestSetFamilyArchivedRejectsInvalidStateEvenWhenAlreadyDesired
// verifies that invalid-state detection is never bypassed: a family
// member in StateInvalid causes the cascade to fail with
// ErrInvalidState even when that member's archived flag already
// matches the desired value.
func TestSetFamilyArchivedRejectsInvalidStateEvenWhenAlreadyDesired(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
user, org, model := seedFamilyDeps(t, db)
root := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: "root",
Status: database.ChatStatusWaiting,
})
child := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: "child",
// status=waiting, archived=true; we will add a queued message
// to produce the chatstate-invalid combination (archived chat
// with a queued backlog is outside the 13 valid states).
Status: database.ChatStatusWaiting,
Archived: true,
ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
})
// Seed a queued message under the child to push it into the
// chatstate-invalid combination.
rawContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("queued"),
})
require.NoError(t, err)
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
ChatID: child.ID,
Content: rawContent.RawMessage,
ModelConfigID: uuid.NullUUID{},
})
require.NoError(t, err)
pub := newRecordingPubsub()
_, err = chatstate.SetFamilyArchived(ctx, db, pub, chatstate.SetFamilyArchivedInput{
RootID: root.ID,
Archived: true,
})
require.ErrorIs(t, err, chatstate.ErrInvalidState,
"invalid-state child blocks the cascade even when archived flag already matches")
// Root must not be archived because the cascade rolled back.
rootAfter, err := db.GetChatByID(ctx, root.ID)
require.NoError(t, err)
require.False(t, rootAfter.Archived, "root must roll back when a child is in StateInvalid")
require.Empty(t, pub.channels,
"rolled-back cascade must not publish anything")
}
// TestSetFamilyArchivedAcceptsAlreadyDesiredMembers verifies that an
// individually archived child does not block a root archive cascade.
// The cascade converges to the desired state even when some family
// members already match it.
func TestSetFamilyArchivedAcceptsAlreadyDesiredMembers(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
user, org, model := seedFamilyDeps(t, db)
root := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: "root",
Status: database.ChatStatusWaiting,
})
child := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: "child",
Status: database.ChatStatusWaiting,
ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
Archived: true,
})
pub := newRecordingPubsub()
family, err := chatstate.SetFamilyArchived(ctx, db, pub, chatstate.SetFamilyArchivedInput{RootID: root.ID, Archived: true})
require.NoError(t, err,
"already archived members must not block the cascade")
require.Len(t, family, 2)
rootAfter, err := db.GetChatByID(ctx, root.ID)
require.NoError(t, err)
require.True(t, rootAfter.Archived)
childAfter, err := db.GetChatByID(ctx, child.ID)
require.NoError(t, err)
require.True(t, childAfter.Archived)
}
func seedFamilyDeps(t *testing.T, db database.Store) (database.User, database.Organization, database.ChatModelConfig) {
t.Helper()
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "openai",
BaseUrl: "http://example.invalid",
})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
Provider: "openai",
IsDefault: true,
})
return user, org, model
}
+92
View File
@@ -0,0 +1,92 @@
package chatstate_test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/testutil"
)
// ownershipPublishCount returns the number of `chat:ownership` messages
// recorded so far on the test publisher. Tests use it to assert that
// transitions do or do not publish an ownership hint.
func (r *recordingPubsub) ownershipPublishCount() int {
count := 0
for _, c := range r.channels {
if c == coderdpubsub.ChatStateOwnershipChannel {
count++
}
}
return count
}
// sendQueuedMessage seeds one queued user message via SendMessage with
// BusyBehaviorQueue. The chat must already be in a state that allows
// SendMessage (typically R0, R1, or I*).
func sendQueuedMessage(t *testing.T, f *testFixture, m *chatstate.ChatMachine, body string) chatstate.SendMessageResult {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
var send chatstate.SendMessageResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
send, err = tx.SendMessage(chatstate.SendMessageInput{
Message: userTextMessage(body, f.User.ID, f.Model.ID),
BusyBehavior: chatstate.BusyBehaviorQueue,
})
return err
}))
return send
}
// sendInterruptMessage seeds one queued user message via SendMessage
// with BusyBehaviorInterrupt. From R0/R1 this transitions the chat to
// `interrupting` and appends the new user message to the queue tail.
func sendInterruptMessage(t *testing.T, f *testFixture, m *chatstate.ChatMachine, body string) chatstate.SendMessageResult {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
var send chatstate.SendMessageResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
send, err = tx.SendMessage(chatstate.SendMessageInput{
Message: userTextMessage(body, f.User.ID, f.Model.ID),
BusyBehavior: chatstate.BusyBehaviorInterrupt,
})
return err
}))
return send
}
// queuedIDsByPosition returns the queued-message IDs for the chat in
// queue order.
func queuedIDsByPosition(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) []int64 {
t.Helper()
rows, err := f.DB.GetChatQueuedMessagesByPosition(ctx, chatID)
require.NoError(t, err)
ids := make([]int64, len(rows))
for i, r := range rows {
ids[i] = r.ID
}
return ids
}
// historyMessageIDs returns the chat history message IDs ordered by
// row id. Used to assert that PromoteQueuedMessage from R1/I1 does NOT
// insert any history rows.
func historyMessageIDs(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) []int64 {
t.Helper()
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
})
require.NoError(t, err)
out := make([]int64, len(msgs))
for i, m := range msgs {
out[i] = m.ID
}
return out
}
+300
View File
@@ -0,0 +1,300 @@
package chatstate
import (
"context"
"database/sql"
"errors"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
)
// HeartbeatStaleSeconds is the threshold chatstate uses when deciding
// whether to publish a `chat:ownership` hint for a runnable chat. A
// heartbeat older than this many seconds (by database time) counts
// as stale and triggers a hint so an idle worker can attempt a
// takeover.
const HeartbeatStaleSeconds = 30
// Options configures a [ChatMachine]. Reserved for future tunables;
// currently empty.
type Options struct{}
// ChatMachine is a chat-scoped handle for state-machine operations on
// a single chat row. It captures the database store, the pubsub
// publisher, and the chat ID at construction time so callers do not
// have to thread them through Update, Lock, or any transition method.
//
// ChatMachine values are cheap. Create one per chat for the lifetime
// of a request or worker turn; do not cache mutable chat state across
// calls.
type ChatMachine struct {
store database.Store
publisher Publisher
chatID uuid.UUID
opts Options
}
// NewChatMachine constructs a chat-scoped state machine handle. The
// store may be the root database handle or an existing transaction
// handle; publisher is the pubsub used for `chat:update` and
// `chat:ownership` emissions. Both are required and captured for the
// lifetime of the returned machine.
func NewChatMachine(
store database.Store,
publisher Publisher,
chatID uuid.UUID,
opts Options,
) *ChatMachine {
return &ChatMachine{
store: store,
publisher: publisher,
chatID: chatID,
opts: opts,
}
}
// ChatID returns the chat ID this machine is scoped to.
func (m *ChatMachine) ChatID() uuid.UUID { return m.chatID }
// Tx is the per-transaction handle passed to [ChatMachine.Update]
// callbacks. It carries the active context, the transactional store,
// and the chat ID. Tx does not cache mutable chat state across calls:
// every transition method reads the chat row and queue cardinality
// from the database on entry, so a bundle of transitions inside one
// Update callback always validates against the latest committed state.
type Tx struct {
ctx context.Context
store database.Store
chatID uuid.UUID
}
// Ctx returns the context the surrounding [ChatMachine.Update] call
// is using.
func (tx *Tx) Ctx() context.Context { return tx.ctx }
// ChatID returns the chat ID this transaction is scoped to.
func (tx *Tx) ChatID() uuid.UUID { return tx.chatID }
// Store exposes the active transaction store so callers can perform
// validation reads (for example loading the messages affected by an
// EditMessage transition) and metadata writes (for example updating
// title or labels) that must be atomic with the transition.
//
// Callers MUST NOT use Store to mutate execution-state tables
// (chats.status, chat_messages, chat_queued_messages, chat_heartbeats,
// or the version fields on chats). Those mutations belong to the
// transition methods and are validated against the state machine
// matrix.
func (tx *Tx) Store() database.Store { return tx.store }
// loadState reads the current chat row and queue cardinality from the
// active transaction, classifies the execution state, and returns the
// inputs every transition method needs. Returns ErrChatNotFound if
// the chat row was deleted in this transaction (or never existed).
func (tx *Tx) loadState() (database.Chat, ExecutionState, error) {
chat, err := tx.store.GetChatByID(tx.ctx, tx.chatID)
if errors.Is(err, sql.ErrNoRows) {
return database.Chat{}, StateN, ErrChatNotFound
}
if err != nil {
return database.Chat{}, "", xerrors.Errorf("load chat: %w", err)
}
count, err := tx.store.CountChatQueuedMessages(tx.ctx, tx.chatID)
if err != nil {
return database.Chat{}, "", xerrors.Errorf("count queued messages: %w", err)
}
return chat, ClassifyExecutionState(chat, count > 0, true), nil
}
// requireFromAllowed loads the current state and validates t against
// the transition matrix. Returns the loaded chat and execution state
// on success, [ErrInvalidState] when the chat is in an invalid state
// and t is not [TransitionReconcileInvalidState], and a typed
// *TransitionError otherwise.
func (tx *Tx) requireFromAllowed(t Transition) (database.Chat, ExecutionState, error) {
chat, from, err := tx.loadState()
if err != nil {
return chat, from, err
}
if from == StateInvalid && t != TransitionReconcileInvalidState {
return chat, from, ErrInvalidState
}
if err := requireExecutionTransition(t, from); err != nil {
return chat, from, err
}
return chat, from, nil
}
// Update applies one or more transitions to the machine's chat.
//
// Update opens a transaction on the captured store, atomically locks
// the chat row with FOR UPDATE and increments `snapshot_version`
// exactly once, then runs fn against a fresh [*Tx]. It constructs a
// [PublishBuffer], enqueues `chat:update` (and a `chat:ownership` hint
// when the post-transition state is worker-runnable and ownership is
// missing or stale) inside the transaction, and flushes the buffer only after
// the transaction function succeeds. If the transaction rolls back,
// the deferred Discard suppresses every buffered publication so
// subscribers never see uncommitted state.
//
// If Update is called with a store that is already in a transaction,
// [database.Store.InTx] reuses the active transaction. In that case,
// callers that need outer-transaction publication semantics can pass a
// [PublishBuffer] as the machine publisher. The inner buffer flushes
// into the outer buffer, and the outer owner remains responsible for
// publishing only after the outer transaction commits.
//
// Callers must not pass a store or publisher here; they belong on the
// machine.
//
// If the chat row does not exist, Update returns [ErrChatNotFound]
// without mutating anything.
//
// Callbacks that return an error roll back the transaction (rolling
// back the automatic snapshot bump) and publish nothing.
func (m *ChatMachine) Update(
ctx context.Context,
fn func(*Tx) error,
) error {
if m.store == nil {
return xerrors.New("chatstate: ChatMachine has nil store")
}
if m.publisher == nil {
return xerrors.New("chatstate: ChatMachine has nil publisher")
}
buffer := NewPublishBuffer(m.publisher)
defer buffer.Discard()
err := m.store.InTx(func(store database.Store) error {
if _, err := store.LockChatAndBumpSnapshotVersion(ctx, m.chatID); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrChatNotFound
}
return xerrors.Errorf("lock chat and bump snapshot: %w", err)
}
tx := &Tx{
ctx: ctx,
store: store,
chatID: m.chatID,
}
if err := fn(tx); err != nil {
return err
}
chat, state, err := tx.loadState()
if err != nil {
return err
}
if err := buffer.Publish(
coderdpubsub.ChatStateUpdateChannel(chat.ID),
buildChatUpdateMessage(chat),
); err != nil {
return xerrors.Errorf("buffer chat update: %w", err)
}
if state.IsRunnable() {
stale, herr := ownershipStaleOrMissing(ctx, store, chat, HeartbeatStaleSeconds)
if herr != nil {
return xerrors.Errorf("evaluate ownership: %w", herr)
}
if stale {
if err := buffer.Publish(
coderdpubsub.ChatStateOwnershipChannel,
buildChatOwnershipMessage(chat),
); err != nil {
return xerrors.Errorf("buffer ownership hint: %w", err)
}
}
}
return nil
}, nil)
if err != nil {
return err
}
return buffer.Flush()
}
// Lock locks the chat row with FOR UPDATE and runs fn in a
// transaction without advancing snapshot_version. It uses the store
// captured by [NewChatMachine]. Use it when the caller needs a
// consistent chat snapshot plus related rows such as messages or
// queued messages but is NOT applying a transition.
//
// Callers must not pass a store here; it belongs on the machine.
//
// Lock publishes nothing. Callback errors roll back the transaction
// and propagate to the caller.
func (m *ChatMachine) Lock(
ctx context.Context,
fn func(database.Store) error,
) error {
if m.store == nil {
return xerrors.New("chatstate: ChatMachine has nil store")
}
return m.store.InTx(func(store database.Store) error {
// GetChatByIDForUpdate locks the row WITHOUT bumping snapshot.
_, err := store.GetChatByIDForUpdate(ctx, m.chatID)
if errors.Is(err, sql.ErrNoRows) {
return ErrChatNotFound
}
if err != nil {
return xerrors.Errorf("lock chat: %w", err)
}
return fn(store)
}, nil)
}
// ReadLock takes a shared lock on the chat row with FOR SHARE and runs
// fn in a transaction without advancing snapshot_version. It uses the
// store captured by [NewChatMachine]. Use it when the caller needs a
// consistent chat snapshot plus related rows such as messages or queued
// messages but is NOT applying a transition and does NOT need to block
// concurrent readers.
//
// Unlike [ChatMachine.Lock], the FOR SHARE lock permits other shared
// lockers to proceed concurrently while still blocking writers that take
// FOR UPDATE (such as [ChatMachine.Update] and [ChatMachine.Lock]) until
// the transaction commits.
//
// Callers must not pass a store here; it belongs on the machine.
//
// ReadLock publishes nothing. Callback errors roll back the transaction
// and propagate to the caller.
func (m *ChatMachine) ReadLock(
ctx context.Context,
fn func(database.Store) error,
) error {
if m.store == nil {
return xerrors.New("chatstate: ChatMachine has nil store")
}
return m.store.InTx(func(store database.Store) error {
// GetChatByIDForShare takes a shared lock on the row WITHOUT
// bumping snapshot.
_, err := store.GetChatByIDForShare(ctx, m.chatID)
if errors.Is(err, sql.ErrNoRows) {
return ErrChatNotFound
}
if err != nil {
return xerrors.Errorf("read lock chat: %w", err)
}
return fn(store)
}, nil)
}
// ownershipStaleOrMissing reports whether the chat's current
// (chat_id, runner_id) lease is missing or stale. The staleSeconds
// threshold is forwarded to [database.IsChatHeartbeatStale] so the
// comparison runs against database time inside a single SQL query.
func ownershipStaleOrMissing(ctx context.Context, store database.Store, chat database.Chat, staleSeconds int32) (bool, error) {
if !chat.WorkerID.Valid || !chat.RunnerID.Valid {
return true, nil
}
return store.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{
ChatID: chat.ID,
RunnerID: chat.RunnerID.UUID,
StaleSeconds: staleSeconds,
})
}
+406
View File
@@ -0,0 +1,406 @@
package chatstate_test
import (
"context"
"encoding/json"
"sync"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// testFixture bundles the resources every integration test needs:
// a database, a publisher recorder, a user/org/model triple, and
// helper accessors. It is intentionally NOT a generic chatd test
// fixture; tests outside this package should not depend on it.
type testFixture struct {
DB database.Store
PubSub pubsub.Pubsub
Pub *recordingPubsub
User database.User
Org database.Organization
Model database.ChatModelConfig
}
func newTestFixture(t *testing.T) *testFixture {
t.Helper()
db, ps := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "openai",
BaseUrl: "http://example.invalid",
})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
Provider: "openai",
IsDefault: true,
})
pub := newRecordingPubsub()
return &testFixture{
DB: db,
PubSub: ps,
Pub: pub,
User: user,
Org: org,
Model: model,
}
}
// readChat re-reads the chat from the database. Tests use this to
// verify post-transition state because transition results no longer
// carry the chat snapshot.
func (f *testFixture) readChat(ctx context.Context, t *testing.T, chatID uuid.UUID) database.Chat {
t.Helper()
chat, err := f.DB.GetChatByID(ctx, chatID)
require.NoError(t, err)
return chat
}
// classify reads the chat plus queue cardinality and returns the RFC
// execution state shorthand.
func (f *testFixture) classify(ctx context.Context, t *testing.T, chatID uuid.UUID) chatstate.ExecutionState {
t.Helper()
chat := f.readChat(ctx, t, chatID)
count, err := f.DB.CountChatQueuedMessages(ctx, chatID)
require.NoError(t, err)
return chatstate.ClassifyExecutionState(chat, count > 0, true)
}
// recordingPubsub captures every Publish call so tests can assert on
// the chatstate notifications without needing a live subscriber. The
// mutex makes it safe to use from concurrent tests that race multiple
// goroutines through the same publisher (see TestConcurrentUpdatesSerializeOnChatRow).
type recordingPubsub struct {
mu sync.Mutex
channels []string
payloads [][]byte
}
func newRecordingPubsub() *recordingPubsub { return &recordingPubsub{} }
func (r *recordingPubsub) Publish(channel string, payload []byte) error {
r.mu.Lock()
defer r.mu.Unlock()
r.channels = append(r.channels, channel)
cp := make([]byte, len(payload))
copy(cp, payload)
r.payloads = append(r.payloads, cp)
return nil
}
// expectChatUpdate finds the most recent chat:update message on the
// per-chat channel and asserts that it has snapshot_version == want.
func (r *recordingPubsub) expectChatUpdate(t *testing.T, chatID uuid.UUID, wantSnapshot int64) {
t.Helper()
channel := coderdpubsub.ChatStateUpdateChannel(chatID)
for i := len(r.channels) - 1; i >= 0; i-- {
if r.channels[i] != channel {
continue
}
var msg coderdpubsub.ChatStateUpdateMessage
require.NoError(t, json.Unmarshal(r.payloads[i], &msg))
require.Equal(t, wantSnapshot, msg.SnapshotVersion)
return
}
t.Fatalf("no chat:update on %s", channel)
}
func (r *recordingPubsub) hasOwnership() bool {
for _, c := range r.channels {
if c == coderdpubsub.ChatStateOwnershipChannel {
return true
}
}
return false
}
func userTextMessage(text string, createdBy uuid.UUID, modelConfigID uuid.UUID) chatstate.Message {
parts := []codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}
raw, err := chatprompt.MarshalParts(parts)
if err != nil {
panic(err)
}
return chatstate.Message{
Role: database.ChatMessageRoleUser,
Content: raw,
Visibility: database.ChatMessageVisibilityBoth,
ContentVersion: chatprompt.CurrentContentVersion,
CreatedBy: uuid.NullUUID{UUID: createdBy, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
}
// createTestChat is the standard "fresh R0 chat" helper used by other
// tests. It exercises CreateChat itself.
func createTestChat(t *testing.T, f *testFixture) chatstate.CreateChatResult {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
OrganizationID: f.Org.ID,
OwnerID: f.User.ID,
LastModelConfigID: f.Model.ID,
Title: "test",
ClientType: database.ChatClientTypeApi,
InitialMessages: []chatstate.Message{
userTextMessage("hello", f.User.ID, f.Model.ID),
},
})
require.NoError(t, err)
return res
}
func TestChatMachine_Update_RejectsMissingChat(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
m := chatstate.NewChatMachine(f.DB, f.Pub, uuid.New(), chatstate.Options{})
err := m.Update(ctx, func(tx *chatstate.Tx) error { return nil })
require.ErrorIs(t, err, chatstate.ErrChatNotFound)
require.Empty(t, f.Pub.channels)
}
func TestChatMachine_Lock_DoesNotBumpSnapshot(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
before := f.readChat(ctx, t, created.Chat.ID)
publishedBefore := len(f.Pub.channels)
require.NoError(t, m.Lock(ctx, func(_ database.Store) error {
return nil
}))
after := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, before.SnapshotVersion, after.SnapshotVersion)
require.Equal(t, publishedBefore, len(f.Pub.channels), "Lock must not publish")
}
func TestChatMachine_ReadLock_DoesNotBumpSnapshot(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
before := f.readChat(ctx, t, created.Chat.ID)
publishedBefore := len(f.Pub.channels)
var called bool
require.NoError(t, m.ReadLock(ctx, func(_ database.Store) error {
called = true
return nil
}))
require.True(t, called, "ReadLock must invoke the callback")
after := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, before.SnapshotVersion, after.SnapshotVersion)
require.Equal(t, publishedBefore, len(f.Pub.channels), "ReadLock must not publish")
}
func TestChatMachine_ReadLock_RejectsMissingChat(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
m := chatstate.NewChatMachine(f.DB, f.Pub, uuid.New(), chatstate.Options{})
err := m.ReadLock(ctx, func(_ database.Store) error {
t.Fatal("callback must not run when the chat is missing")
return nil
})
require.ErrorIs(t, err, chatstate.ErrChatNotFound)
require.Empty(t, f.Pub.channels)
}
func TestChatMachine_UpdatePublishesAfterCommit(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
publishedBefore := len(f.Pub.channels)
// Run a no-op Update; snapshot bump still happens, one update message
// should follow the commit.
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil }))
channel := coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)
var found bool
for _, c := range f.Pub.channels[publishedBefore:] {
if c == channel {
found = true
break
}
}
require.True(t, found, "expected one chat:update message after commit")
}
func TestChatMachine_FailedUpdate_PublishesNothing(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
before := f.readChat(ctx, t, created.Chat.ID)
channelsBefore := len(f.Pub.channels)
expected := newSentinel()
cbErr := m.Update(ctx, func(_ *chatstate.Tx) error { return expected })
require.ErrorIs(t, cbErr, expected)
require.Equal(t, channelsBefore, len(f.Pub.channels), "failed update should not publish")
// snapshot_version should not have advanced.
after := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, before.SnapshotVersion, after.SnapshotVersion)
}
func TestMessageRevisionTrigger_AssignsRevisionFromSnapshot(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f) // snapshot 1, history_version 1 via trigger
// CommitStep an assistant message; it should land with revision = chat.snapshot_version after the bump.
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
var step chatstate.CommitStepResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
assistant := userTextMessage("assistant", f.User.ID, f.Model.ID)
assistant.Role = database.ChatMessageRoleAssistant
var err error
step, err = tx.CommitStep(chatstate.CommitStepInput{
Messages: []chatstate.Message{assistant},
})
return err
}))
require.Len(t, step.InsertedMessages, 1)
after := f.readChat(ctx, t, created.Chat.ID)
// The Update call bumps snapshot_version once before the trigger
// runs, so the new revision should equal the bumped snapshot.
require.Equal(t, after.SnapshotVersion, step.InsertedMessages[0].Revision)
require.Equal(t, after.SnapshotVersion, after.HistoryVersion)
require.Equal(t, int64(0), after.GenerationAttempt, "trigger resets generation_attempt to 0")
}
func TestQueueVersionTrigger_AdvancesOnInsert(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f) // queue_version starts at 0
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.SendMessage(chatstate.SendMessageInput{
Message: userTextMessage("queue", f.User.ID, f.Model.ID),
BusyBehavior: chatstate.BusyBehaviorQueue,
})
return err
}))
after := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, after.SnapshotVersion, after.QueueVersion)
require.Greater(t, after.QueueVersion, int64(0))
}
func TestQueueVersionTrigger_StableForNonQueueMutations(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
assistant := userTextMessage("assistant", f.User.ID, f.Model.ID)
assistant.Role = database.ChatMessageRoleAssistant
_, err := tx.CommitStep(chatstate.CommitStepInput{
Messages: []chatstate.Message{assistant},
})
return err
}))
// queue_version must remain unchanged from initial 0.
require.Equal(t, int64(0), f.readChat(ctx, t, created.Chat.ID).QueueVersion)
}
// TestUpdateFlushesBufferedPublicationsAfterCommit verifies that
// ChatMachine.Update owns the PublishBuffer lifecycle: nothing
// reaches the inner publisher until after the transaction commits,
// and at commit the buffered chat:update is forwarded.
func TestUpdateFlushesBufferedPublicationsAfterCommit(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
channel := coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)
baseline := countChannel(f.Pub.channels, channel)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
// During the callback, no new chat:update for this chat may have
// reached the inner publisher because the buffer holds it.
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error {
require.Equal(t, baseline, countChannel(f.Pub.channels, channel),
"inner publisher saw chat:update before transaction committed")
return nil
}))
require.Equal(t, baseline+1, countChannel(f.Pub.channels, channel),
"exactly one new chat:update reached the inner publisher after commit")
}
// TestUpdateDiscardsBufferedPublicationsOnCallbackError verifies the
// deferred Discard path: when the callback returns an error the
// transaction rolls back and no buffered messages reach the inner
// publisher.
func TestUpdateDiscardsBufferedPublicationsOnCallbackError(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
before := f.readChat(ctx, t, created.Chat.ID)
channelsBefore := len(f.Pub.channels)
sentinel := xerrors.New("callback boom")
err := m.Update(ctx, func(_ *chatstate.Tx) error { return sentinel })
require.ErrorIs(t, err, sentinel)
require.Equal(t, channelsBefore, len(f.Pub.channels),
"failed update must not flush any buffered publications")
after := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, before.SnapshotVersion, after.SnapshotVersion,
"snapshot bump rolled back when callback returns error")
}
// =============================================================================
// helpers
// =============================================================================
type sentinelError struct{ msg string }
func (s *sentinelError) Error() string { return s.msg }
func newSentinel() error { return &sentinelError{msg: "sentinel"} }
func countChannel(channels []string, channel string) int {
c := 0
for _, ch := range channels {
if ch == channel {
c++
}
}
return c
}
+112
View File
@@ -0,0 +1,112 @@
package chatstate
import (
"database/sql"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/coder/coder/v2/coderd/database"
)
// Message is the durable message input shape used by chatstate
// transitions. It is intentionally lower level than the SDK message
// request types: callers must produce a fully materialized message
// (parsed parts, calculated cost, resolved model config) before
// passing it in.
//
// The state machine never reshapes a Message except to attach the
// runtime `chat_id`. The `revision` column is assigned by the
// `set_chat_message_revision` trigger; runtime code must not populate
// it.
type Message struct {
Role database.ChatMessageRole
Content pqtype.NullRawMessage
Visibility database.ChatMessageVisibility
ModelConfigID uuid.NullUUID
CreatedBy uuid.NullUUID
ContentVersion int16
Compressed bool
InputTokens sql.NullInt64
OutputTokens sql.NullInt64
TotalTokens sql.NullInt64
ReasoningTokens sql.NullInt64
CacheCreationTokens sql.NullInt64
CacheReadTokens sql.NullInt64
ContextLimit sql.NullInt64
TotalCostMicros sql.NullInt64
RuntimeMs sql.NullInt64
ProviderResponseID sql.NullString
}
// toInsertParams converts a batch of Messages into the parallel-array
// shape required by `InsertChatMessages`. The returned struct has all
// arrays sized to len(messages).
//
// The chat ID is supplied by the caller because Message itself does
// not carry one (the chat machine already knows the chat).
func toInsertParams(chatID uuid.UUID, messages []Message) database.InsertChatMessagesParams {
n := len(messages)
params := database.InsertChatMessagesParams{
ChatID: chatID,
CreatedBy: make([]uuid.UUID, n),
ModelConfigID: make([]uuid.UUID, n),
Role: make([]database.ChatMessageRole, n),
Content: make([]string, n),
ContentVersion: make([]int16, n),
Visibility: make([]database.ChatMessageVisibility, n),
InputTokens: make([]int64, n),
OutputTokens: make([]int64, n),
TotalTokens: make([]int64, n),
ReasoningTokens: make([]int64, n),
CacheCreationTokens: make([]int64, n),
CacheReadTokens: make([]int64, n),
ContextLimit: make([]int64, n),
Compressed: make([]bool, n),
TotalCostMicros: make([]int64, n),
RuntimeMs: make([]int64, n),
ProviderResponseID: make([]string, n),
}
for i, m := range messages {
params.CreatedBy[i] = nullUUIDOrNil(m.CreatedBy)
params.ModelConfigID[i] = nullUUIDOrNil(m.ModelConfigID)
params.Role[i] = m.Role
if m.Content.Valid {
params.Content[i] = string(m.Content.RawMessage)
} else {
// Use the JSON null literal; UNNEST + ::jsonb requires a
// valid JSON value and the trigger leaves it untouched.
params.Content[i] = "null"
}
params.ContentVersion[i] = m.ContentVersion
params.Visibility[i] = m.Visibility
params.InputTokens[i] = nullInt64Or(m.InputTokens, 0)
params.OutputTokens[i] = nullInt64Or(m.OutputTokens, 0)
params.TotalTokens[i] = nullInt64Or(m.TotalTokens, 0)
params.ReasoningTokens[i] = nullInt64Or(m.ReasoningTokens, 0)
params.CacheCreationTokens[i] = nullInt64Or(m.CacheCreationTokens, 0)
params.CacheReadTokens[i] = nullInt64Or(m.CacheReadTokens, 0)
params.ContextLimit[i] = nullInt64Or(m.ContextLimit, 0)
params.Compressed[i] = m.Compressed
params.TotalCostMicros[i] = nullInt64Or(m.TotalCostMicros, 0)
params.RuntimeMs[i] = nullInt64Or(m.RuntimeMs, 0)
if m.ProviderResponseID.Valid {
params.ProviderResponseID[i] = m.ProviderResponseID.String
}
}
return params
}
func nullUUIDOrNil(u uuid.NullUUID) uuid.UUID {
if u.Valid {
return u.UUID
}
return uuid.Nil
}
func nullInt64Or(v sql.NullInt64, fallback int64) int64 {
if v.Valid {
return v.Int64
}
return fallback
}
+169
View File
@@ -0,0 +1,169 @@
package chatstate
import (
"encoding/json"
"fmt"
"sync"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
)
// Publisher is the minimal interface chatstate needs to publish
// pubsub messages. It is intentionally compatible with
// database/pubsub.Pubsub: real callers pass the live pubsub directly
// and tests pass a recording fake. Chatstate entry points
// ([ChatMachine.Update], [CreateChat], [SetFamilyArchived]) own the
// internal [PublishBuffer] lifecycle so callers never see it.
type Publisher interface {
Publish(event string, message []byte) error
}
// PublishBuffer is a [Publisher] that records each Publish call in
// order without forwarding it until [PublishBuffer.Flush] is called.
// It is an internal primitive used by chatstate entry points to
// hold pubsub messages until the surrounding transaction commits,
// and by tests that need to observe buffered output. Normal callers
// do not construct a PublishBuffer themselves and do not invoke
// Flush or Discard; chatstate's entry points own that lifecycle.
type PublishBuffer struct {
inner Publisher
mu sync.Mutex
pending []bufferedMessage
flushed bool
disabled bool
}
type bufferedMessage struct {
Channel string
Payload []byte
}
// NewPublishBuffer constructs a PublishBuffer that, when flushed, will
// forward messages in order to inner.
func NewPublishBuffer(inner Publisher) *PublishBuffer {
return &PublishBuffer{inner: inner}
}
// Publish records a message. It never forwards to the inner publisher
// until [PublishBuffer.Flush] is called. Returns an error if Flush has
// already happened to make accidental reuse obvious.
func (b *PublishBuffer) Publish(channel string, payload []byte) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.flushed {
return xerrors.Errorf("publish buffer flushed; cannot accept message for %q", channel)
}
if b.disabled {
return nil
}
cp := make([]byte, len(payload))
copy(cp, payload)
b.pending = append(b.pending, bufferedMessage{Channel: channel, Payload: cp})
return nil
}
// Flush forwards every pending message to the inner publisher in the
// order it was buffered, then marks the buffer flushed. The first
// publish error is returned with the channel name annotated; later
// messages still attempt to publish so the inner publisher sees a
// consistent state.
func (b *PublishBuffer) Flush() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.flushed {
return nil
}
b.flushed = true
var firstErr error
for _, msg := range b.pending {
if err := b.inner.Publish(msg.Channel, msg.Payload); err != nil && firstErr == nil {
firstErr = xerrors.Errorf("publish %s: %w", msg.Channel, err)
}
}
return firstErr
}
// Discard clears the buffered messages without forwarding them. It
// is safe to call multiple times and is harmless after [PublishBuffer.Flush]:
// once Flush has marked the buffer flushed and forwarded its
// pending messages, a subsequent Discard simply clears the (now
// empty) pending slice and sets the buffer to drop any future
// Publish calls. This makes `defer buf.Discard()` a safe pattern
// after a successful flush, including the one chatstate entry
// points use to own the buffer lifecycle.
func (b *PublishBuffer) Discard() {
b.mu.Lock()
defer b.mu.Unlock()
b.pending = nil
b.disabled = true
}
// pending returns a snapshot of the buffered messages, primarily for
// tests via [PublishBuffer.BufferedChannels]. The returned slice is a
// copy and safe to inspect without holding the buffer lock.
func (b *PublishBuffer) snapshotPending() []bufferedMessage {
b.mu.Lock()
defer b.mu.Unlock()
out := make([]bufferedMessage, len(b.pending))
copy(out, b.pending)
return out
}
// BufferedChannels returns just the channels of the pending messages
// in order. Primarily useful for assertions in tests.
func (b *PublishBuffer) BufferedChannels() []string {
pending := b.snapshotPending()
out := make([]string, len(pending))
for i, m := range pending {
out[i] = m.Channel
}
return out
}
// buildChatUpdateMessage produces the JSON payload for a
// `chat:update:{chat_id}` message describing the post-transition
// snapshot of chat.
func buildChatUpdateMessage(chat database.Chat) []byte {
msg := coderdpubsub.ChatStateUpdateMessage{
SnapshotVersion: chat.SnapshotVersion,
HistoryVersion: chat.HistoryVersion,
QueueVersion: chat.QueueVersion,
RetryStateVersion: chat.RetryStateVersion,
GenerationAttempt: chat.GenerationAttempt,
Status: string(chat.Status),
Archived: chat.Archived,
}
if chat.WorkerID.Valid {
id := chat.WorkerID.UUID
msg.WorkerID = &id
}
if chat.RunnerID.Valid {
id := chat.RunnerID.UUID
msg.RunnerID = &id
}
payload, err := json.Marshal(msg)
if err != nil {
// json.Marshal on this struct is total; panic is acceptable
// because the only failure mode would be a bug in this
// package, not user input.
panic(fmt.Sprintf("marshal chat state update: %v", err))
}
return payload
}
// buildChatOwnershipMessage produces the JSON payload for the global
// `chat:ownership` ownership hint for chat.
func buildChatOwnershipMessage(chat database.Chat) []byte {
payload, err := json.Marshal(coderdpubsub.ChatStateOwnershipMessage{
ChatID: chat.ID,
SnapshotVersion: chat.SnapshotVersion,
})
if err != nil {
panic(fmt.Sprintf("marshal chat state ownership: %v", err))
}
return payload
}
@@ -0,0 +1,373 @@
package chatstate_test
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/testutil"
)
// =============================================================================
// Publication helpers
// =============================================================================
// publishedOn returns the indices into f.Pub.channels (and f.Pub.payloads)
// that match the given channel name, in order.
func publishedOn(f *testFixture, channel string) []int {
var idx []int
for i, c := range f.Pub.channels {
if c == channel {
idx = append(idx, i)
}
}
return idx
}
// TestCreateChatPublishesAfterCommit asserts that a successful
// CreateChat call publishes exactly one chat:update message on the
// per-chat channel after the inner transaction commits.
func TestCreateChatPublishesAfterCommit(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
res := createTestChat(t, f)
channel := coderdpubsub.ChatStateUpdateChannel(res.Chat.ID)
idx := publishedOn(f, channel)
require.Len(t, idx, 1, "exactly one chat:update for the new chat")
var msg coderdpubsub.ChatStateUpdateMessage
require.NoError(t, json.Unmarshal(f.Pub.payloads[idx[0]], &msg))
require.Equal(t, res.Chat.SnapshotVersion, msg.SnapshotVersion)
require.Equal(t, string(database.ChatStatusRunning), msg.Status)
}
// TestUpdatePublishesAfterCommit asserts that ChatMachine.Update
// publishes one chat:update on the per-chat channel after the inner
// transaction commits, even when the callback performs no transition.
func TestUpdatePublishesAfterCommit(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
createIdx := publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID))
require.Len(t, createIdx, 1, "create published one chat:update")
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil }))
updIdx := publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID))
require.Len(t, updIdx, 2, "no-op Update still publishes a chat:update")
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
var msg coderdpubsub.ChatStateUpdateMessage
require.NoError(t, json.Unmarshal(f.Pub.payloads[updIdx[1]], &msg))
require.Equal(t, after.SnapshotVersion, msg.SnapshotVersion)
}
// TestUpdatePublishesOneFinalChatUpdateForTransitionBundle bundles
// several transitions inside one Update callback and verifies the
// commit publishes exactly one chat:update on the per-chat channel
// (not one per transition).
func TestUpdatePublishesOneFinalChatUpdateForTransitionBundle(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
baseUpdates := len(publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)))
// R0 -> W (FinishTurn) -> XW (SetArchived true) -> W (SetArchived false).
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
if _, err := tx.FinishTurn(chatstate.FinishTurnInput{}); err != nil {
return err
}
if _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true}); err != nil {
return err
}
if _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: false}); err != nil {
return err
}
return nil
}))
updIdx := publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID))
require.Equal(t, baseUpdates+1, len(updIdx),
"three-transition bundle publishes exactly one final chat:update")
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusWaiting, after.Status, "ends in W")
}
// TestUpdateAppliesTransitionBundleSequentially verifies that
// transitions chained inside a single Update callback see each
// other's effects: later transitions validate against the state
// produced by earlier ones (R0 -> W is rejected when called twice
// because the second call sees state W and FinishTurn is no longer
// allowed).
func TestUpdateAppliesTransitionBundleSequentially(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
err := m.Update(ctx, func(tx *chatstate.Tx) error {
if _, err := tx.FinishTurn(chatstate.FinishTurnInput{}); err != nil {
return err
}
// Second FinishTurn should fail because state is now W.
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
return err
})
require.Error(t, err)
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
// Failed bundle rolls back: state must not have advanced past R0.
require.Equal(t, chatstate.StateR0, f.classify(ctx, t, created.Chat.ID),
"failed bundle rolls back the whole transaction")
}
// TestFailedUpdatePublishesNothing verifies that a callback error
// rolls back the snapshot bump and publishes nothing.
func TestFailedUpdatePublishesNothing(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
publishedBefore := len(f.Pub.channels)
beforeChat := f.readChat(ctx, t, created.Chat.ID)
sentinel := xerrors.New("forced failure")
err := m.Update(ctx, func(_ *chatstate.Tx) error { return sentinel })
require.ErrorIs(t, err, sentinel)
require.Equal(t, publishedBefore, len(f.Pub.channels), "failed update publishes nothing")
after := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, beforeChat.SnapshotVersion, after.SnapshotVersion,
"failed update rolls back snapshot bump")
}
// TestLockPublishesNothing verifies that Lock does not publish even
// though it locks the chat row.
func TestLockPublishesNothing(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
publishedBefore := len(f.Pub.channels)
require.NoError(t, m.Lock(ctx, func(_ database.Store) error { return nil }))
require.Equal(t, publishedBefore, len(f.Pub.channels), "Lock publishes nothing")
}
// TestPublishBufferWithRolledBackOuterTransactionPublishesNothing
// wires a chatstate machine through a PublishBuffer and exercises
// the buffer primitive directly: when the caller discards before
// flushing, the inner publisher receives nothing. ChatMachine.Update
// uses the same primitive internally with a deferred Discard;
// callers no longer drive Flush or Discard themselves.
func TestPublishBufferWithRolledBackOuterTransactionPublishesNothing(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
// Run one normal Update to establish a stable baseline channel
// count. CreateChat plus this Update may publish chat:update
// and chat:ownership messages depending on ownership, so we
// take the snapshot after that activity settles.
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil }))
baseline := len(f.Pub.channels)
// Now exercise the PublishBuffer rollback path explicitly. The
// outer transaction "rolls back": the caller buffers messages,
// discards them, then flushes. The inner publisher must see
// none of the buffered messages.
buf := chatstate.NewPublishBuffer(f.Pub)
require.NoError(t, buf.Publish("chat:update:bogus", []byte("payload")))
require.NoError(t, buf.Publish("chat:ownership", []byte("payload")))
buf.Discard()
require.NoError(t, buf.Flush())
require.Equal(t, baseline, len(f.Pub.channels),
"discarded buffer publishes nothing through the inner publisher")
}
// TestChatUpdateMessagePayloadShape verifies the JSON shape of the
// chat:update payload contains every field consumers depend on:
// snapshot_version, history_version, queue_version,
// retry_state_version, generation_attempt, status, archived, and
// optional worker_id / runner_id.
func TestChatUpdateMessagePayloadShape(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
// Acquire ownership so worker_id and runner_id are present.
worker := uuid.New()
runner := uuid.New()
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner})
return err
}))
// Find the last chat:update message.
channel := coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)
idx := publishedOn(f, channel)
require.NotEmpty(t, idx)
last := f.Pub.payloads[idx[len(idx)-1]]
// Strict-decode against the typed struct.
var typed coderdpubsub.ChatStateUpdateMessage
require.NoError(t, json.Unmarshal(last, &typed))
require.Greater(t, typed.SnapshotVersion, int64(0))
require.NotNil(t, typed.WorkerID)
require.Equal(t, worker, *typed.WorkerID)
require.NotNil(t, typed.RunnerID)
require.Equal(t, runner, *typed.RunnerID)
require.Equal(t, string(database.ChatStatusRunning), typed.Status)
require.False(t, typed.Archived)
// Permissive decode to assert exact JSON keys.
var raw map[string]json.RawMessage
require.NoError(t, json.Unmarshal(last, &raw))
for _, key := range []string{
"snapshot_version",
"history_version",
"queue_version",
"retry_state_version",
"generation_attempt",
"status",
"archived",
"worker_id",
"runner_id",
} {
_, ok := raw[key]
require.True(t, ok, "payload missing key %q", key)
}
}
// TestChatOwnershipMessagePayloadShape verifies the JSON shape of
// chat:ownership: chat_id and snapshot_version.
func TestChatOwnershipMessagePayloadShape(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
// CreateChat publishes one ownership hint because the new chat is
// unowned and runnable.
created := createTestChat(t, f)
idx := publishedOn(f, coderdpubsub.ChatStateOwnershipChannel)
require.NotEmpty(t, idx, "CreateChat publishes at least one chat:ownership hint")
payload := f.Pub.payloads[idx[len(idx)-1]]
var typed coderdpubsub.ChatStateOwnershipMessage
require.NoError(t, json.Unmarshal(payload, &typed))
require.Equal(t, created.Chat.ID, typed.ChatID)
require.Greater(t, typed.SnapshotVersion, int64(0))
var raw map[string]json.RawMessage
require.NoError(t, json.Unmarshal(payload, &raw))
for _, key := range []string{"chat_id", "snapshot_version"} {
_, ok := raw[key]
require.True(t, ok, "ownership payload missing key %q", key)
}
}
// TestOwnershipNotificationUsesDatabaseHeartbeatStaleness verifies
// that an ownership hint fires when the heartbeat is stale by the
// database's clock, regardless of what the local Go clock says. We
// rewrite the heartbeat row to a deterministically old timestamp via
// raw SQL and confirm the post-commit hint is sent on a subsequent
// runnable Update.
func TestOwnershipNotificationUsesDatabaseHeartbeatStaleness(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
// Acquire ownership; this writes a fresh heartbeat.
worker := uuid.New()
runner := uuid.New()
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner})
return err
}))
hb, err := f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
ChatID: created.Chat.ID,
RunnerID: runner,
})
require.NoError(t, err)
require.WithinDuration(t, time.Now(), hb.HeartbeatAt, time.Minute,
"Acquire wrote a fresh heartbeat")
// Snapshot ownership-hint count before the test trigger.
ownershipBefore := f.Pub.ownershipPublishCount()
// Force the heartbeat to a deterministically old time.
_, err = tf.sqlDB.ExecContext(ctx, `
UPDATE chat_heartbeats
SET heartbeat_at = NOW() - INTERVAL '1 hour'
WHERE chat_id = $1 AND runner_id = $2
`, created.Chat.ID, runner)
require.NoError(t, err)
// Confirm database-side staleness check agrees.
stale, err := f.DB.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{
ChatID: created.Chat.ID,
RunnerID: runner,
StaleSeconds: chatstate.HeartbeatStaleSeconds,
})
require.NoError(t, err)
require.True(t, stale, "heartbeat is stale per database time")
// Run a no-op Update. The chat is runnable (R0) and the
// heartbeat is stale, so post-commit logic must publish exactly
// one chat:ownership hint.
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil }))
ownershipAfter := f.Pub.ownershipPublishCount()
require.Equal(t, ownershipBefore+1, ownershipAfter,
"stale heartbeat triggers a fresh ownership hint")
}
// TestUpdateContextCancellationPublishesNothing verifies that
// canceling the caller's context (between the inner commit and the
// publish loop's first call) does not corrupt state. We exercise the
// simpler observable contract: when the user cancels before Update
// gets to do anything, nothing is published. The strict before-publish
// race is exercised in concurrency tests with channel sync.
func TestUpdateContextCancellationPublishesNothing(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
publishedBefore := len(f.Pub.channels)
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := m.Update(ctx, func(_ *chatstate.Tx) error { return nil })
require.Error(t, err)
require.Equal(t, publishedBefore, len(f.Pub.channels),
"caller-aborted update publishes nothing")
}
@@ -0,0 +1,113 @@
package chatstate
import (
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
type recordingPublisher struct {
calls []recordedCall
errOn map[string]error
failed map[string]int
}
type recordedCall struct {
Channel string
Payload []byte
}
func newRecordingPublisher() *recordingPublisher {
return &recordingPublisher{
errOn: map[string]error{},
failed: map[string]int{},
}
}
func (r *recordingPublisher) Publish(channel string, payload []byte) error {
r.calls = append(r.calls, recordedCall{Channel: channel, Payload: append([]byte(nil), payload...)})
if err, ok := r.errOn[channel]; ok {
r.failed[channel]++
return err
}
return nil
}
func TestPublishBuffer_DefersPublishUntilFlush(t *testing.T) {
t.Parallel()
inner := newRecordingPublisher()
buf := NewPublishBuffer(inner)
require.NoError(t, buf.Publish("a", []byte("1")))
require.NoError(t, buf.Publish("b", []byte("2")))
require.Empty(t, inner.calls, "inner publisher should not be called before flush")
require.Equal(t, []string{"a", "b"}, buf.BufferedChannels())
}
func TestPublishBuffer_FlushPublishesInOrder(t *testing.T) {
t.Parallel()
inner := newRecordingPublisher()
buf := NewPublishBuffer(inner)
require.NoError(t, buf.Publish("a", []byte("1")))
require.NoError(t, buf.Publish("b", []byte("2")))
require.NoError(t, buf.Publish("c", []byte("3")))
require.NoError(t, buf.Flush())
require.Len(t, inner.calls, 3)
require.Equal(t, "a", inner.calls[0].Channel)
require.Equal(t, "b", inner.calls[1].Channel)
require.Equal(t, "c", inner.calls[2].Channel)
require.Equal(t, []byte("1"), inner.calls[0].Payload)
}
func TestPublishBuffer_FlushReturnsFirstError(t *testing.T) {
t.Parallel()
inner := newRecordingPublisher()
inner.errOn["b"] = xerrors.New("broken")
buf := NewPublishBuffer(inner)
require.NoError(t, buf.Publish("a", []byte("1")))
require.NoError(t, buf.Publish("b", []byte("2")))
require.NoError(t, buf.Publish("c", []byte("3")))
err := buf.Flush()
require.Error(t, err)
require.Contains(t, err.Error(), "publish b:")
// Even after the broken channel, later messages should still be
// attempted so the inner publisher sees them.
require.Len(t, inner.calls, 3)
}
func TestPublishBuffer_PublishAfterFlushFails(t *testing.T) {
t.Parallel()
inner := newRecordingPublisher()
buf := NewPublishBuffer(inner)
require.NoError(t, buf.Flush())
require.Error(t, buf.Publish("x", []byte("y")))
}
func TestPublishBuffer_DiscardSuppressesPending(t *testing.T) {
t.Parallel()
inner := newRecordingPublisher()
buf := NewPublishBuffer(inner)
require.NoError(t, buf.Publish("a", []byte("1")))
buf.Discard()
require.NoError(t, buf.Flush())
require.Empty(t, inner.calls)
}
func TestPublishBuffer_DiscardBlocksLaterPublishes(t *testing.T) {
t.Parallel()
inner := newRecordingPublisher()
buf := NewPublishBuffer(inner)
buf.Discard()
// Discard sets disabled; subsequent Publish is a no-op (not an
// error) so callers using Discard before/around rollback paths
// do not have to special-case unwind.
require.NoError(t, buf.Publish("a", []byte("1")))
require.NoError(t, buf.Flush())
require.Empty(t, inner.calls)
}
+182
View File
@@ -0,0 +1,182 @@
package chatstate
import (
"github.com/coder/coder/v2/coderd/database"
)
// ExecutionState is the RFC shorthand for a chat's current execution
// state. The valid set contains 13 states; everything outside it is
// represented by [StateInvalid].
type ExecutionState string
const (
// StateN: chat does not exist.
StateN ExecutionState = "N"
// StateW: waiting, empty queue, not archived.
StateW ExecutionState = "W"
// StateE0: error, empty queue, not archived.
StateE0 ExecutionState = "E0"
// StateE1: error, non-empty queue, not archived.
StateE1 ExecutionState = "E1"
// StateR0: running, empty queue, not archived.
StateR0 ExecutionState = "R0"
// StateR1: running, non-empty queue, not archived.
StateR1 ExecutionState = "R1"
// StateI0: interrupting, empty queue, not archived.
StateI0 ExecutionState = "I0"
// StateI1: interrupting, non-empty queue, not archived.
StateI1 ExecutionState = "I1"
// StateA0: requires_action, empty queue, not archived.
StateA0 ExecutionState = "A0"
// StateA1: requires_action, non-empty queue, not archived.
StateA1 ExecutionState = "A1"
// StateXW: archived waiting, empty queue.
StateXW ExecutionState = "XW"
// StateXE0: archived error, empty queue.
StateXE0 ExecutionState = "XE0"
// StateXE1: archived error, non-empty queue.
StateXE1 ExecutionState = "XE1"
// StateInvalid groups every status/archive/queue combination that
// is not one of the 13 valid states above. The state machine
// refuses non-reconciliation transitions on invalid states and
// exposes the [Tx.ReconcileInvalidState] transition to recover.
StateInvalid ExecutionState = "Invalid"
)
// String implements fmt.Stringer.
func (s ExecutionState) String() string { return string(s) }
// AllExecutionStates is the canonical enumeration of every value the
// classifier can return. Tests rely on this list to iterate over every
// state when verifying transition coverage.
var AllExecutionStates = []ExecutionState{
StateN,
StateW,
StateE0,
StateE1,
StateR0,
StateR1,
StateI0,
StateI1,
StateA0,
StateA1,
StateXW,
StateXE0,
StateXE1,
StateInvalid,
}
// IsRunnable returns true for the execution states that the chat
// worker is allowed to acquire and drive forward: R0, R1, I0, I1,
// A0, and A1. Requires-action states need worker ownership for
// timeout processing. Other states are idle (W, E*, XW, XE*), absent
// (N), or invalid.
func (s ExecutionState) IsRunnable() bool {
switch s {
case StateR0, StateR1, StateI0, StateI1, StateA0, StateA1:
return true
default:
return false
}
}
// IsArchived returns true for the three archived execution states.
func (s ExecutionState) IsArchived() bool {
switch s {
case StateXW, StateXE0, StateXE1:
return true
default:
return false
}
}
// QueueNonEmpty returns true for execution states that require a
// non-empty queue. Useful when seeding test fixtures.
func (s ExecutionState) QueueNonEmpty() bool {
switch s {
case StateE1, StateR1, StateI1, StateA1, StateXE1:
return true
default:
return false
}
}
// ClassifyExecutionState turns the chat row, queue cardinality, and
// whether the chat row exists into one of the 14 [ExecutionState]
// values. The caller is responsible for loading the chat under the
// row lock and reading the queue count in the same transaction.
//
// Callers that have no chat row (lookup returned sql.ErrNoRows)
// should pass exists=false; the chat, status, and archive arguments
// are then ignored.
//
// The classifier is a single flat switch over the 13 valid
// (status, archived, queue) tuples from the RFC. Anything outside
// that set (legacy pending/paused/completed statuses, archived busy
// states, waiting with a non-empty queue, future enum values) falls
// through to [StateInvalid].
//
//nolint:revive // queueNonEmpty/exists are simple flags that mirror the RFC inputs.
func ClassifyExecutionState(chat database.Chat, queueNonEmpty, exists bool) ExecutionState {
if !exists {
return StateN
}
switch {
case chat.Status == database.ChatStatusWaiting && !chat.Archived && !queueNonEmpty:
return StateW
case chat.Status == database.ChatStatusWaiting && chat.Archived && !queueNonEmpty:
return StateXW
case chat.Status == database.ChatStatusError && !chat.Archived && !queueNonEmpty:
return StateE0
case chat.Status == database.ChatStatusError && !chat.Archived && queueNonEmpty:
return StateE1
case chat.Status == database.ChatStatusError && chat.Archived && !queueNonEmpty:
return StateXE0
case chat.Status == database.ChatStatusError && chat.Archived && queueNonEmpty:
return StateXE1
case chat.Status == database.ChatStatusRunning && !chat.Archived && !queueNonEmpty:
return StateR0
case chat.Status == database.ChatStatusRunning && !chat.Archived && queueNonEmpty:
return StateR1
case chat.Status == database.ChatStatusInterrupting && !chat.Archived && !queueNonEmpty:
return StateI0
case chat.Status == database.ChatStatusInterrupting && !chat.Archived && queueNonEmpty:
return StateI1
case chat.Status == database.ChatStatusRequiresAction && !chat.Archived && !queueNonEmpty:
return StateA0
case chat.Status == database.ChatStatusRequiresAction && !chat.Archived && queueNonEmpty:
return StateA1
}
return StateInvalid
}
// OwnershipState is the RFC shorthand for whether a chat row is
// currently owned by a worker. The state machine treats execution and
// ownership as orthogonal.
type OwnershipState string
const (
// StateU: chat has no owner (worker_id IS NULL).
StateU OwnershipState = "U"
// StateO: chat has an owner (worker_id IS NOT NULL).
StateO OwnershipState = "O"
)
// String implements fmt.Stringer.
func (s OwnershipState) String() string { return string(s) }
// AllOwnershipStates is the canonical enumeration of ownership states.
var AllOwnershipStates = []OwnershipState{StateU, StateO}
// ClassifyOwnershipState returns [StateU] when worker_id is NULL and
// [StateO] otherwise. The runner_id field is intentionally not part of
// the ownership classification. Acquire writes both worker_id and
// runner_id; Abandon clears the current row without using runner_id as
// a transition input.
func ClassifyOwnershipState(chat database.Chat) OwnershipState {
if !chat.WorkerID.Valid {
return StateU
}
return StateO
}
@@ -0,0 +1,195 @@
package chatstate
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
)
func chatWithStatus(status database.ChatStatus, archived bool) database.Chat {
return database.Chat{
ID: uuid.New(),
Status: status,
Archived: archived,
OwnerID: uuid.New(),
}
}
// TestClassifyExecutionState_Valid covers every valid classification:
// N (missing chat) plus the twelve valid existing-chat states.
func TestClassifyExecutionState_Valid(t *testing.T) {
t.Parallel()
cases := []struct {
name string
status database.ChatStatus
archived bool
queueNonEmpty bool
exists bool
want ExecutionState
}{
{name: "N", exists: false, want: StateN},
{name: "W", status: database.ChatStatusWaiting, exists: true, want: StateW},
{name: "E0", status: database.ChatStatusError, exists: true, want: StateE0},
{name: "E1", status: database.ChatStatusError, queueNonEmpty: true, exists: true, want: StateE1},
{name: "R0", status: database.ChatStatusRunning, exists: true, want: StateR0},
{name: "R1", status: database.ChatStatusRunning, queueNonEmpty: true, exists: true, want: StateR1},
{name: "I0", status: database.ChatStatusInterrupting, exists: true, want: StateI0},
{name: "I1", status: database.ChatStatusInterrupting, queueNonEmpty: true, exists: true, want: StateI1},
{name: "A0", status: database.ChatStatusRequiresAction, exists: true, want: StateA0},
{name: "A1", status: database.ChatStatusRequiresAction, queueNonEmpty: true, exists: true, want: StateA1},
{name: "XW", status: database.ChatStatusWaiting, archived: true, exists: true, want: StateXW},
{name: "XE0", status: database.ChatStatusError, archived: true, exists: true, want: StateXE0},
{name: "XE1", status: database.ChatStatusError, archived: true, queueNonEmpty: true, exists: true, want: StateXE1},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
chat := database.Chat{}
if tc.exists {
chat = chatWithStatus(tc.status, tc.archived)
}
require.Equal(t, tc.want, ClassifyExecutionState(chat, tc.queueNonEmpty, tc.exists))
})
}
}
// TestClassifyExecutionState_Invalid covers every documented invalid
// combination: legacy statuses, waiting-with-queue, and archived busy
// statuses.
func TestClassifyExecutionState_Invalid(t *testing.T) {
t.Parallel()
cases := []struct {
name string
status database.ChatStatus
archived bool
queueNonEmpty bool
}{
// Legacy statuses (pending/paused/completed) are invalid for
// the new state machine.
{name: "LegacyPending", status: "pending"},
{name: "LegacyPaused", status: "paused"},
{name: "LegacyCompleted", status: "completed"},
// Waiting must always have an empty queue.
{name: "WaitingWithQueue", status: database.ChatStatusWaiting, queueNonEmpty: true},
{name: "WaitingArchivedWithQueue", status: database.ChatStatusWaiting, archived: true, queueNonEmpty: true},
// Archived busy statuses are invalid.
{name: "ArchivedRunning", status: database.ChatStatusRunning, archived: true},
{name: "ArchivedInterrupting", status: database.ChatStatusInterrupting, archived: true},
{name: "ArchivedRequiresAction", status: database.ChatStatusRequiresAction, archived: true},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := ClassifyExecutionState(chatWithStatus(tc.status, tc.archived), tc.queueNonEmpty, true)
require.Equal(t, StateInvalid, got)
})
}
}
// TestClassifyExecutionState_RejectsAllUnlistedCombinations enumerates
// every (status, archived, queueNonEmpty) tuple for an existing chat
// and asserts exactly the twelve RFC tuples classify out of
// [StateInvalid]. Missing chats are handled separately via the N case
// in [TestClassifyExecutionState_Valid].
func TestClassifyExecutionState_RejectsAllUnlistedCombinations(t *testing.T) {
t.Parallel()
allStatuses := []database.ChatStatus{
database.ChatStatusWaiting,
database.ChatStatusError,
database.ChatStatusRunning,
database.ChatStatusInterrupting,
database.ChatStatusRequiresAction,
"pending", "paused", "completed",
}
validCount := 0
for _, status := range allStatuses {
for _, archived := range []bool{false, true} {
for _, queueNonEmpty := range []bool{false, true} {
got := ClassifyExecutionState(chatWithStatus(status, archived), queueNonEmpty, true)
if got != StateInvalid {
validCount++
}
}
}
}
require.Equal(t, 12, validCount, "exactly 12 valid existing-chat (status, archived, queue) tuples")
}
// TestClassifyOwnershipState covers both ownership classifications.
func TestClassifyOwnershipState(t *testing.T) {
t.Parallel()
cases := []struct {
name string
chat func() database.Chat
want OwnershipState
}{
{
name: "Unowned",
chat: func() database.Chat { return database.Chat{} },
want: StateU,
},
{
name: "Owned",
chat: func() database.Chat {
c := database.Chat{}
c.WorkerID.UUID = uuid.New()
c.WorkerID.Valid = true
return c
},
want: StateO,
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.want, ClassifyOwnershipState(tc.chat()))
})
}
}
// TestAllExecutionStates_Enumeration verifies AllExecutionStates
// contains every declared execution state exactly once.
func TestAllExecutionStates_Enumeration(t *testing.T) {
t.Parallel()
want := map[ExecutionState]bool{
StateN: true, StateW: true, StateE0: true, StateE1: true,
StateR0: true, StateR1: true, StateI0: true, StateI1: true,
StateA0: true, StateA1: true, StateXW: true, StateXE0: true,
StateXE1: true, StateInvalid: true,
}
require.Len(t, AllExecutionStates, len(want))
seen := make(map[ExecutionState]bool, len(want))
for _, s := range AllExecutionStates {
require.True(t, want[s], "unexpected state %s", s)
require.False(t, seen[s], "duplicate state %s", s)
seen[s] = true
}
}
// TestExecutionState_Predicates covers IsRunnable and QueueNonEmpty
// for every declared execution state.
func TestExecutionState_Predicates(t *testing.T) {
t.Parallel()
runnable := map[ExecutionState]bool{
StateR0: true, StateR1: true, StateI0: true, StateI1: true,
StateA0: true, StateA1: true,
}
nonEmpty := map[ExecutionState]bool{
StateE1: true, StateR1: true, StateI1: true, StateA1: true, StateXE1: true,
}
for _, s := range AllExecutionStates {
require.Equal(t, runnable[s], s.IsRunnable(), "IsRunnable(%s)", s)
require.Equal(t, nonEmpty[s], s.QueueNonEmpty(), "QueueNonEmpty(%s)", s)
}
}
@@ -0,0 +1,485 @@
package chatstate_test
import (
"context"
"encoding/json"
"testing"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// =============================================================================
// Helpers for synthetic cancellation tests
// =============================================================================
// nonDynamicAssistantToolCallMessage builds an assistant message that
// issues a single tool call against a tool that is NOT in the chat's
// dynamic_tools set. The send-message and edit-message paths use the
// "cancel every outstanding tool call regardless of source" variant
// (dynamicOnly=false), so the cancellation must still fire even for
// non-dynamic tools.
func nonDynamicAssistantToolCallMessage(t *testing.T, modelID uuid.UUID, callID string) chatstate.Message {
t.Helper()
raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: callID,
ToolName: "non_dynamic_tool",
Args: json.RawMessage(`{}`),
}})
require.NoError(t, err)
return chatstate.Message{
Role: database.ChatMessageRoleAssistant,
Content: raw,
Visibility: database.ChatMessageVisibilityBoth,
ContentVersion: chatprompt.CurrentContentVersion,
ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true},
}
}
// assertToolResultForCall asserts that msg is a tool-result message
// that resolves a tool call with id wantCallID and is_error=true.
func assertToolResultForCall(t *testing.T, msg database.ChatMessage, wantCallID string) {
t.Helper()
require.Equal(t, database.ChatMessageRoleTool, msg.Role)
parts, err := chatprompt.ParseContent(msg)
require.NoError(t, err)
require.NotEmpty(t, parts)
var found bool
for _, p := range parts {
if p.Type != codersdk.ChatMessagePartTypeToolResult {
continue
}
require.Equal(t, wantCallID, p.ToolCallID, "tool-call id matches")
require.True(t, p.IsError, "synthetic cancellation must be marked is_error=true")
found = true
}
require.True(t, found, "expected at least one tool-result part")
}
// commitAssistantToolCall pushes an assistant message that calls
// `tool_name` with `callID` into history via CommitStep. Returns the
// inserted assistant ChatMessage. Use the dynamic-tools chat fixture
// (createTestChatWithDynamicTools) when dynamicOnly cancellation
// paths are exercised.
func commitAssistantToolCall(
t *testing.T,
f *testFixture,
m *chatstate.ChatMachine,
msg chatstate.Message,
) database.ChatMessage {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
var step chatstate.CommitStepResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
step, err = tx.CommitStep(chatstate.CommitStepInput{Messages: []chatstate.Message{msg}})
return err
}))
require.Len(t, step.InsertedMessages, 1)
return step.InsertedMessages[0]
}
// landInW puts a fresh R0 chat into state W (waiting) via FinishTurn.
func landInW(t *testing.T, f *testFixture, m *chatstate.ChatMachine) {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
return err
}))
require.Equal(t, chatstate.StateW, f.classify(ctx, t, m.ChatID()))
}
// landInE0 puts a fresh R0 chat into state E0 (error, empty queue).
func landInE0(t *testing.T, f *testFixture, m *chatstate.ChatMachine) {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishError(chatstate.FinishErrorInput{
LastError: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"message":"boom"}`),
Valid: true,
},
})
return err
}))
require.Equal(t, chatstate.StateE0, f.classify(ctx, t, m.ChatID()))
}
// =============================================================================
// SendMessage direct-history paths (W, E0)
// =============================================================================
// TestSendMessageDirect_W_SynthesizesToolCancellations verifies that
// from W, SendMessage inserts synthetic tool-result rows for every
// outstanding tool call on the last assistant message BEFORE the new
// user message, regardless of whether the tools are dynamic.
func TestSendMessageDirect_W_SynthesizesToolCancellations(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
callID := "call_" + uuid.NewString()
assistant := commitAssistantToolCall(t, f, m,
nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID))
require.Equal(t, database.ChatMessageRoleAssistant, assistant.Role)
// R0 -> W.
landInW(t, f, m)
// SendMessage with a fresh user message. The direct-history path
// must insert a synthetic tool-result (for callID) followed by
// the new user message.
var send chatstate.SendMessageResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
send, err = tx.SendMessage(chatstate.SendMessageInput{
Message: userTextMessage("after-cancel", f.User.ID, f.Model.ID),
BusyBehavior: chatstate.BusyBehaviorQueue,
})
return err
}))
require.Len(t, send.InsertedMessages, 2, "synthetic cancel + new user")
assertToolResultForCall(t, send.InsertedMessages[0], callID)
require.Equal(t, database.ChatMessageRoleUser, send.InsertedMessages[1].Role)
require.Less(t, send.InsertedMessages[0].ID, send.InsertedMessages[1].ID,
"synthetic cancel is inserted before the user message")
}
// TestSendMessageDirect_E0_SynthesizesToolCancellations exercises
// the same path from E0.
func TestSendMessageDirect_E0_SynthesizesToolCancellations(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
callID := "call_" + uuid.NewString()
commitAssistantToolCall(t, f, m,
nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID))
// R0 -> E0.
landInE0(t, f, m)
var send chatstate.SendMessageResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
send, err = tx.SendMessage(chatstate.SendMessageInput{
Message: userTextMessage("after-error", f.User.ID, f.Model.ID),
BusyBehavior: chatstate.BusyBehaviorQueue,
})
return err
}))
require.Len(t, send.InsertedMessages, 2)
assertToolResultForCall(t, send.InsertedMessages[0], callID)
require.Equal(t, database.ChatMessageRoleUser, send.InsertedMessages[1].Role)
}
// =============================================================================
// EditMessage replacement insertion
// =============================================================================
// TestEditMessage_SynthesizesToolCancellationsBeforeReplacement
// verifies that EditMessage from a state with an outstanding tool
// call before the edited user message inserts a synthetic
// tool-result before the replacement user message in history.
//
// The scenario is:
// - user message 1 (initial)
// - assistant tool-call (outstanding)
// - user message 2 (the one we will edit)
//
// EditMessage soft-deletes user message 2 and everything after it,
// then synthesizes cancellations for tool calls on the last
// surviving assistant message that have no matching tool-result.
func TestEditMessage_SynthesizesToolCancellationsBeforeReplacement(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
// Build the history described above. CommitStep is happy to insert
// a mixed batch as long as it stays inside R0.
callID := "call_" + uuid.NewString()
assistantTC := nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID)
secondUser := userTextMessage("second user", f.User.ID, f.Model.ID)
var step chatstate.CommitStepResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
step, err = tx.CommitStep(chatstate.CommitStepInput{
Messages: []chatstate.Message{assistantTC, secondUser},
})
return err
}))
require.Len(t, step.InsertedMessages, 2)
secondUserID := step.InsertedMessages[1].ID
require.Equal(t, database.ChatMessageRoleUser, step.InsertedMessages[1].Role)
var edit chatstate.EditMessageResult
editedContent := mustMarshalParts(t, []codersdk.ChatMessagePart{
codersdk.ChatMessageText("edited"),
})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
edit, err = tx.EditMessage(chatstate.EditMessageInput{
MessageID: secondUserID,
CreatedBy: f.User.ID,
Content: editedContent,
})
return err
}))
require.Len(t, edit.CancellationMessages, 1, "synthetic cancel inserted")
assertToolResultForCall(t, edit.CancellationMessages[0], callID)
require.Equal(t, database.ChatMessageRoleUser, edit.ReplacementMessage.Role)
require.Less(t, edit.CancellationMessages[0].ID, edit.ReplacementMessage.ID,
"cancellations are inserted before the replacement user message")
}
// PromoteQueuedMessage synchronous-history paths (E1, A1)
// =============================================================================
// TestPromoteQueuedMessage_E1_SynthesizesToolCancellations verifies
// that promoting a queued message from E1 inserts synthetic
// tool-result rows for outstanding tool calls before the promoted
// user message in history.
func TestPromoteQueuedMessage_E1_SynthesizesToolCancellations(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
callID := "call_" + uuid.NewString()
commitAssistantToolCall(t, f, m,
nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID))
// Land in R1 with one queued message.
queued := sendQueuedMessage(t, f, m, "queued-for-promote")
require.NotNil(t, queued.QueuedMessage)
require.Equal(t, chatstate.StateR1, f.classify(ctx, t, created.Chat.ID))
// R1 -> E1 via FinishError.
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishError(chatstate.FinishErrorInput{
LastError: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"message":"boom"}`),
Valid: true,
},
})
return err
}))
require.Equal(t, chatstate.StateE1, f.classify(ctx, t, created.Chat.ID))
var promote chatstate.PromoteQueuedMessageResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
promote, err = tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{
QueuedMessageID: queued.QueuedMessage.ID,
})
return err
}))
require.Len(t, promote.CancellationMessages, 1)
assertToolResultForCall(t, promote.CancellationMessages[0], callID)
require.NotNil(t, promote.InsertedMessage)
require.Equal(t, database.ChatMessageRoleUser, promote.InsertedMessage.Role)
require.Less(t, promote.CancellationMessages[0].ID, promote.InsertedMessage.ID,
"cancel is inserted before the promoted user message")
}
// TestPromoteQueuedMessage_A1_SynthesizesDynamicToolCancellations
// verifies that the dynamic outstanding tool call is canceled when
// promoting from A1. The non-dynamic cancellation counterpart lives
// in TestPromoteQueuedMessage_A1_CancelsAllOutstandingToolCalls.
func TestPromoteQueuedMessage_A1_SynthesizesDynamicToolCancellations(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
toolName := "dyn_promote_a1"
created := createTestChatWithDynamicTools(t, f, toolName)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
dynCallID := "call_" + uuid.NewString()
commitAssistantToolCall(t, f, m,
assistantToolCallMessage(t, f.Model.ID, toolName, dynCallID))
// Land in A0.
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{})
return err
}))
require.Equal(t, chatstate.StateA0, f.classify(ctx, t, created.Chat.ID))
// A0 -> A1 with one queued user message.
queued := sendQueuedMessage(t, f, m, "queued-for-a1-promote")
require.NotNil(t, queued.QueuedMessage)
require.Equal(t, chatstate.StateA1, f.classify(ctx, t, created.Chat.ID))
var promote chatstate.PromoteQueuedMessageResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
promote, err = tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{
QueuedMessageID: queued.QueuedMessage.ID,
})
return err
}))
require.Len(t, promote.CancellationMessages, 1, "dynamic tool call canceled")
assertToolResultForCall(t, promote.CancellationMessages[0], dynCallID)
require.NotNil(t, promote.InsertedMessage)
require.Equal(t, database.ChatMessageRoleUser, promote.InsertedMessage.Role)
}
// =============================================================================
// FinishTurn queue-promotion paths
// =============================================================================
// TestFinishTurn_R1_SynthesizesToolCancellationsBeforePromotion
// verifies that finishing a turn while a queued message exists
// synthesizes outstanding tool cancellations before promoting the
// queue head into history.
func TestFinishTurn_R1_SynthesizesToolCancellationsBeforePromotion(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
callID := "call_" + uuid.NewString()
commitAssistantToolCall(t, f, m,
nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID))
queued := sendQueuedMessage(t, f, m, "queued-for-finish")
require.NotNil(t, queued.QueuedMessage)
require.Equal(t, chatstate.StateR1, f.classify(ctx, t, created.Chat.ID))
beforeIDs := historyMessageIDs(ctx, t, f, created.Chat.ID)
var finish chatstate.FinishTurnResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
finish, err = tx.FinishTurn(chatstate.FinishTurnInput{})
return err
}))
require.NotNil(t, finish.PromotedMessage)
require.Equal(t, database.ChatMessageRoleUser, finish.PromotedMessage.Role)
afterIDs := historyMessageIDs(ctx, t, f, created.Chat.ID)
require.Equal(t, len(beforeIDs)+2, len(afterIDs),
"finish inserts both a tool cancel and the promoted user")
// The two newly inserted messages are tool-result then user.
newIDs := afterIDs[len(beforeIDs):]
cancel, err := f.DB.GetChatMessageByID(ctx, newIDs[0])
require.NoError(t, err)
assertToolResultForCall(t, cancel, callID)
require.Equal(t, finish.PromotedMessage.ID, newIDs[1])
}
// =============================================================================
// FinishInterruption queue-promotion paths and outstanding tool calls
// =============================================================================
// TestFinishInterruption_I1_PromotesQueueHead verifies that
// FinishInterruption from I1 with no outstanding tool calls
// promotes the queue head into history.
func TestFinishInterruption_I1_PromotesQueueHead(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
// Reach R1 with one queued message.
queued := sendQueuedMessage(t, f, m, "queued-for-interruption")
require.NotNil(t, queued.QueuedMessage)
// R1 -> I1 via Interrupt.
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"})
return err
}))
require.Equal(t, chatstate.StateI1, f.classify(ctx, t, created.Chat.ID))
beforeIDs := historyMessageIDs(ctx, t, f, created.Chat.ID)
var finish chatstate.FinishInterruptionResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
finish, err = tx.FinishInterruption(chatstate.FinishInterruptionInput{})
return err
}))
require.NotNil(t, finish.PromotedMessage)
require.Equal(t, database.ChatMessageRoleUser, finish.PromotedMessage.Role)
afterIDs := historyMessageIDs(ctx, t, f, created.Chat.ID)
require.Equal(t, len(beforeIDs)+1, len(afterIDs))
require.Equal(t, chatstate.StateR0, f.classify(ctx, t, created.Chat.ID))
}
// TestFinishInterruption_RejectsOutstandingToolCalls verifies that
// FinishInterruption fails (TransitionNotAllowed-shaped) when the
// chat still has an outstanding dynamic tool call after the partial
// commit. The non-dynamic counterpart lives in
// TestFinishInterruption_RejectsNonDynamicOutstandingToolCall. The
// chat must remain in its prior state.
func TestFinishInterruption_RejectsOutstandingToolCalls(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
toolName := "dyn_finish_reject"
created := createTestChatWithDynamicTools(t, f, toolName)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
dynCallID := "call_" + uuid.NewString()
commitAssistantToolCall(t, f, m,
assistantToolCallMessage(t, f.Model.ID, toolName, dynCallID))
// R0 -> I0 via Interrupt. Interrupt closes pending dynamic calls
// when transitioning from A0/A1, but from R0 it does NOT, so the
// chat keeps its outstanding dynamic call.
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"})
return err
}))
require.Equal(t, chatstate.StateI0, f.classify(ctx, t, created.Chat.ID))
stateBefore := f.classify(ctx, t, created.Chat.ID)
historyBefore := historyMessageIDs(ctx, t, f, created.Chat.ID)
publishedBefore := len(f.Pub.channels)
// FinishInterruption with no partial commits should reject
// because the dynamic call is still outstanding.
err := m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishInterruption(chatstate.FinishInterruptionInput{})
return err
})
require.Error(t, err)
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
require.Equal(t, stateBefore, f.classify(ctx, t, created.Chat.ID), "state unchanged")
require.Equal(t, historyBefore, historyMessageIDs(ctx, t, f, created.Chat.ID),
"history unchanged on rejected finish")
require.Equal(t, publishedBefore, len(f.Pub.channels),
"failed FinishInterruption publishes nothing")
}
// ensure unused imports don't break the build if any helper is
// removed later.
var _ = context.Background
+234
View File
@@ -0,0 +1,234 @@
package chatstate
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
)
// synthesizePendingToolCancellations builds [Message] inserts that
// satisfy every outstanding tool call on the chat's last assistant
// message with a synthetic cancellation tool-result message.
//
// "Outstanding" means a tool call present on the last assistant
// message that does not yet have a matching tool-result message in
// the active history after it. The caller controls whether to limit
// to dynamic-tool calls (true) or close every outstanding tool call
// regardless of source (false). The dynamic-only variant is used by
// requires-action interrupts; the all-tool variant is used by any
// transition that needs to insert a new user message into history.
//
// The synthetic results use the supplied chat's last_model_config_id.
// Returns (nil, nil) when there is nothing to synthesize.
//
//nolint:revive // dynamicOnly is a domain flag, not a control flag.
func synthesizePendingToolCancellations(
ctx context.Context,
store database.Store,
chat database.Chat,
reason string,
dynamicOnly bool,
) ([]Message, error) {
var dynamicToolNames map[string]bool
if dynamicOnly {
var err error
dynamicToolNames, err = parseDynamicToolNamesFromRaw(chat.DynamicTools)
if err != nil {
return nil, xerrors.Errorf("parse dynamic tool names: %w", err)
}
if len(dynamicToolNames) == 0 {
return nil, nil
}
}
lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
ChatID: chat.ID,
Role: database.ChatMessageRoleAssistant,
})
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
if err != nil {
return nil, xerrors.Errorf("get last assistant message: %w", err)
}
assistantParts, err := chatprompt.ParseContent(lastAssistant)
if err != nil {
return nil, xerrors.Errorf("parse assistant message: %w", err)
}
afterMsgs, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: lastAssistant.ID,
})
if err != nil {
return nil, xerrors.Errorf("get messages after assistant: %w", err)
}
handled := make(map[string]bool)
for _, msg := range afterMsgs {
if msg.Role != database.ChatMessageRoleTool {
continue
}
parts, perr := chatprompt.ParseContent(msg)
if perr != nil {
// Don't fail the whole cancellation just because one
// historical message is unparseable; treat its tool
// results as unknown.
continue
}
for _, p := range parts {
if p.Type == codersdk.ChatMessagePartTypeToolResult {
handled[p.ToolCallID] = true
}
}
}
out := make([]Message, 0)
for _, part := range assistantParts {
if part.Type != codersdk.ChatMessagePartTypeToolCall {
continue
}
if dynamicOnly && !dynamicToolNames[part.ToolName] {
continue
}
if handled[part.ToolCallID] {
continue
}
resultPart := codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: part.ToolCallID,
ToolName: part.ToolName,
Result: json.RawMessage(fmt.Sprintf("%q", reason)),
IsError: true,
}
raw, merr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{resultPart})
if merr != nil {
return nil, xerrors.Errorf("marshal synthetic tool result: %w", merr)
}
out = append(out, Message{
Role: database.ChatMessageRoleTool,
Content: raw,
Visibility: database.ChatMessageVisibilityBoth,
ContentVersion: chatprompt.CurrentContentVersion,
ModelConfigID: uuid.NullUUID{UUID: chat.LastModelConfigID, Valid: true},
})
}
if len(out) == 0 {
return nil, nil
}
return out, nil
}
// pendingDynamicToolCallIDs returns the dynamic tool-call IDs on the
// chat's last assistant message that do not yet have a matching
// tool-result message in active history. The returned map is keyed by
// tool-call ID and valued by tool name so callers can build matching
// result messages without re-parsing the assistant content.
func pendingDynamicToolCallIDs(ctx context.Context, store database.Store, chat database.Chat) (map[string]string, error) {
dynamic, err := parseDynamicToolNamesFromRaw(chat.DynamicTools)
if err != nil {
return nil, err
}
if len(dynamic) == 0 {
return map[string]string{}, nil
}
return outstandingToolCallIDs(ctx, store, chat, func(toolName string) bool {
return dynamic[toolName]
})
}
// pendingAllToolCallIDs returns the tool-call IDs of every outstanding
// tool call on the chat's last assistant message, regardless of
// whether the tool is dynamic. The returned map is keyed by tool-call
// ID and valued by tool name. Callers that must guarantee a valid
// LLM message history (e.g. before promoting a user message into
// active history, or after committing an interruption's partial
// messages) should use this variant so non-dynamic tool calls do not
// silently bypass the check.
func pendingAllToolCallIDs(ctx context.Context, store database.Store, chat database.Chat) (map[string]string, error) {
return outstandingToolCallIDs(ctx, store, chat, func(string) bool { return true })
}
// outstandingToolCallIDs walks the chat's last assistant message and
// returns the subset of its tool calls that have no matching
// tool-result message in the active history after it. The accept
// callback can be used to restrict the walk to a subset of tools
// (e.g. dynamic-only).
func outstandingToolCallIDs(ctx context.Context, store database.Store, chat database.Chat, accept func(toolName string) bool) (map[string]string, error) {
lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
ChatID: chat.ID,
Role: database.ChatMessageRoleAssistant,
})
if errors.Is(err, sql.ErrNoRows) {
return map[string]string{}, nil
}
if err != nil {
return nil, xerrors.Errorf("get last assistant: %w", err)
}
parts, err := chatprompt.ParseContent(lastAssistant)
if err != nil {
return nil, xerrors.Errorf("parse assistant: %w", err)
}
afterMsgs, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: lastAssistant.ID,
})
if err != nil {
return nil, xerrors.Errorf("get messages after assistant: %w", err)
}
handled := make(map[string]bool)
for _, msg := range afterMsgs {
if msg.Role != database.ChatMessageRoleTool {
continue
}
ps, perr := chatprompt.ParseContent(msg)
if perr != nil {
continue
}
for _, p := range ps {
if p.Type == codersdk.ChatMessagePartTypeToolResult {
handled[p.ToolCallID] = true
}
}
}
out := make(map[string]string)
for _, p := range parts {
if p.Type != codersdk.ChatMessagePartTypeToolCall {
continue
}
if !accept(p.ToolName) {
continue
}
if handled[p.ToolCallID] {
continue
}
out[p.ToolCallID] = p.ToolName
}
return out, nil
}
// parseDynamicToolNamesFromRaw is a private mirror of
// chatd.parseDynamicToolNames so chatstate does not pull a runtime
// dependency on the chatd package. It accepts a nullable raw JSON
// blob and returns a name set.
func parseDynamicToolNamesFromRaw(raw pqtype.NullRawMessage) (map[string]bool, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return map[string]bool{}, nil
}
var tools []codersdk.DynamicTool
if err := json.Unmarshal(raw.RawMessage, &tools); err != nil {
return nil, err
}
out := make(map[string]bool, len(tools))
for _, t := range tools {
out[t.Name] = true
}
return out, nil
}
+226
View File
@@ -0,0 +1,226 @@
package chatstate
// Transition is the enumeration of transitions implemented by the
// state machine. Values intentionally match the names of the public
// methods on [Tx] (and [CreateChat]). The transition matrix below
// declares the legal (from -> to) execution-state mappings used by
// each transition method for validation.
type Transition string
const (
TransitionCreateChat Transition = "CreateChat"
TransitionSetArchived Transition = "SetArchived"
TransitionSendMessage Transition = "SendMessage"
TransitionEditMessage Transition = "EditMessage"
TransitionDeleteQueuedMessage Transition = "DeleteQueuedMessage"
TransitionPromoteQueuedMessage Transition = "PromoteQueuedMessage"
TransitionInterrupt Transition = "Interrupt"
TransitionCompleteRequiresAction Transition = "CompleteRequiresAction"
TransitionAcquire Transition = "Acquire"
TransitionAbandon Transition = "Abandon"
TransitionRecordGenerationAttempt Transition = "RecordGenerationAttempt"
TransitionRecordRetryState Transition = "RecordRetryState"
TransitionCommitStep Transition = "CommitStep"
TransitionEnterRequiresAction Transition = "EnterRequiresAction"
TransitionFinishInterruption Transition = "FinishInterruption"
TransitionFinishTurn Transition = "FinishTurn"
TransitionFinishError Transition = "FinishError"
TransitionCancelRequiresAction Transition = "CancelRequiresAction"
TransitionReconcileInvalidState Transition = "ReconcileInvalidState"
)
// String implements fmt.Stringer.
func (t Transition) String() string { return string(t) }
// AllExecutionTransitions is the canonical enumeration of every
// execution-state transition that has an entry in the matrix below.
// Ownership transitions (Acquire, Abandon) are intentionally not part
// of this slice because they are validated independently and do not
// have a (from->to) execution mapping.
var AllExecutionTransitions = []Transition{
TransitionCreateChat,
TransitionSetArchived,
TransitionSendMessage,
TransitionEditMessage,
TransitionDeleteQueuedMessage,
TransitionPromoteQueuedMessage,
TransitionInterrupt,
TransitionCompleteRequiresAction,
TransitionRecordGenerationAttempt,
TransitionRecordRetryState,
TransitionCommitStep,
TransitionEnterRequiresAction,
TransitionFinishInterruption,
TransitionFinishTurn,
TransitionFinishError,
TransitionCancelRequiresAction,
TransitionReconcileInvalidState,
}
// transitionMatrix is the in-code mirror of the RFC's execution-state
// transition table. Each entry maps an input state to the set of
// allowed transitions together with the possible classified output
// states that the transition implementation may land in. Outputs may
// depend on the post-mutation queue cardinality (for example
// DeleteQueuedMessage from E1 lands in E0 when the deleted row was the
// last queued message, or stays in E1 otherwise), which is why several
// entries list more than one output.
//
// Ownership transitions (Acquire, Abandon) are intentionally not
// included; they are orthogonal to execution state.
var transitionMatrix = map[ExecutionState]map[Transition][]ExecutionState{
StateN: {
TransitionCreateChat: {StateR0},
},
StateW: {
TransitionSetArchived: {StateXW},
TransitionSendMessage: {StateR0},
TransitionEditMessage: {StateR0},
},
StateE0: {
TransitionSetArchived: {StateXE0},
TransitionSendMessage: {StateR0},
TransitionEditMessage: {StateR0},
},
StateE1: {
TransitionSetArchived: {StateXE1},
TransitionSendMessage: {StateR1},
TransitionEditMessage: {StateR0},
TransitionDeleteQueuedMessage: {StateE0, StateE1},
TransitionPromoteQueuedMessage: {StateR0, StateR1},
},
StateR0: {
TransitionSendMessage: {StateR1, StateI1},
TransitionEditMessage: {StateR0},
TransitionInterrupt: {StateI0},
TransitionRecordGenerationAttempt: {StateR0},
TransitionRecordRetryState: {StateR0},
TransitionCommitStep: {StateR0},
TransitionEnterRequiresAction: {StateA0},
TransitionFinishTurn: {StateW},
TransitionFinishError: {StateE0},
},
StateR1: {
TransitionSendMessage: {StateR1, StateI1},
TransitionEditMessage: {StateR0},
TransitionDeleteQueuedMessage: {StateR0, StateR1},
TransitionPromoteQueuedMessage: {StateI1},
TransitionInterrupt: {StateI1},
TransitionRecordGenerationAttempt: {StateR1},
TransitionRecordRetryState: {StateR1},
TransitionCommitStep: {StateR1},
TransitionEnterRequiresAction: {StateA1},
TransitionFinishTurn: {StateR0, StateR1},
TransitionFinishError: {StateE1},
},
StateI0: {
TransitionSendMessage: {StateI1},
TransitionEditMessage: {StateR0},
TransitionFinishInterruption: {StateW},
},
StateI1: {
TransitionSendMessage: {StateI1},
TransitionEditMessage: {StateR0},
TransitionDeleteQueuedMessage: {StateI0, StateI1},
TransitionPromoteQueuedMessage: {StateI1},
TransitionFinishInterruption: {StateR0, StateR1},
},
StateA0: {
TransitionSendMessage: {StateA1, StateR1},
TransitionEditMessage: {StateR0},
TransitionInterrupt: {StateR0},
TransitionCompleteRequiresAction: {StateR0},
TransitionCancelRequiresAction: {StateR0},
},
StateA1: {
TransitionSendMessage: {StateA1, StateR1},
TransitionEditMessage: {StateR0},
TransitionDeleteQueuedMessage: {StateA0, StateA1},
TransitionPromoteQueuedMessage: {StateR0, StateR1},
TransitionInterrupt: {StateR1},
TransitionCompleteRequiresAction: {StateR1},
TransitionCancelRequiresAction: {StateR1},
},
StateXW: {
TransitionSetArchived: {StateW},
},
StateXE0: {
TransitionSetArchived: {StateE0},
},
StateXE1: {
TransitionSetArchived: {StateE1},
},
StateInvalid: {
TransitionReconcileInvalidState: {StateE0, StateE1},
},
}
// isExecutionTransitionAllowed reports whether a transition is legal
// from the supplied input state per the matrix above. Ownership
// transitions are not stored in the matrix and always return false.
func isExecutionTransitionAllowed(t Transition, from ExecutionState) bool {
allowed, ok := transitionMatrix[from]
if !ok {
return false
}
_, ok = allowed[t]
return ok
}
// requireExecutionTransition validates that t is legal from `from`
// and returns a typed *TransitionError otherwise.
func requireExecutionTransition(t Transition, from ExecutionState) error {
if isExecutionTransitionAllowed(t, from) {
return nil
}
return newTransitionError(t, from, "")
}
// AllowedExecutionTransitionsFrom returns a deterministic slice of
// transitions legal from `from`. Mostly used by tests to enumerate the
// matrix without leaking the internal map.
func AllowedExecutionTransitionsFrom(from ExecutionState) []Transition {
allowed := transitionMatrix[from]
out := make([]Transition, 0, len(allowed))
for _, t := range AllExecutionTransitions {
if _, ok := allowed[t]; ok {
out = append(out, t)
}
}
return out
}
// AllowedInputStates returns a deterministic slice of execution states
// from which `tr` is legal per the matrix above. Mostly used by tests
// to enumerate the matrix without leaking the internal map.
func AllowedInputStates(tr Transition) []ExecutionState {
var out []ExecutionState
for _, from := range AllExecutionStates {
if isExecutionTransitionAllowed(tr, from) {
out = append(out, from)
}
}
return out
}
// AllowedExecutionTransitionOutputs returns the set of classified
// post-states that the transition `tr` may produce from `from` per
// the matrix above. The returned slice is a copy so callers may mutate
// it without affecting the underlying matrix.
//
// When `tr` is not allowed from `from`, an empty (nil) slice is
// returned. Tests use this helper to enumerate the (transition, from,
// want) triples that must be exercised by the row-level matrix tests.
func AllowedExecutionTransitionOutputs(from ExecutionState, tr Transition) []ExecutionState {
allowed, ok := transitionMatrix[from]
if !ok {
return nil
}
outputs, ok := allowed[tr]
if !ok {
return nil
}
cp := make([]ExecutionState, len(outputs))
copy(cp, outputs)
return cp
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,907 @@
package chatstate_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// =============================================================================
// State seeding helpers
//
// seededChat is the shared output of seedState. Some transition tests
// need extra context beyond the chat ID (for example, the queued
// message ID to delete, or the message ID to edit), so this struct
// surfaces what each state was seeded with.
// =============================================================================
type seededChat struct {
chatID uuid.UUID
exists bool
initialUserMessageID int64
assistantToolCallMsgID int64
queuedMessageIDs []int64
// queuedMessageBodies is parallel to queuedMessageIDs and records
// the text body each queued message was seeded with. Cases that
// promote queued messages into history use this to assert the
// promoted message content matches what was originally queued.
queuedMessageBodies []string
queuedMessageCreatedBy []uuid.UUID
dynamicToolName string
pendingToolCallID string
pendingToolCallIDs []string
}
// dynamicToolJSON returns the canonical [{name,description,input_schema}]
// payload used to seed dynamic_tools on a chat. Tests that need
// pending dynamic tool calls (A0, A1) reuse this and reference the
// returned tool name in their assistant tool-call message.
func dynamicToolJSON(name string) []byte {
tools := []codersdk.DynamicTool{{
Name: name,
Description: "test tool",
InputSchema: json.RawMessage(`{"type":"object","properties":{}}`),
}}
raw, err := json.Marshal(tools)
if err != nil {
panic(err)
}
return raw
}
// assistantToolCallMessage builds a chatstate.Message for an
// assistant message that issues one tool call against the supplied
// dynamic tool name. The tool-call ID is unique per call so multiple
// messages do not collide.
func assistantToolCallMessage(t *testing.T, modelID uuid.UUID, toolName, callID string) chatstate.Message {
t.Helper()
raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: callID,
ToolName: toolName,
Args: json.RawMessage(`{}`),
}})
require.NoError(t, err)
return chatstate.Message{
Role: database.ChatMessageRoleAssistant,
Content: raw,
Visibility: database.ChatMessageVisibilityBoth,
ContentVersion: chatprompt.CurrentContentVersion,
ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true},
}
}
func mixedAssistantToolCallMessage(t *testing.T, modelID uuid.UUID, dynamicTool, dynCallID, nonDynCallID string) chatstate.Message {
t.Helper()
parts := []codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: dynCallID,
ToolName: dynamicTool,
Args: json.RawMessage(`{}`),
},
{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: nonDynCallID,
ToolName: "non_dynamic_tool",
Args: json.RawMessage(`{}`),
},
}
raw, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
return chatstate.Message{
Role: database.ChatMessageRoleAssistant,
Content: raw,
Visibility: database.ChatMessageVisibilityBoth,
ContentVersion: chatprompt.CurrentContentVersion,
ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true},
}
}
// createTestChatWithDynamicTools mirrors createTestChat but seeds the
// chat with a non-empty dynamic_tools blob so EnterRequiresAction,
// CompleteRequiresAction, and CancelRequiresAction can find pending
// dynamic tool calls.
func createTestChatWithDynamicTools(t *testing.T, f *testFixture, toolName string) chatstate.CreateChatResult {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
OrganizationID: f.Org.ID,
OwnerID: f.User.ID,
LastModelConfigID: f.Model.ID,
Title: "test",
ClientType: database.ChatClientTypeApi,
DynamicTools: pqtype.NullRawMessage{
RawMessage: dynamicToolJSON(toolName),
Valid: true,
},
InitialMessages: []chatstate.Message{
userTextMessage("hello", f.User.ID, f.Model.ID),
},
})
require.NoError(t, err)
return res
}
// seedAOrA1 seeds a chat into A0 (queuedExtras=0) or A1
// (queuedExtras>=1) with a real pending dynamic tool call. Used by
// cases that need A0 or A1 with a configurable queue cardinality.
func seedAOrA1(t *testing.T, f *testFixture, queuedExtras int, namePrefix string) seededChat {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
toolName := namePrefix
callID := "call_" + uuid.NewString()
created := createTestChatWithDynamicTools(t, f, toolName)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
var step chatstate.CommitStepResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
step, err = tx.CommitStep(chatstate.CommitStepInput{
Messages: []chatstate.Message{
assistantToolCallMessage(t, f.Model.ID, toolName, callID),
},
})
return err
}))
require.Len(t, step.InsertedMessages, 1)
// R0 -> A0.
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{})
return err
}))
var (
queuedIDs []int64
queuedBodies []string
)
for i := 0; i < queuedExtras; i++ {
body := fmt.Sprintf("queued-%s-%d", namePrefix, i)
sm := sendQueuedMessage(t, f, m, body)
require.NotNil(t, sm.QueuedMessage)
queuedIDs = append(queuedIDs, sm.QueuedMessage.ID)
queuedBodies = append(queuedBodies, body)
}
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
assistantToolCallMsgID: step.InsertedMessages[0].ID,
queuedMessageIDs: queuedIDs,
queuedMessageBodies: queuedBodies,
dynamicToolName: toolName,
pendingToolCallID: callID,
}
}
// seedState seeds a chat into the supplied execution state and
// returns identifying handles useful for downstream assertions. For
// [chatstate.StateN] the returned chatID is a fresh UUID that does
// not exist in the database. Multi-queued seeds (for E1, R1, I1,
// A1 with 2 queued messages, and Invalid with a non-empty queue) live in
// seedStateMultiQueued.
func seedState(t *testing.T, f *testFixture, state chatstate.ExecutionState) seededChat {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
switch state {
case chatstate.StateN:
return seededChat{chatID: uuid.New(), exists: false}
case chatstate.StateR0:
created := createTestChat(t, f)
initial := firstUserMessageID(ctx, t, f, created.Chat.ID)
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: initial,
}
case chatstate.StateW:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
return err
}))
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
}
case chatstate.StateE0:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishError(chatstate.FinishErrorInput{
LastError: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"message":"boom"}`),
Valid: true,
},
})
return err
}))
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
}
case chatstate.StateE1:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
// R0 -> R1
queuedBody := "queued-for-E1"
queued := sendQueuedMessage(t, f, m, queuedBody)
require.NotNil(t, queued.QueuedMessage)
// R1 -> E1
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishError(chatstate.FinishErrorInput{
LastError: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"message":"boom"}`),
Valid: true,
},
})
return err
}))
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
queuedMessageIDs: []int64{queued.QueuedMessage.ID},
queuedMessageBodies: []string{queuedBody},
}
case chatstate.StateR1:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
queuedBody := "queued-for-R1"
queued := sendQueuedMessage(t, f, m, queuedBody)
require.NotNil(t, queued.QueuedMessage)
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
queuedMessageIDs: []int64{queued.QueuedMessage.ID},
queuedMessageBodies: []string{queuedBody},
}
case chatstate.StateI0:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Interrupt(chatstate.InterruptInput{Reason: "seed"})
return err
}))
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
}
case chatstate.StateI1:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
// R0 -> I1: SendMessage with interrupt behavior queues the
// message and sets status to interrupting.
queuedBody := "queued-for-I1"
sm := sendInterruptMessage(t, f, m, queuedBody)
require.NotNil(t, sm.QueuedMessage)
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
queuedMessageIDs: []int64{sm.QueuedMessage.ID},
queuedMessageBodies: []string{queuedBody},
}
case chatstate.StateA0:
return seedAOrA1(t, f, 0, "seed_tool_a0")
case chatstate.StateA1:
return seedAOrA1(t, f, 1, "seed_tool_a1")
case chatstate.StateXW:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
return err
}))
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true})
return err
}))
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
}
case chatstate.StateXE0:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishError(chatstate.FinishErrorInput{
LastError: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"message":"boom"}`),
Valid: true,
},
})
return err
}))
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true})
return err
}))
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
}
case chatstate.StateXE1:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
queuedBody := "queued-for-XE1"
queued := sendQueuedMessage(t, f, m, queuedBody)
require.NotNil(t, queued.QueuedMessage)
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishError(chatstate.FinishErrorInput{
LastError: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"message":"boom"}`),
Valid: true,
},
})
return err
}))
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true})
return err
}))
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
queuedMessageIDs: []int64{queued.QueuedMessage.ID},
queuedMessageBodies: []string{queuedBody},
}
case chatstate.StateInvalid:
created := createTestChat(t, f)
// Force running + archived, a deliberately invalid
// combination per the classifier.
_, err := f.DB.UpdateChatExecutionState(ctx, database.UpdateChatExecutionStateParams{
ID: created.Chat.ID,
Status: database.ChatStatusRunning,
Archived: true,
WorkerID: created.Chat.WorkerID,
RunnerID: created.Chat.RunnerID,
LastError: created.Chat.LastError,
RequiresActionDeadlineAt: created.Chat.RequiresActionDeadlineAt,
})
require.NoError(t, err)
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
}
}
t.Fatalf("seedState: unsupported execution state %s", state)
return seededChat{}
}
// seedStateMultiQueued seeds a state with two queued messages. Used
// by cases that need the post-mutation queue to remain non-empty.
// Supported states: E1, R1, I1, A1.
func seedStateMultiQueued(t *testing.T, f *testFixture, state chatstate.ExecutionState) seededChat {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
switch state {
case chatstate.StateE1:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
firstBody := "queued-e1-a"
first := sendQueuedMessage(t, f, m, firstBody)
require.NotNil(t, first.QueuedMessage)
secondBody := "queued-e1-b"
second := sendQueuedMessage(t, f, m, secondBody)
require.NotNil(t, second.QueuedMessage)
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishError(chatstate.FinishErrorInput{
LastError: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"message":"boom"}`),
Valid: true,
},
})
return err
}))
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
queuedMessageIDs: []int64{first.QueuedMessage.ID, second.QueuedMessage.ID},
queuedMessageBodies: []string{firstBody, secondBody},
}
case chatstate.StateR1:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
firstBody := "queued-r1-a"
first := sendQueuedMessage(t, f, m, firstBody)
require.NotNil(t, first.QueuedMessage)
secondBody := "queued-r1-b"
second := sendQueuedMessage(t, f, m, secondBody)
require.NotNil(t, second.QueuedMessage)
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
queuedMessageIDs: []int64{first.QueuedMessage.ID, second.QueuedMessage.ID},
queuedMessageBodies: []string{firstBody, secondBody},
}
case chatstate.StateI1:
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
firstBody := "queued-i1-a"
first := sendQueuedMessage(t, f, m, firstBody)
require.NotNil(t, first.QueuedMessage)
// R1 -> I1 via interrupt-mode SendMessage queues a second
// message and flips status to interrupting.
secondBody := "queued-i1-b"
second := sendInterruptMessage(t, f, m, secondBody)
require.NotNil(t, second.QueuedMessage)
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
queuedMessageIDs: []int64{first.QueuedMessage.ID, second.QueuedMessage.ID},
queuedMessageBodies: []string{firstBody, secondBody},
}
case chatstate.StateA1:
return seedAOrA1(t, f, 2, "seed_tool_a1_multi")
}
t.Fatalf("seedStateMultiQueued: unsupported execution state %s", state)
return seededChat{}
}
// seedA1WithMixedOutstandingToolCalls seeds A1 with one queued message
// and one assistant message carrying both a dynamic and non-dynamic
// outstanding tool call. It is used by PromoteQueuedMessage(A1) to
// prove all tool calls are closed before inserting the promoted user.
func seedA1WithMixedOutstandingToolCalls(t *testing.T, f *testFixture, queuedExtras int, namePrefix string) seededChat {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
toolName := namePrefix
dynCallID := "call_" + uuid.NewString()
nonDynCallID := "call_" + uuid.NewString()
created := createTestChatWithDynamicTools(t, f, toolName)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
var step chatstate.CommitStepResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
step, err = tx.CommitStep(chatstate.CommitStepInput{
Messages: []chatstate.Message{
mixedAssistantToolCallMessage(t, f.Model.ID, toolName, dynCallID, nonDynCallID),
},
})
return err
}))
require.Len(t, step.InsertedMessages, 1)
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{})
return err
}))
var (
queuedIDs []int64
queuedBodies []string
queuedCreatedBy []uuid.UUID
)
for i := range queuedExtras {
body := fmt.Sprintf("queued-%s-%d", namePrefix, i)
createdBy := uuid.New()
queued, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
ChatID: created.Chat.ID,
Content: userMessageContent(t, body),
ModelConfigID: uuid.NullUUID{UUID: f.Model.ID, Valid: true},
CreatedBy: createdBy,
})
require.NoError(t, err)
queuedIDs = append(queuedIDs, queued.ID)
queuedBodies = append(queuedBodies, body)
queuedCreatedBy = append(queuedCreatedBy, createdBy)
}
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
assistantToolCallMsgID: step.InsertedMessages[0].ID,
queuedMessageIDs: queuedIDs,
queuedMessageBodies: queuedBodies,
queuedMessageCreatedBy: queuedCreatedBy,
dynamicToolName: toolName,
pendingToolCallID: dynCallID,
pendingToolCallIDs: []string{dynCallID, nonDynCallID},
}
}
// seedInvalidWithQueue seeds Invalid with a single queued message so
// ReconcileInvalidState lands in E1 (non-empty queue) instead of E0.
func seedInvalidWithQueue(t *testing.T, f *testFixture) seededChat {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
queuedBody := "queued-invalid"
queued := sendQueuedMessage(t, f, m, queuedBody)
require.NotNil(t, queued.QueuedMessage)
// Force the deliberately invalid running + archived combo on
// top of the queue.
chat, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
_, err = f.DB.UpdateChatExecutionState(ctx, database.UpdateChatExecutionStateParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
Archived: true,
WorkerID: chat.WorkerID,
RunnerID: chat.RunnerID,
LastError: chat.LastError,
RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt,
})
require.NoError(t, err)
return seededChat{
chatID: chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, chat.ID),
queuedMessageIDs: []int64{queued.QueuedMessage.ID},
queuedMessageBodies: []string{queuedBody},
}
}
// firstUserMessageID returns the lowest-id non-deleted user message
// on the chat. Most transition tests reuse this when they need a
// user message to edit.
func firstUserMessageID(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) int64 {
t.Helper()
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
})
require.NoError(t, err)
for _, m := range msgs {
if m.Role == database.ChatMessageRoleUser && !m.Deleted {
return m.ID
}
}
t.Fatalf("firstUserMessageID: chat %s has no user messages", chatID)
return 0
}
func firstAssistantMessageID(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) int64 {
t.Helper()
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
})
require.NoError(t, err)
for _, m := range msgs {
if m.Role == database.ChatMessageRoleAssistant && !m.Deleted {
return m.ID
}
}
t.Fatalf("firstAssistantMessageID: chat %s has no assistant messages", chatID)
return 0
}
// seedForEnterRequiresAction extends seedState for R0 and R1 with a
// chat that has dynamic_tools plus an assistant tool-call message in
// history. EnterRequiresAction's precondition rejects R0/R1 without
// pending dynamic tool calls, so the generic seedState path will not
// do. Other states fall through to the default seedState.
func seedForEnterRequiresAction(t *testing.T, f *testFixture, state chatstate.ExecutionState) seededChat {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
switch state {
case chatstate.StateR0:
toolName := "ra_tool_r0"
callID := "call_" + uuid.NewString()
created := createTestChatWithDynamicTools(t, f, toolName)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
var step chatstate.CommitStepResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
step, err = tx.CommitStep(chatstate.CommitStepInput{
Messages: []chatstate.Message{
assistantToolCallMessage(t, f.Model.ID, toolName, callID),
},
})
return err
}))
require.Len(t, step.InsertedMessages, 1)
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
assistantToolCallMsgID: step.InsertedMessages[0].ID,
dynamicToolName: toolName,
pendingToolCallID: callID,
pendingToolCallIDs: []string{callID},
}
case chatstate.StateR1:
toolName := "ra_tool_r1"
callID := "call_" + uuid.NewString()
created := createTestChatWithDynamicTools(t, f, toolName)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
var step chatstate.CommitStepResult
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
var err error
step, err = tx.CommitStep(chatstate.CommitStepInput{
Messages: []chatstate.Message{
assistantToolCallMessage(t, f.Model.ID, toolName, callID),
},
})
return err
}))
// R0 -> R1 with a queued message.
queuedBody := "queued-for-RA-r1"
sm := sendQueuedMessage(t, f, m, queuedBody)
require.NotNil(t, sm.QueuedMessage)
return seededChat{
chatID: created.Chat.ID,
exists: true,
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
assistantToolCallMsgID: step.InsertedMessages[0].ID,
queuedMessageIDs: []int64{sm.QueuedMessage.ID},
queuedMessageBodies: []string{queuedBody},
dynamicToolName: toolName,
pendingToolCallID: callID,
pendingToolCallIDs: []string{callID},
}
}
return seedState(t, f, state)
}
// =============================================================================
// Snapshot baselines and shared assertion helpers used by every matrix
// case. captureBaseline records the chat's snapshot_version and the
// publisher's recorded channel count immediately before a transition
// runs; assertSnapshotBumpedOnce and assertNoMutationOrPublish use the
// baseline to verify either a single snapshot bump and one chat:update
// on success, or zero mutation and zero publishes on failure.
// =============================================================================
// activeHistoryIDs returns the ids of non-deleted history messages
// for the chat in row-id order. Useful for verifying CommitStep,
// EditMessage replacement, and PromoteQueuedMessage head insertion.
func activeHistoryIDs(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) []int64 {
t.Helper()
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
})
require.NoError(t, err)
out := make([]int64, 0, len(msgs))
for _, m := range msgs {
if !m.Deleted {
out = append(out, m.ID)
}
}
return out
}
func requireChatMessageByID(ctx context.Context, t *testing.T, f *testFixture, id int64) database.ChatMessage {
t.Helper()
msg, err := f.DB.GetChatMessageByID(ctx, id)
require.NoError(t, err)
return msg
}
func requireQueuedMessageByID(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, id int64) database.ChatQueuedMessage {
t.Helper()
msg, err := f.DB.GetChatQueuedMessageByID(ctx, database.GetChatQueuedMessageByIDParams{
ID: id,
ChatID: chatID,
})
require.NoError(t, err)
return msg
}
func requireQueuedMessageDeleted(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, id int64) {
t.Helper()
_, err := f.DB.GetChatQueuedMessageByID(ctx, database.GetChatQueuedMessageByIDParams{
ID: id,
ChatID: chatID,
})
require.Error(t, err)
}
func assertFetchedUserMessage(ctx context.Context, t *testing.T, f *testFixture, msg database.ChatMessage) database.ChatMessage {
t.Helper()
fetched := requireChatMessageByID(ctx, t, f, msg.ID)
require.Equal(t, msg.ChatID, fetched.ChatID)
require.Equal(t, database.ChatMessageRoleUser, fetched.Role)
require.True(t, fetched.CreatedBy.Valid)
require.Equal(t, f.User.ID, fetched.CreatedBy.UUID)
require.True(t, fetched.ModelConfigID.Valid)
require.Equal(t, f.Model.ID, fetched.ModelConfigID.UUID)
require.Equal(t, chatprompt.CurrentContentVersion, fetched.ContentVersion)
return fetched
}
func assertFetchedQueuedMessage(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, queued database.ChatQueuedMessage) database.ChatQueuedMessage {
t.Helper()
fetched := requireQueuedMessageByID(ctx, t, f, chatID, queued.ID)
require.Equal(t, chatID, fetched.ChatID)
require.Equal(t, f.User.ID, fetched.CreatedBy)
require.True(t, fetched.ModelConfigID.Valid)
require.Equal(t, f.Model.ID, fetched.ModelConfigID.UUID)
require.NotEmpty(t, fetched.Content)
return fetched
}
func newActiveMessageIDs(base snapshotBaseline, after []int64) []int64 {
seen := make(map[int64]struct{}, len(base.historyIDs))
for _, id := range base.historyIDs {
seen[id] = struct{}{}
}
out := make([]int64, 0, len(after))
for _, id := range after {
if _, ok := seen[id]; !ok {
out = append(out, id)
}
}
return out
}
// assertToolResultForCallNoError asserts that msg is a tool-result
// message that resolves a tool call with id wantCallID, is_error=false,
// and that the result JSON matches wantResultJSON. Complements
// assertToolResultForCall in synthetic_cancellation_test.go which
// asserts is_error=true.
func assertToolResultForCallNoError(t *testing.T, msg database.ChatMessage, wantCallID, wantResultJSON string) {
t.Helper()
require.Equal(t, database.ChatMessageRoleTool, msg.Role)
parts, err := chatprompt.ParseContent(msg)
require.NoError(t, err)
require.NotEmpty(t, parts)
var found bool
for _, p := range parts {
if p.Type != codersdk.ChatMessagePartTypeToolResult {
continue
}
require.Equal(t, wantCallID, p.ToolCallID, "tool-call id matches")
require.False(t, p.IsError, "CompleteRequiresAction tool result must not be is_error")
require.JSONEq(t, wantResultJSON, string(p.Result), "CompleteRequiresAction tool result JSON matches submitted output")
found = true
}
require.True(t, found, "expected at least one tool-result part")
}
// assertChatMessageText asserts that the persisted content of msg
// decodes to a single text part with the supplied body. Used by
// matrix cases that need to verify the actual text submitted via
// SendMessage / EditMessage / CommitStep, or the text that was
// promoted out of the queue into history.
func assertChatMessageText(t *testing.T, msg database.ChatMessage, want string) {
t.Helper()
parts, err := chatprompt.ParseContent(msg)
require.NoError(t, err, "parse chat message content")
require.Len(t, parts, 1, "expected exactly one content part")
require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type,
"expected a text content part")
require.Equal(t, want, parts[0].Text, "unexpected chat message text")
}
// assertQueuedMessageText asserts that the JSON content of queued
// decodes to a single text part with the supplied body. Used by
// matrix cases that need to verify the body inserted into
// chat_queued_messages via SendMessage.
func assertQueuedMessageText(t *testing.T, queued database.ChatQueuedMessage, want string) {
t.Helper()
var parts []codersdk.ChatMessagePart
require.NoError(t, json.Unmarshal(queued.Content, &parts), "unmarshal queued content")
require.Len(t, parts, 1, "expected exactly one queued content part")
require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type,
"expected a text content part")
require.Equal(t, want, parts[0].Text, "unexpected queued message text")
}
// assertQueueBodiesInOrder fetches the queued messages for the chat
// in queue order and asserts each row's text body matches the
// supplied bodies. Used by matrix cases that need to verify the
// remaining queue content after a promote / finish-turn /
// finish-interruption.
func assertQueueBodiesInOrder(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, want []string) {
t.Helper()
rows, err := f.DB.GetChatQueuedMessagesByPosition(ctx, chatID)
require.NoError(t, err)
require.Len(t, rows, len(want), "queue length must match expected bodies")
for i, r := range rows {
assertQueuedMessageText(t, r, want[i])
}
}
// snapshotBaseline records the chat's snapshot_version and the
// publisher's recorded channel count immediately before a transition
// runs. Tests use it to verify either a single snapshot bump and one
// chat:update on success, or zero mutation and zero publishes on
// failure.
type snapshotBaseline struct {
exists bool
chat database.Chat
snapshot int64
historyVersion int64
queueVersion int64
retryStateVersion int64
generationAttempt int64
queueCount int64
queueIDs []int64
historyIDs []int64
channels int
}
func captureBaseline(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat) snapshotBaseline {
t.Helper()
base := snapshotBaseline{
exists: seeded.exists,
channels: len(f.Pub.channels),
}
if !seeded.exists {
return base
}
chat, err := f.DB.GetChatByID(ctx, seeded.chatID)
require.NoError(t, err)
base.chat = chat
base.snapshot = chat.SnapshotVersion
base.historyVersion = chat.HistoryVersion
base.queueVersion = chat.QueueVersion
base.retryStateVersion = chat.RetryStateVersion
base.generationAttempt = chat.GenerationAttempt
base.queueIDs = queuedIDsByPosition(ctx, t, f, seeded.chatID)
count, err := f.DB.CountChatQueuedMessages(ctx, seeded.chatID)
require.NoError(t, err)
base.queueCount = count
base.historyIDs = activeHistoryIDs(ctx, t, f, seeded.chatID)
return base
}
// assertSnapshotBumpedOnce asserts that one Update committed; that is,
// snapshot_version advanced by exactly one and the publisher saw at
// least one chat:update on the per-chat channel after the baseline.
func assertSnapshotBumpedOnce(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, base snapshotBaseline) {
t.Helper()
after, err := f.DB.GetChatByID(ctx, chatID)
require.NoError(t, err)
require.Equal(t, base.snapshot+1, after.SnapshotVersion, "snapshot_version must bump exactly once")
channel := coderdpubsub.ChatStateUpdateChannel(chatID)
found := false
for _, c := range f.Pub.channels[base.channels:] {
if c == channel {
found = true
break
}
}
require.True(t, found, "expected one chat:update on %s after commit", channel)
}
// assertNoMutationOrPublish asserts a failed transition rolled back
// the automatic snapshot bump and published nothing.
func assertNoMutationOrPublish(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, base snapshotBaseline) {
t.Helper()
require.Equal(t, base.channels, len(f.Pub.channels), "failed transition must not publish")
if base.exists {
after, err := f.DB.GetChatByID(ctx, chatID)
require.NoError(t, err)
require.Equal(t, base.snapshot, after.SnapshotVersion, "failed transition must not advance snapshot_version")
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,747 @@
package chatstate_test
import (
"encoding/json"
"testing"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// =============================================================================
// CreateChat tests.
//
// CreateChat is the only transition that originates from StateN and it
// is not exercised through ChatMachine.Update, so it lives outside
// TestTransitionMatrix_AllCombinations.
// =============================================================================
// TestTransitionCreate_NToR0 verifies that CreateChat lands a fresh
// chat in R0 with snapshot_version 1, the initial user message
// recorded at revision 1, queue_version still 0, and the post-commit
// publish requesting an ownership hint plus a chat:update.
func TestTransitionCreate_NToR0(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
res := createTestChat(t, f)
require.Equal(t, database.ChatStatusRunning, res.Chat.Status)
require.False(t, res.Chat.Archived)
require.Equal(t, int64(1), res.Chat.SnapshotVersion, "snapshot_version starts at 1")
require.Equal(t, int64(1), res.Chat.HistoryVersion, "history_version set by trigger after initial insert")
require.Equal(t, int64(0), res.Chat.QueueVersion, "queue_version stays 0 when no queue rows")
require.Equal(t, int64(0), res.Chat.GenerationAttempt)
require.NotEmpty(t, res.InitialMessages)
require.Equal(t, int64(1), res.InitialMessages[0].Revision)
require.Equal(t, chatstate.StateR0, f.classify(ctx, t, res.Chat.ID))
require.True(t, f.Pub.hasOwnership(), "newly created chat is runnable and unowned")
f.Pub.expectChatUpdate(t, res.Chat.ID, 1)
}
// TestCreateChat_RejectsEmptyInitialMessages verifies that CreateChat
// rejects an empty InitialMessages slice with ErrTransitionNotAllowed
// and does not publish anything.
func TestCreateChat_RejectsEmptyInitialMessages(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
_, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
OrganizationID: f.Org.ID,
OwnerID: f.User.ID,
LastModelConfigID: f.Model.ID,
ClientType: database.ChatClientTypeApi,
Title: "t",
InitialMessages: nil,
})
require.Error(t, err)
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
require.Empty(t, f.Pub.channels, "rejected create must not publish")
}
func TestCreateChat_AllowsNoUserMessages(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
assistant := userTextMessage("oops", f.User.ID, f.Model.ID)
assistant.Role = database.ChatMessageRoleAssistant
res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
OrganizationID: f.Org.ID,
OwnerID: f.User.ID,
LastModelConfigID: f.Model.ID,
Title: "t",
ClientType: database.ChatClientTypeApi,
InitialMessages: []chatstate.Message{assistant},
})
require.NoError(t, err)
require.Len(t, res.InitialMessages, 1)
}
func TestCreateChat_AllowsNonFinalUserMessage(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
OrganizationID: f.Org.ID,
OwnerID: f.User.ID,
LastModelConfigID: f.Model.ID,
Title: "t",
ClientType: database.ChatClientTypeApi,
InitialMessages: []chatstate.Message{
userTextMessage("context user", f.User.ID, f.Model.ID),
userTextMessage("final user", f.User.ID, f.Model.ID),
},
})
require.NoError(t, err)
require.Len(t, res.InitialMessages, 2)
}
// =============================================================================
// Input-specific rejection cases.
//
// These tests cover the same matrix rows as TestTransitionMatrix_AllCombinations
// but exercise legal source states with invalid transition inputs. They are
// intentionally outside the matrix entry point so the matrix focus stays on
// positive cases and generated disallowed cases.
// =============================================================================
type setArchivedWrongDirectionCase struct {
from chatstate.ExecutionState
wantArchive bool
label string
}
func setArchivedWrongDirectionCases() []setArchivedWrongDirectionCase {
return []setArchivedWrongDirectionCase{
// Non-archived states with archived=false: no-op.
{from: chatstate.StateW, wantArchive: false, label: "W_to_W"},
{from: chatstate.StateE0, wantArchive: false, label: "E0_to_E0"},
{from: chatstate.StateE1, wantArchive: false, label: "E1_to_E1"},
// Archived states with archived=true: no-op.
{from: chatstate.StateXW, wantArchive: true, label: "XW_to_XW"},
{from: chatstate.StateXE0, wantArchive: true, label: "XE0_to_XE0"},
{from: chatstate.StateXE1, wantArchive: true, label: "XE1_to_XE1"},
}
}
var invalidBusyBehaviors = []chatstate.BusyBehavior{
chatstate.BusyBehavior(""),
chatstate.BusyBehavior("not-a-real-mode"),
}
func runSetArchivedWrongDirectionCase(t *testing.T, tc setArchivedWrongDirectionCase) {
t.Helper()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
seeded := seedState(t, f, tc.from)
require.Equal(t, tc.from, f.classify(ctx, t, seeded.chatID))
base := captureBaseline(ctx, t, f, seeded)
m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{})
err := m.Update(ctx, func(tx *chatstate.Tx) error {
_, serr := tx.SetArchived(chatstate.SetArchivedInput{Archived: tc.wantArchive})
return serr
})
require.Error(t, err, "SetArchived must reject when Archived matches the current value")
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed,
"SetArchived must wrap ErrTransitionNotAllowed")
var te *chatstate.TransitionError
require.ErrorAs(t, err, &te,
"SetArchived must return a typed TransitionError")
require.Equal(t, chatstate.TransitionSetArchived, te.Transition)
require.Equal(t, tc.from, te.From, "TransitionError records the loaded from-state")
require.Equal(t, tc.from, f.classify(ctx, t, seeded.chatID),
"rejected SetArchived must leave the chat in the same state")
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
}
func runInvalidBusyBehaviorCase(t *testing.T, from chatstate.ExecutionState, bb chatstate.BusyBehavior) {
t.Helper()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
seeded := seedState(t, f, from)
require.Equal(t, from, f.classify(ctx, t, seeded.chatID),
"seed must land in %s", from)
base := captureBaseline(ctx, t, f, seeded)
m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{})
err := m.Update(ctx, func(tx *chatstate.Tx) error {
_, serr := tx.SendMessage(chatstate.SendMessageInput{
Message: userTextMessage("invalid-bb", f.User.ID, f.Model.ID),
BusyBehavior: bb,
})
return serr
})
require.Error(t, err, "SendMessage must reject invalid BusyBehavior")
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed,
"SendMessage rejection must wrap ErrTransitionNotAllowed")
var te *chatstate.TransitionError
require.ErrorAs(t, err, &te,
"SendMessage must return a typed TransitionError")
require.Equal(t, chatstate.TransitionSendMessage, te.Transition)
require.Equal(t, from, te.From,
"TransitionError records the source state")
require.Equal(t, from, f.classify(ctx, t, seeded.chatID),
"rejected SendMessage must leave the chat in the same state")
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
}
type completeRequiresActionRejectCase struct {
name string
results func(seeded seededChat) []chatstate.ToolResultInput
}
type recordRetryStateRejectCase struct {
name string
retryState pqtype.NullRawMessage
}
func completeRequiresActionRejectCases() []completeRequiresActionRejectCase {
valid := func(id string) chatstate.ToolResultInput {
return chatstate.ToolResultInput{
ToolCallID: id,
Output: json.RawMessage(`{"ok":true}`),
}
}
return []completeRequiresActionRejectCase{
{
name: "missing_required_tool_result",
results: func(seeded seededChat) []chatstate.ToolResultInput { return nil },
},
{
name: "extra_tool_result",
results: func(seeded seededChat) []chatstate.ToolResultInput {
return []chatstate.ToolResultInput{valid(seeded.pendingToolCallID), valid("call_extra")}
},
},
{
name: "duplicate_tool_call_id",
results: func(seeded seededChat) []chatstate.ToolResultInput {
return []chatstate.ToolResultInput{valid(seeded.pendingToolCallID), valid(seeded.pendingToolCallID)}
},
},
{
name: "mismatched_tool_call_id",
results: func(seeded seededChat) []chatstate.ToolResultInput {
return []chatstate.ToolResultInput{valid("call_mismatch")}
},
},
{
name: "invalid_json_output",
results: func(seeded seededChat) []chatstate.ToolResultInput {
return []chatstate.ToolResultInput{{ToolCallID: seeded.pendingToolCallID, Output: json.RawMessage(`{`)}}
},
},
}
}
func recordRetryStateRejectCases() []recordRetryStateRejectCase {
return []recordRetryStateRejectCase{
{
name: "sql_null_payload",
},
{
name: "empty_payload",
retryState: pqtype.NullRawMessage{RawMessage: json.RawMessage(``), Valid: true},
},
{
name: "invalid_json_payload",
retryState: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{`), Valid: true},
},
}
}
func runCompleteRequiresActionRejectCase(t *testing.T, tc completeRequiresActionRejectCase) {
t.Helper()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
seeded := seedAOrA1(t, f, 0, "reject_complete_requires_action")
require.Equal(t, chatstate.StateA0, f.classify(ctx, t, seeded.chatID))
base := captureBaseline(ctx, t, f, seeded)
m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{})
err := m.Update(ctx, func(tx *chatstate.Tx) error {
_, cerr := tx.CompleteRequiresAction(chatstate.CompleteRequiresActionInput{
CreatedBy: f.User.ID,
ModelConfigID: f.Model.ID,
Results: tc.results(seeded),
})
return cerr
})
require.Error(t, err)
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
var te *chatstate.TransitionError
require.ErrorAs(t, err, &te)
require.Equal(t, chatstate.TransitionCompleteRequiresAction, te.Transition)
require.Equal(t, chatstate.StateA0, te.From)
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
}
func runRecordRetryStateRejectCase(t *testing.T, tc recordRetryStateRejectCase) {
t.Helper()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
seeded := seedState(t, f, chatstate.StateR0)
require.Equal(t, chatstate.StateR0, f.classify(ctx, t, seeded.chatID))
base := captureBaseline(ctx, t, f, seeded)
m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{})
err := m.Update(ctx, func(tx *chatstate.Tx) error {
_, rerr := tx.RecordRetryState(chatstate.RecordRetryStateInput{
RetryState: tc.retryState,
})
return rerr
})
require.Error(t, err)
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
var te *chatstate.TransitionError
require.ErrorAs(t, err, &te)
require.Equal(t, chatstate.TransitionRecordRetryState, te.Transition)
require.Equal(t, chatstate.StateR0, te.From)
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
}
// TestTransitionInputValidation groups every input-specific rejection
// test. The matrix coverage entry point in
// TestTransitionMatrix_AllCombinations intentionally focuses on
// positive cases and generated disallowed cases; rejection cases that
// exercise legal matrix rows with invalid inputs live here so the
// matrix entry point stays focused.
func TestTransitionInputValidation(t *testing.T) {
t.Parallel()
t.Run("SetArchived_wrong_direction", func(t *testing.T) {
t.Parallel()
for _, tc := range setArchivedWrongDirectionCases() {
t.Run(tc.label, func(t *testing.T) {
t.Parallel()
runSetArchivedWrongDirectionCase(t, tc)
})
}
})
t.Run("SendMessage_invalid_busy_behavior", func(t *testing.T) {
t.Parallel()
for _, from := range chatstate.AllowedInputStates(chatstate.TransitionSendMessage) {
for _, bb := range invalidBusyBehaviors {
label := from.String() + "/" + string(bb)
if bb == "" {
label = from.String() + "/empty"
}
t.Run(label, func(t *testing.T) {
t.Parallel()
runInvalidBusyBehaviorCase(t, from, bb)
})
}
}
})
t.Run("CompleteRequiresAction_invalid_results", func(t *testing.T) {
t.Parallel()
for _, tc := range completeRequiresActionRejectCases() {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
runCompleteRequiresActionRejectCase(t, tc)
})
}
})
t.Run("RecordRetryState_invalid_payload", func(t *testing.T) {
t.Parallel()
for _, tc := range recordRetryStateRejectCases() {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
runRecordRetryStateRejectCase(t, tc)
})
}
})
}
// TestSendMessageQueueCapRejectsQueueAppend seeds a chat with the
// maximum queued messages and asserts that the next SendMessage in
// a queue-appending state returns chatstate.ErrMessageQueueFull and
// rolls back without persisting another queued row.
func TestSendMessageQueueCapRejectsQueueAppend(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
// createTestChat lands the chat in R0; SendMessage in R0 with
// BusyBehaviorQueue queues. Fill the queue to MaxQueueSize.
for i := 0; i < chatstate.MaxQueueSize; i++ {
sendQueuedMessage(t, f, m, "filler")
}
count, err := f.DB.CountChatQueuedMessages(ctx, created.Chat.ID)
require.NoError(t, err)
require.EqualValues(t, chatstate.MaxQueueSize, count)
chatBefore := f.readChat(ctx, t, created.Chat.ID)
// The next queue append must fail with ErrMessageQueueFull and a
// typed wrapper that exposes the cap.
err = m.Update(ctx, func(tx *chatstate.Tx) error {
_, serr := tx.SendMessage(chatstate.SendMessageInput{
Message: userTextMessage("overflow", f.User.ID, f.Model.ID),
BusyBehavior: chatstate.BusyBehaviorQueue,
})
return serr
})
require.Error(t, err)
require.ErrorIs(t, err, chatstate.ErrMessageQueueFull,
"queue-append over the cap returns ErrMessageQueueFull")
var typed *chatstate.MessageQueueFullError
require.ErrorAs(t, err, &typed, "ErrMessageQueueFull is carried as a typed error")
require.EqualValues(t, chatstate.MaxQueueSize, typed.Max)
// The transaction rolled back: queue size, snapshot version,
// and queue version are unchanged.
countAfter, err := f.DB.CountChatQueuedMessages(ctx, created.Chat.ID)
require.NoError(t, err)
require.EqualValues(t, chatstate.MaxQueueSize, countAfter,
"queue size must not change when the cap rejects the append")
chatAfter := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, chatBefore.SnapshotVersion, chatAfter.SnapshotVersion,
"failed queue append must not bump snapshot_version")
require.Equal(t, chatBefore.QueueVersion, chatAfter.QueueVersion,
"failed queue append must not bump queue_version")
}
// TestEditMessageNonUserReturnsSentinel asserts that editing a
// non-user message returns chatstate.ErrEditedMessageNotUser via
// the TransitionError cause chain, and still matches the generic
// transition sentinel.
func TestEditMessageNonUserReturnsSentinel(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
// Insert an assistant message via CommitStep so we have a
// non-user message to target.
var assistantID int64
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
assistant := userTextMessage("assistant", f.User.ID, f.Model.ID)
assistant.Role = database.ChatMessageRoleAssistant
step, err := tx.CommitStep(chatstate.CommitStepInput{
Messages: []chatstate.Message{assistant},
})
if err != nil {
return err
}
require.Len(t, step.InsertedMessages, 1)
assistantID = step.InsertedMessages[0].ID
return nil
}))
rawContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("new content"),
})
require.NoError(t, err)
editErr := m.Update(ctx, func(tx *chatstate.Tx) error {
_, eerr := tx.EditMessage(chatstate.EditMessageInput{
MessageID: assistantID,
CreatedBy: f.User.ID,
Content: rawContent,
})
return eerr
})
require.Error(t, editErr)
require.ErrorIs(t, editErr, chatstate.ErrEditedMessageNotUser,
"non-user edit returns ErrEditedMessageNotUser via TransitionError cause")
require.ErrorIs(t, editErr, chatstate.ErrTransitionNotAllowed,
"ErrEditedMessageNotUser still matches the generic transition sentinel")
}
// TestTransitionAbandon_RejectsUnowned verifies that calling Abandon
// on a chat the runner does not own returns ErrTransitionNotAllowed
// wrapped in a TransitionError that records the loaded from-state,
// without mutating chat state or publishing anything.
func TestTransitionAbandon_RejectsUnowned(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
seeded := seededChat{chatID: created.Chat.ID, exists: true}
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
base := captureBaseline(ctx, t, f, seeded)
err := m.Update(ctx, func(tx *chatstate.Tx) error {
_, aerr := tx.Abandon(chatstate.AbandonInput{})
return aerr
})
require.Error(t, err)
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
var te *chatstate.TransitionError
require.ErrorAs(t, err, &te)
require.Equal(t, chatstate.TransitionAbandon, te.Transition)
// createTestChat lands the chat in R0; Abandon's precondition
// rejects an unowned chat there.
require.Equal(t, chatstate.StateR0, te.From)
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
}
// TestTransitionAbandon_ClearsOwnership verifies the Acquire/Abandon
// round-trip: after Acquire the chat carries a worker+runner and a
// fresh heartbeat row exists, and after Abandon both ownership fields
// are cleared. The heartbeat row is not deleted by Abandon; heartbeat
// cleanup is a separate concern.
func TestTransitionAbandon_ClearsOwnership(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
worker := uuid.New()
runner := uuid.New()
// Acquire writes ownership and a fresh heartbeat row.
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner})
return err
}))
owned := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, worker, owned.WorkerID.UUID)
require.Equal(t, runner, owned.RunnerID.UUID)
hb, err := f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
ChatID: created.Chat.ID,
RunnerID: runner,
})
require.NoError(t, err, "Acquire writes a fresh heartbeat row")
require.Equal(t, runner, hb.RunnerID)
// Abandon clears ownership but leaves the heartbeat row intact.
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Abandon(chatstate.AbandonInput{})
return err
}))
hb, err = f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
ChatID: created.Chat.ID,
RunnerID: runner,
})
require.NoError(t, err, "Abandon does not delete the heartbeat row")
abandoned := f.readChat(ctx, t, created.Chat.ID)
require.False(t, abandoned.WorkerID.Valid, "Abandon clears worker_id")
require.False(t, abandoned.RunnerID.Valid, "Abandon clears runner_id")
}
// TestTransitionAcquire_OverwritesFreshOwnership verifies that Acquire
// is an unconditional ownership handoff: a second worker calling
// Acquire on a chat that was *just* acquired by another worker
// successfully replaces ownership without inspecting heartbeat
// freshness. It also asserts that Acquire itself does not request an
// ownership hint, so the post-commit publish stays quiet on
// `chat:ownership` when the resulting heartbeat is fresh.
func TestTransitionAcquire_OverwritesFreshOwnership(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
firstWorker := uuid.New()
firstRunner := uuid.New()
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: firstWorker, RunnerID: firstRunner})
return err
}))
// The chat is now owned with a fresh (chat_id, firstRunner)
// heartbeat written by the first Acquire.
firstChat := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, firstWorker, firstChat.WorkerID.UUID)
require.Equal(t, firstRunner, firstChat.RunnerID.UUID)
_, err := f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
ChatID: created.Chat.ID,
RunnerID: firstRunner,
})
require.NoError(t, err, "first Acquire wrote a fresh heartbeat")
// Sanity check: heartbeat is not stale by the same threshold the
// machine uses for ownership-hint decisions.
stale, err := f.DB.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{
ChatID: created.Chat.ID,
RunnerID: firstRunner,
StaleSeconds: chatstate.HeartbeatStaleSeconds,
})
require.NoError(t, err)
require.False(t, stale, "first runner's heartbeat is fresh before the second Acquire")
// Snapshot publish counts before the takeover so we can assert
// Acquire does not publish an ownership hint itself.
ownershipBefore := f.Pub.ownershipPublishCount()
beforeChat := f.readChat(ctx, t, created.Chat.ID)
secondWorker := uuid.New()
secondRunner := uuid.New()
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: secondWorker, RunnerID: secondRunner})
return err
}))
after := f.readChat(ctx, t, created.Chat.ID)
require.Equal(t, secondWorker, after.WorkerID.UUID, "ownership replaced")
require.Equal(t, secondRunner, after.RunnerID.UUID, "runner replaced")
require.Equal(t, beforeChat.SnapshotVersion+1, after.SnapshotVersion, "snapshot bumps exactly once")
f.Pub.expectChatUpdate(t, created.Chat.ID, after.SnapshotVersion)
// The new (chat_id, secondRunner) heartbeat exists. The old
// (chat_id, firstRunner) row may or may not exist; Acquire is not
// responsible for cleaning it up.
_, err = f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
ChatID: created.Chat.ID,
RunnerID: secondRunner,
})
require.NoError(t, err, "second Acquire wrote a heartbeat for the new runner")
// Acquire does not publish an ownership hint when it writes a fresh
// heartbeat. The post-commit ownership-hint logic in Update stays
// quiet because the new heartbeat is fresh, so no `chat:ownership`
// notification fires.
require.Equal(t, ownershipBefore, f.Pub.ownershipPublishCount(),
"Acquire must not publish an ownership hint when the resulting heartbeat is fresh")
}
// TestTransitionAcquire_ExecutionStateOrthogonal verifies that Acquire
// preserves every execution-state field on the chat across
// representative valid execution states, including idle, runnable, and
// archived states. The transition only mutates ownership.
func TestTransitionAcquire_ExecutionStateOrthogonal(t *testing.T) {
t.Parallel()
// Each setup leaves the chat in the named state and returns the
// chat ID for downstream assertions.
cases := []struct {
name string
state chatstate.ExecutionState
setup func(t *testing.T, f *testFixture) uuid.UUID
}{
{
name: "R0",
state: chatstate.StateR0,
setup: func(t *testing.T, f *testFixture) uuid.UUID {
return createTestChat(t, f).Chat.ID
},
},
{
name: "W",
state: chatstate.StateW,
setup: func(t *testing.T, f *testFixture) uuid.UUID {
created := createTestChat(t, f)
ctx := testutil.Context(t, testutil.WaitShort)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
return err
}))
return created.Chat.ID
},
},
{
name: "E0",
state: chatstate.StateE0,
setup: func(t *testing.T, f *testFixture) uuid.UUID {
created := createTestChat(t, f)
ctx := testutil.Context(t, testutil.WaitShort)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishError(chatstate.FinishErrorInput{
LastError: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`{"message":"boom"}`),
Valid: true,
},
})
return err
}))
return created.Chat.ID
},
},
{
name: "I0",
state: chatstate.StateI0,
setup: func(t *testing.T, f *testFixture) uuid.UUID {
created := createTestChat(t, f)
ctx := testutil.Context(t, testutil.WaitShort)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"})
return err
}))
return created.Chat.ID
},
},
{
name: "XW",
state: chatstate.StateXW,
setup: func(t *testing.T, f *testFixture) uuid.UUID {
created := createTestChat(t, f)
ctx := testutil.Context(t, testutil.WaitShort)
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
return err
}))
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true})
return err
}))
return created.Chat.ID
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
f := newTestFixture(t)
ctx := testutil.Context(t, testutil.WaitShort)
chatID := tc.setup(t, f)
require.Equal(t, tc.state, f.classify(ctx, t, chatID), "test setup must leave chat in %s", tc.state)
before := f.readChat(ctx, t, chatID)
queueBefore, err := f.DB.CountChatQueuedMessages(ctx, chatID)
require.NoError(t, err)
historyBefore := historyMessageIDs(ctx, t, f, chatID)
worker := uuid.New()
runner := uuid.New()
m := chatstate.NewChatMachine(f.DB, f.Pub, chatID, chatstate.Options{})
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner})
return err
}))
after := f.readChat(ctx, t, chatID)
// Ownership updated.
require.Equal(t, worker, after.WorkerID.UUID)
require.Equal(t, runner, after.RunnerID.UUID)
// Execution state preserved.
require.Equal(t, before.Status, after.Status, "status preserved")
require.Equal(t, before.Archived, after.Archived, "archived flag preserved")
require.Equal(t, before.RequiresActionDeadlineAt, after.RequiresActionDeadlineAt, "requires-action deadline preserved")
require.Equal(t, before.LastError, after.LastError, "last_error preserved")
require.Equal(t, before.HistoryVersion, after.HistoryVersion, "history_version preserved")
require.Equal(t, before.QueueVersion, after.QueueVersion, "queue_version preserved")
require.Equal(t, before.GenerationAttempt, after.GenerationAttempt, "generation_attempt preserved")
// Classified state unchanged.
require.Equal(t, tc.state, f.classify(ctx, t, chatID), "execution state preserved by Acquire")
// Queue and history rows untouched.
queueAfter, err := f.DB.CountChatQueuedMessages(ctx, chatID)
require.NoError(t, err)
require.Equal(t, queueBefore, queueAfter, "queue cardinality preserved")
require.Equal(t, historyBefore, historyMessageIDs(ctx, t, f, chatID), "history preserved")
})
}
}
+631
View File
@@ -0,0 +1,631 @@
package chatstate_test
import (
"database/sql"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// triggerFixture is a slim variant of testFixture that also exposes a
// raw *sql.DB so the trigger tests can run UPDATE/INSERT statements
// that bypass the typed sqlc layer. Tests that only need the typed
// store should keep using newTestFixture.
type triggerFixture struct {
f *testFixture
sqlDB *sql.DB
}
func newTriggerFixture(t *testing.T) *triggerFixture {
t.Helper()
db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
dbgen.ChatProvider(t, db, database.ChatProvider{
Provider: "openai",
DisplayName: "openai",
BaseUrl: "http://example.invalid",
})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
Provider: "openai",
IsDefault: true,
})
f := &testFixture{
DB: db,
PubSub: ps,
Pub: newRecordingPubsub(),
User: user,
Org: org,
Model: model,
}
return &triggerFixture{f: f, sqlDB: sqlDB}
}
// userMessageContent returns a marshaled user message body suitable
// for raw INSERT into chat_messages.
func userMessageContent(t *testing.T, text string) []byte {
t.Helper()
raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)})
require.NoError(t, err)
return raw.RawMessage
}
// TestMessageInsertAssignsRevisionAndHistoryVersion verifies that
// inserting a chat message via the legacy InsertChatMessages query
// assigns NEW.revision from chats.snapshot_version (BEFORE trigger)
// and bumps chats.history_version + resets generation_attempt (AFTER
// STATEMENT trigger).
func TestMessageInsertAssignsRevisionAndHistoryVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
require.Equal(t, int64(1), created.Chat.SnapshotVersion)
require.Equal(t, int64(1), created.Chat.HistoryVersion)
// Force generation_attempt > 0 so we can prove the trigger
// resets it on a new history change.
_, err := f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID)
require.NoError(t, err)
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, int64(1), before.GenerationAttempt)
// Bump snapshot_version directly to simulate a transition having
// taken the row lock.
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, before.SnapshotVersion+1, bumped.SnapshotVersion)
// Insert a new assistant message via raw SQL so we know the
// BEFORE+AFTER triggers (and only those) decide revision and
// history_version.
content := userMessageContent(t, "hello-after-bump")
_, err = tf.sqlDB.ExecContext(ctx, `
INSERT INTO chat_messages (chat_id, role, content, content_version, visibility)
VALUES ($1, 'assistant', $2::jsonb, $3, 'both')
`, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion))
require.NoError(t, err)
// History version equals snapshot_version, generation_attempt resets.
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, bumped.SnapshotVersion, after.HistoryVersion)
require.Equal(t, int64(0), after.GenerationAttempt)
// The inserted message picked up revision = bumped snapshot.
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: created.Chat.ID,
})
require.NoError(t, err)
require.NotEmpty(t, msgs)
last := msgs[len(msgs)-1]
require.Equal(t, database.ChatMessageRoleAssistant, last.Role)
require.Equal(t, bumped.SnapshotVersion, last.Revision)
}
// TestMessageUpdateAssignsNewRevisionAndHistoryVersion verifies that
// updating a chat message's content advances NEW.revision to the
// current chats.snapshot_version and that chats.history_version
// bumps to match.
func TestMessageUpdateAssignsNewRevisionAndHistoryVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: created.Chat.ID,
})
require.NoError(t, err)
require.NotEmpty(t, msgs)
target := msgs[0]
originalRevision := target.Revision
// Bump the snapshot so the trigger sees a new revision target.
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
require.Greater(t, bumped.SnapshotVersion, originalRevision)
newContent := userMessageContent(t, "edited content")
_, err = tf.sqlDB.ExecContext(ctx, `
UPDATE chat_messages SET content = $1::jsonb WHERE id = $2
`, string(newContent), target.ID)
require.NoError(t, err)
reloaded, err := f.DB.GetChatMessageByID(ctx, target.ID)
require.NoError(t, err)
require.Equal(t, bumped.SnapshotVersion, reloaded.Revision,
"updated message picks up the current snapshot version")
chatAfter, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, bumped.SnapshotVersion, chatAfter.HistoryVersion)
require.Equal(t, int64(0), chatAfter.GenerationAttempt,
"history change resets generation_attempt")
}
// TestMessageRevisionCannotBeSetByRuntimeCode verifies the BEFORE
// trigger rejects explicit revision values on INSERT and rejects
// revision changes on UPDATE.
func TestMessageRevisionCannotBeSetByRuntimeCode(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
content := userMessageContent(t, "explicit revision")
_, err := tf.sqlDB.ExecContext(ctx, `
INSERT INTO chat_messages (chat_id, role, content, content_version, visibility, revision)
VALUES ($1, 'user', $2::jsonb, $3, 'both', 999)
`, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion))
require.Error(t, err, "INSERT with explicit revision must be rejected")
require.Contains(t, err.Error(), "revision must be assigned by trigger")
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: created.Chat.ID,
})
require.NoError(t, err)
require.NotEmpty(t, msgs)
target := msgs[0]
_, err = tf.sqlDB.ExecContext(ctx, `
UPDATE chat_messages SET revision = revision + 100 WHERE id = $1
`, target.ID)
require.Error(t, err, "UPDATE that changes revision must be rejected")
require.Contains(t, err.Error(), "revision must be assigned by trigger")
}
// TestMessageChatIDCannotChange verifies the BEFORE trigger rejects
// updates that change chat_messages.chat_id.
func TestMessageChatIDCannotChange(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
first := createTestChat(t, f)
second := createTestChat(t, f)
firstMsgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: first.Chat.ID,
})
require.NoError(t, err)
require.NotEmpty(t, firstMsgs)
target := firstMsgs[0]
_, err = tf.sqlDB.ExecContext(ctx, `
UPDATE chat_messages SET chat_id = $1 WHERE id = $2
`, second.Chat.ID, target.ID)
require.Error(t, err, "UPDATE that changes chat_id must be rejected")
require.Contains(t, err.Error(), "chat_id is immutable")
}
// TestNoopMessageUpdateDoesNotAdvanceHistoryVersion verifies that a
// no-op UPDATE on a chat_messages row (one whose OLD and NEW are
// indistinguishable) does NOT advance chats.history_version even
// when the snapshot was previously bumped. This guards against the
// AFTER UPDATE STATEMENT trigger naively reacting to every touched
// row id regardless of whether the row actually changed.
func TestNoopMessageUpdateDoesNotAdvanceHistoryVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: created.Chat.ID,
})
require.NoError(t, err)
require.NotEmpty(t, msgs)
target := msgs[0]
originalRevision := target.Revision
// Bump snapshot so the AFTER STATEMENT guard
// (history_version != snapshot_version) is now true.
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
require.NotEqual(t, bumped.SnapshotVersion, bumped.HistoryVersion,
"snapshot bump leaves history_version trailing")
// No-op UPDATE: SET content = content. OLD IS NOT DISTINCT FROM NEW.
_, err = tf.sqlDB.ExecContext(ctx, `
UPDATE chat_messages SET content = content WHERE id = $1
`, target.ID)
require.NoError(t, err)
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, bumped.HistoryVersion, after.HistoryVersion,
"no-op update must NOT advance history_version")
// And the row's revision is untouched.
reloaded, err := f.DB.GetChatMessageByID(ctx, target.ID)
require.NoError(t, err)
require.Equal(t, originalRevision, reloaded.Revision,
"no-op update must NOT advance message revision")
}
// =============================================================================
// Queue version triggers
// =============================================================================
// TestQueueInsertUpdatesQueueVersion verifies that an INSERT into
// chat_queued_messages bumps chats.queue_version to the current
// snapshot_version.
func TestQueueInsertUpdatesQueueVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, int64(0), before.QueueVersion)
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
content := userMessageContent(t, "queued")
_, err = f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
ChatID: created.Chat.ID,
Content: content,
CreatedBy: f.User.ID,
})
require.NoError(t, err)
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, bumped.SnapshotVersion, after.QueueVersion,
"INSERT into chat_queued_messages bumps queue_version")
}
// TestQueuedMessageCreatedByIsRequired verifies the database enforces
// creator metadata for every queued message row.
func TestQueuedMessageCreatedByIsRequired(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
content := userMessageContent(t, "queued-without-creator")
_, err := tf.sqlDB.ExecContext(ctx, `
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, created_by)
VALUES ($1, $2::jsonb, NULL, NULL)
`, created.Chat.ID, string(content))
require.Error(t, err)
require.Contains(t, err.Error(), "created_by")
}
func TestLegacyQueuedMessageInsertUsesChatOwnerAsCreator(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
queued, err := f.DB.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
ChatID: created.Chat.ID,
Content: userMessageContent(t, "legacy-queued"),
})
require.NoError(t, err)
require.Equal(t, created.Chat.OwnerID, queued.CreatedBy)
}
// TestQueueUpdateContentUpdatesQueueVersion verifies that an UPDATE
// of chat_queued_messages.content bumps queue_version. The
// AFTER UPDATE trigger explicitly listens for content changes.
func TestQueueUpdateContentUpdatesQueueVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
queued, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
ChatID: created.Chat.ID,
Content: userMessageContent(t, "initial"),
CreatedBy: f.User.ID,
})
require.NoError(t, err)
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
require.Greater(t, bumped.SnapshotVersion, before.QueueVersion)
updated := userMessageContent(t, "updated")
_, err = tf.sqlDB.ExecContext(ctx, `
UPDATE chat_queued_messages SET content = $1::jsonb WHERE id = $2
`, string(updated), queued.ID)
require.NoError(t, err)
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, bumped.SnapshotVersion, after.QueueVersion,
"UPDATE of queued content bumps queue_version")
}
// TestQueueUpdatePositionUpdatesQueueVersion verifies that an UPDATE
// of chat_queued_messages.position (such as the reorder-to-head
// path) bumps queue_version.
func TestQueueUpdatePositionUpdatesQueueVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
q1, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
ChatID: created.Chat.ID,
Content: userMessageContent(t, "first"),
CreatedBy: f.User.ID,
})
require.NoError(t, err)
q2, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
ChatID: created.Chat.ID,
Content: userMessageContent(t, "second"),
CreatedBy: f.User.ID,
})
require.NoError(t, err)
require.NotEqual(t, q1.ID, q2.ID)
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
// Move q2 to head by setting its position to q1.position - 1.
_, err = tf.sqlDB.ExecContext(ctx, `
UPDATE chat_queued_messages SET position = $1 WHERE id = $2
`, q1.Position-1, q2.ID)
require.NoError(t, err)
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, bumped.SnapshotVersion, after.QueueVersion,
"UPDATE of queued position bumps queue_version")
}
// TestQueueDeleteUpdatesQueueVersion verifies that DELETE from
// chat_queued_messages bumps queue_version.
func TestQueueDeleteUpdatesQueueVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
queued, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
ChatID: created.Chat.ID,
Content: userMessageContent(t, "to delete"),
CreatedBy: f.User.ID,
})
require.NoError(t, err)
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
rows, err := f.DB.DeleteChatQueuedMessageReturningCount(ctx, database.DeleteChatQueuedMessageReturningCountParams{
ID: queued.ID,
ChatID: created.Chat.ID,
})
require.NoError(t, err)
require.Equal(t, int64(1), rows)
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, bumped.SnapshotVersion, after.QueueVersion,
"DELETE from queue bumps queue_version")
}
// TestNonQueueUpdateDoesNotUpdateQueueVersion verifies that mutations
// on other chat-related tables do NOT bump queue_version. The
// canonical case is inserting a chat message: it must update
// history_version but leave queue_version untouched.
func TestNonQueueUpdateDoesNotUpdateQueueVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
content := userMessageContent(t, "non-queue mutation")
_, err = tf.sqlDB.ExecContext(ctx, `
INSERT INTO chat_messages (chat_id, role, content, content_version, visibility)
VALUES ($1, 'assistant', $2::jsonb, $3, 'both')
`, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion))
require.NoError(t, err)
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, before.QueueVersion, after.QueueVersion,
"chat_messages INSERT must not bump queue_version")
// Sanity: history_version DID move.
require.Equal(t, bumped.SnapshotVersion, after.HistoryVersion)
}
// =============================================================================
// Retry state triggers
// =============================================================================
func TestRetryStateDefaults(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
chat, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.False(t, chat.RetryState.Valid)
require.Equal(t, int64(0), chat.RetryStateVersion)
}
func TestRetryStateUpdateSetsRetryStateVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
after, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
ID: created.Chat.ID,
RetryState: []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`),
})
require.NoError(t, err)
require.True(t, after.RetryState.Valid)
require.JSONEq(t,
`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`,
string(after.RetryState.RawMessage))
require.Equal(t, bumped.SnapshotVersion, after.RetryStateVersion)
}
func TestRetryStateSameValueDoesNotUpdateRetryStateVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
payload := []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`)
_, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
first, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
ID: created.Chat.ID,
RetryState: payload,
})
require.NoError(t, err)
_, err = f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
second, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
ID: created.Chat.ID,
RetryState: payload,
})
require.NoError(t, err)
require.Equal(t, first.RetryStateVersion, second.RetryStateVersion,
"same retry_state payload must not update retry_state_version")
}
func TestGenerationAttemptClearsRetryState(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
_, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
withRetry, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
ID: created.Chat.ID,
RetryState: []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`),
})
require.NoError(t, err)
require.True(t, withRetry.RetryState.Valid)
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
attempt, err := f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, int64(1), attempt)
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.False(t, after.RetryState.Valid)
require.Equal(t, bumped.SnapshotVersion, after.RetryStateVersion,
"clearing retry_state on generation attempt bumps retry_state_version")
}
func TestGenerationAttemptWithNullRetryStateDoesNotUpdateRetryStateVersion(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.False(t, before.RetryState.Valid)
_, err = f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
_, err = f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID)
require.NoError(t, err)
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.False(t, after.RetryState.Valid)
require.Equal(t, before.RetryStateVersion, after.RetryStateVersion,
"generation attempt with null retry_state leaves retry_state_version unchanged")
}
func TestRetryStateVersionCannotBeSetByRuntimeCode(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
_, err := tf.sqlDB.ExecContext(ctx, `
UPDATE chats SET retry_state_version = retry_state_version + 1 WHERE id = $1
`, created.Chat.ID)
require.Error(t, err)
require.Contains(t, err.Error(), "retry_state_version must be assigned by trigger")
}
func TestHistoryChangeClearsRetryState(t *testing.T) {
t.Parallel()
tf := newTriggerFixture(t)
f := tf.f
ctx := testutil.Context(t, testutil.WaitShort)
created := createTestChat(t, f)
_, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
_, err = f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID)
require.NoError(t, err)
_, err = f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
ID: created.Chat.ID,
RetryState: []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`),
})
require.NoError(t, err)
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
require.NoError(t, err)
content := userMessageContent(t, "history clears retry state")
_, err = tf.sqlDB.ExecContext(ctx, `
INSERT INTO chat_messages (chat_id, role, content, content_version, visibility)
VALUES ($1, 'assistant', $2::jsonb, $3, 'both')
`, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion))
require.NoError(t, err)
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
require.NoError(t, err)
require.Equal(t, int64(0), after.GenerationAttempt)
require.False(t, after.RetryState.Valid)
require.Equal(t, bumped.SnapshotVersion, after.RetryStateVersion,
"history reset of generation_attempt clears retry_state")
}
+40 -32
View File
@@ -414,38 +414,46 @@ var auditableResourcesTypes = map[any]map[string]Action{
"deleted_at": ActionIgnore, // Changes, but is implicit when a delete event is fired.
},
&database.Chat{}: {
"id": ActionTrack,
"owner_id": ActionTrack,
"owner_username": ActionIgnore,
"owner_name": ActionIgnore,
"organization_id": ActionIgnore, // Never changes after creation.
"workspace_id": ActionTrack,
"build_id": ActionIgnore, // Internal lifecycle.
"agent_id": ActionIgnore, // Internal lifecycle.
"title": ActionSecret, // May contain sensitive content.
"status": ActionIgnore, // Churns every message.
"worker_id": ActionIgnore, // Internal.
"started_at": ActionIgnore,
"heartbeat_at": ActionIgnore, // Internal.
"created_at": ActionIgnore, // Never changes.
"updated_at": ActionIgnore, // Bumped on every mutation.
"parent_chat_id": ActionIgnore, // Immutable after creation.
"root_chat_id": ActionIgnore, // Immutable after creation.
"last_model_config_id": ActionIgnore, // Churns every message.
"archived": ActionTrack,
"last_error": ActionIgnore, // Internal.
"last_turn_summary": ActionIgnore, // Internal cached display text.
"mode": ActionTrack,
"mcp_server_ids": ActionTrack,
"labels": ActionTrack,
"user_acl": ActionTrack,
"group_acl": ActionTrack,
"pin_order": ActionTrack,
"last_read_message_id": ActionIgnore, // User-scoped read cursor.
"last_injected_context": ActionIgnore, // Internal lifecycle.
"dynamic_tools": ActionIgnore, // Internal lifecycle.
"plan_mode": ActionIgnore, // Can flip back and forth during a session.
"client_type": ActionIgnore, // Set at creation.
"id": ActionTrack,
"owner_id": ActionTrack,
"owner_username": ActionIgnore,
"owner_name": ActionIgnore,
"organization_id": ActionIgnore, // Never changes after creation.
"workspace_id": ActionTrack,
"build_id": ActionIgnore, // Internal lifecycle.
"agent_id": ActionIgnore, // Internal lifecycle.
"title": ActionSecret, // May contain sensitive content.
"status": ActionIgnore, // Churns every message.
"worker_id": ActionIgnore, // Internal.
"started_at": ActionIgnore,
"heartbeat_at": ActionIgnore, // Internal.
"created_at": ActionIgnore, // Never changes.
"updated_at": ActionIgnore, // Bumped on every mutation.
"parent_chat_id": ActionIgnore, // Immutable after creation.
"root_chat_id": ActionIgnore, // Immutable after creation.
"last_model_config_id": ActionIgnore, // Churns every message.
"archived": ActionTrack,
"last_error": ActionIgnore, // Internal.
"last_turn_summary": ActionIgnore, // Internal cached display text.
"mode": ActionTrack,
"mcp_server_ids": ActionTrack,
"labels": ActionTrack,
"user_acl": ActionTrack,
"group_acl": ActionTrack,
"pin_order": ActionTrack,
"last_read_message_id": ActionIgnore, // User-scoped read cursor.
"last_injected_context": ActionIgnore, // Internal lifecycle.
"dynamic_tools": ActionIgnore, // Internal lifecycle.
"plan_mode": ActionIgnore, // Can flip back and forth during a session.
"client_type": ActionIgnore, // Set at creation.
"snapshot_version": ActionIgnore, // Internal state machine version.
"history_version": ActionIgnore, // Internal state machine version.
"queue_version": ActionIgnore, // Internal state machine version.
"retry_state": ActionIgnore, // Internal transient retry UI state.
"retry_state_version": ActionIgnore, // Internal state machine version.
"generation_attempt": ActionIgnore, // Internal retry counter.
"runner_id": ActionIgnore, // Internal ownership identifier.
"requires_action_deadline_at": ActionIgnore, // Internal pending-action deadline.
},
&database.UserSkill{}: {
"id": ActionTrack,