mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
source files
This commit is contained in:
@@ -0,0 +1,106 @@
|
||||
-- Rollback for the chatd core state machine foundation migration.
|
||||
|
||||
-- 1. Recreate chats_expanded without the new chat fields. We must drop
|
||||
-- the view first because the subsequent column drops would fail with
|
||||
-- "view depends on column".
|
||||
DROP VIEW IF EXISTS chats_expanded;
|
||||
|
||||
-- 2. Drop the retry state trigger and function.
|
||||
DROP TRIGGER IF EXISTS trigger_sync_chat_retry_state ON chats;
|
||||
DROP FUNCTION IF EXISTS sync_chat_retry_state();
|
||||
|
||||
-- 3. Drop the queue version triggers and function.
|
||||
DROP TRIGGER IF EXISTS trigger_bump_chat_queue_version_on_queued_message_delete ON chat_queued_messages;
|
||||
DROP TRIGGER IF EXISTS trigger_bump_chat_queue_version_on_queued_message_update ON chat_queued_messages;
|
||||
DROP TRIGGER IF EXISTS trigger_bump_chat_queue_version_on_queued_message_insert ON chat_queued_messages;
|
||||
DROP FUNCTION IF EXISTS bump_chat_queue_version_on_queued_message_change();
|
||||
|
||||
-- 4. Drop the message revision triggers and functions.
|
||||
DROP TRIGGER IF EXISTS trigger_update_chat_history_after_message_update ON chat_messages;
|
||||
DROP TRIGGER IF EXISTS trigger_update_chat_history_after_message_insert ON chat_messages;
|
||||
DROP TRIGGER IF EXISTS trigger_set_chat_message_revision_on_update ON chat_messages;
|
||||
DROP TRIGGER IF EXISTS trigger_set_chat_message_revision_on_insert ON chat_messages;
|
||||
DROP FUNCTION IF EXISTS update_chat_history_after_message_update();
|
||||
DROP FUNCTION IF EXISTS update_chat_history_after_message_insert();
|
||||
-- The pre-split function name is kept here for backward compatibility
|
||||
-- with environments that may have applied an earlier draft of the up
|
||||
-- migration. DROP FUNCTION IF EXISTS is a no-op if the function is
|
||||
-- absent.
|
||||
DROP FUNCTION IF EXISTS update_chat_history_after_message_changes();
|
||||
DROP FUNCTION IF EXISTS set_chat_message_revision_before();
|
||||
DROP FUNCTION IF EXISTS set_chat_message_revision();
|
||||
|
||||
-- 5. Drop chat_heartbeats (and its index by association).
|
||||
DROP TABLE IF EXISTS chat_heartbeats;
|
||||
|
||||
-- 6. Drop the queue-order index.
|
||||
DROP INDEX IF EXISTS idx_chat_queued_messages_chat_position_id;
|
||||
|
||||
-- 7. Drop chat_queued_messages.position and its default sequence, plus
|
||||
-- created_by.
|
||||
ALTER TABLE chat_queued_messages
|
||||
ALTER COLUMN position DROP DEFAULT;
|
||||
ALTER TABLE chat_queued_messages
|
||||
DROP COLUMN IF EXISTS position,
|
||||
DROP COLUMN IF EXISTS created_by;
|
||||
DROP SEQUENCE IF EXISTS chat_queued_messages_position_seq;
|
||||
|
||||
-- 8. Drop chat_messages.revision.
|
||||
ALTER TABLE chat_messages
|
||||
DROP COLUMN IF EXISTS revision;
|
||||
|
||||
-- 9. Drop the new chats columns.
|
||||
ALTER TABLE chats
|
||||
DROP COLUMN IF EXISTS snapshot_version,
|
||||
DROP COLUMN IF EXISTS history_version,
|
||||
DROP COLUMN IF EXISTS queue_version,
|
||||
DROP COLUMN IF EXISTS generation_attempt,
|
||||
DROP COLUMN IF EXISTS retry_state,
|
||||
DROP COLUMN IF EXISTS retry_state_version,
|
||||
DROP COLUMN IF EXISTS runner_id,
|
||||
DROP COLUMN IF EXISTS requires_action_deadline_at;
|
||||
|
||||
-- 10. Recreate chats_expanded with the pre-migration field list.
|
||||
CREATE VIEW chats_expanded AS
|
||||
SELECT
|
||||
c.id,
|
||||
c.owner_id,
|
||||
c.workspace_id,
|
||||
c.title,
|
||||
c.status,
|
||||
c.worker_id,
|
||||
c.started_at,
|
||||
c.heartbeat_at,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.parent_chat_id,
|
||||
c.root_chat_id,
|
||||
c.last_model_config_id,
|
||||
c.archived,
|
||||
c.last_error,
|
||||
c.mode,
|
||||
c.mcp_server_ids,
|
||||
c.labels,
|
||||
c.build_id,
|
||||
c.agent_id,
|
||||
c.pin_order,
|
||||
c.last_read_message_id,
|
||||
c.last_injected_context,
|
||||
c.dynamic_tools,
|
||||
c.organization_id,
|
||||
c.plan_mode,
|
||||
c.client_type,
|
||||
c.last_turn_summary,
|
||||
COALESCE(root.user_acl, c.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, c.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
owner.name AS owner_name
|
||||
FROM
|
||||
chats c
|
||||
LEFT JOIN chats root ON root.id = COALESCE(c.root_chat_id, c.parent_chat_id)
|
||||
JOIN visible_users owner ON owner.id = c.owner_id;
|
||||
|
||||
-- 11. The `interrupting` chat_status enum value is intentionally left
|
||||
-- in place. Postgres does not support dropping a single enum value
|
||||
-- without recreating the entire type, which would require rewriting
|
||||
-- every chat row and is unsafe inside a transactional rollback.
|
||||
@@ -0,0 +1,346 @@
|
||||
-- Foundation for the chatd core state machine refactor (PR 1).
|
||||
-- Adds new versioning fields to chats, a revision column to chat_messages,
|
||||
-- positional ordering and creator tracking to chat_queued_messages, an
|
||||
-- unlogged chat_heartbeats table for ownership leases, and Postgres
|
||||
-- triggers that keep history/queue versioning consistent.
|
||||
--
|
||||
-- This migration does not switch any runtime code over to the new
|
||||
-- state machine. Existing chatd code keeps running against the legacy
|
||||
-- "pending"/"paused"/"completed" enum values and the legacy ownership
|
||||
-- columns (chats.started_at, chats.heartbeat_at). Those columns and
|
||||
-- enum values stay in place for compatibility.
|
||||
|
||||
-- 1. Add `interrupting` to the chat_status enum.
|
||||
ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'interrupting';
|
||||
|
||||
-- 2. Add new versioning, ownership, retry, and pending-action fields to chats.
|
||||
ALTER TABLE chats
|
||||
ADD COLUMN snapshot_version bigint NOT NULL DEFAULT 1,
|
||||
ADD COLUMN history_version bigint NOT NULL DEFAULT 0,
|
||||
ADD COLUMN queue_version bigint NOT NULL DEFAULT 0,
|
||||
ADD COLUMN generation_attempt bigint NOT NULL DEFAULT 0,
|
||||
ADD COLUMN retry_state jsonb,
|
||||
ADD COLUMN retry_state_version bigint NOT NULL DEFAULT 0,
|
||||
ADD COLUMN runner_id uuid,
|
||||
ADD COLUMN requires_action_deadline_at timestamp with time zone;
|
||||
|
||||
-- 3. Add `revision` to chat_messages (nullable for now, backfilled below).
|
||||
ALTER TABLE chat_messages
|
||||
ADD COLUMN revision bigint;
|
||||
|
||||
-- 4. Backfill chat_messages.revision = 1 for existing rows.
|
||||
UPDATE chat_messages SET revision = 1 WHERE revision IS NULL;
|
||||
|
||||
-- 5. Backfill chats.history_version = 1 for chats that already have at
|
||||
-- least one message. We avoid recursive trigger fire by performing the
|
||||
-- backfill before the triggers are created.
|
||||
UPDATE chats
|
||||
SET history_version = 1
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM chat_messages WHERE chat_messages.chat_id = chats.id
|
||||
);
|
||||
|
||||
-- 6. Enforce NOT NULL on chat_messages.revision.
|
||||
ALTER TABLE chat_messages
|
||||
ALTER COLUMN revision SET NOT NULL;
|
||||
|
||||
-- 7. Add `position` and `created_by` to chat_queued_messages.
|
||||
ALTER TABLE chat_queued_messages
|
||||
ADD COLUMN position bigint,
|
||||
ADD COLUMN created_by uuid;
|
||||
|
||||
-- 8. Backfill chat_queued_messages.position per chat using row_number(),
|
||||
-- ordering by created_at and breaking ties by id.
|
||||
WITH ordered AS (
|
||||
SELECT
|
||||
id,
|
||||
row_number() OVER (
|
||||
PARTITION BY chat_id
|
||||
ORDER BY created_at, id
|
||||
) AS rn
|
||||
FROM chat_queued_messages
|
||||
)
|
||||
UPDATE chat_queued_messages
|
||||
SET position = ordered.rn
|
||||
FROM ordered
|
||||
WHERE chat_queued_messages.id = ordered.id;
|
||||
|
||||
-- 9. Backfill chat_queued_messages.created_by from chats.owner_id.
|
||||
UPDATE chat_queued_messages
|
||||
SET created_by = chats.owner_id
|
||||
FROM chats
|
||||
WHERE chat_queued_messages.chat_id = chats.id
|
||||
AND chat_queued_messages.created_by IS NULL;
|
||||
|
||||
-- 10. Enforce NOT NULL on chat_queued_messages.position and
|
||||
-- created_by. Legacy queued-message inserts are updated to populate
|
||||
-- created_by from the chat owner when no explicit creator exists.
|
||||
ALTER TABLE chat_queued_messages
|
||||
ALTER COLUMN position SET NOT NULL,
|
||||
ALTER COLUMN created_by SET NOT NULL;
|
||||
|
||||
-- 11. Index for queue-order reads and head selection.
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_queued_messages_chat_position_id
|
||||
ON chat_queued_messages(chat_id, position, id);
|
||||
|
||||
-- 12. Default sequence for new queued-message positions.
|
||||
-- A global sequence is acceptable because ordering only needs to be
|
||||
-- stable within a chat.
|
||||
CREATE SEQUENCE IF NOT EXISTS chat_queued_messages_position_seq AS bigint START WITH 1;
|
||||
SELECT setval(
|
||||
'chat_queued_messages_position_seq',
|
||||
GREATEST((SELECT COALESCE(MAX(position), 0) FROM chat_queued_messages), 1)
|
||||
);
|
||||
ALTER TABLE chat_queued_messages
|
||||
ALTER COLUMN position SET DEFAULT nextval('chat_queued_messages_position_seq');
|
||||
|
||||
-- 13. Backfill chats.queue_version = 1 for chats that already have queued
|
||||
-- messages. Same trigger-avoidance reasoning as for history_version.
|
||||
UPDATE chats
|
||||
SET queue_version = 1
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM chat_queued_messages WHERE chat_queued_messages.chat_id = chats.id
|
||||
);
|
||||
|
||||
-- 14. chat_heartbeats: unlogged table for ownership leases. Keyed by
|
||||
-- (chat_id, runner_id) so a single chat can briefly have entries from
|
||||
-- multiple runners during failover.
|
||||
CREATE UNLOGGED TABLE IF NOT EXISTS chat_heartbeats (
|
||||
chat_id uuid NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
runner_id uuid NOT NULL,
|
||||
heartbeat_at timestamp with time zone NOT NULL,
|
||||
PRIMARY KEY (chat_id, runner_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS chat_heartbeats_heartbeat_at_idx
|
||||
ON chat_heartbeats (heartbeat_at);
|
||||
|
||||
-- 15. Message revision trigger.
|
||||
-- The BEFORE-trigger only assigns NEW.revision from chats.snapshot_version
|
||||
-- and validates immutability. The chats.history_version /
|
||||
-- generation_attempt update is performed by an AFTER STATEMENT trigger
|
||||
-- so it doesn't conflict with CTE updates on the chats row in the same
|
||||
-- command (the legacy InsertChatMessages query updates last_model_config_id
|
||||
-- in a CTE on chats and then inserts messages).
|
||||
CREATE FUNCTION set_chat_message_revision_before()
|
||||
RETURNS trigger AS $$
|
||||
DECLARE
|
||||
chat_snapshot_version bigint;
|
||||
BEGIN
|
||||
IF TG_OP = 'INSERT' AND NEW.revision IS NOT NULL THEN
|
||||
RAISE EXCEPTION 'chat_messages.revision must be assigned by trigger';
|
||||
END IF;
|
||||
|
||||
IF TG_OP = 'UPDATE' THEN
|
||||
IF OLD.chat_id IS DISTINCT FROM NEW.chat_id THEN
|
||||
RAISE EXCEPTION 'chat_messages.chat_id is immutable';
|
||||
END IF;
|
||||
|
||||
IF OLD.revision IS DISTINCT FROM NEW.revision THEN
|
||||
RAISE EXCEPTION 'chat_messages.revision must be assigned by trigger';
|
||||
END IF;
|
||||
|
||||
IF OLD IS NOT DISTINCT FROM NEW THEN
|
||||
RETURN NEW;
|
||||
END IF;
|
||||
END IF;
|
||||
|
||||
SELECT snapshot_version INTO chat_snapshot_version
|
||||
FROM chats WHERE id = NEW.chat_id;
|
||||
|
||||
IF chat_snapshot_version IS NULL THEN
|
||||
RAISE EXCEPTION 'chat % does not exist', NEW.chat_id;
|
||||
END IF;
|
||||
|
||||
NEW.revision = chat_snapshot_version;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- AFTER STATEMENT trigger functions. Use the transition tables to
|
||||
-- update chats.history_version / generation_attempt once per chat per
|
||||
-- command. Running AFTER row inserts/updates complete lets a CTE
|
||||
-- update on the same chats row in the same command finalize before
|
||||
-- this trigger needs to update it.
|
||||
--
|
||||
-- The INSERT and UPDATE variants are split so the UPDATE variant can
|
||||
-- reference both the OLD and NEW transition tables and skip rows that
|
||||
-- did not actually change. Without that filter, a no-op UPDATE on a
|
||||
-- chat_messages row (one whose OLD IS NOT DISTINCT FROM NEW) would
|
||||
-- still advance chats.history_version whenever the chat's snapshot
|
||||
-- had previously been bumped.
|
||||
CREATE FUNCTION update_chat_history_after_message_insert()
|
||||
RETURNS trigger AS $$
|
||||
BEGIN
|
||||
UPDATE chats c
|
||||
SET history_version = c.snapshot_version,
|
||||
generation_attempt = 0
|
||||
FROM (
|
||||
SELECT DISTINCT chat_id FROM chat_message_history_new_rows
|
||||
) AS affected
|
||||
WHERE c.id = affected.chat_id
|
||||
AND (
|
||||
c.history_version IS DISTINCT FROM c.snapshot_version
|
||||
OR c.generation_attempt <> 0
|
||||
);
|
||||
RETURN NULL;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE FUNCTION update_chat_history_after_message_update()
|
||||
RETURNS trigger AS $$
|
||||
BEGIN
|
||||
UPDATE chats c
|
||||
SET history_version = c.snapshot_version,
|
||||
generation_attempt = 0
|
||||
FROM (
|
||||
SELECT DISTINCT n.chat_id
|
||||
FROM chat_message_history_new_rows n
|
||||
JOIN chat_message_history_old_rows o ON o.id = n.id
|
||||
WHERE o IS DISTINCT FROM n
|
||||
) AS affected
|
||||
WHERE c.id = affected.chat_id
|
||||
AND (
|
||||
c.history_version IS DISTINCT FROM c.snapshot_version
|
||||
OR c.generation_attempt <> 0
|
||||
);
|
||||
RETURN NULL;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_set_chat_message_revision_on_insert
|
||||
BEFORE INSERT ON chat_messages
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION set_chat_message_revision_before();
|
||||
|
||||
CREATE TRIGGER trigger_set_chat_message_revision_on_update
|
||||
BEFORE UPDATE ON chat_messages
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION set_chat_message_revision_before();
|
||||
|
||||
CREATE TRIGGER trigger_update_chat_history_after_message_insert
|
||||
AFTER INSERT ON chat_messages
|
||||
REFERENCING NEW TABLE AS chat_message_history_new_rows
|
||||
FOR EACH STATEMENT
|
||||
EXECUTE FUNCTION update_chat_history_after_message_insert();
|
||||
|
||||
CREATE TRIGGER trigger_update_chat_history_after_message_update
|
||||
AFTER UPDATE ON chat_messages
|
||||
REFERENCING OLD TABLE AS chat_message_history_old_rows NEW TABLE AS chat_message_history_new_rows
|
||||
FOR EACH STATEMENT
|
||||
EXECUTE FUNCTION update_chat_history_after_message_update();
|
||||
|
||||
-- 16. Queue version trigger function.
|
||||
CREATE FUNCTION bump_chat_queue_version_on_queued_message_change()
|
||||
RETURNS trigger AS $$
|
||||
DECLARE
|
||||
changed_chat_id uuid;
|
||||
BEGIN
|
||||
IF TG_OP = 'DELETE' THEN
|
||||
changed_chat_id = OLD.chat_id;
|
||||
ELSE
|
||||
changed_chat_id = NEW.chat_id;
|
||||
END IF;
|
||||
|
||||
UPDATE chats
|
||||
SET queue_version = snapshot_version
|
||||
WHERE id = changed_chat_id;
|
||||
|
||||
IF TG_OP = 'DELETE' THEN
|
||||
RETURN OLD;
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_bump_chat_queue_version_on_queued_message_insert
|
||||
AFTER INSERT ON chat_queued_messages
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION bump_chat_queue_version_on_queued_message_change();
|
||||
|
||||
CREATE TRIGGER trigger_bump_chat_queue_version_on_queued_message_update
|
||||
AFTER UPDATE OF content, model_config_id, position, created_by
|
||||
ON chat_queued_messages
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION bump_chat_queue_version_on_queued_message_change();
|
||||
|
||||
CREATE TRIGGER trigger_bump_chat_queue_version_on_queued_message_delete
|
||||
AFTER DELETE ON chat_queued_messages
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION bump_chat_queue_version_on_queued_message_change();
|
||||
|
||||
-- 17. Retry state trigger function.
|
||||
CREATE FUNCTION sync_chat_retry_state()
|
||||
RETURNS trigger AS $$
|
||||
BEGIN
|
||||
IF OLD.retry_state_version IS DISTINCT FROM NEW.retry_state_version THEN
|
||||
RAISE EXCEPTION 'chats.retry_state_version must be assigned by trigger';
|
||||
END IF;
|
||||
|
||||
IF NEW.generation_attempt IS DISTINCT FROM OLD.generation_attempt THEN
|
||||
NEW.retry_state = NULL;
|
||||
END IF;
|
||||
|
||||
IF NEW.retry_state IS DISTINCT FROM OLD.retry_state THEN
|
||||
NEW.retry_state_version = NEW.snapshot_version;
|
||||
END IF;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_sync_chat_retry_state
|
||||
BEFORE UPDATE OF retry_state, retry_state_version, generation_attempt
|
||||
ON chats
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION sync_chat_retry_state();
|
||||
|
||||
-- 18. Refresh chats_expanded to include the new chat fields. Drop and
|
||||
-- recreate so column ordering is stable.
|
||||
DROP VIEW IF EXISTS chats_expanded;
|
||||
CREATE VIEW chats_expanded AS
|
||||
SELECT
|
||||
c.id,
|
||||
c.owner_id,
|
||||
c.workspace_id,
|
||||
c.title,
|
||||
c.status,
|
||||
c.worker_id,
|
||||
c.started_at,
|
||||
c.heartbeat_at,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.parent_chat_id,
|
||||
c.root_chat_id,
|
||||
c.last_model_config_id,
|
||||
c.archived,
|
||||
c.last_error,
|
||||
c.mode,
|
||||
c.mcp_server_ids,
|
||||
c.labels,
|
||||
c.build_id,
|
||||
c.agent_id,
|
||||
c.pin_order,
|
||||
c.last_read_message_id,
|
||||
c.last_injected_context,
|
||||
c.dynamic_tools,
|
||||
c.organization_id,
|
||||
c.plan_mode,
|
||||
c.client_type,
|
||||
c.last_turn_summary,
|
||||
c.snapshot_version,
|
||||
c.history_version,
|
||||
c.queue_version,
|
||||
c.generation_attempt,
|
||||
c.retry_state,
|
||||
c.retry_state_version,
|
||||
c.runner_id,
|
||||
c.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, c.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, c.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
owner.name AS owner_name
|
||||
FROM
|
||||
chats c
|
||||
LEFT JOIN chats root ON root.id = COALESCE(c.root_chat_id, c.parent_chat_id)
|
||||
JOIN visible_users owner ON owner.id = c.owner_id;
|
||||
+17
@@ -0,0 +1,17 @@
|
||||
-- Fixture coverage for the chat_heartbeats table introduced in
|
||||
-- migration 000500. The earlier chat fixtures already insert at least
|
||||
-- one row into chats; we attach a heartbeat for the first such chat so
|
||||
-- migration tests see a non-empty chat_heartbeats table without
|
||||
-- hard-coding a specific chat ID.
|
||||
INSERT INTO chat_heartbeats (
|
||||
chat_id,
|
||||
runner_id,
|
||||
heartbeat_at
|
||||
)
|
||||
SELECT
|
||||
chats.id,
|
||||
'00000000-0000-0000-0000-0000000fea51'::uuid,
|
||||
'2024-01-01 00:00:00+00'
|
||||
FROM chats
|
||||
ORDER BY created_at, id
|
||||
LIMIT 1;
|
||||
@@ -820,6 +820,12 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.Chat.PlanMode,
|
||||
&i.Chat.ClientType,
|
||||
&i.Chat.LastTurnSummary,
|
||||
&i.Chat.SnapshotVersion,
|
||||
&i.Chat.HistoryVersion,
|
||||
&i.Chat.QueueVersion,
|
||||
&i.Chat.GenerationAttempt,
|
||||
&i.Chat.RunnerID,
|
||||
&i.Chat.RequiresActionDeadlineAt,
|
||||
&i.Chat.UserACL,
|
||||
&i.Chat.GroupACL,
|
||||
&i.Chat.OwnerUsername,
|
||||
@@ -887,6 +893,12 @@ func (q *sqlQuerier) GetAuthorizedChatsByChatFileID(ctx context.Context, fileID
|
||||
&i.PlanMode,
|
||||
&i.ClientType,
|
||||
&i.LastTurnSummary,
|
||||
&i.SnapshotVersion,
|
||||
&i.HistoryVersion,
|
||||
&i.QueueVersion,
|
||||
&i.GenerationAttempt,
|
||||
&i.RunnerID,
|
||||
&i.RequiresActionDeadlineAt,
|
||||
&i.UserACL,
|
||||
&i.GroupACL,
|
||||
&i.OwnerUsername,
|
||||
|
||||
@@ -35,6 +35,14 @@ chats_expanded AS (
|
||||
updated_chats.plan_mode,
|
||||
updated_chats.client_type,
|
||||
updated_chats.last_turn_summary,
|
||||
updated_chats.snapshot_version,
|
||||
updated_chats.history_version,
|
||||
updated_chats.queue_version,
|
||||
updated_chats.generation_attempt,
|
||||
updated_chats.retry_state,
|
||||
updated_chats.retry_state_version,
|
||||
updated_chats.runner_id,
|
||||
updated_chats.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chats.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chats.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -90,6 +98,14 @@ chats_expanded AS (
|
||||
updated_chats.plan_mode,
|
||||
updated_chats.client_type,
|
||||
updated_chats.last_turn_summary,
|
||||
updated_chats.snapshot_version,
|
||||
updated_chats.history_version,
|
||||
updated_chats.queue_version,
|
||||
updated_chats.generation_attempt,
|
||||
updated_chats.retry_state,
|
||||
updated_chats.retry_state_version,
|
||||
updated_chats.runner_id,
|
||||
updated_chats.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chats.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chats.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -293,6 +309,15 @@ SELECT *
|
||||
FROM chats_expanded
|
||||
WHERE id = @id::uuid;
|
||||
|
||||
-- name: GetChatFamilyIDsByRootID :many
|
||||
-- Returns the chat IDs of every chat in a family (root + all children)
|
||||
-- in deterministic order. The id parameter must be the root id; the
|
||||
-- query does not walk up from a child.
|
||||
SELECT id
|
||||
FROM chats
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: GetChatACLByID :one
|
||||
SELECT
|
||||
user_acl AS users,
|
||||
@@ -719,6 +744,14 @@ chats_expanded AS (
|
||||
inserted_chat.plan_mode,
|
||||
inserted_chat.client_type,
|
||||
inserted_chat.last_turn_summary,
|
||||
inserted_chat.snapshot_version,
|
||||
inserted_chat.history_version,
|
||||
inserted_chat.queue_version,
|
||||
inserted_chat.generation_attempt,
|
||||
inserted_chat.retry_state,
|
||||
inserted_chat.retry_state_version,
|
||||
inserted_chat.runner_id,
|
||||
inserted_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, inserted_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, inserted_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -854,6 +887,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -909,6 +950,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -962,6 +1011,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1015,6 +1072,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1068,6 +1133,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1120,6 +1193,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1172,6 +1253,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1226,6 +1315,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1242,11 +1339,9 @@ FROM chats_expanded;
|
||||
-- Updates the cached last completed turn summary for sidebar display.
|
||||
-- Empty or whitespace-only summaries are stored as NULL here so direct
|
||||
-- query callers cannot accidentally persist blank sidebar text.
|
||||
-- This intentionally preserves updated_at. The staleness guard relies on
|
||||
-- every new-turn query, such as UpdateChatStatus and AcquireChats, bumping
|
||||
-- updated_at. Future chat-field updates that do not bump updated_at can let
|
||||
-- stale summaries persist. If this query ever bumps updated_at, later
|
||||
-- goroutine summary writes will be rejected as stale.
|
||||
-- This intentionally preserves updated_at. The staleness guard uses
|
||||
-- history_version so worker lifecycle transitions that do not change the
|
||||
-- active message history cannot reject final turn summary writes.
|
||||
-- Two summary workers using the same freshness marker are last-write-wins.
|
||||
UPDATE chats
|
||||
SET
|
||||
@@ -1255,7 +1350,7 @@ SET
|
||||
), '')
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
AND updated_at = @expected_updated_at::timestamptz;
|
||||
AND history_version = @expected_history_version::bigint;
|
||||
|
||||
-- name: UpdateChatMCPServerIDs :one
|
||||
WITH updated_chat AS (
|
||||
@@ -1298,6 +1393,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1407,6 +1510,14 @@ chats_expanded AS (
|
||||
acquired_chats.plan_mode,
|
||||
acquired_chats.client_type,
|
||||
acquired_chats.last_turn_summary,
|
||||
acquired_chats.snapshot_version,
|
||||
acquired_chats.history_version,
|
||||
acquired_chats.queue_version,
|
||||
acquired_chats.generation_attempt,
|
||||
acquired_chats.retry_state,
|
||||
acquired_chats.retry_state_version,
|
||||
acquired_chats.runner_id,
|
||||
acquired_chats.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, acquired_chats.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, acquired_chats.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1464,6 +1575,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1521,6 +1640,14 @@ chats_expanded AS (
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1688,12 +1815,17 @@ RETURNING
|
||||
*;
|
||||
|
||||
-- name: InsertChatQueuedMessage :one
|
||||
INSERT INTO chat_queued_messages (chat_id, content, model_config_id)
|
||||
VALUES (
|
||||
@chat_id,
|
||||
@content,
|
||||
sqlc.narg('model_config_id')::uuid
|
||||
)
|
||||
-- Legacy queue insertion path. When no caller-supplied creator exists,
|
||||
-- preserve the created_by invariant by attributing the queued row to the
|
||||
-- chat owner.
|
||||
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, created_by)
|
||||
SELECT
|
||||
@chat_id::uuid,
|
||||
@content::jsonb,
|
||||
sqlc.narg('model_config_id')::uuid,
|
||||
chats.owner_id
|
||||
FROM chats
|
||||
WHERE chats.id = @chat_id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetChatQueuedMessages :many
|
||||
@@ -1779,6 +1911,14 @@ chats_expanded AS (
|
||||
locked_chat.plan_mode,
|
||||
locked_chat.client_type,
|
||||
locked_chat.last_turn_summary,
|
||||
locked_chat.snapshot_version,
|
||||
locked_chat.history_version,
|
||||
locked_chat.queue_version,
|
||||
locked_chat.generation_attempt,
|
||||
locked_chat.retry_state,
|
||||
locked_chat.retry_state_version,
|
||||
locked_chat.runner_id,
|
||||
locked_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, locked_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, locked_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
@@ -1791,6 +1931,63 @@ chats_expanded AS (
|
||||
SELECT *
|
||||
FROM chats_expanded;
|
||||
|
||||
-- name: GetChatByIDForShare :one
|
||||
WITH shared_chat AS (
|
||||
SELECT *
|
||||
FROM chats
|
||||
WHERE id = @id::uuid
|
||||
FOR SHARE
|
||||
),
|
||||
chats_expanded AS (
|
||||
SELECT
|
||||
shared_chat.id,
|
||||
shared_chat.owner_id,
|
||||
shared_chat.workspace_id,
|
||||
shared_chat.title,
|
||||
shared_chat.status,
|
||||
shared_chat.worker_id,
|
||||
shared_chat.started_at,
|
||||
shared_chat.heartbeat_at,
|
||||
shared_chat.created_at,
|
||||
shared_chat.updated_at,
|
||||
shared_chat.parent_chat_id,
|
||||
shared_chat.root_chat_id,
|
||||
shared_chat.last_model_config_id,
|
||||
shared_chat.archived,
|
||||
shared_chat.last_error,
|
||||
shared_chat.mode,
|
||||
shared_chat.mcp_server_ids,
|
||||
shared_chat.labels,
|
||||
shared_chat.build_id,
|
||||
shared_chat.agent_id,
|
||||
shared_chat.pin_order,
|
||||
shared_chat.last_read_message_id,
|
||||
shared_chat.last_injected_context,
|
||||
shared_chat.dynamic_tools,
|
||||
shared_chat.organization_id,
|
||||
shared_chat.plan_mode,
|
||||
shared_chat.client_type,
|
||||
shared_chat.last_turn_summary,
|
||||
shared_chat.snapshot_version,
|
||||
shared_chat.history_version,
|
||||
shared_chat.queue_version,
|
||||
shared_chat.generation_attempt,
|
||||
shared_chat.retry_state,
|
||||
shared_chat.retry_state_version,
|
||||
shared_chat.runner_id,
|
||||
shared_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, shared_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, shared_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
owner.name AS owner_name
|
||||
FROM
|
||||
shared_chat
|
||||
LEFT JOIN chats root ON root.id = COALESCE(shared_chat.root_chat_id, shared_chat.parent_chat_id)
|
||||
JOIN visible_users owner ON owner.id = shared_chat.owner_id
|
||||
)
|
||||
SELECT *
|
||||
FROM chats_expanded;
|
||||
|
||||
-- name: GetChatsByChatFileID :many
|
||||
SELECT
|
||||
*
|
||||
@@ -2310,59 +2507,404 @@ WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND content::jsonb @> '[{"type": "context-file"}]';
|
||||
|
||||
-- name: AutoArchiveInactiveChats :many
|
||||
-- Archives inactive root chats (pinned and already-archived chats skipped),
|
||||
-- cascading to children via root_chat_id. Limits apply to roots, not total
|
||||
-- rows. The Go caller passes @archive_cutoff as UTC midnight so that all
|
||||
-- chats sharing the same last-activity date are archived together.
|
||||
-- Used by dbpurge.
|
||||
WITH to_archive AS (
|
||||
SELECT
|
||||
c.id,
|
||||
-- Activity = MAX(cm.created_at) across the family, or c.created_at
|
||||
-- when the family has no non-deleted messages.
|
||||
COALESCE(activity.last_activity_at, c.created_at) AS last_activity_at
|
||||
FROM chats c
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT MAX(cm.created_at) AS last_activity_at
|
||||
FROM chat_messages cm
|
||||
JOIN chats fc ON fc.id = cm.chat_id
|
||||
WHERE (fc.id = c.id OR fc.root_chat_id = c.id)
|
||||
AND cm.deleted = false
|
||||
) activity ON TRUE
|
||||
WHERE c.archived = false
|
||||
AND c.pin_order = 0
|
||||
AND c.parent_chat_id IS NULL -- roots only
|
||||
-- Redundant filter helps the planner use the partial index on created_at.
|
||||
AND c.created_at < @archive_cutoff::timestamptz
|
||||
-- New active statuses must be added here to prevent archiving.
|
||||
AND c.status NOT IN ('running', 'pending', 'paused', 'requires_action')
|
||||
AND COALESCE(activity.last_activity_at, c.created_at) < @archive_cutoff::timestamptz
|
||||
-- Sorting by created_at lets Postgres drive the scan from the
|
||||
-- partial index instead of evaluating every LATERAL subquery
|
||||
-- before sorting. All candidates are past the cutoff, so the
|
||||
-- archive order is immaterial once the backlog drains.
|
||||
ORDER BY c.created_at ASC
|
||||
LIMIT @limit_count
|
||||
),
|
||||
archived AS (
|
||||
UPDATE chats c
|
||||
SET archived = true, pin_order = 0, updated_at = NOW()
|
||||
FROM to_archive t
|
||||
WHERE (c.id = t.id OR c.root_chat_id = t.id) -- cascade to children
|
||||
AND c.archived = false
|
||||
RETURNING c.*
|
||||
)
|
||||
-- name: GetChatWorkerAcquisitionCandidates :many
|
||||
-- Returns worker-runnable chats whose ownership is missing or whose
|
||||
-- current runner heartbeat is stale. The runner_id IS NULL predicate is
|
||||
-- a robustness extension for inconsistent rows where a worker_id exists
|
||||
-- without a runner_id; normal missing ownership is worker_id IS NULL or
|
||||
-- a missing or stale heartbeat row.
|
||||
SELECT
|
||||
a.*,
|
||||
-- Children inherit their root's activity so last_activity_at is never null.
|
||||
COALESCE(
|
||||
t.last_activity_at,
|
||||
(SELECT tr.last_activity_at FROM to_archive tr WHERE tr.id = a.root_chat_id),
|
||||
a.created_at
|
||||
)::timestamptz AS last_activity_at
|
||||
FROM archived a
|
||||
LEFT JOIN to_archive t ON t.id = a.id
|
||||
-- created_at ASC flows through to dbpurge's digest truncation; see
|
||||
-- buildDigestData in dbpurge.go for the tradeoff rationale.
|
||||
ORDER BY (a.root_chat_id IS NULL) DESC, a.owner_id ASC, a.created_at ASC, a.id ASC;
|
||||
chats_expanded.*,
|
||||
chat_heartbeats.heartbeat_at AS current_heartbeat_at,
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM chat_heartbeats current_lease
|
||||
WHERE current_lease.chat_id = chats_expanded.id
|
||||
AND current_lease.runner_id = chats_expanded.runner_id
|
||||
AND current_lease.heartbeat_at > NOW() - (INTERVAL '1 second' * @stale_seconds::int)
|
||||
) AS heartbeat_stale
|
||||
FROM chats_expanded
|
||||
LEFT JOIN chat_heartbeats
|
||||
ON chat_heartbeats.chat_id = chats_expanded.id
|
||||
AND chat_heartbeats.runner_id = chats_expanded.runner_id
|
||||
WHERE
|
||||
chats_expanded.status IN ('running'::chat_status, 'interrupting'::chat_status, 'requires_action'::chat_status)
|
||||
AND chats_expanded.archived = false
|
||||
AND (
|
||||
chats_expanded.worker_id IS NULL
|
||||
OR chats_expanded.runner_id IS NULL
|
||||
OR NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM chat_heartbeats current_lease
|
||||
WHERE current_lease.chat_id = chats_expanded.id
|
||||
AND current_lease.runner_id = chats_expanded.runner_id
|
||||
AND current_lease.heartbeat_at > NOW() - (INTERVAL '1 second' * @stale_seconds::int)
|
||||
)
|
||||
)
|
||||
ORDER BY chats_expanded.updated_at ASC, chats_expanded.id ASC
|
||||
LIMIT @limit_count::int;
|
||||
|
||||
-- name: GetChatsByIDsForRunnerSync :many
|
||||
SELECT *
|
||||
FROM chats_expanded
|
||||
WHERE id = ANY(@ids::uuid[])
|
||||
ORDER BY id ASC;
|
||||
|
||||
-- name: UpsertChatHeartbeats :exec
|
||||
INSERT INTO chat_heartbeats (chat_id, runner_id, heartbeat_at)
|
||||
SELECT chat_ids.chat_id, runner_ids.runner_id, NOW()
|
||||
FROM unnest(@chat_ids::uuid[]) WITH ORDINALITY AS chat_ids(chat_id, ord)
|
||||
JOIN unnest(@runner_ids::uuid[]) WITH ORDINALITY AS runner_ids(runner_id, ord) USING (ord)
|
||||
ON CONFLICT (chat_id, runner_id) DO UPDATE
|
||||
SET heartbeat_at = EXCLUDED.heartbeat_at;
|
||||
|
||||
-- name: DeleteStaleChatHeartbeats :execrows
|
||||
DELETE FROM chat_heartbeats
|
||||
WHERE heartbeat_at < NOW() - (INTERVAL '1 second' * @stale_seconds::int);
|
||||
|
||||
-- name: GetAutoArchiveInactiveChatCandidates :many
|
||||
-- Returns read-only root chat candidates for state-machine-backed
|
||||
-- auto-archive. Activity is computed across the root family. The query
|
||||
-- limits roots, not total family members.
|
||||
SELECT
|
||||
chats_expanded.*,
|
||||
COALESCE(activity.last_activity_at, chats_expanded.created_at)::timestamptz AS last_activity_at
|
||||
FROM chats_expanded
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT MAX(chat_messages.created_at) AS last_activity_at
|
||||
FROM chat_messages
|
||||
JOIN chats family_chat ON family_chat.id = chat_messages.chat_id
|
||||
WHERE (family_chat.id = chats_expanded.id OR family_chat.root_chat_id = chats_expanded.id)
|
||||
AND chat_messages.deleted = false
|
||||
) activity ON TRUE
|
||||
WHERE
|
||||
chats_expanded.archived = false
|
||||
AND chats_expanded.pin_order = 0
|
||||
AND chats_expanded.parent_chat_id IS NULL
|
||||
AND chats_expanded.created_at < @archive_cutoff::timestamptz
|
||||
AND chats_expanded.status NOT IN (
|
||||
'running'::chat_status,
|
||||
'interrupting'::chat_status,
|
||||
'pending'::chat_status,
|
||||
'paused'::chat_status,
|
||||
'requires_action'::chat_status
|
||||
)
|
||||
AND COALESCE(activity.last_activity_at, chats_expanded.created_at) < @archive_cutoff::timestamptz
|
||||
ORDER BY chats_expanded.created_at ASC
|
||||
LIMIT @limit_count::int;
|
||||
|
||||
|
||||
-- =====================================================================
|
||||
-- chatd core state machine queries.
|
||||
--
|
||||
-- These are consumed by the coderd/x/chatd/chatstate package. They
|
||||
-- are intentionally kept side-by-side with the legacy chatd queries
|
||||
-- above so the existing runtime keeps working while the state machine
|
||||
-- lands behind it.
|
||||
-- =====================================================================
|
||||
|
||||
-- name: LockChatAndBumpSnapshotVersion :one
|
||||
-- Locks the chat row with FOR UPDATE and atomically increments its
|
||||
-- snapshot_version, returning the post-bump chat. This is the single
|
||||
-- entry point ChatMachine.Update uses to acquire the row lock and
|
||||
-- allocate a new snapshot version in one round trip.
|
||||
WITH bumped_chat AS (
|
||||
UPDATE chats
|
||||
SET snapshot_version = snapshot_version + 1
|
||||
WHERE id = (
|
||||
SELECT id FROM chats
|
||||
WHERE id = @id::uuid
|
||||
FOR UPDATE
|
||||
)
|
||||
RETURNING *
|
||||
),
|
||||
chats_expanded AS (
|
||||
SELECT
|
||||
bumped_chat.id,
|
||||
bumped_chat.owner_id,
|
||||
bumped_chat.workspace_id,
|
||||
bumped_chat.title,
|
||||
bumped_chat.status,
|
||||
bumped_chat.worker_id,
|
||||
bumped_chat.started_at,
|
||||
bumped_chat.heartbeat_at,
|
||||
bumped_chat.created_at,
|
||||
bumped_chat.updated_at,
|
||||
bumped_chat.parent_chat_id,
|
||||
bumped_chat.root_chat_id,
|
||||
bumped_chat.last_model_config_id,
|
||||
bumped_chat.archived,
|
||||
bumped_chat.last_error,
|
||||
bumped_chat.mode,
|
||||
bumped_chat.mcp_server_ids,
|
||||
bumped_chat.labels,
|
||||
bumped_chat.build_id,
|
||||
bumped_chat.agent_id,
|
||||
bumped_chat.pin_order,
|
||||
bumped_chat.last_read_message_id,
|
||||
bumped_chat.last_injected_context,
|
||||
bumped_chat.dynamic_tools,
|
||||
bumped_chat.organization_id,
|
||||
bumped_chat.plan_mode,
|
||||
bumped_chat.client_type,
|
||||
bumped_chat.last_turn_summary,
|
||||
bumped_chat.snapshot_version,
|
||||
bumped_chat.history_version,
|
||||
bumped_chat.queue_version,
|
||||
bumped_chat.generation_attempt,
|
||||
bumped_chat.retry_state,
|
||||
bumped_chat.retry_state_version,
|
||||
bumped_chat.runner_id,
|
||||
bumped_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, bumped_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, bumped_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
owner.name AS owner_name
|
||||
FROM bumped_chat
|
||||
LEFT JOIN chats root ON root.id = COALESCE(bumped_chat.root_chat_id, bumped_chat.parent_chat_id)
|
||||
JOIN visible_users owner ON owner.id = bumped_chat.owner_id
|
||||
)
|
||||
SELECT *
|
||||
FROM chats_expanded;
|
||||
|
||||
-- name: UpdateChatExecutionState :one
|
||||
-- Atomically updates the execution-state-managed fields on a chat:
|
||||
-- status, archived, last_error, ownership identifiers, and the
|
||||
-- requires-action deadline. Callers compose this with transition
|
||||
-- mutations inside a single ChatMachine.Update transaction.
|
||||
WITH updated_chat AS (
|
||||
UPDATE chats
|
||||
SET
|
||||
status = @status::chat_status,
|
||||
archived = @archived::boolean,
|
||||
worker_id = sqlc.narg('worker_id')::uuid,
|
||||
runner_id = sqlc.narg('runner_id')::uuid,
|
||||
last_error = sqlc.narg('last_error')::jsonb,
|
||||
requires_action_deadline_at = sqlc.narg('requires_action_deadline_at')::timestamptz,
|
||||
pin_order = CASE WHEN @archived::boolean THEN 0 ELSE pin_order END,
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid
|
||||
RETURNING *
|
||||
),
|
||||
chats_expanded AS (
|
||||
SELECT
|
||||
updated_chat.id,
|
||||
updated_chat.owner_id,
|
||||
updated_chat.workspace_id,
|
||||
updated_chat.title,
|
||||
updated_chat.status,
|
||||
updated_chat.worker_id,
|
||||
updated_chat.started_at,
|
||||
updated_chat.heartbeat_at,
|
||||
updated_chat.created_at,
|
||||
updated_chat.updated_at,
|
||||
updated_chat.parent_chat_id,
|
||||
updated_chat.root_chat_id,
|
||||
updated_chat.last_model_config_id,
|
||||
updated_chat.archived,
|
||||
updated_chat.last_error,
|
||||
updated_chat.mode,
|
||||
updated_chat.mcp_server_ids,
|
||||
updated_chat.labels,
|
||||
updated_chat.build_id,
|
||||
updated_chat.agent_id,
|
||||
updated_chat.pin_order,
|
||||
updated_chat.last_read_message_id,
|
||||
updated_chat.last_injected_context,
|
||||
updated_chat.dynamic_tools,
|
||||
updated_chat.organization_id,
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
owner.name AS owner_name
|
||||
FROM updated_chat
|
||||
LEFT JOIN chats root ON root.id = COALESCE(updated_chat.root_chat_id, updated_chat.parent_chat_id)
|
||||
JOIN visible_users owner ON owner.id = updated_chat.owner_id
|
||||
)
|
||||
SELECT *
|
||||
FROM chats_expanded;
|
||||
|
||||
-- name: UpdateChatRetryState :one
|
||||
-- Stores the client-visible retry payload. retry_state_version is
|
||||
-- assigned by trigger from the current snapshot_version.
|
||||
WITH updated_chat AS (
|
||||
UPDATE chats
|
||||
SET
|
||||
retry_state = @retry_state::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid
|
||||
RETURNING *
|
||||
),
|
||||
chats_expanded AS (
|
||||
SELECT
|
||||
updated_chat.id,
|
||||
updated_chat.owner_id,
|
||||
updated_chat.workspace_id,
|
||||
updated_chat.title,
|
||||
updated_chat.status,
|
||||
updated_chat.worker_id,
|
||||
updated_chat.started_at,
|
||||
updated_chat.heartbeat_at,
|
||||
updated_chat.created_at,
|
||||
updated_chat.updated_at,
|
||||
updated_chat.parent_chat_id,
|
||||
updated_chat.root_chat_id,
|
||||
updated_chat.last_model_config_id,
|
||||
updated_chat.archived,
|
||||
updated_chat.last_error,
|
||||
updated_chat.mode,
|
||||
updated_chat.mcp_server_ids,
|
||||
updated_chat.labels,
|
||||
updated_chat.build_id,
|
||||
updated_chat.agent_id,
|
||||
updated_chat.pin_order,
|
||||
updated_chat.last_read_message_id,
|
||||
updated_chat.last_injected_context,
|
||||
updated_chat.dynamic_tools,
|
||||
updated_chat.organization_id,
|
||||
updated_chat.plan_mode,
|
||||
updated_chat.client_type,
|
||||
updated_chat.last_turn_summary,
|
||||
updated_chat.snapshot_version,
|
||||
updated_chat.history_version,
|
||||
updated_chat.queue_version,
|
||||
updated_chat.generation_attempt,
|
||||
updated_chat.retry_state,
|
||||
updated_chat.retry_state_version,
|
||||
updated_chat.runner_id,
|
||||
updated_chat.requires_action_deadline_at,
|
||||
COALESCE(root.user_acl, updated_chat.user_acl) AS user_acl,
|
||||
COALESCE(root.group_acl, updated_chat.group_acl) AS group_acl,
|
||||
owner.username AS owner_username,
|
||||
owner.name AS owner_name
|
||||
FROM updated_chat
|
||||
LEFT JOIN chats root ON root.id = COALESCE(updated_chat.root_chat_id, updated_chat.parent_chat_id)
|
||||
JOIN visible_users owner ON owner.id = updated_chat.owner_id
|
||||
)
|
||||
SELECT *
|
||||
FROM chats_expanded;
|
||||
|
||||
-- name: IncrementChatGenerationAttempt :one
|
||||
-- Increments generation_attempt and returns the resulting value.
|
||||
UPDATE chats
|
||||
SET generation_attempt = generation_attempt + 1, updated_at = NOW()
|
||||
WHERE id = @id::uuid
|
||||
RETURNING generation_attempt;
|
||||
|
||||
-- name: GetDatabaseNow :one
|
||||
-- Returns the current database timestamp. Used so transitions that
|
||||
-- record deadlines or heartbeats rely on a clock that is consistent
|
||||
-- with the database rather than the caller's local clock.
|
||||
SELECT NOW()::timestamptz AS now;
|
||||
|
||||
-- name: InsertChatQueuedMessageWithCreator :one
|
||||
-- Inserts a queued message that carries a position (from the default
|
||||
-- sequence) and an explicit created_by reference. Use this when the
|
||||
-- queued-message creator differs from the chat owner.
|
||||
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, created_by)
|
||||
VALUES (
|
||||
@chat_id::uuid,
|
||||
@content::jsonb,
|
||||
sqlc.narg('model_config_id')::uuid,
|
||||
@created_by::uuid
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetChatQueuedMessagesByPosition :many
|
||||
-- Returns queued messages in state-machine order (position ASC, id ASC).
|
||||
SELECT * FROM chat_queued_messages
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
ORDER BY position ASC, id ASC;
|
||||
|
||||
-- name: CountChatQueuedMessages :one
|
||||
-- Cheap queue-length check used by ChatMachine.Update when deciding
|
||||
-- whether the chat is in a "1" sub-state.
|
||||
SELECT COUNT(*)::bigint AS count
|
||||
FROM chat_queued_messages
|
||||
WHERE chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: GetChatQueuedMessageHead :one
|
||||
-- Returns the queue head (lowest position, then lowest id).
|
||||
SELECT * FROM chat_queued_messages
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
ORDER BY position ASC, id ASC
|
||||
LIMIT 1;
|
||||
|
||||
-- name: GetChatQueuedMessageByID :one
|
||||
SELECT * FROM chat_queued_messages
|
||||
WHERE id = @id::bigint AND chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: DeleteChatQueuedMessageReturningCount :execrows
|
||||
-- Deletes a queued message, scoped to the parent chat. Returns the
|
||||
-- number of affected rows so callers can detect missing rows without
|
||||
-- a follow-up read.
|
||||
DELETE FROM chat_queued_messages
|
||||
WHERE id = @id::bigint AND chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: DeleteAllChatQueuedMessagesReturningCount :execrows
|
||||
DELETE FROM chat_queued_messages
|
||||
WHERE chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: ReorderChatQueuedMessageToHead :execrows
|
||||
-- Sets the target queued message's position to one less than the
|
||||
-- current minimum position for that chat, moving it to the head.
|
||||
UPDATE chat_queued_messages AS target
|
||||
SET position = COALESCE(
|
||||
(SELECT MIN(position) FROM chat_queued_messages WHERE chat_id = @chat_id::uuid),
|
||||
0
|
||||
) - 1
|
||||
WHERE target.id = @id::bigint
|
||||
AND target.chat_id = @chat_id::uuid
|
||||
AND target.position > COALESCE(
|
||||
(SELECT MIN(position) FROM chat_queued_messages WHERE chat_id = @chat_id::uuid),
|
||||
target.position
|
||||
);
|
||||
|
||||
-- name: UpsertChatHeartbeat :exec
|
||||
-- Upserts a heartbeat row for the (chat_id, runner_id) lease. Uses
|
||||
-- database time so callers do not depend on a local clock.
|
||||
INSERT INTO chat_heartbeats (chat_id, runner_id, heartbeat_at)
|
||||
VALUES (@chat_id::uuid, @runner_id::uuid, NOW())
|
||||
ON CONFLICT (chat_id, runner_id) DO UPDATE
|
||||
SET heartbeat_at = EXCLUDED.heartbeat_at;
|
||||
|
||||
-- name: GetChatHeartbeat :one
|
||||
SELECT * FROM chat_heartbeats
|
||||
WHERE chat_id = @chat_id::uuid AND runner_id = @runner_id::uuid;
|
||||
|
||||
-- name: IsChatHeartbeatStale :one
|
||||
-- Returns true when there is no heartbeat row for (chat_id, runner_id)
|
||||
-- or the existing row is older than @stale_seconds seconds by database
|
||||
-- time. chatstate calls this in a single query so the staleness check
|
||||
-- is atomic and does not depend on the caller's local clock.
|
||||
SELECT NOT EXISTS (
|
||||
SELECT 1 FROM chat_heartbeats
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND runner_id = @runner_id::uuid
|
||||
AND heartbeat_at > NOW() - (INTERVAL '1 second' * @stale_seconds::int)
|
||||
) AS stale;
|
||||
|
||||
-- name: DeleteChatHeartbeats :execrows
|
||||
-- Deletes heartbeat rows for the supplied (chat_id, runner_id) pairs.
|
||||
DELETE FROM chat_heartbeats
|
||||
USING unnest(@chat_ids::uuid[]) WITH ORDINALITY AS chat_ids(chat_id, ord)
|
||||
JOIN unnest(@runner_ids::uuid[]) WITH ORDINALITY AS runner_ids(runner_id, ord) USING (ord)
|
||||
WHERE chat_heartbeats.chat_id = chat_ids.chat_id
|
||||
AND chat_heartbeats.runner_id = runner_ids.runner_id;
|
||||
|
||||
-- name: DeleteAllChatHeartbeats :exec
|
||||
-- Deletes all heartbeat rows for the chat. Used during ownership
|
||||
-- transitions that abandon a lease.
|
||||
DELETE FROM chat_heartbeats WHERE chat_id = @chat_id::uuid;
|
||||
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ChatStateUpdateChannel returns the pubsub channel that receives one
|
||||
// `chat:update:{chat_id}` message every time the chatstate state
|
||||
// machine commits a transition for the chat.
|
||||
func ChatStateUpdateChannel(chatID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:update:%s", chatID)
|
||||
}
|
||||
|
||||
// ChatStateOwnershipChannel is the global pubsub channel that
|
||||
// receives ownership hints when a chat is runnable but currently has
|
||||
// missing or stale ownership. Workers listen on this channel to know
|
||||
// when to attempt acquisition.
|
||||
const ChatStateOwnershipChannel = "chat:ownership"
|
||||
|
||||
// ChatStateUpdateMessage is the JSON payload published on
|
||||
// [ChatStateUpdateChannel] after every successful CreateChat or
|
||||
// ChatMachine.Update commit. It carries the committed post-transition
|
||||
// versions and ownership identifiers so stream loops and workers can
|
||||
// decide whether to refetch state.
|
||||
type ChatStateUpdateMessage struct {
|
||||
SnapshotVersion int64 `json:"snapshot_version"`
|
||||
WorkerID *uuid.UUID `json:"worker_id,omitempty"`
|
||||
RunnerID *uuid.UUID `json:"runner_id,omitempty"`
|
||||
HistoryVersion int64 `json:"history_version"`
|
||||
QueueVersion int64 `json:"queue_version"`
|
||||
RetryStateVersion int64 `json:"retry_state_version"`
|
||||
GenerationAttempt int64 `json:"generation_attempt"`
|
||||
Status string `json:"status"`
|
||||
Archived bool `json:"archived"`
|
||||
}
|
||||
|
||||
// ChatStateOwnershipMessage is the JSON payload published on
|
||||
// [ChatStateOwnershipChannel] when ownership is missing or stale for
|
||||
// a runnable chat. Subscribers should reload the chat row to confirm
|
||||
// ownership before acting.
|
||||
type ChatStateOwnershipMessage struct {
|
||||
ChatID uuid.UUID `json:"chat_id"`
|
||||
SnapshotVersion int64 `json:"snapshot_version"`
|
||||
}
|
||||
|
||||
// HandleChatStateUpdate wraps a typed callback for
|
||||
// [ChatStateUpdateMessage] consumption, following the same pattern as
|
||||
// HandleChatWatchEvent.
|
||||
func HandleChatStateUpdate(cb func(ctx context.Context, payload ChatStateUpdateMessage, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatStateUpdateMessage{}, xerrors.Errorf("chat state update pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatStateUpdateMessage
|
||||
if uerr := json.Unmarshal(message, &payload); uerr != nil {
|
||||
cb(ctx, ChatStateUpdateMessage{}, xerrors.Errorf("unmarshal chat state update: %w", uerr))
|
||||
return
|
||||
}
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChatStateOwnership wraps a typed callback for
|
||||
// [ChatStateOwnershipMessage] consumption.
|
||||
func HandleChatStateOwnership(cb func(ctx context.Context, payload ChatStateOwnershipMessage, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatStateOwnershipMessage{}, xerrors.Errorf("chat state ownership pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatStateOwnershipMessage
|
||||
if uerr := json.Unmarshal(message, &payload); uerr != nil {
|
||||
cb(ctx, ChatStateOwnershipMessage{}, xerrors.Errorf("unmarshal chat state ownership: %w", uerr))
|
||||
return
|
||||
}
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
package chatstate_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// waitForChan returns true if c receives a value before ctx is done.
|
||||
// Helper used in concurrency tests to avoid time.Sleep.
|
||||
func waitForChan(ctx context.Context, c <-chan struct{}) bool {
|
||||
select {
|
||||
case <-c:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// stillBlocked returns true if c has NOT received a value yet. The
|
||||
// caller must already have established a happens-before ordering via
|
||||
// some other channel so this check is meaningful.
|
||||
func stillBlocked(c <-chan struct{}) bool {
|
||||
select {
|
||||
case <-c:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// TestLockLocksChatRow verifies that ChatMachine.Lock holds the chat
|
||||
// row's FOR UPDATE lock until the callback returns, so a concurrent
|
||||
// ChatMachine.Update cannot enter its callback until the Lock
|
||||
// callback releases.
|
||||
func TestLockLocksChatRow(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
lockEntered := make(chan struct{})
|
||||
releaseLock := make(chan struct{})
|
||||
updateEntered := make(chan struct{})
|
||||
|
||||
// Goroutine A: hold a Lock and block.
|
||||
var lockErr error
|
||||
var lockWG sync.WaitGroup
|
||||
lockWG.Add(1)
|
||||
go func() {
|
||||
defer lockWG.Done()
|
||||
lockErr = m.Lock(ctx, func(_ database.Store) error {
|
||||
close(lockEntered)
|
||||
<-releaseLock
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
// Wait until A is inside its Lock callback (and therefore holds
|
||||
// the FOR UPDATE lock).
|
||||
require.True(t, waitForChan(ctx, lockEntered), "Lock callback never started")
|
||||
|
||||
// Goroutine B: try to Update the same chat. It must block on
|
||||
// LockChatAndBumpSnapshotVersion until A releases.
|
||||
var updateErr error
|
||||
var updateWG sync.WaitGroup
|
||||
updateWG.Add(1)
|
||||
go func() {
|
||||
defer updateWG.Done()
|
||||
updateErr = m.Update(ctx, func(_ *chatstate.Tx) error {
|
||||
close(updateEntered)
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
// Loop a few times re-checking that B is still blocked, with a
|
||||
// fresh round-trip through the database to give B's transaction
|
||||
// every chance to commit if the lock weren't held. The loop avoids
|
||||
// time.Sleep by using the database call itself as a "barrier".
|
||||
for range 5 {
|
||||
// Force a sync round-trip through the DB. This serves the
|
||||
// same role as a Sleep but is deterministic: by the time
|
||||
// this read completes, the scheduler has had a chance to
|
||||
// run goroutine B if it could make progress.
|
||||
_, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, stillBlocked(updateEntered),
|
||||
"Update entered while Lock was still held")
|
||||
}
|
||||
|
||||
// Release Lock and confirm Update completes successfully.
|
||||
close(releaseLock)
|
||||
require.True(t, waitForChan(ctx, updateEntered),
|
||||
"Update callback never started after Lock released")
|
||||
updateWG.Wait()
|
||||
lockWG.Wait()
|
||||
require.NoError(t, lockErr)
|
||||
require.NoError(t, updateErr)
|
||||
}
|
||||
|
||||
// TestLockRollsBackCallbackError verifies that a Lock callback
|
||||
// returning an error rolls back the surrounding transaction.
|
||||
func TestLockRollsBackCallbackError(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
before := f.readChat(ctx, t, created.Chat.ID)
|
||||
publishedBefore := len(f.Pub.channels)
|
||||
|
||||
sentinel := xerrors.New("lock callback error")
|
||||
err := m.Lock(ctx, func(store database.Store) error {
|
||||
// Try a write that should be rolled back.
|
||||
_, werr := store.UpdateChatByID(ctx, database.UpdateChatByIDParams{
|
||||
ID: created.Chat.ID,
|
||||
Title: "rollback-me",
|
||||
})
|
||||
require.NoError(t, werr)
|
||||
return sentinel
|
||||
})
|
||||
require.ErrorIs(t, err, sentinel)
|
||||
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, before.Title, after.Title, "Lock callback error rolls back writes")
|
||||
require.Equal(t, publishedBefore, len(f.Pub.channels), "Lock publishes nothing on error")
|
||||
}
|
||||
|
||||
// TestConcurrentUpdatesSerializeOnChatRow verifies that two
|
||||
// goroutines racing to Update the same chat both succeed but their
|
||||
// effects serialize on the chat row lock: snapshot_version advances
|
||||
// by exactly N (one per Update) and each transition observes the
|
||||
// effects of the prior one.
|
||||
func TestConcurrentUpdatesSerializeOnChatRow(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
created := createTestChat(t, f)
|
||||
before := f.readChat(ctx, t, created.Chat.ID)
|
||||
|
||||
const updates = 8
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(updates)
|
||||
errs := make([]error, updates)
|
||||
for i := 0; i < updates; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
errs[i] = m.Update(ctx, func(_ *chatstate.Tx) error { return nil })
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
for i, err := range errs {
|
||||
require.NoError(t, err, "concurrent update %d failed", i)
|
||||
}
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, before.SnapshotVersion+int64(updates), after.SnapshotVersion,
|
||||
"snapshot_version advanced by exactly one per update")
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
// Package chatstate owns the durable execution-state transitions for
|
||||
// the chatd subsystem. It implements the core state machine described
|
||||
// in the chatd RFC: a 13-state execution model plus a 2-state
|
||||
// ownership model on top of database rows in `chats`,
|
||||
// `chat_messages`, `chat_queued_messages`, and `chat_heartbeats`.
|
||||
//
|
||||
// The package exposes two top-level entry points:
|
||||
//
|
||||
// - [CreateChat] creates a brand new chat with its initial history
|
||||
// in a single transaction. It is standalone because no chat-scoped
|
||||
// state machine instance can exist before the chat row is written.
|
||||
// - [ChatMachine] wraps an existing chat. Callers use it to apply
|
||||
// one or more transitions atomically via [ChatMachine.Update], or
|
||||
// to read related rows while holding the chat row lock via
|
||||
// [ChatMachine.Lock].
|
||||
//
|
||||
// Every successful [ChatMachine.Update] call locks the chat row,
|
||||
// advances `snapshot_version` exactly once, applies transition methods
|
||||
// in order, and (on commit) publishes a single typed `chat:update`
|
||||
// pubsub message describing the post-transition snapshot. Optional
|
||||
// `chat:ownership` hints are published only when the post-transition
|
||||
// state is runnable and ownership is missing or stale. Stream side
|
||||
// effects are handled by `chat:update` consumers, and ownership hints
|
||||
// wake chat workers.
|
||||
//
|
||||
// Transition methods are explicit, typed wrappers around the SQL
|
||||
// mutations needed to move between states. Each transition reads the
|
||||
// current chat row and queue cardinality, classifies the resulting
|
||||
// execution state, and rejects with an [*TransitionError] wrapping
|
||||
// [ErrTransitionNotAllowed] when the transition is not legal from
|
||||
// that state. The transition matrix and
|
||||
// state classification helpers live in `state.go` and `transition.go`
|
||||
// alongside unit-testable classifiers; the SQL is in
|
||||
// `coderd/database/queries/chats.sql` (e.g. `LockChatAndBumpSnapshotVersion`,
|
||||
// `UpdateChatExecutionState`).
|
||||
package chatstate
|
||||
@@ -0,0 +1,152 @@
|
||||
package chatstate
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by chatstate transitions and helpers.
|
||||
// Callers should use errors.Is to test for these.
|
||||
var (
|
||||
// ErrTransitionNotAllowed is returned when a transition is applied
|
||||
// to a chat whose current execution state does not permit it. The
|
||||
// concrete error returned by transition methods is a
|
||||
// *TransitionError that wraps this sentinel.
|
||||
ErrTransitionNotAllowed = xerrors.New("chat state transition not allowed")
|
||||
|
||||
// ErrInvalidState is returned when the chat row, queue, and
|
||||
// archive flag together produce a combination outside the 13
|
||||
// valid execution states described in the RFC.
|
||||
ErrInvalidState = xerrors.New("chat is in an invalid execution state")
|
||||
|
||||
// ErrQueuedMessageNotFound is returned by queue-targeting
|
||||
// transitions (delete, promote) when the supplied queued message
|
||||
// ID does not match a row on the chat.
|
||||
ErrQueuedMessageNotFound = xerrors.New("queued message not found")
|
||||
|
||||
// ErrMessageNotFound is returned by [Tx.EditMessage] when the
|
||||
// target chat_messages row is missing or belongs to another chat.
|
||||
ErrMessageNotFound = xerrors.New("chat message not found")
|
||||
|
||||
// ErrChatNotFound is returned when a non-create transition is
|
||||
// applied to a chat row that does not exist (or has been deleted
|
||||
// since the transition started).
|
||||
ErrChatNotFound = xerrors.New("chat not found")
|
||||
|
||||
// ErrChatNotRoot is returned by family-archive helpers when the
|
||||
// supplied chat is not a root chat (its parent_chat_id is set).
|
||||
ErrChatNotRoot = xerrors.New("chat is not a root chat")
|
||||
|
||||
// ErrEditedMessageNotUser is returned by [Tx.EditMessage] when the
|
||||
// targeted chat_messages row exists but its role is not user.
|
||||
ErrEditedMessageNotUser = xerrors.New("only user messages can be edited")
|
||||
|
||||
// ErrMessageQueueFull is returned by queue-appending transitions
|
||||
// when the per-chat queue cap has been reached. The concrete
|
||||
// error returned by transitions is a *MessageQueueFullError that
|
||||
// wraps this sentinel.
|
||||
ErrMessageQueueFull = xerrors.New("chat message queue is full")
|
||||
|
||||
// ErrToolResultDuplicate is returned by [Tx.CompleteRequiresAction]
|
||||
// when the same tool_call_id appears more than once in the
|
||||
// submitted results.
|
||||
ErrToolResultDuplicate = xerrors.New("duplicate tool result")
|
||||
|
||||
// ErrToolResultUnexpected is returned by
|
||||
// [Tx.CompleteRequiresAction] when a submitted tool_call_id does
|
||||
// not correspond to a pending dynamic tool call.
|
||||
ErrToolResultUnexpected = xerrors.New("unexpected tool result")
|
||||
|
||||
// ErrToolResultMissing is returned by [Tx.CompleteRequiresAction]
|
||||
// when a pending dynamic tool call has no submitted result.
|
||||
ErrToolResultMissing = xerrors.New("missing tool result")
|
||||
|
||||
// ErrToolResultInvalidJSON is returned by
|
||||
// [Tx.CompleteRequiresAction] when a submitted tool result output
|
||||
// is not valid JSON.
|
||||
ErrToolResultInvalidJSON = xerrors.New("tool result output is not valid JSON")
|
||||
)
|
||||
|
||||
// MessageQueueFullError carries the per-chat queue cap so HTTP
|
||||
// endpoints can include the cap in their response detail. It wraps
|
||||
// [ErrMessageQueueFull] so callers can match it with errors.Is.
|
||||
type MessageQueueFullError struct {
|
||||
Max int64
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *MessageQueueFullError) Error() string {
|
||||
return fmt.Sprintf("chat message queue is full (max %d)", e.Max)
|
||||
}
|
||||
|
||||
// Unwrap returns [ErrMessageQueueFull] so callers can match the
|
||||
// generic sentinel.
|
||||
func (*MessageQueueFullError) Unwrap() error { return ErrMessageQueueFull }
|
||||
|
||||
// ToolResultValidationError carries a structured tool-result
|
||||
// validation failure. It always wraps a specific sentinel
|
||||
// (ErrToolResultDuplicate, ErrToolResultMissing,
|
||||
// ErrToolResultUnexpected, ErrToolResultInvalidJSON) so callers can
|
||||
// match either the generic sentinel or the specific cause.
|
||||
type ToolResultValidationError struct {
|
||||
Cause error
|
||||
ToolCallID string
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *ToolResultValidationError) Error() string {
|
||||
if e.ToolCallID != "" {
|
||||
return fmt.Sprintf("%s: %s", e.Cause.Error(), e.ToolCallID)
|
||||
}
|
||||
return e.Cause.Error()
|
||||
}
|
||||
|
||||
// Unwrap returns the specific cause so callers can match it.
|
||||
func (e *ToolResultValidationError) Unwrap() error { return e.Cause }
|
||||
|
||||
// TransitionError carries the structured detail for a rejected
|
||||
// transition. It always wraps [ErrTransitionNotAllowed] so callers can
|
||||
// match with errors.Is without losing context. When a specific
|
||||
// chatstate sentinel is the proximate cause, Cause is set and
|
||||
// errors.Is will match that sentinel too.
|
||||
type TransitionError struct {
|
||||
Transition Transition
|
||||
From ExecutionState
|
||||
Reason string
|
||||
Cause error
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *TransitionError) Error() string {
|
||||
if e.Reason == "" {
|
||||
return fmt.Sprintf(
|
||||
"chat state transition %s not allowed from state %s",
|
||||
e.Transition, e.From,
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"chat state transition %s not allowed from state %s: %s",
|
||||
e.Transition, e.From, e.Reason,
|
||||
)
|
||||
}
|
||||
|
||||
// Unwrap returns the error chain attached to this error. The chain
|
||||
// always includes [ErrTransitionNotAllowed], and may include a more
|
||||
// specific cause through errors.Join, so callers can use errors.Is
|
||||
// without custom matching logic on TransitionError.
|
||||
func (e *TransitionError) Unwrap() error { return e.Cause }
|
||||
|
||||
// newTransitionError constructs a typed TransitionError. Returning the
|
||||
// pointer type lets callers inspect the structured fields when needed.
|
||||
func newTransitionError(t Transition, from ExecutionState, reason string) *TransitionError {
|
||||
return &TransitionError{Transition: t, From: from, Reason: reason, Cause: ErrTransitionNotAllowed}
|
||||
}
|
||||
|
||||
// newTransitionErrorWithCause constructs a TransitionError carrying
|
||||
// a specific underlying sentinel so callers can match the cause with
|
||||
// errors.Is.
|
||||
func newTransitionErrorWithCause(t Transition, from ExecutionState, cause error, reason string) *TransitionError {
|
||||
return &TransitionError{Transition: t, From: from, Reason: reason, Cause: errors.Join(ErrTransitionNotAllowed, cause)}
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package chatstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
// SetFamilyArchivedInput configures [SetFamilyArchived]. The struct
|
||||
// shape avoids a boolean flag parameter at the API surface; callers
|
||||
// build it explicitly with named fields for clarity.
|
||||
type SetFamilyArchivedInput struct {
|
||||
// RootID identifies the family root. SetFamilyArchived rejects
|
||||
// calls for child chats with [ErrChatNotRoot] and unknown chats
|
||||
// with [ErrChatNotFound].
|
||||
RootID uuid.UUID
|
||||
// Archived is the desired post-call archived value for every
|
||||
// family member.
|
||||
Archived bool
|
||||
}
|
||||
|
||||
// SetFamilyArchived runs Update for every chat in the root chat's
|
||||
// family inside one transaction, applying SetArchived when the chat's
|
||||
// archived flag differs from the requested value. It owns its
|
||||
// transaction lifecycle and its [PublishBuffer] lifecycle: pubsub
|
||||
// publications are buffered while the transaction is open and
|
||||
// flushed only after a successful commit; the deferred Discard
|
||||
// suppresses every buffered publication on failure.
|
||||
//
|
||||
// On success SetFamilyArchived returns one [database.Chat] per
|
||||
// family member in the order returned by GetChatFamilyIDsByRootID
|
||||
// (root first, then children).
|
||||
//
|
||||
// Family members that are already in the [StateInvalid] execution
|
||||
// state cause SetFamilyArchived to return [ErrInvalidState] and roll
|
||||
// back the cascade even when their archived flag already matches the
|
||||
// desired value; invalid-state detection is never bypassed.
|
||||
//
|
||||
// Family members that are valid and already match the desired
|
||||
// archived value still run through Update, which increments their
|
||||
// snapshot version and publishes a fresh snapshot without changing
|
||||
// the archived flag. Advancing the snapshot version without a field
|
||||
// change is safe, and it keeps publication behavior uniform while a
|
||||
// partially archived family converges to the desired state.
|
||||
func SetFamilyArchived(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
publisher Publisher,
|
||||
input SetFamilyArchivedInput,
|
||||
) ([]database.Chat, error) {
|
||||
if store == nil {
|
||||
return nil, xerrors.New("chatstate: SetFamilyArchived called with nil store")
|
||||
}
|
||||
if publisher == nil {
|
||||
return nil, xerrors.New("chatstate: SetFamilyArchived called with nil publisher")
|
||||
}
|
||||
|
||||
buffer := NewPublishBuffer(publisher)
|
||||
defer buffer.Discard()
|
||||
|
||||
var familyChats []database.Chat
|
||||
err := store.InTx(func(tx database.Store) error {
|
||||
// Lock the root chat first so concurrent archive races on the
|
||||
// same family serialize on a stable row.
|
||||
root, err := tx.GetChatByIDForUpdate(ctx, input.RootID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrChatNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock root chat for archive: %w", err)
|
||||
}
|
||||
if root.ParentChatID.Valid {
|
||||
return ErrChatNotRoot
|
||||
}
|
||||
ids, err := tx.GetChatFamilyIDsByRootID(ctx, input.RootID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat family: %w", err)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return ErrChatNotFound
|
||||
}
|
||||
familyChats = make([]database.Chat, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
var chat database.Chat
|
||||
machine := NewChatMachine(tx, buffer, id, Options{})
|
||||
err := machine.Update(ctx, func(state *Tx) error {
|
||||
// Load the current chat and classify it so we can
|
||||
// reject invalid-state members with ErrInvalidState
|
||||
// even when their archived flag already matches.
|
||||
current, from, lerr := state.loadState()
|
||||
if lerr != nil {
|
||||
return lerr
|
||||
}
|
||||
if from == StateInvalid {
|
||||
return ErrInvalidState
|
||||
}
|
||||
if current.Archived == input.Archived {
|
||||
chat = current
|
||||
return nil
|
||||
}
|
||||
if _, err := state.SetArchived(SetArchivedInput{Archived: input.Archived}); err != nil {
|
||||
return err
|
||||
}
|
||||
var err error
|
||||
chat, err = state.Store().GetChatByID(state.Ctx(), state.ChatID())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reload archived chat: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
familyChats = append(familyChats, chat)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := buffer.Flush(); err != nil {
|
||||
return familyChats, err
|
||||
}
|
||||
return familyChats, nil
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
package chatstate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestSetFamilyArchivedRejectsChildChat asserts the chatstate helper
|
||||
// rejects calls that target a child chat. Family archive flows must
|
||||
// always start at the root.
|
||||
func TestSetFamilyArchivedRejectsChildChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
root := dbgen.Chat(t, f.DB, database.Chat{
|
||||
OrganizationID: f.Org.ID,
|
||||
OwnerID: f.User.ID,
|
||||
LastModelConfigID: f.Model.ID,
|
||||
Title: "root",
|
||||
})
|
||||
child := dbgen.Chat(t, f.DB, database.Chat{
|
||||
OrganizationID: f.Org.ID,
|
||||
OwnerID: f.User.ID,
|
||||
LastModelConfigID: f.Model.ID,
|
||||
Title: "child",
|
||||
ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
|
||||
})
|
||||
|
||||
_, err := chatstate.SetFamilyArchived(ctx, f.DB, f.Pub, chatstate.SetFamilyArchivedInput{RootID: child.ID, Archived: true})
|
||||
require.ErrorIs(t, err, chatstate.ErrChatNotRoot)
|
||||
|
||||
require.False(t, f.readChat(ctx, t, root.ID).Archived,
|
||||
"failed family archive must not touch the root")
|
||||
require.False(t, f.readChat(ctx, t, child.ID).Archived,
|
||||
"failed family archive must not touch the child")
|
||||
}
|
||||
|
||||
// TestSetFamilyArchivedRollsBackWhenMemberCannotArchive verifies that
|
||||
// SetFamilyArchived is atomic: when one family member is in a state
|
||||
// that cannot satisfy the SetArchived transition, the whole cascade
|
||||
// rolls back and no publications reach the inner publisher.
|
||||
func TestSetFamilyArchivedRollsBackWhenMemberCannotArchive(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
user, org, model := seedFamilyDeps(t, db)
|
||||
|
||||
// Root chat: waiting is archive-eligible (state W).
|
||||
root := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "root",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
// Child chat: running with no queue is R0 and NOT archive
|
||||
// eligible per the chatstate transition matrix.
|
||||
child := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "child",
|
||||
Status: database.ChatStatusRunning,
|
||||
ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
|
||||
})
|
||||
|
||||
pub := newRecordingPubsub()
|
||||
_, err := chatstate.SetFamilyArchived(ctx, db, pub, chatstate.SetFamilyArchivedInput{RootID: root.ID, Archived: true})
|
||||
require.Error(t, err, "child in R0 must reject SetArchived")
|
||||
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
|
||||
|
||||
rootAfter, err := db.GetChatByID(ctx, root.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, rootAfter.Archived, "root archive must roll back when a child cannot archive")
|
||||
childAfter, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, childAfter.Archived, "child must not be archived in the rolled-back cascade")
|
||||
|
||||
require.Empty(t, pub.channels,
|
||||
"rolled-back family archive must publish nothing through the inner publisher")
|
||||
}
|
||||
|
||||
// TestSetFamilyArchivedRejectsInvalidStateEvenWhenAlreadyDesired
|
||||
// verifies that invalid-state detection is never bypassed: a family
|
||||
// member in StateInvalid causes the cascade to fail with
|
||||
// ErrInvalidState even when that member's archived flag already
|
||||
// matches the desired value.
|
||||
func TestSetFamilyArchivedRejectsInvalidStateEvenWhenAlreadyDesired(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
user, org, model := seedFamilyDeps(t, db)
|
||||
|
||||
root := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "root",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
child := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "child",
|
||||
// status=waiting, archived=true; we will add a queued message
|
||||
// to produce the chatstate-invalid combination (archived chat
|
||||
// with a queued backlog is outside the 13 valid states).
|
||||
Status: database.ChatStatusWaiting,
|
||||
Archived: true,
|
||||
ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
|
||||
})
|
||||
|
||||
// Seed a queued message under the child to push it into the
|
||||
// chatstate-invalid combination.
|
||||
rawContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("queued"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
||||
ChatID: child.ID,
|
||||
Content: rawContent.RawMessage,
|
||||
ModelConfigID: uuid.NullUUID{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pub := newRecordingPubsub()
|
||||
_, err = chatstate.SetFamilyArchived(ctx, db, pub, chatstate.SetFamilyArchivedInput{
|
||||
RootID: root.ID,
|
||||
Archived: true,
|
||||
})
|
||||
require.ErrorIs(t, err, chatstate.ErrInvalidState,
|
||||
"invalid-state child blocks the cascade even when archived flag already matches")
|
||||
|
||||
// Root must not be archived because the cascade rolled back.
|
||||
rootAfter, err := db.GetChatByID(ctx, root.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, rootAfter.Archived, "root must roll back when a child is in StateInvalid")
|
||||
|
||||
require.Empty(t, pub.channels,
|
||||
"rolled-back cascade must not publish anything")
|
||||
}
|
||||
|
||||
// TestSetFamilyArchivedAcceptsAlreadyDesiredMembers verifies that an
|
||||
// individually archived child does not block a root archive cascade.
|
||||
// The cascade converges to the desired state even when some family
|
||||
// members already match it.
|
||||
func TestSetFamilyArchivedAcceptsAlreadyDesiredMembers(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
user, org, model := seedFamilyDeps(t, db)
|
||||
|
||||
root := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "root",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
child := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "child",
|
||||
Status: database.ChatStatusWaiting,
|
||||
ParentChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: root.ID, Valid: true},
|
||||
Archived: true,
|
||||
})
|
||||
|
||||
pub := newRecordingPubsub()
|
||||
family, err := chatstate.SetFamilyArchived(ctx, db, pub, chatstate.SetFamilyArchivedInput{RootID: root.ID, Archived: true})
|
||||
require.NoError(t, err,
|
||||
"already archived members must not block the cascade")
|
||||
require.Len(t, family, 2)
|
||||
|
||||
rootAfter, err := db.GetChatByID(ctx, root.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, rootAfter.Archived)
|
||||
childAfter, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, childAfter.Archived)
|
||||
}
|
||||
|
||||
func seedFamilyDeps(t *testing.T, db database.Store) (database.User, database.Organization, database.ChatModelConfig) {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "openai",
|
||||
BaseUrl: "http://example.invalid",
|
||||
})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
Provider: "openai",
|
||||
IsDefault: true,
|
||||
})
|
||||
return user, org, model
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package chatstate_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// ownershipPublishCount returns the number of `chat:ownership` messages
|
||||
// recorded so far on the test publisher. Tests use it to assert that
|
||||
// transitions do or do not publish an ownership hint.
|
||||
func (r *recordingPubsub) ownershipPublishCount() int {
|
||||
count := 0
|
||||
for _, c := range r.channels {
|
||||
if c == coderdpubsub.ChatStateOwnershipChannel {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// sendQueuedMessage seeds one queued user message via SendMessage with
|
||||
// BusyBehaviorQueue. The chat must already be in a state that allows
|
||||
// SendMessage (typically R0, R1, or I*).
|
||||
func sendQueuedMessage(t *testing.T, f *testFixture, m *chatstate.ChatMachine, body string) chatstate.SendMessageResult {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
var send chatstate.SendMessageResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
send, err = tx.SendMessage(chatstate.SendMessageInput{
|
||||
Message: userTextMessage(body, f.User.ID, f.Model.ID),
|
||||
BusyBehavior: chatstate.BusyBehaviorQueue,
|
||||
})
|
||||
return err
|
||||
}))
|
||||
return send
|
||||
}
|
||||
|
||||
// sendInterruptMessage seeds one queued user message via SendMessage
|
||||
// with BusyBehaviorInterrupt. From R0/R1 this transitions the chat to
|
||||
// `interrupting` and appends the new user message to the queue tail.
|
||||
func sendInterruptMessage(t *testing.T, f *testFixture, m *chatstate.ChatMachine, body string) chatstate.SendMessageResult {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
var send chatstate.SendMessageResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
send, err = tx.SendMessage(chatstate.SendMessageInput{
|
||||
Message: userTextMessage(body, f.User.ID, f.Model.ID),
|
||||
BusyBehavior: chatstate.BusyBehaviorInterrupt,
|
||||
})
|
||||
return err
|
||||
}))
|
||||
return send
|
||||
}
|
||||
|
||||
// queuedIDsByPosition returns the queued-message IDs for the chat in
|
||||
// queue order.
|
||||
func queuedIDsByPosition(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) []int64 {
|
||||
t.Helper()
|
||||
rows, err := f.DB.GetChatQueuedMessagesByPosition(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
ids := make([]int64, len(rows))
|
||||
for i, r := range rows {
|
||||
ids[i] = r.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// historyMessageIDs returns the chat history message IDs ordered by
|
||||
// row id. Used to assert that PromoteQueuedMessage from R1/I1 does NOT
|
||||
// insert any history rows.
|
||||
func historyMessageIDs(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) []int64 {
|
||||
t.Helper()
|
||||
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
out := make([]int64, len(msgs))
|
||||
for i, m := range msgs {
|
||||
out[i] = m.ID
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package chatstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
)
|
||||
|
||||
// HeartbeatStaleSeconds is the threshold chatstate uses when deciding
|
||||
// whether to publish a `chat:ownership` hint for a runnable chat. A
|
||||
// heartbeat older than this many seconds (by database time) counts
|
||||
// as stale and triggers a hint so an idle worker can attempt a
|
||||
// takeover.
|
||||
const HeartbeatStaleSeconds = 30
|
||||
|
||||
// Options configures a [ChatMachine]. Reserved for future tunables;
|
||||
// currently empty.
|
||||
type Options struct{}
|
||||
|
||||
// ChatMachine is a chat-scoped handle for state-machine operations on
|
||||
// a single chat row. It captures the database store, the pubsub
|
||||
// publisher, and the chat ID at construction time so callers do not
|
||||
// have to thread them through Update, Lock, or any transition method.
|
||||
//
|
||||
// ChatMachine values are cheap. Create one per chat for the lifetime
|
||||
// of a request or worker turn; do not cache mutable chat state across
|
||||
// calls.
|
||||
type ChatMachine struct {
|
||||
store database.Store
|
||||
publisher Publisher
|
||||
chatID uuid.UUID
|
||||
opts Options
|
||||
}
|
||||
|
||||
// NewChatMachine constructs a chat-scoped state machine handle. The
|
||||
// store may be the root database handle or an existing transaction
|
||||
// handle; publisher is the pubsub used for `chat:update` and
|
||||
// `chat:ownership` emissions. Both are required and captured for the
|
||||
// lifetime of the returned machine.
|
||||
func NewChatMachine(
|
||||
store database.Store,
|
||||
publisher Publisher,
|
||||
chatID uuid.UUID,
|
||||
opts Options,
|
||||
) *ChatMachine {
|
||||
return &ChatMachine{
|
||||
store: store,
|
||||
publisher: publisher,
|
||||
chatID: chatID,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatID returns the chat ID this machine is scoped to.
|
||||
func (m *ChatMachine) ChatID() uuid.UUID { return m.chatID }
|
||||
|
||||
// Tx is the per-transaction handle passed to [ChatMachine.Update]
|
||||
// callbacks. It carries the active context, the transactional store,
|
||||
// and the chat ID. Tx does not cache mutable chat state across calls:
|
||||
// every transition method reads the chat row and queue cardinality
|
||||
// from the database on entry, so a bundle of transitions inside one
|
||||
// Update callback always validates against the latest committed state.
|
||||
type Tx struct {
|
||||
ctx context.Context
|
||||
store database.Store
|
||||
chatID uuid.UUID
|
||||
}
|
||||
|
||||
// Ctx returns the context the surrounding [ChatMachine.Update] call
|
||||
// is using.
|
||||
func (tx *Tx) Ctx() context.Context { return tx.ctx }
|
||||
|
||||
// ChatID returns the chat ID this transaction is scoped to.
|
||||
func (tx *Tx) ChatID() uuid.UUID { return tx.chatID }
|
||||
|
||||
// Store exposes the active transaction store so callers can perform
|
||||
// validation reads (for example loading the messages affected by an
|
||||
// EditMessage transition) and metadata writes (for example updating
|
||||
// title or labels) that must be atomic with the transition.
|
||||
//
|
||||
// Callers MUST NOT use Store to mutate execution-state tables
|
||||
// (chats.status, chat_messages, chat_queued_messages, chat_heartbeats,
|
||||
// or the version fields on chats). Those mutations belong to the
|
||||
// transition methods and are validated against the state machine
|
||||
// matrix.
|
||||
func (tx *Tx) Store() database.Store { return tx.store }
|
||||
|
||||
// loadState reads the current chat row and queue cardinality from the
|
||||
// active transaction, classifies the execution state, and returns the
|
||||
// inputs every transition method needs. Returns ErrChatNotFound if
|
||||
// the chat row was deleted in this transaction (or never existed).
|
||||
func (tx *Tx) loadState() (database.Chat, ExecutionState, error) {
|
||||
chat, err := tx.store.GetChatByID(tx.ctx, tx.chatID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return database.Chat{}, StateN, ErrChatNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return database.Chat{}, "", xerrors.Errorf("load chat: %w", err)
|
||||
}
|
||||
count, err := tx.store.CountChatQueuedMessages(tx.ctx, tx.chatID)
|
||||
if err != nil {
|
||||
return database.Chat{}, "", xerrors.Errorf("count queued messages: %w", err)
|
||||
}
|
||||
return chat, ClassifyExecutionState(chat, count > 0, true), nil
|
||||
}
|
||||
|
||||
// requireFromAllowed loads the current state and validates t against
|
||||
// the transition matrix. Returns the loaded chat and execution state
|
||||
// on success, [ErrInvalidState] when the chat is in an invalid state
|
||||
// and t is not [TransitionReconcileInvalidState], and a typed
|
||||
// *TransitionError otherwise.
|
||||
func (tx *Tx) requireFromAllowed(t Transition) (database.Chat, ExecutionState, error) {
|
||||
chat, from, err := tx.loadState()
|
||||
if err != nil {
|
||||
return chat, from, err
|
||||
}
|
||||
if from == StateInvalid && t != TransitionReconcileInvalidState {
|
||||
return chat, from, ErrInvalidState
|
||||
}
|
||||
if err := requireExecutionTransition(t, from); err != nil {
|
||||
return chat, from, err
|
||||
}
|
||||
return chat, from, nil
|
||||
}
|
||||
|
||||
// Update applies one or more transitions to the machine's chat.
|
||||
//
|
||||
// Update opens a transaction on the captured store, atomically locks
|
||||
// the chat row with FOR UPDATE and increments `snapshot_version`
|
||||
// exactly once, then runs fn against a fresh [*Tx]. It constructs a
|
||||
// [PublishBuffer], enqueues `chat:update` (and a `chat:ownership` hint
|
||||
// when the post-transition state is worker-runnable and ownership is
|
||||
// missing or stale) inside the transaction, and flushes the buffer only after
|
||||
// the transaction function succeeds. If the transaction rolls back,
|
||||
// the deferred Discard suppresses every buffered publication so
|
||||
// subscribers never see uncommitted state.
|
||||
//
|
||||
// If Update is called with a store that is already in a transaction,
|
||||
// [database.Store.InTx] reuses the active transaction. In that case,
|
||||
// callers that need outer-transaction publication semantics can pass a
|
||||
// [PublishBuffer] as the machine publisher. The inner buffer flushes
|
||||
// into the outer buffer, and the outer owner remains responsible for
|
||||
// publishing only after the outer transaction commits.
|
||||
//
|
||||
// Callers must not pass a store or publisher here; they belong on the
|
||||
// machine.
|
||||
//
|
||||
// If the chat row does not exist, Update returns [ErrChatNotFound]
|
||||
// without mutating anything.
|
||||
//
|
||||
// Callbacks that return an error roll back the transaction (rolling
|
||||
// back the automatic snapshot bump) and publish nothing.
|
||||
func (m *ChatMachine) Update(
|
||||
ctx context.Context,
|
||||
fn func(*Tx) error,
|
||||
) error {
|
||||
if m.store == nil {
|
||||
return xerrors.New("chatstate: ChatMachine has nil store")
|
||||
}
|
||||
if m.publisher == nil {
|
||||
return xerrors.New("chatstate: ChatMachine has nil publisher")
|
||||
}
|
||||
|
||||
buffer := NewPublishBuffer(m.publisher)
|
||||
defer buffer.Discard()
|
||||
|
||||
err := m.store.InTx(func(store database.Store) error {
|
||||
if _, err := store.LockChatAndBumpSnapshotVersion(ctx, m.chatID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrChatNotFound
|
||||
}
|
||||
return xerrors.Errorf("lock chat and bump snapshot: %w", err)
|
||||
}
|
||||
tx := &Tx{
|
||||
ctx: ctx,
|
||||
store: store,
|
||||
chatID: m.chatID,
|
||||
}
|
||||
if err := fn(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
chat, state, err := tx.loadState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := buffer.Publish(
|
||||
coderdpubsub.ChatStateUpdateChannel(chat.ID),
|
||||
buildChatUpdateMessage(chat),
|
||||
); err != nil {
|
||||
return xerrors.Errorf("buffer chat update: %w", err)
|
||||
}
|
||||
if state.IsRunnable() {
|
||||
stale, herr := ownershipStaleOrMissing(ctx, store, chat, HeartbeatStaleSeconds)
|
||||
if herr != nil {
|
||||
return xerrors.Errorf("evaluate ownership: %w", herr)
|
||||
}
|
||||
if stale {
|
||||
if err := buffer.Publish(
|
||||
coderdpubsub.ChatStateOwnershipChannel,
|
||||
buildChatOwnershipMessage(chat),
|
||||
); err != nil {
|
||||
return xerrors.Errorf("buffer ownership hint: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return buffer.Flush()
|
||||
}
|
||||
|
||||
// Lock locks the chat row with FOR UPDATE and runs fn in a
|
||||
// transaction without advancing snapshot_version. It uses the store
|
||||
// captured by [NewChatMachine]. Use it when the caller needs a
|
||||
// consistent chat snapshot plus related rows such as messages or
|
||||
// queued messages but is NOT applying a transition.
|
||||
//
|
||||
// Callers must not pass a store here; it belongs on the machine.
|
||||
//
|
||||
// Lock publishes nothing. Callback errors roll back the transaction
|
||||
// and propagate to the caller.
|
||||
func (m *ChatMachine) Lock(
|
||||
ctx context.Context,
|
||||
fn func(database.Store) error,
|
||||
) error {
|
||||
if m.store == nil {
|
||||
return xerrors.New("chatstate: ChatMachine has nil store")
|
||||
}
|
||||
return m.store.InTx(func(store database.Store) error {
|
||||
// GetChatByIDForUpdate locks the row WITHOUT bumping snapshot.
|
||||
_, err := store.GetChatByIDForUpdate(ctx, m.chatID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrChatNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
return fn(store)
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// ReadLock takes a shared lock on the chat row with FOR SHARE and runs
|
||||
// fn in a transaction without advancing snapshot_version. It uses the
|
||||
// store captured by [NewChatMachine]. Use it when the caller needs a
|
||||
// consistent chat snapshot plus related rows such as messages or queued
|
||||
// messages but is NOT applying a transition and does NOT need to block
|
||||
// concurrent readers.
|
||||
//
|
||||
// Unlike [ChatMachine.Lock], the FOR SHARE lock permits other shared
|
||||
// lockers to proceed concurrently while still blocking writers that take
|
||||
// FOR UPDATE (such as [ChatMachine.Update] and [ChatMachine.Lock]) until
|
||||
// the transaction commits.
|
||||
//
|
||||
// Callers must not pass a store here; it belongs on the machine.
|
||||
//
|
||||
// ReadLock publishes nothing. Callback errors roll back the transaction
|
||||
// and propagate to the caller.
|
||||
func (m *ChatMachine) ReadLock(
|
||||
ctx context.Context,
|
||||
fn func(database.Store) error,
|
||||
) error {
|
||||
if m.store == nil {
|
||||
return xerrors.New("chatstate: ChatMachine has nil store")
|
||||
}
|
||||
return m.store.InTx(func(store database.Store) error {
|
||||
// GetChatByIDForShare takes a shared lock on the row WITHOUT
|
||||
// bumping snapshot.
|
||||
_, err := store.GetChatByIDForShare(ctx, m.chatID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrChatNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read lock chat: %w", err)
|
||||
}
|
||||
return fn(store)
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// ownershipStaleOrMissing reports whether the chat's current
|
||||
// (chat_id, runner_id) lease is missing or stale. The staleSeconds
|
||||
// threshold is forwarded to [database.IsChatHeartbeatStale] so the
|
||||
// comparison runs against database time inside a single SQL query.
|
||||
func ownershipStaleOrMissing(ctx context.Context, store database.Store, chat database.Chat, staleSeconds int32) (bool, error) {
|
||||
if !chat.WorkerID.Valid || !chat.RunnerID.Valid {
|
||||
return true, nil
|
||||
}
|
||||
return store.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{
|
||||
ChatID: chat.ID,
|
||||
RunnerID: chat.RunnerID.UUID,
|
||||
StaleSeconds: staleSeconds,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
package chatstate_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// testFixture bundles the resources every integration test needs:
|
||||
// a database, a publisher recorder, a user/org/model triple, and
|
||||
// helper accessors. It is intentionally NOT a generic chatd test
|
||||
// fixture; tests outside this package should not depend on it.
|
||||
type testFixture struct {
|
||||
DB database.Store
|
||||
PubSub pubsub.Pubsub
|
||||
Pub *recordingPubsub
|
||||
User database.User
|
||||
Org database.Organization
|
||||
Model database.ChatModelConfig
|
||||
}
|
||||
|
||||
func newTestFixture(t *testing.T) *testFixture {
|
||||
t.Helper()
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "openai",
|
||||
BaseUrl: "http://example.invalid",
|
||||
})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
Provider: "openai",
|
||||
IsDefault: true,
|
||||
})
|
||||
pub := newRecordingPubsub()
|
||||
return &testFixture{
|
||||
DB: db,
|
||||
PubSub: ps,
|
||||
Pub: pub,
|
||||
User: user,
|
||||
Org: org,
|
||||
Model: model,
|
||||
}
|
||||
}
|
||||
|
||||
// readChat re-reads the chat from the database. Tests use this to
|
||||
// verify post-transition state because transition results no longer
|
||||
// carry the chat snapshot.
|
||||
func (f *testFixture) readChat(ctx context.Context, t *testing.T, chatID uuid.UUID) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := f.DB.GetChatByID(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
// classify reads the chat plus queue cardinality and returns the RFC
|
||||
// execution state shorthand.
|
||||
func (f *testFixture) classify(ctx context.Context, t *testing.T, chatID uuid.UUID) chatstate.ExecutionState {
|
||||
t.Helper()
|
||||
chat := f.readChat(ctx, t, chatID)
|
||||
count, err := f.DB.CountChatQueuedMessages(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
return chatstate.ClassifyExecutionState(chat, count > 0, true)
|
||||
}
|
||||
|
||||
// recordingPubsub captures every Publish call so tests can assert on
|
||||
// the chatstate notifications without needing a live subscriber. The
|
||||
// mutex makes it safe to use from concurrent tests that race multiple
|
||||
// goroutines through the same publisher (see TestConcurrentUpdatesSerializeOnChatRow).
|
||||
type recordingPubsub struct {
|
||||
mu sync.Mutex
|
||||
channels []string
|
||||
payloads [][]byte
|
||||
}
|
||||
|
||||
func newRecordingPubsub() *recordingPubsub { return &recordingPubsub{} }
|
||||
|
||||
func (r *recordingPubsub) Publish(channel string, payload []byte) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.channels = append(r.channels, channel)
|
||||
cp := make([]byte, len(payload))
|
||||
copy(cp, payload)
|
||||
r.payloads = append(r.payloads, cp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// expectChatUpdate finds the most recent chat:update message on the
|
||||
// per-chat channel and asserts that it has snapshot_version == want.
|
||||
func (r *recordingPubsub) expectChatUpdate(t *testing.T, chatID uuid.UUID, wantSnapshot int64) {
|
||||
t.Helper()
|
||||
channel := coderdpubsub.ChatStateUpdateChannel(chatID)
|
||||
for i := len(r.channels) - 1; i >= 0; i-- {
|
||||
if r.channels[i] != channel {
|
||||
continue
|
||||
}
|
||||
var msg coderdpubsub.ChatStateUpdateMessage
|
||||
require.NoError(t, json.Unmarshal(r.payloads[i], &msg))
|
||||
require.Equal(t, wantSnapshot, msg.SnapshotVersion)
|
||||
return
|
||||
}
|
||||
t.Fatalf("no chat:update on %s", channel)
|
||||
}
|
||||
|
||||
func (r *recordingPubsub) hasOwnership() bool {
|
||||
for _, c := range r.channels {
|
||||
if c == coderdpubsub.ChatStateOwnershipChannel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func userTextMessage(text string, createdBy uuid.UUID, modelConfigID uuid.UUID) chatstate.Message {
|
||||
parts := []codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}
|
||||
raw, err := chatprompt.MarshalParts(parts)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return chatstate.Message{
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Content: raw,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
CreatedBy: uuid.NullUUID{UUID: createdBy, Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
}
|
||||
|
||||
// createTestChat is the standard "fresh R0 chat" helper used by other
|
||||
// tests. It exercises CreateChat itself.
|
||||
func createTestChat(t *testing.T, f *testFixture) chatstate.CreateChatResult {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
|
||||
OrganizationID: f.Org.ID,
|
||||
OwnerID: f.User.ID,
|
||||
LastModelConfigID: f.Model.ID,
|
||||
Title: "test",
|
||||
ClientType: database.ChatClientTypeApi,
|
||||
InitialMessages: []chatstate.Message{
|
||||
userTextMessage("hello", f.User.ID, f.Model.ID),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return res
|
||||
}
|
||||
|
||||
func TestChatMachine_Update_RejectsMissingChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, uuid.New(), chatstate.Options{})
|
||||
err := m.Update(ctx, func(tx *chatstate.Tx) error { return nil })
|
||||
require.ErrorIs(t, err, chatstate.ErrChatNotFound)
|
||||
require.Empty(t, f.Pub.channels)
|
||||
}
|
||||
|
||||
func TestChatMachine_Lock_DoesNotBumpSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
before := f.readChat(ctx, t, created.Chat.ID)
|
||||
publishedBefore := len(f.Pub.channels)
|
||||
|
||||
require.NoError(t, m.Lock(ctx, func(_ database.Store) error {
|
||||
return nil
|
||||
}))
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, before.SnapshotVersion, after.SnapshotVersion)
|
||||
require.Equal(t, publishedBefore, len(f.Pub.channels), "Lock must not publish")
|
||||
}
|
||||
|
||||
func TestChatMachine_ReadLock_DoesNotBumpSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
before := f.readChat(ctx, t, created.Chat.ID)
|
||||
publishedBefore := len(f.Pub.channels)
|
||||
|
||||
var called bool
|
||||
require.NoError(t, m.ReadLock(ctx, func(_ database.Store) error {
|
||||
called = true
|
||||
return nil
|
||||
}))
|
||||
require.True(t, called, "ReadLock must invoke the callback")
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, before.SnapshotVersion, after.SnapshotVersion)
|
||||
require.Equal(t, publishedBefore, len(f.Pub.channels), "ReadLock must not publish")
|
||||
}
|
||||
|
||||
func TestChatMachine_ReadLock_RejectsMissingChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, uuid.New(), chatstate.Options{})
|
||||
err := m.ReadLock(ctx, func(_ database.Store) error {
|
||||
t.Fatal("callback must not run when the chat is missing")
|
||||
return nil
|
||||
})
|
||||
require.ErrorIs(t, err, chatstate.ErrChatNotFound)
|
||||
require.Empty(t, f.Pub.channels)
|
||||
}
|
||||
|
||||
func TestChatMachine_UpdatePublishesAfterCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
publishedBefore := len(f.Pub.channels)
|
||||
// Run a no-op Update; snapshot bump still happens, one update message
|
||||
// should follow the commit.
|
||||
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil }))
|
||||
channel := coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)
|
||||
var found bool
|
||||
for _, c := range f.Pub.channels[publishedBefore:] {
|
||||
if c == channel {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, found, "expected one chat:update message after commit")
|
||||
}
|
||||
|
||||
func TestChatMachine_FailedUpdate_PublishesNothing(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
before := f.readChat(ctx, t, created.Chat.ID)
|
||||
channelsBefore := len(f.Pub.channels)
|
||||
expected := newSentinel()
|
||||
cbErr := m.Update(ctx, func(_ *chatstate.Tx) error { return expected })
|
||||
require.ErrorIs(t, cbErr, expected)
|
||||
require.Equal(t, channelsBefore, len(f.Pub.channels), "failed update should not publish")
|
||||
// snapshot_version should not have advanced.
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, before.SnapshotVersion, after.SnapshotVersion)
|
||||
}
|
||||
|
||||
func TestMessageRevisionTrigger_AssignsRevisionFromSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f) // snapshot 1, history_version 1 via trigger
|
||||
|
||||
// CommitStep an assistant message; it should land with revision = chat.snapshot_version after the bump.
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
var step chatstate.CommitStepResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
assistant := userTextMessage("assistant", f.User.ID, f.Model.ID)
|
||||
assistant.Role = database.ChatMessageRoleAssistant
|
||||
var err error
|
||||
step, err = tx.CommitStep(chatstate.CommitStepInput{
|
||||
Messages: []chatstate.Message{assistant},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
require.Len(t, step.InsertedMessages, 1)
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
// The Update call bumps snapshot_version once before the trigger
|
||||
// runs, so the new revision should equal the bumped snapshot.
|
||||
require.Equal(t, after.SnapshotVersion, step.InsertedMessages[0].Revision)
|
||||
require.Equal(t, after.SnapshotVersion, after.HistoryVersion)
|
||||
require.Equal(t, int64(0), after.GenerationAttempt, "trigger resets generation_attempt to 0")
|
||||
}
|
||||
|
||||
func TestQueueVersionTrigger_AdvancesOnInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f) // queue_version starts at 0
|
||||
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.SendMessage(chatstate.SendMessageInput{
|
||||
Message: userTextMessage("queue", f.User.ID, f.Model.ID),
|
||||
BusyBehavior: chatstate.BusyBehaviorQueue,
|
||||
})
|
||||
return err
|
||||
}))
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, after.SnapshotVersion, after.QueueVersion)
|
||||
require.Greater(t, after.QueueVersion, int64(0))
|
||||
}
|
||||
|
||||
func TestQueueVersionTrigger_StableForNonQueueMutations(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
assistant := userTextMessage("assistant", f.User.ID, f.Model.ID)
|
||||
assistant.Role = database.ChatMessageRoleAssistant
|
||||
_, err := tx.CommitStep(chatstate.CommitStepInput{
|
||||
Messages: []chatstate.Message{assistant},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
// queue_version must remain unchanged from initial 0.
|
||||
require.Equal(t, int64(0), f.readChat(ctx, t, created.Chat.ID).QueueVersion)
|
||||
}
|
||||
|
||||
// TestUpdateFlushesBufferedPublicationsAfterCommit verifies that
|
||||
// ChatMachine.Update owns the PublishBuffer lifecycle: nothing
|
||||
// reaches the inner publisher until after the transaction commits,
|
||||
// and at commit the buffered chat:update is forwarded.
|
||||
func TestUpdateFlushesBufferedPublicationsAfterCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
channel := coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)
|
||||
baseline := countChannel(f.Pub.channels, channel)
|
||||
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
// During the callback, no new chat:update for this chat may have
|
||||
// reached the inner publisher because the buffer holds it.
|
||||
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error {
|
||||
require.Equal(t, baseline, countChannel(f.Pub.channels, channel),
|
||||
"inner publisher saw chat:update before transaction committed")
|
||||
return nil
|
||||
}))
|
||||
|
||||
require.Equal(t, baseline+1, countChannel(f.Pub.channels, channel),
|
||||
"exactly one new chat:update reached the inner publisher after commit")
|
||||
}
|
||||
|
||||
// TestUpdateDiscardsBufferedPublicationsOnCallbackError verifies the
|
||||
// deferred Discard path: when the callback returns an error the
|
||||
// transaction rolls back and no buffered messages reach the inner
|
||||
// publisher.
|
||||
func TestUpdateDiscardsBufferedPublicationsOnCallbackError(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
before := f.readChat(ctx, t, created.Chat.ID)
|
||||
channelsBefore := len(f.Pub.channels)
|
||||
|
||||
sentinel := xerrors.New("callback boom")
|
||||
err := m.Update(ctx, func(_ *chatstate.Tx) error { return sentinel })
|
||||
require.ErrorIs(t, err, sentinel)
|
||||
|
||||
require.Equal(t, channelsBefore, len(f.Pub.channels),
|
||||
"failed update must not flush any buffered publications")
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, before.SnapshotVersion, after.SnapshotVersion,
|
||||
"snapshot bump rolled back when callback returns error")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// helpers
|
||||
// =============================================================================
|
||||
|
||||
type sentinelError struct{ msg string }
|
||||
|
||||
func (s *sentinelError) Error() string { return s.msg }
|
||||
|
||||
func newSentinel() error { return &sentinelError{msg: "sentinel"} }
|
||||
|
||||
func countChannel(channels []string, channel string) int {
|
||||
c := 0
|
||||
for _, ch := range channels {
|
||||
if ch == channel {
|
||||
c++
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package chatstate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
// Message is the durable message input shape used by chatstate
|
||||
// transitions. It is intentionally lower level than the SDK message
|
||||
// request types: callers must produce a fully materialized message
|
||||
// (parsed parts, calculated cost, resolved model config) before
|
||||
// passing it in.
|
||||
//
|
||||
// The state machine never reshapes a Message except to attach the
|
||||
// runtime `chat_id`. The `revision` column is assigned by the
|
||||
// `set_chat_message_revision` trigger; runtime code must not populate
|
||||
// it.
|
||||
type Message struct {
|
||||
Role database.ChatMessageRole
|
||||
Content pqtype.NullRawMessage
|
||||
Visibility database.ChatMessageVisibility
|
||||
ModelConfigID uuid.NullUUID
|
||||
CreatedBy uuid.NullUUID
|
||||
ContentVersion int16
|
||||
Compressed bool
|
||||
InputTokens sql.NullInt64
|
||||
OutputTokens sql.NullInt64
|
||||
TotalTokens sql.NullInt64
|
||||
ReasoningTokens sql.NullInt64
|
||||
CacheCreationTokens sql.NullInt64
|
||||
CacheReadTokens sql.NullInt64
|
||||
ContextLimit sql.NullInt64
|
||||
TotalCostMicros sql.NullInt64
|
||||
RuntimeMs sql.NullInt64
|
||||
ProviderResponseID sql.NullString
|
||||
}
|
||||
|
||||
// toInsertParams converts a batch of Messages into the parallel-array
|
||||
// shape required by `InsertChatMessages`. The returned struct has all
|
||||
// arrays sized to len(messages).
|
||||
//
|
||||
// The chat ID is supplied by the caller because Message itself does
|
||||
// not carry one (the chat machine already knows the chat).
|
||||
func toInsertParams(chatID uuid.UUID, messages []Message) database.InsertChatMessagesParams {
|
||||
n := len(messages)
|
||||
params := database.InsertChatMessagesParams{
|
||||
ChatID: chatID,
|
||||
CreatedBy: make([]uuid.UUID, n),
|
||||
ModelConfigID: make([]uuid.UUID, n),
|
||||
Role: make([]database.ChatMessageRole, n),
|
||||
Content: make([]string, n),
|
||||
ContentVersion: make([]int16, n),
|
||||
Visibility: make([]database.ChatMessageVisibility, n),
|
||||
InputTokens: make([]int64, n),
|
||||
OutputTokens: make([]int64, n),
|
||||
TotalTokens: make([]int64, n),
|
||||
ReasoningTokens: make([]int64, n),
|
||||
CacheCreationTokens: make([]int64, n),
|
||||
CacheReadTokens: make([]int64, n),
|
||||
ContextLimit: make([]int64, n),
|
||||
Compressed: make([]bool, n),
|
||||
TotalCostMicros: make([]int64, n),
|
||||
RuntimeMs: make([]int64, n),
|
||||
ProviderResponseID: make([]string, n),
|
||||
}
|
||||
for i, m := range messages {
|
||||
params.CreatedBy[i] = nullUUIDOrNil(m.CreatedBy)
|
||||
params.ModelConfigID[i] = nullUUIDOrNil(m.ModelConfigID)
|
||||
params.Role[i] = m.Role
|
||||
if m.Content.Valid {
|
||||
params.Content[i] = string(m.Content.RawMessage)
|
||||
} else {
|
||||
// Use the JSON null literal; UNNEST + ::jsonb requires a
|
||||
// valid JSON value and the trigger leaves it untouched.
|
||||
params.Content[i] = "null"
|
||||
}
|
||||
params.ContentVersion[i] = m.ContentVersion
|
||||
params.Visibility[i] = m.Visibility
|
||||
params.InputTokens[i] = nullInt64Or(m.InputTokens, 0)
|
||||
params.OutputTokens[i] = nullInt64Or(m.OutputTokens, 0)
|
||||
params.TotalTokens[i] = nullInt64Or(m.TotalTokens, 0)
|
||||
params.ReasoningTokens[i] = nullInt64Or(m.ReasoningTokens, 0)
|
||||
params.CacheCreationTokens[i] = nullInt64Or(m.CacheCreationTokens, 0)
|
||||
params.CacheReadTokens[i] = nullInt64Or(m.CacheReadTokens, 0)
|
||||
params.ContextLimit[i] = nullInt64Or(m.ContextLimit, 0)
|
||||
params.Compressed[i] = m.Compressed
|
||||
params.TotalCostMicros[i] = nullInt64Or(m.TotalCostMicros, 0)
|
||||
params.RuntimeMs[i] = nullInt64Or(m.RuntimeMs, 0)
|
||||
if m.ProviderResponseID.Valid {
|
||||
params.ProviderResponseID[i] = m.ProviderResponseID.String
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func nullUUIDOrNil(u uuid.NullUUID) uuid.UUID {
|
||||
if u.Valid {
|
||||
return u.UUID
|
||||
}
|
||||
return uuid.Nil
|
||||
}
|
||||
|
||||
func nullInt64Or(v sql.NullInt64, fallback int64) int64 {
|
||||
if v.Valid {
|
||||
return v.Int64
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
package chatstate
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
)
|
||||
|
||||
// Publisher is the minimal interface chatstate needs to publish
|
||||
// pubsub messages. It is intentionally compatible with
|
||||
// database/pubsub.Pubsub: real callers pass the live pubsub directly
|
||||
// and tests pass a recording fake. Chatstate entry points
|
||||
// ([ChatMachine.Update], [CreateChat], [SetFamilyArchived]) own the
|
||||
// internal [PublishBuffer] lifecycle so callers never see it.
|
||||
type Publisher interface {
|
||||
Publish(event string, message []byte) error
|
||||
}
|
||||
|
||||
// PublishBuffer is a [Publisher] that records each Publish call in
|
||||
// order without forwarding it until [PublishBuffer.Flush] is called.
|
||||
// It is an internal primitive used by chatstate entry points to
|
||||
// hold pubsub messages until the surrounding transaction commits,
|
||||
// and by tests that need to observe buffered output. Normal callers
|
||||
// do not construct a PublishBuffer themselves and do not invoke
|
||||
// Flush or Discard; chatstate's entry points own that lifecycle.
|
||||
type PublishBuffer struct {
|
||||
inner Publisher
|
||||
|
||||
mu sync.Mutex
|
||||
pending []bufferedMessage
|
||||
flushed bool
|
||||
disabled bool
|
||||
}
|
||||
|
||||
type bufferedMessage struct {
|
||||
Channel string
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
// NewPublishBuffer constructs a PublishBuffer that, when flushed, will
|
||||
// forward messages in order to inner.
|
||||
func NewPublishBuffer(inner Publisher) *PublishBuffer {
|
||||
return &PublishBuffer{inner: inner}
|
||||
}
|
||||
|
||||
// Publish records a message. It never forwards to the inner publisher
|
||||
// until [PublishBuffer.Flush] is called. Returns an error if Flush has
|
||||
// already happened to make accidental reuse obvious.
|
||||
func (b *PublishBuffer) Publish(channel string, payload []byte) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.flushed {
|
||||
return xerrors.Errorf("publish buffer flushed; cannot accept message for %q", channel)
|
||||
}
|
||||
if b.disabled {
|
||||
return nil
|
||||
}
|
||||
cp := make([]byte, len(payload))
|
||||
copy(cp, payload)
|
||||
b.pending = append(b.pending, bufferedMessage{Channel: channel, Payload: cp})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush forwards every pending message to the inner publisher in the
|
||||
// order it was buffered, then marks the buffer flushed. The first
|
||||
// publish error is returned with the channel name annotated; later
|
||||
// messages still attempt to publish so the inner publisher sees a
|
||||
// consistent state.
|
||||
func (b *PublishBuffer) Flush() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.flushed {
|
||||
return nil
|
||||
}
|
||||
b.flushed = true
|
||||
var firstErr error
|
||||
for _, msg := range b.pending {
|
||||
if err := b.inner.Publish(msg.Channel, msg.Payload); err != nil && firstErr == nil {
|
||||
firstErr = xerrors.Errorf("publish %s: %w", msg.Channel, err)
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// Discard clears the buffered messages without forwarding them. It
|
||||
// is safe to call multiple times and is harmless after [PublishBuffer.Flush]:
|
||||
// once Flush has marked the buffer flushed and forwarded its
|
||||
// pending messages, a subsequent Discard simply clears the (now
|
||||
// empty) pending slice and sets the buffer to drop any future
|
||||
// Publish calls. This makes `defer buf.Discard()` a safe pattern
|
||||
// after a successful flush, including the one chatstate entry
|
||||
// points use to own the buffer lifecycle.
|
||||
func (b *PublishBuffer) Discard() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.pending = nil
|
||||
b.disabled = true
|
||||
}
|
||||
|
||||
// pending returns a snapshot of the buffered messages, primarily for
|
||||
// tests via [PublishBuffer.BufferedChannels]. The returned slice is a
|
||||
// copy and safe to inspect without holding the buffer lock.
|
||||
func (b *PublishBuffer) snapshotPending() []bufferedMessage {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
out := make([]bufferedMessage, len(b.pending))
|
||||
copy(out, b.pending)
|
||||
return out
|
||||
}
|
||||
|
||||
// BufferedChannels returns just the channels of the pending messages
|
||||
// in order. Primarily useful for assertions in tests.
|
||||
func (b *PublishBuffer) BufferedChannels() []string {
|
||||
pending := b.snapshotPending()
|
||||
out := make([]string, len(pending))
|
||||
for i, m := range pending {
|
||||
out[i] = m.Channel
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// buildChatUpdateMessage produces the JSON payload for a
|
||||
// `chat:update:{chat_id}` message describing the post-transition
|
||||
// snapshot of chat.
|
||||
func buildChatUpdateMessage(chat database.Chat) []byte {
|
||||
msg := coderdpubsub.ChatStateUpdateMessage{
|
||||
SnapshotVersion: chat.SnapshotVersion,
|
||||
HistoryVersion: chat.HistoryVersion,
|
||||
QueueVersion: chat.QueueVersion,
|
||||
RetryStateVersion: chat.RetryStateVersion,
|
||||
GenerationAttempt: chat.GenerationAttempt,
|
||||
Status: string(chat.Status),
|
||||
Archived: chat.Archived,
|
||||
}
|
||||
if chat.WorkerID.Valid {
|
||||
id := chat.WorkerID.UUID
|
||||
msg.WorkerID = &id
|
||||
}
|
||||
if chat.RunnerID.Valid {
|
||||
id := chat.RunnerID.UUID
|
||||
msg.RunnerID = &id
|
||||
}
|
||||
payload, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
// json.Marshal on this struct is total; panic is acceptable
|
||||
// because the only failure mode would be a bug in this
|
||||
// package, not user input.
|
||||
panic(fmt.Sprintf("marshal chat state update: %v", err))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildChatOwnershipMessage produces the JSON payload for the global
|
||||
// `chat:ownership` ownership hint for chat.
|
||||
func buildChatOwnershipMessage(chat database.Chat) []byte {
|
||||
payload, err := json.Marshal(coderdpubsub.ChatStateOwnershipMessage{
|
||||
ChatID: chat.ID,
|
||||
SnapshotVersion: chat.SnapshotVersion,
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("marshal chat state ownership: %v", err))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
@@ -0,0 +1,373 @@
|
||||
package chatstate_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Publication helpers
|
||||
// =============================================================================
|
||||
|
||||
// publishedOn returns the indices into f.Pub.channels (and f.Pub.payloads)
|
||||
// that match the given channel name, in order.
|
||||
func publishedOn(f *testFixture, channel string) []int {
|
||||
var idx []int
|
||||
for i, c := range f.Pub.channels {
|
||||
if c == channel {
|
||||
idx = append(idx, i)
|
||||
}
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
// TestCreateChatPublishesAfterCommit asserts that a successful
|
||||
// CreateChat call publishes exactly one chat:update message on the
|
||||
// per-chat channel after the inner transaction commits.
|
||||
func TestCreateChatPublishesAfterCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
res := createTestChat(t, f)
|
||||
|
||||
channel := coderdpubsub.ChatStateUpdateChannel(res.Chat.ID)
|
||||
idx := publishedOn(f, channel)
|
||||
require.Len(t, idx, 1, "exactly one chat:update for the new chat")
|
||||
|
||||
var msg coderdpubsub.ChatStateUpdateMessage
|
||||
require.NoError(t, json.Unmarshal(f.Pub.payloads[idx[0]], &msg))
|
||||
require.Equal(t, res.Chat.SnapshotVersion, msg.SnapshotVersion)
|
||||
require.Equal(t, string(database.ChatStatusRunning), msg.Status)
|
||||
}
|
||||
|
||||
// TestUpdatePublishesAfterCommit asserts that ChatMachine.Update
|
||||
// publishes one chat:update on the per-chat channel after the inner
|
||||
// transaction commits, even when the callback performs no transition.
|
||||
func TestUpdatePublishesAfterCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
createIdx := publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID))
|
||||
require.Len(t, createIdx, 1, "create published one chat:update")
|
||||
|
||||
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil }))
|
||||
|
||||
updIdx := publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID))
|
||||
require.Len(t, updIdx, 2, "no-op Update still publishes a chat:update")
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
var msg coderdpubsub.ChatStateUpdateMessage
|
||||
require.NoError(t, json.Unmarshal(f.Pub.payloads[updIdx[1]], &msg))
|
||||
require.Equal(t, after.SnapshotVersion, msg.SnapshotVersion)
|
||||
}
|
||||
|
||||
// TestUpdatePublishesOneFinalChatUpdateForTransitionBundle bundles
|
||||
// several transitions inside one Update callback and verifies the
|
||||
// commit publishes exactly one chat:update on the per-chat channel
|
||||
// (not one per transition).
|
||||
func TestUpdatePublishesOneFinalChatUpdateForTransitionBundle(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
baseUpdates := len(publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)))
|
||||
|
||||
// R0 -> W (FinishTurn) -> XW (SetArchived true) -> W (SetArchived false).
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
if _, err := tx.FinishTurn(chatstate.FinishTurnInput{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true}); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: false}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
|
||||
updIdx := publishedOn(f, coderdpubsub.ChatStateUpdateChannel(created.Chat.ID))
|
||||
require.Equal(t, baseUpdates+1, len(updIdx),
|
||||
"three-transition bundle publishes exactly one final chat:update")
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusWaiting, after.Status, "ends in W")
|
||||
}
|
||||
|
||||
// TestUpdateAppliesTransitionBundleSequentially verifies that
|
||||
// transitions chained inside a single Update callback see each
|
||||
// other's effects: later transitions validate against the state
|
||||
// produced by earlier ones (R0 -> W is rejected when called twice
|
||||
// because the second call sees state W and FinishTurn is no longer
|
||||
// allowed).
|
||||
func TestUpdateAppliesTransitionBundleSequentially(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
err := m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
if _, err := tx.FinishTurn(chatstate.FinishTurnInput{}); err != nil {
|
||||
return err
|
||||
}
|
||||
// Second FinishTurn should fail because state is now W.
|
||||
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
|
||||
return err
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
|
||||
|
||||
// Failed bundle rolls back: state must not have advanced past R0.
|
||||
require.Equal(t, chatstate.StateR0, f.classify(ctx, t, created.Chat.ID),
|
||||
"failed bundle rolls back the whole transaction")
|
||||
}
|
||||
|
||||
// TestFailedUpdatePublishesNothing verifies that a callback error
|
||||
// rolls back the snapshot bump and publishes nothing.
|
||||
func TestFailedUpdatePublishesNothing(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
publishedBefore := len(f.Pub.channels)
|
||||
beforeChat := f.readChat(ctx, t, created.Chat.ID)
|
||||
|
||||
sentinel := xerrors.New("forced failure")
|
||||
err := m.Update(ctx, func(_ *chatstate.Tx) error { return sentinel })
|
||||
require.ErrorIs(t, err, sentinel)
|
||||
require.Equal(t, publishedBefore, len(f.Pub.channels), "failed update publishes nothing")
|
||||
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, beforeChat.SnapshotVersion, after.SnapshotVersion,
|
||||
"failed update rolls back snapshot bump")
|
||||
}
|
||||
|
||||
// TestLockPublishesNothing verifies that Lock does not publish even
|
||||
// though it locks the chat row.
|
||||
func TestLockPublishesNothing(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
publishedBefore := len(f.Pub.channels)
|
||||
require.NoError(t, m.Lock(ctx, func(_ database.Store) error { return nil }))
|
||||
require.Equal(t, publishedBefore, len(f.Pub.channels), "Lock publishes nothing")
|
||||
}
|
||||
|
||||
// TestPublishBufferWithRolledBackOuterTransactionPublishesNothing
|
||||
// wires a chatstate machine through a PublishBuffer and exercises
|
||||
// the buffer primitive directly: when the caller discards before
|
||||
// flushing, the inner publisher receives nothing. ChatMachine.Update
|
||||
// uses the same primitive internally with a deferred Discard;
|
||||
// callers no longer drive Flush or Discard themselves.
|
||||
func TestPublishBufferWithRolledBackOuterTransactionPublishesNothing(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
// Run one normal Update to establish a stable baseline channel
|
||||
// count. CreateChat plus this Update may publish chat:update
|
||||
// and chat:ownership messages depending on ownership, so we
|
||||
// take the snapshot after that activity settles.
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil }))
|
||||
baseline := len(f.Pub.channels)
|
||||
|
||||
// Now exercise the PublishBuffer rollback path explicitly. The
|
||||
// outer transaction "rolls back": the caller buffers messages,
|
||||
// discards them, then flushes. The inner publisher must see
|
||||
// none of the buffered messages.
|
||||
buf := chatstate.NewPublishBuffer(f.Pub)
|
||||
require.NoError(t, buf.Publish("chat:update:bogus", []byte("payload")))
|
||||
require.NoError(t, buf.Publish("chat:ownership", []byte("payload")))
|
||||
buf.Discard()
|
||||
require.NoError(t, buf.Flush())
|
||||
|
||||
require.Equal(t, baseline, len(f.Pub.channels),
|
||||
"discarded buffer publishes nothing through the inner publisher")
|
||||
}
|
||||
|
||||
// TestChatUpdateMessagePayloadShape verifies the JSON shape of the
|
||||
// chat:update payload contains every field consumers depend on:
|
||||
// snapshot_version, history_version, queue_version,
|
||||
// retry_state_version, generation_attempt, status, archived, and
|
||||
// optional worker_id / runner_id.
|
||||
func TestChatUpdateMessagePayloadShape(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
// Acquire ownership so worker_id and runner_id are present.
|
||||
worker := uuid.New()
|
||||
runner := uuid.New()
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner})
|
||||
return err
|
||||
}))
|
||||
|
||||
// Find the last chat:update message.
|
||||
channel := coderdpubsub.ChatStateUpdateChannel(created.Chat.ID)
|
||||
idx := publishedOn(f, channel)
|
||||
require.NotEmpty(t, idx)
|
||||
last := f.Pub.payloads[idx[len(idx)-1]]
|
||||
|
||||
// Strict-decode against the typed struct.
|
||||
var typed coderdpubsub.ChatStateUpdateMessage
|
||||
require.NoError(t, json.Unmarshal(last, &typed))
|
||||
require.Greater(t, typed.SnapshotVersion, int64(0))
|
||||
require.NotNil(t, typed.WorkerID)
|
||||
require.Equal(t, worker, *typed.WorkerID)
|
||||
require.NotNil(t, typed.RunnerID)
|
||||
require.Equal(t, runner, *typed.RunnerID)
|
||||
require.Equal(t, string(database.ChatStatusRunning), typed.Status)
|
||||
require.False(t, typed.Archived)
|
||||
|
||||
// Permissive decode to assert exact JSON keys.
|
||||
var raw map[string]json.RawMessage
|
||||
require.NoError(t, json.Unmarshal(last, &raw))
|
||||
for _, key := range []string{
|
||||
"snapshot_version",
|
||||
"history_version",
|
||||
"queue_version",
|
||||
"retry_state_version",
|
||||
"generation_attempt",
|
||||
"status",
|
||||
"archived",
|
||||
"worker_id",
|
||||
"runner_id",
|
||||
} {
|
||||
_, ok := raw[key]
|
||||
require.True(t, ok, "payload missing key %q", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatOwnershipMessagePayloadShape verifies the JSON shape of
|
||||
// chat:ownership: chat_id and snapshot_version.
|
||||
func TestChatOwnershipMessagePayloadShape(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
// CreateChat publishes one ownership hint because the new chat is
|
||||
// unowned and runnable.
|
||||
created := createTestChat(t, f)
|
||||
|
||||
idx := publishedOn(f, coderdpubsub.ChatStateOwnershipChannel)
|
||||
require.NotEmpty(t, idx, "CreateChat publishes at least one chat:ownership hint")
|
||||
|
||||
payload := f.Pub.payloads[idx[len(idx)-1]]
|
||||
var typed coderdpubsub.ChatStateOwnershipMessage
|
||||
require.NoError(t, json.Unmarshal(payload, &typed))
|
||||
require.Equal(t, created.Chat.ID, typed.ChatID)
|
||||
require.Greater(t, typed.SnapshotVersion, int64(0))
|
||||
|
||||
var raw map[string]json.RawMessage
|
||||
require.NoError(t, json.Unmarshal(payload, &raw))
|
||||
for _, key := range []string{"chat_id", "snapshot_version"} {
|
||||
_, ok := raw[key]
|
||||
require.True(t, ok, "ownership payload missing key %q", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOwnershipNotificationUsesDatabaseHeartbeatStaleness verifies
|
||||
// that an ownership hint fires when the heartbeat is stale by the
|
||||
// database's clock, regardless of what the local Go clock says. We
|
||||
// rewrite the heartbeat row to a deterministically old timestamp via
|
||||
// raw SQL and confirm the post-commit hint is sent on a subsequent
|
||||
// runnable Update.
|
||||
func TestOwnershipNotificationUsesDatabaseHeartbeatStaleness(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
// Acquire ownership; this writes a fresh heartbeat.
|
||||
worker := uuid.New()
|
||||
runner := uuid.New()
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner})
|
||||
return err
|
||||
}))
|
||||
hb, err := f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
|
||||
ChatID: created.Chat.ID,
|
||||
RunnerID: runner,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.WithinDuration(t, time.Now(), hb.HeartbeatAt, time.Minute,
|
||||
"Acquire wrote a fresh heartbeat")
|
||||
|
||||
// Snapshot ownership-hint count before the test trigger.
|
||||
ownershipBefore := f.Pub.ownershipPublishCount()
|
||||
|
||||
// Force the heartbeat to a deterministically old time.
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
UPDATE chat_heartbeats
|
||||
SET heartbeat_at = NOW() - INTERVAL '1 hour'
|
||||
WHERE chat_id = $1 AND runner_id = $2
|
||||
`, created.Chat.ID, runner)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Confirm database-side staleness check agrees.
|
||||
stale, err := f.DB.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{
|
||||
ChatID: created.Chat.ID,
|
||||
RunnerID: runner,
|
||||
StaleSeconds: chatstate.HeartbeatStaleSeconds,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, stale, "heartbeat is stale per database time")
|
||||
|
||||
// Run a no-op Update. The chat is runnable (R0) and the
|
||||
// heartbeat is stale, so post-commit logic must publish exactly
|
||||
// one chat:ownership hint.
|
||||
require.NoError(t, m.Update(ctx, func(_ *chatstate.Tx) error { return nil }))
|
||||
|
||||
ownershipAfter := f.Pub.ownershipPublishCount()
|
||||
require.Equal(t, ownershipBefore+1, ownershipAfter,
|
||||
"stale heartbeat triggers a fresh ownership hint")
|
||||
}
|
||||
|
||||
// TestUpdateContextCancellationPublishesNothing verifies that
|
||||
// canceling the caller's context (between the inner commit and the
|
||||
// publish loop's first call) does not corrupt state. We exercise the
|
||||
// simpler observable contract: when the user cancels before Update
|
||||
// gets to do anything, nothing is published. The strict before-publish
|
||||
// race is exercised in concurrency tests with channel sync.
|
||||
func TestUpdateContextCancellationPublishesNothing(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
publishedBefore := len(f.Pub.channels)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
err := m.Update(ctx, func(_ *chatstate.Tx) error { return nil })
|
||||
require.Error(t, err)
|
||||
require.Equal(t, publishedBefore, len(f.Pub.channels),
|
||||
"caller-aborted update publishes nothing")
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package chatstate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type recordingPublisher struct {
|
||||
calls []recordedCall
|
||||
errOn map[string]error
|
||||
failed map[string]int
|
||||
}
|
||||
|
||||
type recordedCall struct {
|
||||
Channel string
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func newRecordingPublisher() *recordingPublisher {
|
||||
return &recordingPublisher{
|
||||
errOn: map[string]error{},
|
||||
failed: map[string]int{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *recordingPublisher) Publish(channel string, payload []byte) error {
|
||||
r.calls = append(r.calls, recordedCall{Channel: channel, Payload: append([]byte(nil), payload...)})
|
||||
if err, ok := r.errOn[channel]; ok {
|
||||
r.failed[channel]++
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPublishBuffer_DefersPublishUntilFlush(t *testing.T) {
|
||||
t.Parallel()
|
||||
inner := newRecordingPublisher()
|
||||
buf := NewPublishBuffer(inner)
|
||||
|
||||
require.NoError(t, buf.Publish("a", []byte("1")))
|
||||
require.NoError(t, buf.Publish("b", []byte("2")))
|
||||
|
||||
require.Empty(t, inner.calls, "inner publisher should not be called before flush")
|
||||
require.Equal(t, []string{"a", "b"}, buf.BufferedChannels())
|
||||
}
|
||||
|
||||
func TestPublishBuffer_FlushPublishesInOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
inner := newRecordingPublisher()
|
||||
buf := NewPublishBuffer(inner)
|
||||
|
||||
require.NoError(t, buf.Publish("a", []byte("1")))
|
||||
require.NoError(t, buf.Publish("b", []byte("2")))
|
||||
require.NoError(t, buf.Publish("c", []byte("3")))
|
||||
|
||||
require.NoError(t, buf.Flush())
|
||||
require.Len(t, inner.calls, 3)
|
||||
require.Equal(t, "a", inner.calls[0].Channel)
|
||||
require.Equal(t, "b", inner.calls[1].Channel)
|
||||
require.Equal(t, "c", inner.calls[2].Channel)
|
||||
require.Equal(t, []byte("1"), inner.calls[0].Payload)
|
||||
}
|
||||
|
||||
func TestPublishBuffer_FlushReturnsFirstError(t *testing.T) {
|
||||
t.Parallel()
|
||||
inner := newRecordingPublisher()
|
||||
inner.errOn["b"] = xerrors.New("broken")
|
||||
buf := NewPublishBuffer(inner)
|
||||
|
||||
require.NoError(t, buf.Publish("a", []byte("1")))
|
||||
require.NoError(t, buf.Publish("b", []byte("2")))
|
||||
require.NoError(t, buf.Publish("c", []byte("3")))
|
||||
|
||||
err := buf.Flush()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "publish b:")
|
||||
// Even after the broken channel, later messages should still be
|
||||
// attempted so the inner publisher sees them.
|
||||
require.Len(t, inner.calls, 3)
|
||||
}
|
||||
|
||||
func TestPublishBuffer_PublishAfterFlushFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
inner := newRecordingPublisher()
|
||||
buf := NewPublishBuffer(inner)
|
||||
require.NoError(t, buf.Flush())
|
||||
require.Error(t, buf.Publish("x", []byte("y")))
|
||||
}
|
||||
|
||||
func TestPublishBuffer_DiscardSuppressesPending(t *testing.T) {
|
||||
t.Parallel()
|
||||
inner := newRecordingPublisher()
|
||||
buf := NewPublishBuffer(inner)
|
||||
require.NoError(t, buf.Publish("a", []byte("1")))
|
||||
buf.Discard()
|
||||
require.NoError(t, buf.Flush())
|
||||
require.Empty(t, inner.calls)
|
||||
}
|
||||
|
||||
func TestPublishBuffer_DiscardBlocksLaterPublishes(t *testing.T) {
|
||||
t.Parallel()
|
||||
inner := newRecordingPublisher()
|
||||
buf := NewPublishBuffer(inner)
|
||||
buf.Discard()
|
||||
// Discard sets disabled; subsequent Publish is a no-op (not an
|
||||
// error) so callers using Discard before/around rollback paths
|
||||
// do not have to special-case unwind.
|
||||
require.NoError(t, buf.Publish("a", []byte("1")))
|
||||
require.NoError(t, buf.Flush())
|
||||
require.Empty(t, inner.calls)
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package chatstate
|
||||
|
||||
import (
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
// ExecutionState is the RFC shorthand for a chat's current execution
|
||||
// state. The valid set contains 13 states; everything outside it is
|
||||
// represented by [StateInvalid].
|
||||
type ExecutionState string
|
||||
|
||||
const (
|
||||
// StateN: chat does not exist.
|
||||
StateN ExecutionState = "N"
|
||||
// StateW: waiting, empty queue, not archived.
|
||||
StateW ExecutionState = "W"
|
||||
// StateE0: error, empty queue, not archived.
|
||||
StateE0 ExecutionState = "E0"
|
||||
// StateE1: error, non-empty queue, not archived.
|
||||
StateE1 ExecutionState = "E1"
|
||||
// StateR0: running, empty queue, not archived.
|
||||
StateR0 ExecutionState = "R0"
|
||||
// StateR1: running, non-empty queue, not archived.
|
||||
StateR1 ExecutionState = "R1"
|
||||
// StateI0: interrupting, empty queue, not archived.
|
||||
StateI0 ExecutionState = "I0"
|
||||
// StateI1: interrupting, non-empty queue, not archived.
|
||||
StateI1 ExecutionState = "I1"
|
||||
// StateA0: requires_action, empty queue, not archived.
|
||||
StateA0 ExecutionState = "A0"
|
||||
// StateA1: requires_action, non-empty queue, not archived.
|
||||
StateA1 ExecutionState = "A1"
|
||||
// StateXW: archived waiting, empty queue.
|
||||
StateXW ExecutionState = "XW"
|
||||
// StateXE0: archived error, empty queue.
|
||||
StateXE0 ExecutionState = "XE0"
|
||||
// StateXE1: archived error, non-empty queue.
|
||||
StateXE1 ExecutionState = "XE1"
|
||||
|
||||
// StateInvalid groups every status/archive/queue combination that
|
||||
// is not one of the 13 valid states above. The state machine
|
||||
// refuses non-reconciliation transitions on invalid states and
|
||||
// exposes the [Tx.ReconcileInvalidState] transition to recover.
|
||||
StateInvalid ExecutionState = "Invalid"
|
||||
)
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (s ExecutionState) String() string { return string(s) }
|
||||
|
||||
// AllExecutionStates is the canonical enumeration of every value the
|
||||
// classifier can return. Tests rely on this list to iterate over every
|
||||
// state when verifying transition coverage.
|
||||
var AllExecutionStates = []ExecutionState{
|
||||
StateN,
|
||||
StateW,
|
||||
StateE0,
|
||||
StateE1,
|
||||
StateR0,
|
||||
StateR1,
|
||||
StateI0,
|
||||
StateI1,
|
||||
StateA0,
|
||||
StateA1,
|
||||
StateXW,
|
||||
StateXE0,
|
||||
StateXE1,
|
||||
StateInvalid,
|
||||
}
|
||||
|
||||
// IsRunnable returns true for the execution states that the chat
|
||||
// worker is allowed to acquire and drive forward: R0, R1, I0, I1,
|
||||
// A0, and A1. Requires-action states need worker ownership for
|
||||
// timeout processing. Other states are idle (W, E*, XW, XE*), absent
|
||||
// (N), or invalid.
|
||||
func (s ExecutionState) IsRunnable() bool {
|
||||
switch s {
|
||||
case StateR0, StateR1, StateI0, StateI1, StateA0, StateA1:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsArchived returns true for the three archived execution states.
|
||||
func (s ExecutionState) IsArchived() bool {
|
||||
switch s {
|
||||
case StateXW, StateXE0, StateXE1:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// QueueNonEmpty returns true for execution states that require a
|
||||
// non-empty queue. Useful when seeding test fixtures.
|
||||
func (s ExecutionState) QueueNonEmpty() bool {
|
||||
switch s {
|
||||
case StateE1, StateR1, StateI1, StateA1, StateXE1:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ClassifyExecutionState turns the chat row, queue cardinality, and
|
||||
// whether the chat row exists into one of the 14 [ExecutionState]
|
||||
// values. The caller is responsible for loading the chat under the
|
||||
// row lock and reading the queue count in the same transaction.
|
||||
//
|
||||
// Callers that have no chat row (lookup returned sql.ErrNoRows)
|
||||
// should pass exists=false; the chat, status, and archive arguments
|
||||
// are then ignored.
|
||||
//
|
||||
// The classifier is a single flat switch over the 13 valid
|
||||
// (status, archived, queue) tuples from the RFC. Anything outside
|
||||
// that set (legacy pending/paused/completed statuses, archived busy
|
||||
// states, waiting with a non-empty queue, future enum values) falls
|
||||
// through to [StateInvalid].
|
||||
//
|
||||
//nolint:revive // queueNonEmpty/exists are simple flags that mirror the RFC inputs.
|
||||
func ClassifyExecutionState(chat database.Chat, queueNonEmpty, exists bool) ExecutionState {
|
||||
if !exists {
|
||||
return StateN
|
||||
}
|
||||
switch {
|
||||
case chat.Status == database.ChatStatusWaiting && !chat.Archived && !queueNonEmpty:
|
||||
return StateW
|
||||
case chat.Status == database.ChatStatusWaiting && chat.Archived && !queueNonEmpty:
|
||||
return StateXW
|
||||
case chat.Status == database.ChatStatusError && !chat.Archived && !queueNonEmpty:
|
||||
return StateE0
|
||||
case chat.Status == database.ChatStatusError && !chat.Archived && queueNonEmpty:
|
||||
return StateE1
|
||||
case chat.Status == database.ChatStatusError && chat.Archived && !queueNonEmpty:
|
||||
return StateXE0
|
||||
case chat.Status == database.ChatStatusError && chat.Archived && queueNonEmpty:
|
||||
return StateXE1
|
||||
case chat.Status == database.ChatStatusRunning && !chat.Archived && !queueNonEmpty:
|
||||
return StateR0
|
||||
case chat.Status == database.ChatStatusRunning && !chat.Archived && queueNonEmpty:
|
||||
return StateR1
|
||||
case chat.Status == database.ChatStatusInterrupting && !chat.Archived && !queueNonEmpty:
|
||||
return StateI0
|
||||
case chat.Status == database.ChatStatusInterrupting && !chat.Archived && queueNonEmpty:
|
||||
return StateI1
|
||||
case chat.Status == database.ChatStatusRequiresAction && !chat.Archived && !queueNonEmpty:
|
||||
return StateA0
|
||||
case chat.Status == database.ChatStatusRequiresAction && !chat.Archived && queueNonEmpty:
|
||||
return StateA1
|
||||
}
|
||||
return StateInvalid
|
||||
}
|
||||
|
||||
// OwnershipState is the RFC shorthand for whether a chat row is
|
||||
// currently owned by a worker. The state machine treats execution and
|
||||
// ownership as orthogonal.
|
||||
type OwnershipState string
|
||||
|
||||
const (
|
||||
// StateU: chat has no owner (worker_id IS NULL).
|
||||
StateU OwnershipState = "U"
|
||||
// StateO: chat has an owner (worker_id IS NOT NULL).
|
||||
StateO OwnershipState = "O"
|
||||
)
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (s OwnershipState) String() string { return string(s) }
|
||||
|
||||
// AllOwnershipStates is the canonical enumeration of ownership states.
|
||||
var AllOwnershipStates = []OwnershipState{StateU, StateO}
|
||||
|
||||
// ClassifyOwnershipState returns [StateU] when worker_id is NULL and
|
||||
// [StateO] otherwise. The runner_id field is intentionally not part of
|
||||
// the ownership classification. Acquire writes both worker_id and
|
||||
// runner_id; Abandon clears the current row without using runner_id as
|
||||
// a transition input.
|
||||
func ClassifyOwnershipState(chat database.Chat) OwnershipState {
|
||||
if !chat.WorkerID.Valid {
|
||||
return StateU
|
||||
}
|
||||
return StateO
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
package chatstate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
func chatWithStatus(status database.ChatStatus, archived bool) database.Chat {
|
||||
return database.Chat{
|
||||
ID: uuid.New(),
|
||||
Status: status,
|
||||
Archived: archived,
|
||||
OwnerID: uuid.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// TestClassifyExecutionState_Valid covers every valid classification:
|
||||
// N (missing chat) plus the twelve valid existing-chat states.
|
||||
func TestClassifyExecutionState_Valid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
status database.ChatStatus
|
||||
archived bool
|
||||
queueNonEmpty bool
|
||||
exists bool
|
||||
want ExecutionState
|
||||
}{
|
||||
{name: "N", exists: false, want: StateN},
|
||||
{name: "W", status: database.ChatStatusWaiting, exists: true, want: StateW},
|
||||
{name: "E0", status: database.ChatStatusError, exists: true, want: StateE0},
|
||||
{name: "E1", status: database.ChatStatusError, queueNonEmpty: true, exists: true, want: StateE1},
|
||||
{name: "R0", status: database.ChatStatusRunning, exists: true, want: StateR0},
|
||||
{name: "R1", status: database.ChatStatusRunning, queueNonEmpty: true, exists: true, want: StateR1},
|
||||
{name: "I0", status: database.ChatStatusInterrupting, exists: true, want: StateI0},
|
||||
{name: "I1", status: database.ChatStatusInterrupting, queueNonEmpty: true, exists: true, want: StateI1},
|
||||
{name: "A0", status: database.ChatStatusRequiresAction, exists: true, want: StateA0},
|
||||
{name: "A1", status: database.ChatStatusRequiresAction, queueNonEmpty: true, exists: true, want: StateA1},
|
||||
{name: "XW", status: database.ChatStatusWaiting, archived: true, exists: true, want: StateXW},
|
||||
{name: "XE0", status: database.ChatStatusError, archived: true, exists: true, want: StateXE0},
|
||||
{name: "XE1", status: database.ChatStatusError, archived: true, queueNonEmpty: true, exists: true, want: StateXE1},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
chat := database.Chat{}
|
||||
if tc.exists {
|
||||
chat = chatWithStatus(tc.status, tc.archived)
|
||||
}
|
||||
require.Equal(t, tc.want, ClassifyExecutionState(chat, tc.queueNonEmpty, tc.exists))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClassifyExecutionState_Invalid covers every documented invalid
|
||||
// combination: legacy statuses, waiting-with-queue, and archived busy
|
||||
// statuses.
|
||||
func TestClassifyExecutionState_Invalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
status database.ChatStatus
|
||||
archived bool
|
||||
queueNonEmpty bool
|
||||
}{
|
||||
// Legacy statuses (pending/paused/completed) are invalid for
|
||||
// the new state machine.
|
||||
{name: "LegacyPending", status: "pending"},
|
||||
{name: "LegacyPaused", status: "paused"},
|
||||
{name: "LegacyCompleted", status: "completed"},
|
||||
|
||||
// Waiting must always have an empty queue.
|
||||
{name: "WaitingWithQueue", status: database.ChatStatusWaiting, queueNonEmpty: true},
|
||||
{name: "WaitingArchivedWithQueue", status: database.ChatStatusWaiting, archived: true, queueNonEmpty: true},
|
||||
|
||||
// Archived busy statuses are invalid.
|
||||
{name: "ArchivedRunning", status: database.ChatStatusRunning, archived: true},
|
||||
{name: "ArchivedInterrupting", status: database.ChatStatusInterrupting, archived: true},
|
||||
{name: "ArchivedRequiresAction", status: database.ChatStatusRequiresAction, archived: true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := ClassifyExecutionState(chatWithStatus(tc.status, tc.archived), tc.queueNonEmpty, true)
|
||||
require.Equal(t, StateInvalid, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClassifyExecutionState_RejectsAllUnlistedCombinations enumerates
|
||||
// every (status, archived, queueNonEmpty) tuple for an existing chat
|
||||
// and asserts exactly the twelve RFC tuples classify out of
|
||||
// [StateInvalid]. Missing chats are handled separately via the N case
|
||||
// in [TestClassifyExecutionState_Valid].
|
||||
func TestClassifyExecutionState_RejectsAllUnlistedCombinations(t *testing.T) {
|
||||
t.Parallel()
|
||||
allStatuses := []database.ChatStatus{
|
||||
database.ChatStatusWaiting,
|
||||
database.ChatStatusError,
|
||||
database.ChatStatusRunning,
|
||||
database.ChatStatusInterrupting,
|
||||
database.ChatStatusRequiresAction,
|
||||
"pending", "paused", "completed",
|
||||
}
|
||||
validCount := 0
|
||||
for _, status := range allStatuses {
|
||||
for _, archived := range []bool{false, true} {
|
||||
for _, queueNonEmpty := range []bool{false, true} {
|
||||
got := ClassifyExecutionState(chatWithStatus(status, archived), queueNonEmpty, true)
|
||||
if got != StateInvalid {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 12, validCount, "exactly 12 valid existing-chat (status, archived, queue) tuples")
|
||||
}
|
||||
|
||||
// TestClassifyOwnershipState covers both ownership classifications.
|
||||
func TestClassifyOwnershipState(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
chat func() database.Chat
|
||||
want OwnershipState
|
||||
}{
|
||||
{
|
||||
name: "Unowned",
|
||||
chat: func() database.Chat { return database.Chat{} },
|
||||
want: StateU,
|
||||
},
|
||||
{
|
||||
name: "Owned",
|
||||
chat: func() database.Chat {
|
||||
c := database.Chat{}
|
||||
c.WorkerID.UUID = uuid.New()
|
||||
c.WorkerID.Valid = true
|
||||
return c
|
||||
},
|
||||
want: StateO,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tc.want, ClassifyOwnershipState(tc.chat()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExecutionStates_Enumeration verifies AllExecutionStates
|
||||
// contains every declared execution state exactly once.
|
||||
func TestAllExecutionStates_Enumeration(t *testing.T) {
|
||||
t.Parallel()
|
||||
want := map[ExecutionState]bool{
|
||||
StateN: true, StateW: true, StateE0: true, StateE1: true,
|
||||
StateR0: true, StateR1: true, StateI0: true, StateI1: true,
|
||||
StateA0: true, StateA1: true, StateXW: true, StateXE0: true,
|
||||
StateXE1: true, StateInvalid: true,
|
||||
}
|
||||
require.Len(t, AllExecutionStates, len(want))
|
||||
seen := make(map[ExecutionState]bool, len(want))
|
||||
for _, s := range AllExecutionStates {
|
||||
require.True(t, want[s], "unexpected state %s", s)
|
||||
require.False(t, seen[s], "duplicate state %s", s)
|
||||
seen[s] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestExecutionState_Predicates covers IsRunnable and QueueNonEmpty
|
||||
// for every declared execution state.
|
||||
func TestExecutionState_Predicates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runnable := map[ExecutionState]bool{
|
||||
StateR0: true, StateR1: true, StateI0: true, StateI1: true,
|
||||
StateA0: true, StateA1: true,
|
||||
}
|
||||
nonEmpty := map[ExecutionState]bool{
|
||||
StateE1: true, StateR1: true, StateI1: true, StateA1: true, StateXE1: true,
|
||||
}
|
||||
for _, s := range AllExecutionStates {
|
||||
require.Equal(t, runnable[s], s.IsRunnable(), "IsRunnable(%s)", s)
|
||||
require.Equal(t, nonEmpty[s], s.QueueNonEmpty(), "QueueNonEmpty(%s)", s)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
package chatstate_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Helpers for synthetic cancellation tests
|
||||
// =============================================================================
|
||||
|
||||
// nonDynamicAssistantToolCallMessage builds an assistant message that
|
||||
// issues a single tool call against a tool that is NOT in the chat's
|
||||
// dynamic_tools set. The send-message and edit-message paths use the
|
||||
// "cancel every outstanding tool call regardless of source" variant
|
||||
// (dynamicOnly=false), so the cancellation must still fire even for
|
||||
// non-dynamic tools.
|
||||
func nonDynamicAssistantToolCallMessage(t *testing.T, modelID uuid.UUID, callID string) chatstate.Message {
|
||||
t.Helper()
|
||||
raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: callID,
|
||||
ToolName: "non_dynamic_tool",
|
||||
Args: json.RawMessage(`{}`),
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
return chatstate.Message{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
Content: raw,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true},
|
||||
}
|
||||
}
|
||||
|
||||
// assertToolResultForCall asserts that msg is a tool-result message
|
||||
// that resolves a tool call with id wantCallID and is_error=true.
|
||||
func assertToolResultForCall(t *testing.T, msg database.ChatMessage, wantCallID string) {
|
||||
t.Helper()
|
||||
require.Equal(t, database.ChatMessageRoleTool, msg.Role)
|
||||
parts, err := chatprompt.ParseContent(msg)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, parts)
|
||||
var found bool
|
||||
for _, p := range parts {
|
||||
if p.Type != codersdk.ChatMessagePartTypeToolResult {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, wantCallID, p.ToolCallID, "tool-call id matches")
|
||||
require.True(t, p.IsError, "synthetic cancellation must be marked is_error=true")
|
||||
found = true
|
||||
}
|
||||
require.True(t, found, "expected at least one tool-result part")
|
||||
}
|
||||
|
||||
// commitAssistantToolCall pushes an assistant message that calls
|
||||
// `tool_name` with `callID` into history via CommitStep. Returns the
|
||||
// inserted assistant ChatMessage. Use the dynamic-tools chat fixture
|
||||
// (createTestChatWithDynamicTools) when dynamicOnly cancellation
|
||||
// paths are exercised.
|
||||
func commitAssistantToolCall(
|
||||
t *testing.T,
|
||||
f *testFixture,
|
||||
m *chatstate.ChatMachine,
|
||||
msg chatstate.Message,
|
||||
) database.ChatMessage {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
var step chatstate.CommitStepResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
step, err = tx.CommitStep(chatstate.CommitStepInput{Messages: []chatstate.Message{msg}})
|
||||
return err
|
||||
}))
|
||||
require.Len(t, step.InsertedMessages, 1)
|
||||
return step.InsertedMessages[0]
|
||||
}
|
||||
|
||||
// landInW puts a fresh R0 chat into state W (waiting) via FinishTurn.
|
||||
func landInW(t *testing.T, f *testFixture, m *chatstate.ChatMachine) {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
|
||||
return err
|
||||
}))
|
||||
require.Equal(t, chatstate.StateW, f.classify(ctx, t, m.ChatID()))
|
||||
}
|
||||
|
||||
// landInE0 puts a fresh R0 chat into state E0 (error, empty queue).
|
||||
func landInE0(t *testing.T, f *testFixture, m *chatstate.ChatMachine) {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishError(chatstate.FinishErrorInput{
|
||||
LastError: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"message":"boom"}`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
require.Equal(t, chatstate.StateE0, f.classify(ctx, t, m.ChatID()))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SendMessage direct-history paths (W, E0)
|
||||
// =============================================================================
|
||||
|
||||
// TestSendMessageDirect_W_SynthesizesToolCancellations verifies that
|
||||
// from W, SendMessage inserts synthetic tool-result rows for every
|
||||
// outstanding tool call on the last assistant message BEFORE the new
|
||||
// user message, regardless of whether the tools are dynamic.
|
||||
func TestSendMessageDirect_W_SynthesizesToolCancellations(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
callID := "call_" + uuid.NewString()
|
||||
assistant := commitAssistantToolCall(t, f, m,
|
||||
nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID))
|
||||
require.Equal(t, database.ChatMessageRoleAssistant, assistant.Role)
|
||||
|
||||
// R0 -> W.
|
||||
landInW(t, f, m)
|
||||
|
||||
// SendMessage with a fresh user message. The direct-history path
|
||||
// must insert a synthetic tool-result (for callID) followed by
|
||||
// the new user message.
|
||||
var send chatstate.SendMessageResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
send, err = tx.SendMessage(chatstate.SendMessageInput{
|
||||
Message: userTextMessage("after-cancel", f.User.ID, f.Model.ID),
|
||||
BusyBehavior: chatstate.BusyBehaviorQueue,
|
||||
})
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Len(t, send.InsertedMessages, 2, "synthetic cancel + new user")
|
||||
assertToolResultForCall(t, send.InsertedMessages[0], callID)
|
||||
require.Equal(t, database.ChatMessageRoleUser, send.InsertedMessages[1].Role)
|
||||
require.Less(t, send.InsertedMessages[0].ID, send.InsertedMessages[1].ID,
|
||||
"synthetic cancel is inserted before the user message")
|
||||
}
|
||||
|
||||
// TestSendMessageDirect_E0_SynthesizesToolCancellations exercises
|
||||
// the same path from E0.
|
||||
func TestSendMessageDirect_E0_SynthesizesToolCancellations(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
callID := "call_" + uuid.NewString()
|
||||
commitAssistantToolCall(t, f, m,
|
||||
nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID))
|
||||
|
||||
// R0 -> E0.
|
||||
landInE0(t, f, m)
|
||||
|
||||
var send chatstate.SendMessageResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
send, err = tx.SendMessage(chatstate.SendMessageInput{
|
||||
Message: userTextMessage("after-error", f.User.ID, f.Model.ID),
|
||||
BusyBehavior: chatstate.BusyBehaviorQueue,
|
||||
})
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Len(t, send.InsertedMessages, 2)
|
||||
assertToolResultForCall(t, send.InsertedMessages[0], callID)
|
||||
require.Equal(t, database.ChatMessageRoleUser, send.InsertedMessages[1].Role)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EditMessage replacement insertion
|
||||
// =============================================================================
|
||||
|
||||
// TestEditMessage_SynthesizesToolCancellationsBeforeReplacement
|
||||
// verifies that EditMessage from a state with an outstanding tool
|
||||
// call before the edited user message inserts a synthetic
|
||||
// tool-result before the replacement user message in history.
|
||||
//
|
||||
// The scenario is:
|
||||
// - user message 1 (initial)
|
||||
// - assistant tool-call (outstanding)
|
||||
// - user message 2 (the one we will edit)
|
||||
//
|
||||
// EditMessage soft-deletes user message 2 and everything after it,
|
||||
// then synthesizes cancellations for tool calls on the last
|
||||
// surviving assistant message that have no matching tool-result.
|
||||
func TestEditMessage_SynthesizesToolCancellationsBeforeReplacement(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
// Build the history described above. CommitStep is happy to insert
|
||||
// a mixed batch as long as it stays inside R0.
|
||||
callID := "call_" + uuid.NewString()
|
||||
assistantTC := nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID)
|
||||
secondUser := userTextMessage("second user", f.User.ID, f.Model.ID)
|
||||
var step chatstate.CommitStepResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
step, err = tx.CommitStep(chatstate.CommitStepInput{
|
||||
Messages: []chatstate.Message{assistantTC, secondUser},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
require.Len(t, step.InsertedMessages, 2)
|
||||
secondUserID := step.InsertedMessages[1].ID
|
||||
require.Equal(t, database.ChatMessageRoleUser, step.InsertedMessages[1].Role)
|
||||
|
||||
var edit chatstate.EditMessageResult
|
||||
editedContent := mustMarshalParts(t, []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("edited"),
|
||||
})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
edit, err = tx.EditMessage(chatstate.EditMessageInput{
|
||||
MessageID: secondUserID,
|
||||
CreatedBy: f.User.ID,
|
||||
Content: editedContent,
|
||||
})
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Len(t, edit.CancellationMessages, 1, "synthetic cancel inserted")
|
||||
assertToolResultForCall(t, edit.CancellationMessages[0], callID)
|
||||
require.Equal(t, database.ChatMessageRoleUser, edit.ReplacementMessage.Role)
|
||||
require.Less(t, edit.CancellationMessages[0].ID, edit.ReplacementMessage.ID,
|
||||
"cancellations are inserted before the replacement user message")
|
||||
}
|
||||
|
||||
// PromoteQueuedMessage synchronous-history paths (E1, A1)
|
||||
// =============================================================================
|
||||
|
||||
// TestPromoteQueuedMessage_E1_SynthesizesToolCancellations verifies
|
||||
// that promoting a queued message from E1 inserts synthetic
|
||||
// tool-result rows for outstanding tool calls before the promoted
|
||||
// user message in history.
|
||||
func TestPromoteQueuedMessage_E1_SynthesizesToolCancellations(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
callID := "call_" + uuid.NewString()
|
||||
commitAssistantToolCall(t, f, m,
|
||||
nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID))
|
||||
|
||||
// Land in R1 with one queued message.
|
||||
queued := sendQueuedMessage(t, f, m, "queued-for-promote")
|
||||
require.NotNil(t, queued.QueuedMessage)
|
||||
require.Equal(t, chatstate.StateR1, f.classify(ctx, t, created.Chat.ID))
|
||||
|
||||
// R1 -> E1 via FinishError.
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishError(chatstate.FinishErrorInput{
|
||||
LastError: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"message":"boom"}`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
require.Equal(t, chatstate.StateE1, f.classify(ctx, t, created.Chat.ID))
|
||||
|
||||
var promote chatstate.PromoteQueuedMessageResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
promote, err = tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{
|
||||
QueuedMessageID: queued.QueuedMessage.ID,
|
||||
})
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Len(t, promote.CancellationMessages, 1)
|
||||
assertToolResultForCall(t, promote.CancellationMessages[0], callID)
|
||||
require.NotNil(t, promote.InsertedMessage)
|
||||
require.Equal(t, database.ChatMessageRoleUser, promote.InsertedMessage.Role)
|
||||
require.Less(t, promote.CancellationMessages[0].ID, promote.InsertedMessage.ID,
|
||||
"cancel is inserted before the promoted user message")
|
||||
}
|
||||
|
||||
// TestPromoteQueuedMessage_A1_SynthesizesDynamicToolCancellations
|
||||
// verifies that the dynamic outstanding tool call is canceled when
|
||||
// promoting from A1. The non-dynamic cancellation counterpart lives
|
||||
// in TestPromoteQueuedMessage_A1_CancelsAllOutstandingToolCalls.
|
||||
func TestPromoteQueuedMessage_A1_SynthesizesDynamicToolCancellations(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
toolName := "dyn_promote_a1"
|
||||
created := createTestChatWithDynamicTools(t, f, toolName)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
dynCallID := "call_" + uuid.NewString()
|
||||
commitAssistantToolCall(t, f, m,
|
||||
assistantToolCallMessage(t, f.Model.ID, toolName, dynCallID))
|
||||
|
||||
// Land in A0.
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{})
|
||||
return err
|
||||
}))
|
||||
require.Equal(t, chatstate.StateA0, f.classify(ctx, t, created.Chat.ID))
|
||||
|
||||
// A0 -> A1 with one queued user message.
|
||||
queued := sendQueuedMessage(t, f, m, "queued-for-a1-promote")
|
||||
require.NotNil(t, queued.QueuedMessage)
|
||||
require.Equal(t, chatstate.StateA1, f.classify(ctx, t, created.Chat.ID))
|
||||
|
||||
var promote chatstate.PromoteQueuedMessageResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
promote, err = tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{
|
||||
QueuedMessageID: queued.QueuedMessage.ID,
|
||||
})
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Len(t, promote.CancellationMessages, 1, "dynamic tool call canceled")
|
||||
assertToolResultForCall(t, promote.CancellationMessages[0], dynCallID)
|
||||
require.NotNil(t, promote.InsertedMessage)
|
||||
require.Equal(t, database.ChatMessageRoleUser, promote.InsertedMessage.Role)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// FinishTurn queue-promotion paths
|
||||
// =============================================================================
|
||||
|
||||
// TestFinishTurn_R1_SynthesizesToolCancellationsBeforePromotion
|
||||
// verifies that finishing a turn while a queued message exists
|
||||
// synthesizes outstanding tool cancellations before promoting the
|
||||
// queue head into history.
|
||||
func TestFinishTurn_R1_SynthesizesToolCancellationsBeforePromotion(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
callID := "call_" + uuid.NewString()
|
||||
commitAssistantToolCall(t, f, m,
|
||||
nonDynamicAssistantToolCallMessage(t, f.Model.ID, callID))
|
||||
|
||||
queued := sendQueuedMessage(t, f, m, "queued-for-finish")
|
||||
require.NotNil(t, queued.QueuedMessage)
|
||||
require.Equal(t, chatstate.StateR1, f.classify(ctx, t, created.Chat.ID))
|
||||
|
||||
beforeIDs := historyMessageIDs(ctx, t, f, created.Chat.ID)
|
||||
|
||||
var finish chatstate.FinishTurnResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
finish, err = tx.FinishTurn(chatstate.FinishTurnInput{})
|
||||
return err
|
||||
}))
|
||||
require.NotNil(t, finish.PromotedMessage)
|
||||
require.Equal(t, database.ChatMessageRoleUser, finish.PromotedMessage.Role)
|
||||
|
||||
afterIDs := historyMessageIDs(ctx, t, f, created.Chat.ID)
|
||||
require.Equal(t, len(beforeIDs)+2, len(afterIDs),
|
||||
"finish inserts both a tool cancel and the promoted user")
|
||||
|
||||
// The two newly inserted messages are tool-result then user.
|
||||
newIDs := afterIDs[len(beforeIDs):]
|
||||
cancel, err := f.DB.GetChatMessageByID(ctx, newIDs[0])
|
||||
require.NoError(t, err)
|
||||
assertToolResultForCall(t, cancel, callID)
|
||||
require.Equal(t, finish.PromotedMessage.ID, newIDs[1])
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// FinishInterruption queue-promotion paths and outstanding tool calls
|
||||
// =============================================================================
|
||||
|
||||
// TestFinishInterruption_I1_PromotesQueueHead verifies that
|
||||
// FinishInterruption from I1 with no outstanding tool calls
|
||||
// promotes the queue head into history.
|
||||
func TestFinishInterruption_I1_PromotesQueueHead(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
// Reach R1 with one queued message.
|
||||
queued := sendQueuedMessage(t, f, m, "queued-for-interruption")
|
||||
require.NotNil(t, queued.QueuedMessage)
|
||||
// R1 -> I1 via Interrupt.
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"})
|
||||
return err
|
||||
}))
|
||||
require.Equal(t, chatstate.StateI1, f.classify(ctx, t, created.Chat.ID))
|
||||
|
||||
beforeIDs := historyMessageIDs(ctx, t, f, created.Chat.ID)
|
||||
|
||||
var finish chatstate.FinishInterruptionResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
finish, err = tx.FinishInterruption(chatstate.FinishInterruptionInput{})
|
||||
return err
|
||||
}))
|
||||
require.NotNil(t, finish.PromotedMessage)
|
||||
require.Equal(t, database.ChatMessageRoleUser, finish.PromotedMessage.Role)
|
||||
|
||||
afterIDs := historyMessageIDs(ctx, t, f, created.Chat.ID)
|
||||
require.Equal(t, len(beforeIDs)+1, len(afterIDs))
|
||||
require.Equal(t, chatstate.StateR0, f.classify(ctx, t, created.Chat.ID))
|
||||
}
|
||||
|
||||
// TestFinishInterruption_RejectsOutstandingToolCalls verifies that
|
||||
// FinishInterruption fails (TransitionNotAllowed-shaped) when the
|
||||
// chat still has an outstanding dynamic tool call after the partial
|
||||
// commit. The non-dynamic counterpart lives in
|
||||
// TestFinishInterruption_RejectsNonDynamicOutstandingToolCall. The
|
||||
// chat must remain in its prior state.
|
||||
func TestFinishInterruption_RejectsOutstandingToolCalls(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
toolName := "dyn_finish_reject"
|
||||
created := createTestChatWithDynamicTools(t, f, toolName)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
dynCallID := "call_" + uuid.NewString()
|
||||
commitAssistantToolCall(t, f, m,
|
||||
assistantToolCallMessage(t, f.Model.ID, toolName, dynCallID))
|
||||
|
||||
// R0 -> I0 via Interrupt. Interrupt closes pending dynamic calls
|
||||
// when transitioning from A0/A1, but from R0 it does NOT, so the
|
||||
// chat keeps its outstanding dynamic call.
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"})
|
||||
return err
|
||||
}))
|
||||
require.Equal(t, chatstate.StateI0, f.classify(ctx, t, created.Chat.ID))
|
||||
|
||||
stateBefore := f.classify(ctx, t, created.Chat.ID)
|
||||
historyBefore := historyMessageIDs(ctx, t, f, created.Chat.ID)
|
||||
publishedBefore := len(f.Pub.channels)
|
||||
|
||||
// FinishInterruption with no partial commits should reject
|
||||
// because the dynamic call is still outstanding.
|
||||
err := m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishInterruption(chatstate.FinishInterruptionInput{})
|
||||
return err
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
|
||||
|
||||
require.Equal(t, stateBefore, f.classify(ctx, t, created.Chat.ID), "state unchanged")
|
||||
require.Equal(t, historyBefore, historyMessageIDs(ctx, t, f, created.Chat.ID),
|
||||
"history unchanged on rejected finish")
|
||||
require.Equal(t, publishedBefore, len(f.Pub.channels),
|
||||
"failed FinishInterruption publishes nothing")
|
||||
}
|
||||
|
||||
// ensure unused imports don't break the build if any helper is
|
||||
// removed later.
|
||||
var _ = context.Background
|
||||
@@ -0,0 +1,234 @@
|
||||
package chatstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// synthesizePendingToolCancellations builds [Message] inserts that
|
||||
// satisfy every outstanding tool call on the chat's last assistant
|
||||
// message with a synthetic cancellation tool-result message.
|
||||
//
|
||||
// "Outstanding" means a tool call present on the last assistant
|
||||
// message that does not yet have a matching tool-result message in
|
||||
// the active history after it. The caller controls whether to limit
|
||||
// to dynamic-tool calls (true) or close every outstanding tool call
|
||||
// regardless of source (false). The dynamic-only variant is used by
|
||||
// requires-action interrupts; the all-tool variant is used by any
|
||||
// transition that needs to insert a new user message into history.
|
||||
//
|
||||
// The synthetic results use the supplied chat's last_model_config_id.
|
||||
// Returns (nil, nil) when there is nothing to synthesize.
|
||||
//
|
||||
//nolint:revive // dynamicOnly is a domain flag, not a control flag.
|
||||
func synthesizePendingToolCancellations(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chat database.Chat,
|
||||
reason string,
|
||||
dynamicOnly bool,
|
||||
) ([]Message, error) {
|
||||
var dynamicToolNames map[string]bool
|
||||
if dynamicOnly {
|
||||
var err error
|
||||
dynamicToolNames, err = parseDynamicToolNamesFromRaw(chat.DynamicTools)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse dynamic tool names: %w", err)
|
||||
}
|
||||
if len(dynamicToolNames) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
||||
ChatID: chat.ID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get last assistant message: %w", err)
|
||||
}
|
||||
assistantParts, err := chatprompt.ParseContent(lastAssistant)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse assistant message: %w", err)
|
||||
}
|
||||
afterMsgs, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: lastAssistant.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get messages after assistant: %w", err)
|
||||
}
|
||||
handled := make(map[string]bool)
|
||||
for _, msg := range afterMsgs {
|
||||
if msg.Role != database.ChatMessageRoleTool {
|
||||
continue
|
||||
}
|
||||
parts, perr := chatprompt.ParseContent(msg)
|
||||
if perr != nil {
|
||||
// Don't fail the whole cancellation just because one
|
||||
// historical message is unparseable; treat its tool
|
||||
// results as unknown.
|
||||
continue
|
||||
}
|
||||
for _, p := range parts {
|
||||
if p.Type == codersdk.ChatMessagePartTypeToolResult {
|
||||
handled[p.ToolCallID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
out := make([]Message, 0)
|
||||
for _, part := range assistantParts {
|
||||
if part.Type != codersdk.ChatMessagePartTypeToolCall {
|
||||
continue
|
||||
}
|
||||
if dynamicOnly && !dynamicToolNames[part.ToolName] {
|
||||
continue
|
||||
}
|
||||
if handled[part.ToolCallID] {
|
||||
continue
|
||||
}
|
||||
resultPart := codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: part.ToolCallID,
|
||||
ToolName: part.ToolName,
|
||||
Result: json.RawMessage(fmt.Sprintf("%q", reason)),
|
||||
IsError: true,
|
||||
}
|
||||
raw, merr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{resultPart})
|
||||
if merr != nil {
|
||||
return nil, xerrors.Errorf("marshal synthetic tool result: %w", merr)
|
||||
}
|
||||
out = append(out, Message{
|
||||
Role: database.ChatMessageRoleTool,
|
||||
Content: raw,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
ModelConfigID: uuid.NullUUID{UUID: chat.LastModelConfigID, Valid: true},
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// pendingDynamicToolCallIDs returns the dynamic tool-call IDs on the
|
||||
// chat's last assistant message that do not yet have a matching
|
||||
// tool-result message in active history. The returned map is keyed by
|
||||
// tool-call ID and valued by tool name so callers can build matching
|
||||
// result messages without re-parsing the assistant content.
|
||||
func pendingDynamicToolCallIDs(ctx context.Context, store database.Store, chat database.Chat) (map[string]string, error) {
|
||||
dynamic, err := parseDynamicToolNamesFromRaw(chat.DynamicTools)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dynamic) == 0 {
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
return outstandingToolCallIDs(ctx, store, chat, func(toolName string) bool {
|
||||
return dynamic[toolName]
|
||||
})
|
||||
}
|
||||
|
||||
// pendingAllToolCallIDs returns the tool-call IDs of every outstanding
|
||||
// tool call on the chat's last assistant message, regardless of
|
||||
// whether the tool is dynamic. The returned map is keyed by tool-call
|
||||
// ID and valued by tool name. Callers that must guarantee a valid
|
||||
// LLM message history (e.g. before promoting a user message into
|
||||
// active history, or after committing an interruption's partial
|
||||
// messages) should use this variant so non-dynamic tool calls do not
|
||||
// silently bypass the check.
|
||||
func pendingAllToolCallIDs(ctx context.Context, store database.Store, chat database.Chat) (map[string]string, error) {
|
||||
return outstandingToolCallIDs(ctx, store, chat, func(string) bool { return true })
|
||||
}
|
||||
|
||||
// outstandingToolCallIDs walks the chat's last assistant message and
|
||||
// returns the subset of its tool calls that have no matching
|
||||
// tool-result message in the active history after it. The accept
|
||||
// callback can be used to restrict the walk to a subset of tools
|
||||
// (e.g. dynamic-only).
|
||||
func outstandingToolCallIDs(ctx context.Context, store database.Store, chat database.Chat, accept func(toolName string) bool) (map[string]string, error) {
|
||||
lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
||||
ChatID: chat.ID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get last assistant: %w", err)
|
||||
}
|
||||
parts, err := chatprompt.ParseContent(lastAssistant)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse assistant: %w", err)
|
||||
}
|
||||
afterMsgs, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: lastAssistant.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get messages after assistant: %w", err)
|
||||
}
|
||||
handled := make(map[string]bool)
|
||||
for _, msg := range afterMsgs {
|
||||
if msg.Role != database.ChatMessageRoleTool {
|
||||
continue
|
||||
}
|
||||
ps, perr := chatprompt.ParseContent(msg)
|
||||
if perr != nil {
|
||||
continue
|
||||
}
|
||||
for _, p := range ps {
|
||||
if p.Type == codersdk.ChatMessagePartTypeToolResult {
|
||||
handled[p.ToolCallID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
out := make(map[string]string)
|
||||
for _, p := range parts {
|
||||
if p.Type != codersdk.ChatMessagePartTypeToolCall {
|
||||
continue
|
||||
}
|
||||
if !accept(p.ToolName) {
|
||||
continue
|
||||
}
|
||||
if handled[p.ToolCallID] {
|
||||
continue
|
||||
}
|
||||
out[p.ToolCallID] = p.ToolName
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// parseDynamicToolNamesFromRaw is a private mirror of
|
||||
// chatd.parseDynamicToolNames so chatstate does not pull a runtime
|
||||
// dependency on the chatd package. It accepts a nullable raw JSON
|
||||
// blob and returns a name set.
|
||||
func parseDynamicToolNamesFromRaw(raw pqtype.NullRawMessage) (map[string]bool, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return map[string]bool{}, nil
|
||||
}
|
||||
var tools []codersdk.DynamicTool
|
||||
if err := json.Unmarshal(raw.RawMessage, &tools); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make(map[string]bool, len(tools))
|
||||
for _, t := range tools {
|
||||
out[t.Name] = true
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package chatstate
|
||||
|
||||
// Transition is the enumeration of transitions implemented by the
|
||||
// state machine. Values intentionally match the names of the public
|
||||
// methods on [Tx] (and [CreateChat]). The transition matrix below
|
||||
// declares the legal (from -> to) execution-state mappings used by
|
||||
// each transition method for validation.
|
||||
type Transition string
|
||||
|
||||
const (
|
||||
TransitionCreateChat Transition = "CreateChat"
|
||||
TransitionSetArchived Transition = "SetArchived"
|
||||
TransitionSendMessage Transition = "SendMessage"
|
||||
TransitionEditMessage Transition = "EditMessage"
|
||||
TransitionDeleteQueuedMessage Transition = "DeleteQueuedMessage"
|
||||
TransitionPromoteQueuedMessage Transition = "PromoteQueuedMessage"
|
||||
TransitionInterrupt Transition = "Interrupt"
|
||||
TransitionCompleteRequiresAction Transition = "CompleteRequiresAction"
|
||||
TransitionAcquire Transition = "Acquire"
|
||||
TransitionAbandon Transition = "Abandon"
|
||||
TransitionRecordGenerationAttempt Transition = "RecordGenerationAttempt"
|
||||
TransitionRecordRetryState Transition = "RecordRetryState"
|
||||
TransitionCommitStep Transition = "CommitStep"
|
||||
TransitionEnterRequiresAction Transition = "EnterRequiresAction"
|
||||
TransitionFinishInterruption Transition = "FinishInterruption"
|
||||
TransitionFinishTurn Transition = "FinishTurn"
|
||||
TransitionFinishError Transition = "FinishError"
|
||||
TransitionCancelRequiresAction Transition = "CancelRequiresAction"
|
||||
TransitionReconcileInvalidState Transition = "ReconcileInvalidState"
|
||||
)
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (t Transition) String() string { return string(t) }
|
||||
|
||||
// AllExecutionTransitions is the canonical enumeration of every
|
||||
// execution-state transition that has an entry in the matrix below.
|
||||
// Ownership transitions (Acquire, Abandon) are intentionally not part
|
||||
// of this slice because they are validated independently and do not
|
||||
// have a (from->to) execution mapping.
|
||||
var AllExecutionTransitions = []Transition{
|
||||
TransitionCreateChat,
|
||||
TransitionSetArchived,
|
||||
TransitionSendMessage,
|
||||
TransitionEditMessage,
|
||||
TransitionDeleteQueuedMessage,
|
||||
TransitionPromoteQueuedMessage,
|
||||
TransitionInterrupt,
|
||||
TransitionCompleteRequiresAction,
|
||||
TransitionRecordGenerationAttempt,
|
||||
TransitionRecordRetryState,
|
||||
TransitionCommitStep,
|
||||
TransitionEnterRequiresAction,
|
||||
TransitionFinishInterruption,
|
||||
TransitionFinishTurn,
|
||||
TransitionFinishError,
|
||||
TransitionCancelRequiresAction,
|
||||
TransitionReconcileInvalidState,
|
||||
}
|
||||
|
||||
// transitionMatrix is the in-code mirror of the RFC's execution-state
|
||||
// transition table. Each entry maps an input state to the set of
|
||||
// allowed transitions together with the possible classified output
|
||||
// states that the transition implementation may land in. Outputs may
|
||||
// depend on the post-mutation queue cardinality (for example
|
||||
// DeleteQueuedMessage from E1 lands in E0 when the deleted row was the
|
||||
// last queued message, or stays in E1 otherwise), which is why several
|
||||
// entries list more than one output.
|
||||
//
|
||||
// Ownership transitions (Acquire, Abandon) are intentionally not
|
||||
// included; they are orthogonal to execution state.
|
||||
var transitionMatrix = map[ExecutionState]map[Transition][]ExecutionState{
|
||||
StateN: {
|
||||
TransitionCreateChat: {StateR0},
|
||||
},
|
||||
StateW: {
|
||||
TransitionSetArchived: {StateXW},
|
||||
TransitionSendMessage: {StateR0},
|
||||
TransitionEditMessage: {StateR0},
|
||||
},
|
||||
StateE0: {
|
||||
TransitionSetArchived: {StateXE0},
|
||||
TransitionSendMessage: {StateR0},
|
||||
TransitionEditMessage: {StateR0},
|
||||
},
|
||||
StateE1: {
|
||||
TransitionSetArchived: {StateXE1},
|
||||
TransitionSendMessage: {StateR1},
|
||||
TransitionEditMessage: {StateR0},
|
||||
TransitionDeleteQueuedMessage: {StateE0, StateE1},
|
||||
TransitionPromoteQueuedMessage: {StateR0, StateR1},
|
||||
},
|
||||
StateR0: {
|
||||
TransitionSendMessage: {StateR1, StateI1},
|
||||
TransitionEditMessage: {StateR0},
|
||||
TransitionInterrupt: {StateI0},
|
||||
TransitionRecordGenerationAttempt: {StateR0},
|
||||
TransitionRecordRetryState: {StateR0},
|
||||
TransitionCommitStep: {StateR0},
|
||||
TransitionEnterRequiresAction: {StateA0},
|
||||
TransitionFinishTurn: {StateW},
|
||||
TransitionFinishError: {StateE0},
|
||||
},
|
||||
StateR1: {
|
||||
TransitionSendMessage: {StateR1, StateI1},
|
||||
TransitionEditMessage: {StateR0},
|
||||
TransitionDeleteQueuedMessage: {StateR0, StateR1},
|
||||
TransitionPromoteQueuedMessage: {StateI1},
|
||||
TransitionInterrupt: {StateI1},
|
||||
TransitionRecordGenerationAttempt: {StateR1},
|
||||
TransitionRecordRetryState: {StateR1},
|
||||
TransitionCommitStep: {StateR1},
|
||||
TransitionEnterRequiresAction: {StateA1},
|
||||
TransitionFinishTurn: {StateR0, StateR1},
|
||||
TransitionFinishError: {StateE1},
|
||||
},
|
||||
StateI0: {
|
||||
TransitionSendMessage: {StateI1},
|
||||
TransitionEditMessage: {StateR0},
|
||||
TransitionFinishInterruption: {StateW},
|
||||
},
|
||||
StateI1: {
|
||||
TransitionSendMessage: {StateI1},
|
||||
TransitionEditMessage: {StateR0},
|
||||
TransitionDeleteQueuedMessage: {StateI0, StateI1},
|
||||
TransitionPromoteQueuedMessage: {StateI1},
|
||||
TransitionFinishInterruption: {StateR0, StateR1},
|
||||
},
|
||||
StateA0: {
|
||||
TransitionSendMessage: {StateA1, StateR1},
|
||||
TransitionEditMessage: {StateR0},
|
||||
TransitionInterrupt: {StateR0},
|
||||
TransitionCompleteRequiresAction: {StateR0},
|
||||
TransitionCancelRequiresAction: {StateR0},
|
||||
},
|
||||
StateA1: {
|
||||
TransitionSendMessage: {StateA1, StateR1},
|
||||
TransitionEditMessage: {StateR0},
|
||||
TransitionDeleteQueuedMessage: {StateA0, StateA1},
|
||||
TransitionPromoteQueuedMessage: {StateR0, StateR1},
|
||||
TransitionInterrupt: {StateR1},
|
||||
TransitionCompleteRequiresAction: {StateR1},
|
||||
TransitionCancelRequiresAction: {StateR1},
|
||||
},
|
||||
StateXW: {
|
||||
TransitionSetArchived: {StateW},
|
||||
},
|
||||
StateXE0: {
|
||||
TransitionSetArchived: {StateE0},
|
||||
},
|
||||
StateXE1: {
|
||||
TransitionSetArchived: {StateE1},
|
||||
},
|
||||
StateInvalid: {
|
||||
TransitionReconcileInvalidState: {StateE0, StateE1},
|
||||
},
|
||||
}
|
||||
|
||||
// isExecutionTransitionAllowed reports whether a transition is legal
|
||||
// from the supplied input state per the matrix above. Ownership
|
||||
// transitions are not stored in the matrix and always return false.
|
||||
func isExecutionTransitionAllowed(t Transition, from ExecutionState) bool {
|
||||
allowed, ok := transitionMatrix[from]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, ok = allowed[t]
|
||||
return ok
|
||||
}
|
||||
|
||||
// requireExecutionTransition validates that t is legal from `from`
|
||||
// and returns a typed *TransitionError otherwise.
|
||||
func requireExecutionTransition(t Transition, from ExecutionState) error {
|
||||
if isExecutionTransitionAllowed(t, from) {
|
||||
return nil
|
||||
}
|
||||
return newTransitionError(t, from, "")
|
||||
}
|
||||
|
||||
// AllowedExecutionTransitionsFrom returns a deterministic slice of
|
||||
// transitions legal from `from`. Mostly used by tests to enumerate the
|
||||
// matrix without leaking the internal map.
|
||||
func AllowedExecutionTransitionsFrom(from ExecutionState) []Transition {
|
||||
allowed := transitionMatrix[from]
|
||||
out := make([]Transition, 0, len(allowed))
|
||||
for _, t := range AllExecutionTransitions {
|
||||
if _, ok := allowed[t]; ok {
|
||||
out = append(out, t)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AllowedInputStates returns a deterministic slice of execution states
|
||||
// from which `tr` is legal per the matrix above. Mostly used by tests
|
||||
// to enumerate the matrix without leaking the internal map.
|
||||
func AllowedInputStates(tr Transition) []ExecutionState {
|
||||
var out []ExecutionState
|
||||
for _, from := range AllExecutionStates {
|
||||
if isExecutionTransitionAllowed(tr, from) {
|
||||
out = append(out, from)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AllowedExecutionTransitionOutputs returns the set of classified
|
||||
// post-states that the transition `tr` may produce from `from` per
|
||||
// the matrix above. The returned slice is a copy so callers may mutate
|
||||
// it without affecting the underlying matrix.
|
||||
//
|
||||
// When `tr` is not allowed from `from`, an empty (nil) slice is
|
||||
// returned. Tests use this helper to enumerate the (transition, from,
|
||||
// want) triples that must be exercised by the row-level matrix tests.
|
||||
func AllowedExecutionTransitionOutputs(from ExecutionState, tr Transition) []ExecutionState {
|
||||
allowed, ok := transitionMatrix[from]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
outputs, ok := allowed[tr]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
cp := make([]ExecutionState, len(outputs))
|
||||
copy(cp, outputs)
|
||||
return cp
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,907 @@
|
||||
package chatstate_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// State seeding helpers
|
||||
//
|
||||
// seededChat is the shared output of seedState. Some transition tests
|
||||
// need extra context beyond the chat ID (for example, the queued
|
||||
// message ID to delete, or the message ID to edit), so this struct
|
||||
// surfaces what each state was seeded with.
|
||||
// =============================================================================
|
||||
|
||||
type seededChat struct {
|
||||
chatID uuid.UUID
|
||||
exists bool
|
||||
initialUserMessageID int64
|
||||
assistantToolCallMsgID int64
|
||||
queuedMessageIDs []int64
|
||||
// queuedMessageBodies is parallel to queuedMessageIDs and records
|
||||
// the text body each queued message was seeded with. Cases that
|
||||
// promote queued messages into history use this to assert the
|
||||
// promoted message content matches what was originally queued.
|
||||
queuedMessageBodies []string
|
||||
queuedMessageCreatedBy []uuid.UUID
|
||||
dynamicToolName string
|
||||
pendingToolCallID string
|
||||
pendingToolCallIDs []string
|
||||
}
|
||||
|
||||
// dynamicToolJSON returns the canonical [{name,description,input_schema}]
|
||||
// payload used to seed dynamic_tools on a chat. Tests that need
|
||||
// pending dynamic tool calls (A0, A1) reuse this and reference the
|
||||
// returned tool name in their assistant tool-call message.
|
||||
func dynamicToolJSON(name string) []byte {
|
||||
tools := []codersdk.DynamicTool{{
|
||||
Name: name,
|
||||
Description: "test tool",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{}}`),
|
||||
}}
|
||||
raw, err := json.Marshal(tools)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// assistantToolCallMessage builds a chatstate.Message for an
|
||||
// assistant message that issues one tool call against the supplied
|
||||
// dynamic tool name. The tool-call ID is unique per call so multiple
|
||||
// messages do not collide.
|
||||
func assistantToolCallMessage(t *testing.T, modelID uuid.UUID, toolName, callID string) chatstate.Message {
|
||||
t.Helper()
|
||||
raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: callID,
|
||||
ToolName: toolName,
|
||||
Args: json.RawMessage(`{}`),
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
return chatstate.Message{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
Content: raw,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true},
|
||||
}
|
||||
}
|
||||
|
||||
func mixedAssistantToolCallMessage(t *testing.T, modelID uuid.UUID, dynamicTool, dynCallID, nonDynCallID string) chatstate.Message {
|
||||
t.Helper()
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: dynCallID,
|
||||
ToolName: dynamicTool,
|
||||
Args: json.RawMessage(`{}`),
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: nonDynCallID,
|
||||
ToolName: "non_dynamic_tool",
|
||||
Args: json.RawMessage(`{}`),
|
||||
},
|
||||
}
|
||||
raw, err := chatprompt.MarshalParts(parts)
|
||||
require.NoError(t, err)
|
||||
return chatstate.Message{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
Content: raw,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true},
|
||||
}
|
||||
}
|
||||
|
||||
// createTestChatWithDynamicTools mirrors createTestChat but seeds the
|
||||
// chat with a non-empty dynamic_tools blob so EnterRequiresAction,
|
||||
// CompleteRequiresAction, and CancelRequiresAction can find pending
|
||||
// dynamic tool calls.
|
||||
func createTestChatWithDynamicTools(t *testing.T, f *testFixture, toolName string) chatstate.CreateChatResult {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
|
||||
OrganizationID: f.Org.ID,
|
||||
OwnerID: f.User.ID,
|
||||
LastModelConfigID: f.Model.ID,
|
||||
Title: "test",
|
||||
ClientType: database.ChatClientTypeApi,
|
||||
DynamicTools: pqtype.NullRawMessage{
|
||||
RawMessage: dynamicToolJSON(toolName),
|
||||
Valid: true,
|
||||
},
|
||||
InitialMessages: []chatstate.Message{
|
||||
userTextMessage("hello", f.User.ID, f.Model.ID),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return res
|
||||
}
|
||||
|
||||
// seedAOrA1 seeds a chat into A0 (queuedExtras=0) or A1
|
||||
// (queuedExtras>=1) with a real pending dynamic tool call. Used by
|
||||
// cases that need A0 or A1 with a configurable queue cardinality.
|
||||
func seedAOrA1(t *testing.T, f *testFixture, queuedExtras int, namePrefix string) seededChat {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
toolName := namePrefix
|
||||
callID := "call_" + uuid.NewString()
|
||||
created := createTestChatWithDynamicTools(t, f, toolName)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
var step chatstate.CommitStepResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
step, err = tx.CommitStep(chatstate.CommitStepInput{
|
||||
Messages: []chatstate.Message{
|
||||
assistantToolCallMessage(t, f.Model.ID, toolName, callID),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
require.Len(t, step.InsertedMessages, 1)
|
||||
// R0 -> A0.
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{})
|
||||
return err
|
||||
}))
|
||||
var (
|
||||
queuedIDs []int64
|
||||
queuedBodies []string
|
||||
)
|
||||
for i := 0; i < queuedExtras; i++ {
|
||||
body := fmt.Sprintf("queued-%s-%d", namePrefix, i)
|
||||
sm := sendQueuedMessage(t, f, m, body)
|
||||
require.NotNil(t, sm.QueuedMessage)
|
||||
queuedIDs = append(queuedIDs, sm.QueuedMessage.ID)
|
||||
queuedBodies = append(queuedBodies, body)
|
||||
}
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
assistantToolCallMsgID: step.InsertedMessages[0].ID,
|
||||
queuedMessageIDs: queuedIDs,
|
||||
queuedMessageBodies: queuedBodies,
|
||||
dynamicToolName: toolName,
|
||||
pendingToolCallID: callID,
|
||||
}
|
||||
}
|
||||
|
||||
// seedState seeds a chat into the supplied execution state and
|
||||
// returns identifying handles useful for downstream assertions. For
|
||||
// [chatstate.StateN] the returned chatID is a fresh UUID that does
|
||||
// not exist in the database. Multi-queued seeds (for E1, R1, I1,
|
||||
// A1 with 2 queued messages, and Invalid with a non-empty queue) live in
|
||||
// seedStateMultiQueued.
|
||||
func seedState(t *testing.T, f *testFixture, state chatstate.ExecutionState) seededChat {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
switch state {
|
||||
case chatstate.StateN:
|
||||
return seededChat{chatID: uuid.New(), exists: false}
|
||||
|
||||
case chatstate.StateR0:
|
||||
created := createTestChat(t, f)
|
||||
initial := firstUserMessageID(ctx, t, f, created.Chat.ID)
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: initial,
|
||||
}
|
||||
|
||||
case chatstate.StateW:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
|
||||
return err
|
||||
}))
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
}
|
||||
|
||||
case chatstate.StateE0:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishError(chatstate.FinishErrorInput{
|
||||
LastError: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"message":"boom"}`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
}
|
||||
|
||||
case chatstate.StateE1:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
// R0 -> R1
|
||||
queuedBody := "queued-for-E1"
|
||||
queued := sendQueuedMessage(t, f, m, queuedBody)
|
||||
require.NotNil(t, queued.QueuedMessage)
|
||||
// R1 -> E1
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishError(chatstate.FinishErrorInput{
|
||||
LastError: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"message":"boom"}`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
queuedMessageIDs: []int64{queued.QueuedMessage.ID},
|
||||
queuedMessageBodies: []string{queuedBody},
|
||||
}
|
||||
|
||||
case chatstate.StateR1:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
queuedBody := "queued-for-R1"
|
||||
queued := sendQueuedMessage(t, f, m, queuedBody)
|
||||
require.NotNil(t, queued.QueuedMessage)
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
queuedMessageIDs: []int64{queued.QueuedMessage.ID},
|
||||
queuedMessageBodies: []string{queuedBody},
|
||||
}
|
||||
|
||||
case chatstate.StateI0:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Interrupt(chatstate.InterruptInput{Reason: "seed"})
|
||||
return err
|
||||
}))
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
}
|
||||
|
||||
case chatstate.StateI1:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
// R0 -> I1: SendMessage with interrupt behavior queues the
|
||||
// message and sets status to interrupting.
|
||||
queuedBody := "queued-for-I1"
|
||||
sm := sendInterruptMessage(t, f, m, queuedBody)
|
||||
require.NotNil(t, sm.QueuedMessage)
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
queuedMessageIDs: []int64{sm.QueuedMessage.ID},
|
||||
queuedMessageBodies: []string{queuedBody},
|
||||
}
|
||||
|
||||
case chatstate.StateA0:
|
||||
return seedAOrA1(t, f, 0, "seed_tool_a0")
|
||||
|
||||
case chatstate.StateA1:
|
||||
return seedAOrA1(t, f, 1, "seed_tool_a1")
|
||||
|
||||
case chatstate.StateXW:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
|
||||
return err
|
||||
}))
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true})
|
||||
return err
|
||||
}))
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
}
|
||||
|
||||
case chatstate.StateXE0:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishError(chatstate.FinishErrorInput{
|
||||
LastError: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"message":"boom"}`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true})
|
||||
return err
|
||||
}))
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
}
|
||||
|
||||
case chatstate.StateXE1:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
queuedBody := "queued-for-XE1"
|
||||
queued := sendQueuedMessage(t, f, m, queuedBody)
|
||||
require.NotNil(t, queued.QueuedMessage)
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishError(chatstate.FinishErrorInput{
|
||||
LastError: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"message":"boom"}`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true})
|
||||
return err
|
||||
}))
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
queuedMessageIDs: []int64{queued.QueuedMessage.ID},
|
||||
queuedMessageBodies: []string{queuedBody},
|
||||
}
|
||||
|
||||
case chatstate.StateInvalid:
|
||||
created := createTestChat(t, f)
|
||||
// Force running + archived, a deliberately invalid
|
||||
// combination per the classifier.
|
||||
_, err := f.DB.UpdateChatExecutionState(ctx, database.UpdateChatExecutionStateParams{
|
||||
ID: created.Chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
Archived: true,
|
||||
WorkerID: created.Chat.WorkerID,
|
||||
RunnerID: created.Chat.RunnerID,
|
||||
LastError: created.Chat.LastError,
|
||||
RequiresActionDeadlineAt: created.Chat.RequiresActionDeadlineAt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
}
|
||||
}
|
||||
t.Fatalf("seedState: unsupported execution state %s", state)
|
||||
return seededChat{}
|
||||
}
|
||||
|
||||
// seedStateMultiQueued seeds a state with two queued messages. Used
|
||||
// by cases that need the post-mutation queue to remain non-empty.
|
||||
// Supported states: E1, R1, I1, A1.
|
||||
func seedStateMultiQueued(t *testing.T, f *testFixture, state chatstate.ExecutionState) seededChat {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
switch state {
|
||||
case chatstate.StateE1:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
firstBody := "queued-e1-a"
|
||||
first := sendQueuedMessage(t, f, m, firstBody)
|
||||
require.NotNil(t, first.QueuedMessage)
|
||||
secondBody := "queued-e1-b"
|
||||
second := sendQueuedMessage(t, f, m, secondBody)
|
||||
require.NotNil(t, second.QueuedMessage)
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishError(chatstate.FinishErrorInput{
|
||||
LastError: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"message":"boom"}`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
queuedMessageIDs: []int64{first.QueuedMessage.ID, second.QueuedMessage.ID},
|
||||
queuedMessageBodies: []string{firstBody, secondBody},
|
||||
}
|
||||
|
||||
case chatstate.StateR1:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
firstBody := "queued-r1-a"
|
||||
first := sendQueuedMessage(t, f, m, firstBody)
|
||||
require.NotNil(t, first.QueuedMessage)
|
||||
secondBody := "queued-r1-b"
|
||||
second := sendQueuedMessage(t, f, m, secondBody)
|
||||
require.NotNil(t, second.QueuedMessage)
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
queuedMessageIDs: []int64{first.QueuedMessage.ID, second.QueuedMessage.ID},
|
||||
queuedMessageBodies: []string{firstBody, secondBody},
|
||||
}
|
||||
|
||||
case chatstate.StateI1:
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
firstBody := "queued-i1-a"
|
||||
first := sendQueuedMessage(t, f, m, firstBody)
|
||||
require.NotNil(t, first.QueuedMessage)
|
||||
// R1 -> I1 via interrupt-mode SendMessage queues a second
|
||||
// message and flips status to interrupting.
|
||||
secondBody := "queued-i1-b"
|
||||
second := sendInterruptMessage(t, f, m, secondBody)
|
||||
require.NotNil(t, second.QueuedMessage)
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
queuedMessageIDs: []int64{first.QueuedMessage.ID, second.QueuedMessage.ID},
|
||||
queuedMessageBodies: []string{firstBody, secondBody},
|
||||
}
|
||||
|
||||
case chatstate.StateA1:
|
||||
return seedAOrA1(t, f, 2, "seed_tool_a1_multi")
|
||||
}
|
||||
t.Fatalf("seedStateMultiQueued: unsupported execution state %s", state)
|
||||
return seededChat{}
|
||||
}
|
||||
|
||||
// seedA1WithMixedOutstandingToolCalls seeds A1 with one queued message
|
||||
// and one assistant message carrying both a dynamic and non-dynamic
|
||||
// outstanding tool call. It is used by PromoteQueuedMessage(A1) to
|
||||
// prove all tool calls are closed before inserting the promoted user.
|
||||
func seedA1WithMixedOutstandingToolCalls(t *testing.T, f *testFixture, queuedExtras int, namePrefix string) seededChat {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
toolName := namePrefix
|
||||
dynCallID := "call_" + uuid.NewString()
|
||||
nonDynCallID := "call_" + uuid.NewString()
|
||||
created := createTestChatWithDynamicTools(t, f, toolName)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
var step chatstate.CommitStepResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
step, err = tx.CommitStep(chatstate.CommitStepInput{
|
||||
Messages: []chatstate.Message{
|
||||
mixedAssistantToolCallMessage(t, f.Model.ID, toolName, dynCallID, nonDynCallID),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
require.Len(t, step.InsertedMessages, 1)
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{})
|
||||
return err
|
||||
}))
|
||||
var (
|
||||
queuedIDs []int64
|
||||
queuedBodies []string
|
||||
queuedCreatedBy []uuid.UUID
|
||||
)
|
||||
for i := range queuedExtras {
|
||||
body := fmt.Sprintf("queued-%s-%d", namePrefix, i)
|
||||
createdBy := uuid.New()
|
||||
queued, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
|
||||
ChatID: created.Chat.ID,
|
||||
Content: userMessageContent(t, body),
|
||||
ModelConfigID: uuid.NullUUID{UUID: f.Model.ID, Valid: true},
|
||||
CreatedBy: createdBy,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
queuedIDs = append(queuedIDs, queued.ID)
|
||||
queuedBodies = append(queuedBodies, body)
|
||||
queuedCreatedBy = append(queuedCreatedBy, createdBy)
|
||||
}
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
assistantToolCallMsgID: step.InsertedMessages[0].ID,
|
||||
queuedMessageIDs: queuedIDs,
|
||||
queuedMessageBodies: queuedBodies,
|
||||
queuedMessageCreatedBy: queuedCreatedBy,
|
||||
dynamicToolName: toolName,
|
||||
pendingToolCallID: dynCallID,
|
||||
pendingToolCallIDs: []string{dynCallID, nonDynCallID},
|
||||
}
|
||||
}
|
||||
|
||||
// seedInvalidWithQueue seeds Invalid with a single queued message so
|
||||
// ReconcileInvalidState lands in E1 (non-empty queue) instead of E0.
|
||||
func seedInvalidWithQueue(t *testing.T, f *testFixture) seededChat {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
queuedBody := "queued-invalid"
|
||||
queued := sendQueuedMessage(t, f, m, queuedBody)
|
||||
require.NotNil(t, queued.QueuedMessage)
|
||||
// Force the deliberately invalid running + archived combo on
|
||||
// top of the queue.
|
||||
chat, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = f.DB.UpdateChatExecutionState(ctx, database.UpdateChatExecutionStateParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
Archived: true,
|
||||
WorkerID: chat.WorkerID,
|
||||
RunnerID: chat.RunnerID,
|
||||
LastError: chat.LastError,
|
||||
RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return seededChat{
|
||||
chatID: chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, chat.ID),
|
||||
queuedMessageIDs: []int64{queued.QueuedMessage.ID},
|
||||
queuedMessageBodies: []string{queuedBody},
|
||||
}
|
||||
}
|
||||
|
||||
// firstUserMessageID returns the lowest-id non-deleted user message
|
||||
// on the chat. Most transition tests reuse this when they need a
|
||||
// user message to edit.
|
||||
func firstUserMessageID(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) int64 {
|
||||
t.Helper()
|
||||
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
for _, m := range msgs {
|
||||
if m.Role == database.ChatMessageRoleUser && !m.Deleted {
|
||||
return m.ID
|
||||
}
|
||||
}
|
||||
t.Fatalf("firstUserMessageID: chat %s has no user messages", chatID)
|
||||
return 0
|
||||
}
|
||||
|
||||
func firstAssistantMessageID(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) int64 {
|
||||
t.Helper()
|
||||
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
for _, m := range msgs {
|
||||
if m.Role == database.ChatMessageRoleAssistant && !m.Deleted {
|
||||
return m.ID
|
||||
}
|
||||
}
|
||||
t.Fatalf("firstAssistantMessageID: chat %s has no assistant messages", chatID)
|
||||
return 0
|
||||
}
|
||||
|
||||
// seedForEnterRequiresAction extends seedState for R0 and R1 with a
|
||||
// chat that has dynamic_tools plus an assistant tool-call message in
|
||||
// history. EnterRequiresAction's precondition rejects R0/R1 without
|
||||
// pending dynamic tool calls, so the generic seedState path will not
|
||||
// do. Other states fall through to the default seedState.
|
||||
func seedForEnterRequiresAction(t *testing.T, f *testFixture, state chatstate.ExecutionState) seededChat {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
switch state {
|
||||
case chatstate.StateR0:
|
||||
toolName := "ra_tool_r0"
|
||||
callID := "call_" + uuid.NewString()
|
||||
created := createTestChatWithDynamicTools(t, f, toolName)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
var step chatstate.CommitStepResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
step, err = tx.CommitStep(chatstate.CommitStepInput{
|
||||
Messages: []chatstate.Message{
|
||||
assistantToolCallMessage(t, f.Model.ID, toolName, callID),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
require.Len(t, step.InsertedMessages, 1)
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
assistantToolCallMsgID: step.InsertedMessages[0].ID,
|
||||
dynamicToolName: toolName,
|
||||
pendingToolCallID: callID,
|
||||
pendingToolCallIDs: []string{callID},
|
||||
}
|
||||
case chatstate.StateR1:
|
||||
toolName := "ra_tool_r1"
|
||||
callID := "call_" + uuid.NewString()
|
||||
created := createTestChatWithDynamicTools(t, f, toolName)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
var step chatstate.CommitStepResult
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
var err error
|
||||
step, err = tx.CommitStep(chatstate.CommitStepInput{
|
||||
Messages: []chatstate.Message{
|
||||
assistantToolCallMessage(t, f.Model.ID, toolName, callID),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
// R0 -> R1 with a queued message.
|
||||
queuedBody := "queued-for-RA-r1"
|
||||
sm := sendQueuedMessage(t, f, m, queuedBody)
|
||||
require.NotNil(t, sm.QueuedMessage)
|
||||
return seededChat{
|
||||
chatID: created.Chat.ID,
|
||||
exists: true,
|
||||
initialUserMessageID: firstUserMessageID(ctx, t, f, created.Chat.ID),
|
||||
assistantToolCallMsgID: step.InsertedMessages[0].ID,
|
||||
queuedMessageIDs: []int64{sm.QueuedMessage.ID},
|
||||
queuedMessageBodies: []string{queuedBody},
|
||||
dynamicToolName: toolName,
|
||||
pendingToolCallID: callID,
|
||||
pendingToolCallIDs: []string{callID},
|
||||
}
|
||||
}
|
||||
return seedState(t, f, state)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Snapshot baselines and shared assertion helpers used by every matrix
|
||||
// case. captureBaseline records the chat's snapshot_version and the
|
||||
// publisher's recorded channel count immediately before a transition
|
||||
// runs; assertSnapshotBumpedOnce and assertNoMutationOrPublish use the
|
||||
// baseline to verify either a single snapshot bump and one chat:update
|
||||
// on success, or zero mutation and zero publishes on failure.
|
||||
// =============================================================================
|
||||
|
||||
// activeHistoryIDs returns the ids of non-deleted history messages
|
||||
// for the chat in row-id order. Useful for verifying CommitStep,
|
||||
// EditMessage replacement, and PromoteQueuedMessage head insertion.
|
||||
func activeHistoryIDs(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID) []int64 {
|
||||
t.Helper()
|
||||
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
out := make([]int64, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
if !m.Deleted {
|
||||
out = append(out, m.ID)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func requireChatMessageByID(ctx context.Context, t *testing.T, f *testFixture, id int64) database.ChatMessage {
|
||||
t.Helper()
|
||||
msg, err := f.DB.GetChatMessageByID(ctx, id)
|
||||
require.NoError(t, err)
|
||||
return msg
|
||||
}
|
||||
|
||||
func requireQueuedMessageByID(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, id int64) database.ChatQueuedMessage {
|
||||
t.Helper()
|
||||
msg, err := f.DB.GetChatQueuedMessageByID(ctx, database.GetChatQueuedMessageByIDParams{
|
||||
ID: id,
|
||||
ChatID: chatID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return msg
|
||||
}
|
||||
|
||||
func requireQueuedMessageDeleted(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, id int64) {
|
||||
t.Helper()
|
||||
_, err := f.DB.GetChatQueuedMessageByID(ctx, database.GetChatQueuedMessageByIDParams{
|
||||
ID: id,
|
||||
ChatID: chatID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func assertFetchedUserMessage(ctx context.Context, t *testing.T, f *testFixture, msg database.ChatMessage) database.ChatMessage {
|
||||
t.Helper()
|
||||
fetched := requireChatMessageByID(ctx, t, f, msg.ID)
|
||||
require.Equal(t, msg.ChatID, fetched.ChatID)
|
||||
require.Equal(t, database.ChatMessageRoleUser, fetched.Role)
|
||||
require.True(t, fetched.CreatedBy.Valid)
|
||||
require.Equal(t, f.User.ID, fetched.CreatedBy.UUID)
|
||||
require.True(t, fetched.ModelConfigID.Valid)
|
||||
require.Equal(t, f.Model.ID, fetched.ModelConfigID.UUID)
|
||||
require.Equal(t, chatprompt.CurrentContentVersion, fetched.ContentVersion)
|
||||
return fetched
|
||||
}
|
||||
|
||||
func assertFetchedQueuedMessage(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, queued database.ChatQueuedMessage) database.ChatQueuedMessage {
|
||||
t.Helper()
|
||||
fetched := requireQueuedMessageByID(ctx, t, f, chatID, queued.ID)
|
||||
require.Equal(t, chatID, fetched.ChatID)
|
||||
require.Equal(t, f.User.ID, fetched.CreatedBy)
|
||||
require.True(t, fetched.ModelConfigID.Valid)
|
||||
require.Equal(t, f.Model.ID, fetched.ModelConfigID.UUID)
|
||||
require.NotEmpty(t, fetched.Content)
|
||||
return fetched
|
||||
}
|
||||
|
||||
func newActiveMessageIDs(base snapshotBaseline, after []int64) []int64 {
|
||||
seen := make(map[int64]struct{}, len(base.historyIDs))
|
||||
for _, id := range base.historyIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
out := make([]int64, 0, len(after))
|
||||
for _, id := range after {
|
||||
if _, ok := seen[id]; !ok {
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// assertToolResultForCallNoError asserts that msg is a tool-result
|
||||
// message that resolves a tool call with id wantCallID, is_error=false,
|
||||
// and that the result JSON matches wantResultJSON. Complements
|
||||
// assertToolResultForCall in synthetic_cancellation_test.go which
|
||||
// asserts is_error=true.
|
||||
func assertToolResultForCallNoError(t *testing.T, msg database.ChatMessage, wantCallID, wantResultJSON string) {
|
||||
t.Helper()
|
||||
require.Equal(t, database.ChatMessageRoleTool, msg.Role)
|
||||
parts, err := chatprompt.ParseContent(msg)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, parts)
|
||||
var found bool
|
||||
for _, p := range parts {
|
||||
if p.Type != codersdk.ChatMessagePartTypeToolResult {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, wantCallID, p.ToolCallID, "tool-call id matches")
|
||||
require.False(t, p.IsError, "CompleteRequiresAction tool result must not be is_error")
|
||||
require.JSONEq(t, wantResultJSON, string(p.Result), "CompleteRequiresAction tool result JSON matches submitted output")
|
||||
found = true
|
||||
}
|
||||
require.True(t, found, "expected at least one tool-result part")
|
||||
}
|
||||
|
||||
// assertChatMessageText asserts that the persisted content of msg
|
||||
// decodes to a single text part with the supplied body. Used by
|
||||
// matrix cases that need to verify the actual text submitted via
|
||||
// SendMessage / EditMessage / CommitStep, or the text that was
|
||||
// promoted out of the queue into history.
|
||||
func assertChatMessageText(t *testing.T, msg database.ChatMessage, want string) {
|
||||
t.Helper()
|
||||
parts, err := chatprompt.ParseContent(msg)
|
||||
require.NoError(t, err, "parse chat message content")
|
||||
require.Len(t, parts, 1, "expected exactly one content part")
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type,
|
||||
"expected a text content part")
|
||||
require.Equal(t, want, parts[0].Text, "unexpected chat message text")
|
||||
}
|
||||
|
||||
// assertQueuedMessageText asserts that the JSON content of queued
|
||||
// decodes to a single text part with the supplied body. Used by
|
||||
// matrix cases that need to verify the body inserted into
|
||||
// chat_queued_messages via SendMessage.
|
||||
func assertQueuedMessageText(t *testing.T, queued database.ChatQueuedMessage, want string) {
|
||||
t.Helper()
|
||||
var parts []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(queued.Content, &parts), "unmarshal queued content")
|
||||
require.Len(t, parts, 1, "expected exactly one queued content part")
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type,
|
||||
"expected a text content part")
|
||||
require.Equal(t, want, parts[0].Text, "unexpected queued message text")
|
||||
}
|
||||
|
||||
// assertQueueBodiesInOrder fetches the queued messages for the chat
|
||||
// in queue order and asserts each row's text body matches the
|
||||
// supplied bodies. Used by matrix cases that need to verify the
|
||||
// remaining queue content after a promote / finish-turn /
|
||||
// finish-interruption.
|
||||
func assertQueueBodiesInOrder(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, want []string) {
|
||||
t.Helper()
|
||||
rows, err := f.DB.GetChatQueuedMessagesByPosition(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, len(want), "queue length must match expected bodies")
|
||||
for i, r := range rows {
|
||||
assertQueuedMessageText(t, r, want[i])
|
||||
}
|
||||
}
|
||||
|
||||
// snapshotBaseline records the chat's snapshot_version and the
|
||||
// publisher's recorded channel count immediately before a transition
|
||||
// runs. Tests use it to verify either a single snapshot bump and one
|
||||
// chat:update on success, or zero mutation and zero publishes on
|
||||
// failure.
|
||||
type snapshotBaseline struct {
|
||||
exists bool
|
||||
chat database.Chat
|
||||
snapshot int64
|
||||
historyVersion int64
|
||||
queueVersion int64
|
||||
retryStateVersion int64
|
||||
generationAttempt int64
|
||||
queueCount int64
|
||||
queueIDs []int64
|
||||
historyIDs []int64
|
||||
channels int
|
||||
}
|
||||
|
||||
func captureBaseline(ctx context.Context, t *testing.T, f *testFixture, seeded seededChat) snapshotBaseline {
|
||||
t.Helper()
|
||||
base := snapshotBaseline{
|
||||
exists: seeded.exists,
|
||||
channels: len(f.Pub.channels),
|
||||
}
|
||||
if !seeded.exists {
|
||||
return base
|
||||
}
|
||||
chat, err := f.DB.GetChatByID(ctx, seeded.chatID)
|
||||
require.NoError(t, err)
|
||||
base.chat = chat
|
||||
base.snapshot = chat.SnapshotVersion
|
||||
base.historyVersion = chat.HistoryVersion
|
||||
base.queueVersion = chat.QueueVersion
|
||||
base.retryStateVersion = chat.RetryStateVersion
|
||||
base.generationAttempt = chat.GenerationAttempt
|
||||
base.queueIDs = queuedIDsByPosition(ctx, t, f, seeded.chatID)
|
||||
count, err := f.DB.CountChatQueuedMessages(ctx, seeded.chatID)
|
||||
require.NoError(t, err)
|
||||
base.queueCount = count
|
||||
base.historyIDs = activeHistoryIDs(ctx, t, f, seeded.chatID)
|
||||
return base
|
||||
}
|
||||
|
||||
// assertSnapshotBumpedOnce asserts that one Update committed; that is,
|
||||
// snapshot_version advanced by exactly one and the publisher saw at
|
||||
// least one chat:update on the per-chat channel after the baseline.
|
||||
func assertSnapshotBumpedOnce(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, base snapshotBaseline) {
|
||||
t.Helper()
|
||||
after, err := f.DB.GetChatByID(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, base.snapshot+1, after.SnapshotVersion, "snapshot_version must bump exactly once")
|
||||
channel := coderdpubsub.ChatStateUpdateChannel(chatID)
|
||||
found := false
|
||||
for _, c := range f.Pub.channels[base.channels:] {
|
||||
if c == channel {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, found, "expected one chat:update on %s after commit", channel)
|
||||
}
|
||||
|
||||
// assertNoMutationOrPublish asserts a failed transition rolled back
|
||||
// the automatic snapshot bump and published nothing.
|
||||
func assertNoMutationOrPublish(ctx context.Context, t *testing.T, f *testFixture, chatID uuid.UUID, base snapshotBaseline) {
|
||||
t.Helper()
|
||||
require.Equal(t, base.channels, len(f.Pub.channels), "failed transition must not publish")
|
||||
if base.exists {
|
||||
after, err := f.DB.GetChatByID(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, base.snapshot, after.SnapshotVersion, "failed transition must not advance snapshot_version")
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,747 @@
|
||||
package chatstate_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// CreateChat tests.
|
||||
//
|
||||
// CreateChat is the only transition that originates from StateN and it
|
||||
// is not exercised through ChatMachine.Update, so it lives outside
|
||||
// TestTransitionMatrix_AllCombinations.
|
||||
// =============================================================================
|
||||
|
||||
// TestTransitionCreate_NToR0 verifies that CreateChat lands a fresh
|
||||
// chat in R0 with snapshot_version 1, the initial user message
|
||||
// recorded at revision 1, queue_version still 0, and the post-commit
|
||||
// publish requesting an ownership hint plus a chat:update.
|
||||
func TestTransitionCreate_NToR0(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
res := createTestChat(t, f)
|
||||
|
||||
require.Equal(t, database.ChatStatusRunning, res.Chat.Status)
|
||||
require.False(t, res.Chat.Archived)
|
||||
require.Equal(t, int64(1), res.Chat.SnapshotVersion, "snapshot_version starts at 1")
|
||||
require.Equal(t, int64(1), res.Chat.HistoryVersion, "history_version set by trigger after initial insert")
|
||||
require.Equal(t, int64(0), res.Chat.QueueVersion, "queue_version stays 0 when no queue rows")
|
||||
require.Equal(t, int64(0), res.Chat.GenerationAttempt)
|
||||
require.NotEmpty(t, res.InitialMessages)
|
||||
require.Equal(t, int64(1), res.InitialMessages[0].Revision)
|
||||
require.Equal(t, chatstate.StateR0, f.classify(ctx, t, res.Chat.ID))
|
||||
require.True(t, f.Pub.hasOwnership(), "newly created chat is runnable and unowned")
|
||||
f.Pub.expectChatUpdate(t, res.Chat.ID, 1)
|
||||
}
|
||||
|
||||
// TestCreateChat_RejectsEmptyInitialMessages verifies that CreateChat
|
||||
// rejects an empty InitialMessages slice with ErrTransitionNotAllowed
|
||||
// and does not publish anything.
|
||||
func TestCreateChat_RejectsEmptyInitialMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
_, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
|
||||
OrganizationID: f.Org.ID,
|
||||
OwnerID: f.User.ID,
|
||||
LastModelConfigID: f.Model.ID,
|
||||
ClientType: database.ChatClientTypeApi,
|
||||
Title: "t",
|
||||
InitialMessages: nil,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
|
||||
require.Empty(t, f.Pub.channels, "rejected create must not publish")
|
||||
}
|
||||
|
||||
func TestCreateChat_AllowsNoUserMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
assistant := userTextMessage("oops", f.User.ID, f.Model.ID)
|
||||
assistant.Role = database.ChatMessageRoleAssistant
|
||||
res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
|
||||
OrganizationID: f.Org.ID,
|
||||
OwnerID: f.User.ID,
|
||||
LastModelConfigID: f.Model.ID,
|
||||
Title: "t",
|
||||
ClientType: database.ChatClientTypeApi,
|
||||
InitialMessages: []chatstate.Message{assistant},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.InitialMessages, 1)
|
||||
}
|
||||
|
||||
func TestCreateChat_AllowsNonFinalUserMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
res, err := chatstate.CreateChat(ctx, f.DB, f.Pub, chatstate.CreateChatInput{
|
||||
OrganizationID: f.Org.ID,
|
||||
OwnerID: f.User.ID,
|
||||
LastModelConfigID: f.Model.ID,
|
||||
Title: "t",
|
||||
ClientType: database.ChatClientTypeApi,
|
||||
InitialMessages: []chatstate.Message{
|
||||
userTextMessage("context user", f.User.ID, f.Model.ID),
|
||||
userTextMessage("final user", f.User.ID, f.Model.ID),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.InitialMessages, 2)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Input-specific rejection cases.
|
||||
//
|
||||
// These tests cover the same matrix rows as TestTransitionMatrix_AllCombinations
|
||||
// but exercise legal source states with invalid transition inputs. They are
|
||||
// intentionally outside the matrix entry point so the matrix focus stays on
|
||||
// positive cases and generated disallowed cases.
|
||||
// =============================================================================
|
||||
|
||||
type setArchivedWrongDirectionCase struct {
|
||||
from chatstate.ExecutionState
|
||||
wantArchive bool
|
||||
label string
|
||||
}
|
||||
|
||||
func setArchivedWrongDirectionCases() []setArchivedWrongDirectionCase {
|
||||
return []setArchivedWrongDirectionCase{
|
||||
// Non-archived states with archived=false: no-op.
|
||||
{from: chatstate.StateW, wantArchive: false, label: "W_to_W"},
|
||||
{from: chatstate.StateE0, wantArchive: false, label: "E0_to_E0"},
|
||||
{from: chatstate.StateE1, wantArchive: false, label: "E1_to_E1"},
|
||||
// Archived states with archived=true: no-op.
|
||||
{from: chatstate.StateXW, wantArchive: true, label: "XW_to_XW"},
|
||||
{from: chatstate.StateXE0, wantArchive: true, label: "XE0_to_XE0"},
|
||||
{from: chatstate.StateXE1, wantArchive: true, label: "XE1_to_XE1"},
|
||||
}
|
||||
}
|
||||
|
||||
var invalidBusyBehaviors = []chatstate.BusyBehavior{
|
||||
chatstate.BusyBehavior(""),
|
||||
chatstate.BusyBehavior("not-a-real-mode"),
|
||||
}
|
||||
|
||||
func runSetArchivedWrongDirectionCase(t *testing.T, tc setArchivedWrongDirectionCase) {
|
||||
t.Helper()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
seeded := seedState(t, f, tc.from)
|
||||
require.Equal(t, tc.from, f.classify(ctx, t, seeded.chatID))
|
||||
|
||||
base := captureBaseline(ctx, t, f, seeded)
|
||||
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{})
|
||||
err := m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, serr := tx.SetArchived(chatstate.SetArchivedInput{Archived: tc.wantArchive})
|
||||
return serr
|
||||
})
|
||||
require.Error(t, err, "SetArchived must reject when Archived matches the current value")
|
||||
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed,
|
||||
"SetArchived must wrap ErrTransitionNotAllowed")
|
||||
var te *chatstate.TransitionError
|
||||
require.ErrorAs(t, err, &te,
|
||||
"SetArchived must return a typed TransitionError")
|
||||
require.Equal(t, chatstate.TransitionSetArchived, te.Transition)
|
||||
require.Equal(t, tc.from, te.From, "TransitionError records the loaded from-state")
|
||||
|
||||
require.Equal(t, tc.from, f.classify(ctx, t, seeded.chatID),
|
||||
"rejected SetArchived must leave the chat in the same state")
|
||||
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
|
||||
}
|
||||
|
||||
func runInvalidBusyBehaviorCase(t *testing.T, from chatstate.ExecutionState, bb chatstate.BusyBehavior) {
|
||||
t.Helper()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
seeded := seedState(t, f, from)
|
||||
require.Equal(t, from, f.classify(ctx, t, seeded.chatID),
|
||||
"seed must land in %s", from)
|
||||
base := captureBaseline(ctx, t, f, seeded)
|
||||
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{})
|
||||
err := m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, serr := tx.SendMessage(chatstate.SendMessageInput{
|
||||
Message: userTextMessage("invalid-bb", f.User.ID, f.Model.ID),
|
||||
BusyBehavior: bb,
|
||||
})
|
||||
return serr
|
||||
})
|
||||
require.Error(t, err, "SendMessage must reject invalid BusyBehavior")
|
||||
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed,
|
||||
"SendMessage rejection must wrap ErrTransitionNotAllowed")
|
||||
var te *chatstate.TransitionError
|
||||
require.ErrorAs(t, err, &te,
|
||||
"SendMessage must return a typed TransitionError")
|
||||
require.Equal(t, chatstate.TransitionSendMessage, te.Transition)
|
||||
require.Equal(t, from, te.From,
|
||||
"TransitionError records the source state")
|
||||
|
||||
require.Equal(t, from, f.classify(ctx, t, seeded.chatID),
|
||||
"rejected SendMessage must leave the chat in the same state")
|
||||
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
|
||||
}
|
||||
|
||||
type completeRequiresActionRejectCase struct {
|
||||
name string
|
||||
results func(seeded seededChat) []chatstate.ToolResultInput
|
||||
}
|
||||
|
||||
type recordRetryStateRejectCase struct {
|
||||
name string
|
||||
retryState pqtype.NullRawMessage
|
||||
}
|
||||
|
||||
func completeRequiresActionRejectCases() []completeRequiresActionRejectCase {
|
||||
valid := func(id string) chatstate.ToolResultInput {
|
||||
return chatstate.ToolResultInput{
|
||||
ToolCallID: id,
|
||||
Output: json.RawMessage(`{"ok":true}`),
|
||||
}
|
||||
}
|
||||
return []completeRequiresActionRejectCase{
|
||||
{
|
||||
name: "missing_required_tool_result",
|
||||
results: func(seeded seededChat) []chatstate.ToolResultInput { return nil },
|
||||
},
|
||||
{
|
||||
name: "extra_tool_result",
|
||||
results: func(seeded seededChat) []chatstate.ToolResultInput {
|
||||
return []chatstate.ToolResultInput{valid(seeded.pendingToolCallID), valid("call_extra")}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "duplicate_tool_call_id",
|
||||
results: func(seeded seededChat) []chatstate.ToolResultInput {
|
||||
return []chatstate.ToolResultInput{valid(seeded.pendingToolCallID), valid(seeded.pendingToolCallID)}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mismatched_tool_call_id",
|
||||
results: func(seeded seededChat) []chatstate.ToolResultInput {
|
||||
return []chatstate.ToolResultInput{valid("call_mismatch")}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid_json_output",
|
||||
results: func(seeded seededChat) []chatstate.ToolResultInput {
|
||||
return []chatstate.ToolResultInput{{ToolCallID: seeded.pendingToolCallID, Output: json.RawMessage(`{`)}}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func recordRetryStateRejectCases() []recordRetryStateRejectCase {
|
||||
return []recordRetryStateRejectCase{
|
||||
{
|
||||
name: "sql_null_payload",
|
||||
},
|
||||
{
|
||||
name: "empty_payload",
|
||||
retryState: pqtype.NullRawMessage{RawMessage: json.RawMessage(``), Valid: true},
|
||||
},
|
||||
{
|
||||
name: "invalid_json_payload",
|
||||
retryState: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{`), Valid: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func runCompleteRequiresActionRejectCase(t *testing.T, tc completeRequiresActionRejectCase) {
|
||||
t.Helper()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
seeded := seedAOrA1(t, f, 0, "reject_complete_requires_action")
|
||||
require.Equal(t, chatstate.StateA0, f.classify(ctx, t, seeded.chatID))
|
||||
base := captureBaseline(ctx, t, f, seeded)
|
||||
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{})
|
||||
err := m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, cerr := tx.CompleteRequiresAction(chatstate.CompleteRequiresActionInput{
|
||||
CreatedBy: f.User.ID,
|
||||
ModelConfigID: f.Model.ID,
|
||||
Results: tc.results(seeded),
|
||||
})
|
||||
return cerr
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
|
||||
var te *chatstate.TransitionError
|
||||
require.ErrorAs(t, err, &te)
|
||||
require.Equal(t, chatstate.TransitionCompleteRequiresAction, te.Transition)
|
||||
require.Equal(t, chatstate.StateA0, te.From)
|
||||
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
|
||||
}
|
||||
|
||||
func runRecordRetryStateRejectCase(t *testing.T, tc recordRetryStateRejectCase) {
|
||||
t.Helper()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
seeded := seedState(t, f, chatstate.StateR0)
|
||||
require.Equal(t, chatstate.StateR0, f.classify(ctx, t, seeded.chatID))
|
||||
base := captureBaseline(ctx, t, f, seeded)
|
||||
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, seeded.chatID, chatstate.Options{})
|
||||
err := m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, rerr := tx.RecordRetryState(chatstate.RecordRetryStateInput{
|
||||
RetryState: tc.retryState,
|
||||
})
|
||||
return rerr
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
|
||||
var te *chatstate.TransitionError
|
||||
require.ErrorAs(t, err, &te)
|
||||
require.Equal(t, chatstate.TransitionRecordRetryState, te.Transition)
|
||||
require.Equal(t, chatstate.StateR0, te.From)
|
||||
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
|
||||
}
|
||||
|
||||
// TestTransitionInputValidation groups every input-specific rejection
|
||||
// test. The matrix coverage entry point in
|
||||
// TestTransitionMatrix_AllCombinations intentionally focuses on
|
||||
// positive cases and generated disallowed cases; rejection cases that
|
||||
// exercise legal matrix rows with invalid inputs live here so the
|
||||
// matrix entry point stays focused.
|
||||
func TestTransitionInputValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SetArchived_wrong_direction", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tc := range setArchivedWrongDirectionCases() {
|
||||
t.Run(tc.label, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
runSetArchivedWrongDirectionCase(t, tc)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SendMessage_invalid_busy_behavior", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, from := range chatstate.AllowedInputStates(chatstate.TransitionSendMessage) {
|
||||
for _, bb := range invalidBusyBehaviors {
|
||||
label := from.String() + "/" + string(bb)
|
||||
if bb == "" {
|
||||
label = from.String() + "/empty"
|
||||
}
|
||||
t.Run(label, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
runInvalidBusyBehaviorCase(t, from, bb)
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CompleteRequiresAction_invalid_results", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tc := range completeRequiresActionRejectCases() {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
runCompleteRequiresActionRejectCase(t, tc)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RecordRetryState_invalid_payload", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tc := range recordRetryStateRejectCases() {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
runRecordRetryStateRejectCase(t, tc)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSendMessageQueueCapRejectsQueueAppend seeds a chat with the
|
||||
// maximum queued messages and asserts that the next SendMessage in
|
||||
// a queue-appending state returns chatstate.ErrMessageQueueFull and
|
||||
// rolls back without persisting another queued row.
|
||||
func TestSendMessageQueueCapRejectsQueueAppend(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
// createTestChat lands the chat in R0; SendMessage in R0 with
|
||||
// BusyBehaviorQueue queues. Fill the queue to MaxQueueSize.
|
||||
for i := 0; i < chatstate.MaxQueueSize; i++ {
|
||||
sendQueuedMessage(t, f, m, "filler")
|
||||
}
|
||||
count, err := f.DB.CountChatQueuedMessages(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, chatstate.MaxQueueSize, count)
|
||||
chatBefore := f.readChat(ctx, t, created.Chat.ID)
|
||||
|
||||
// The next queue append must fail with ErrMessageQueueFull and a
|
||||
// typed wrapper that exposes the cap.
|
||||
err = m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, serr := tx.SendMessage(chatstate.SendMessageInput{
|
||||
Message: userTextMessage("overflow", f.User.ID, f.Model.ID),
|
||||
BusyBehavior: chatstate.BusyBehaviorQueue,
|
||||
})
|
||||
return serr
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, chatstate.ErrMessageQueueFull,
|
||||
"queue-append over the cap returns ErrMessageQueueFull")
|
||||
var typed *chatstate.MessageQueueFullError
|
||||
require.ErrorAs(t, err, &typed, "ErrMessageQueueFull is carried as a typed error")
|
||||
require.EqualValues(t, chatstate.MaxQueueSize, typed.Max)
|
||||
|
||||
// The transaction rolled back: queue size, snapshot version,
|
||||
// and queue version are unchanged.
|
||||
countAfter, err := f.DB.CountChatQueuedMessages(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, chatstate.MaxQueueSize, countAfter,
|
||||
"queue size must not change when the cap rejects the append")
|
||||
chatAfter := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, chatBefore.SnapshotVersion, chatAfter.SnapshotVersion,
|
||||
"failed queue append must not bump snapshot_version")
|
||||
require.Equal(t, chatBefore.QueueVersion, chatAfter.QueueVersion,
|
||||
"failed queue append must not bump queue_version")
|
||||
}
|
||||
|
||||
// TestEditMessageNonUserReturnsSentinel asserts that editing a
|
||||
// non-user message returns chatstate.ErrEditedMessageNotUser via
|
||||
// the TransitionError cause chain, and still matches the generic
|
||||
// transition sentinel.
|
||||
func TestEditMessageNonUserReturnsSentinel(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
// Insert an assistant message via CommitStep so we have a
|
||||
// non-user message to target.
|
||||
var assistantID int64
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
assistant := userTextMessage("assistant", f.User.ID, f.Model.ID)
|
||||
assistant.Role = database.ChatMessageRoleAssistant
|
||||
step, err := tx.CommitStep(chatstate.CommitStepInput{
|
||||
Messages: []chatstate.Message{assistant},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
require.Len(t, step.InsertedMessages, 1)
|
||||
assistantID = step.InsertedMessages[0].ID
|
||||
return nil
|
||||
}))
|
||||
|
||||
rawContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("new content"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
editErr := m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, eerr := tx.EditMessage(chatstate.EditMessageInput{
|
||||
MessageID: assistantID,
|
||||
CreatedBy: f.User.ID,
|
||||
Content: rawContent,
|
||||
})
|
||||
return eerr
|
||||
})
|
||||
require.Error(t, editErr)
|
||||
require.ErrorIs(t, editErr, chatstate.ErrEditedMessageNotUser,
|
||||
"non-user edit returns ErrEditedMessageNotUser via TransitionError cause")
|
||||
require.ErrorIs(t, editErr, chatstate.ErrTransitionNotAllowed,
|
||||
"ErrEditedMessageNotUser still matches the generic transition sentinel")
|
||||
}
|
||||
|
||||
// TestTransitionAbandon_RejectsUnowned verifies that calling Abandon
|
||||
// on a chat the runner does not own returns ErrTransitionNotAllowed
|
||||
// wrapped in a TransitionError that records the loaded from-state,
|
||||
// without mutating chat state or publishing anything.
|
||||
func TestTransitionAbandon_RejectsUnowned(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
seeded := seededChat{chatID: created.Chat.ID, exists: true}
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
base := captureBaseline(ctx, t, f, seeded)
|
||||
|
||||
err := m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, aerr := tx.Abandon(chatstate.AbandonInput{})
|
||||
return aerr
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, chatstate.ErrTransitionNotAllowed)
|
||||
var te *chatstate.TransitionError
|
||||
require.ErrorAs(t, err, &te)
|
||||
require.Equal(t, chatstate.TransitionAbandon, te.Transition)
|
||||
// createTestChat lands the chat in R0; Abandon's precondition
|
||||
// rejects an unowned chat there.
|
||||
require.Equal(t, chatstate.StateR0, te.From)
|
||||
assertNoMutationOrPublish(ctx, t, f, seeded.chatID, base)
|
||||
}
|
||||
|
||||
// TestTransitionAbandon_ClearsOwnership verifies the Acquire/Abandon
|
||||
// round-trip: after Acquire the chat carries a worker+runner and a
|
||||
// fresh heartbeat row exists, and after Abandon both ownership fields
|
||||
// are cleared. The heartbeat row is not deleted by Abandon; heartbeat
|
||||
// cleanup is a separate concern.
|
||||
func TestTransitionAbandon_ClearsOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
worker := uuid.New()
|
||||
runner := uuid.New()
|
||||
|
||||
// Acquire writes ownership and a fresh heartbeat row.
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner})
|
||||
return err
|
||||
}))
|
||||
owned := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, worker, owned.WorkerID.UUID)
|
||||
require.Equal(t, runner, owned.RunnerID.UUID)
|
||||
hb, err := f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
|
||||
ChatID: created.Chat.ID,
|
||||
RunnerID: runner,
|
||||
})
|
||||
require.NoError(t, err, "Acquire writes a fresh heartbeat row")
|
||||
require.Equal(t, runner, hb.RunnerID)
|
||||
|
||||
// Abandon clears ownership but leaves the heartbeat row intact.
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Abandon(chatstate.AbandonInput{})
|
||||
return err
|
||||
}))
|
||||
hb, err = f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
|
||||
ChatID: created.Chat.ID,
|
||||
RunnerID: runner,
|
||||
})
|
||||
require.NoError(t, err, "Abandon does not delete the heartbeat row")
|
||||
abandoned := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.False(t, abandoned.WorkerID.Valid, "Abandon clears worker_id")
|
||||
require.False(t, abandoned.RunnerID.Valid, "Abandon clears runner_id")
|
||||
}
|
||||
|
||||
// TestTransitionAcquire_OverwritesFreshOwnership verifies that Acquire
|
||||
// is an unconditional ownership handoff: a second worker calling
|
||||
// Acquire on a chat that was *just* acquired by another worker
|
||||
// successfully replaces ownership without inspecting heartbeat
|
||||
// freshness. It also asserts that Acquire itself does not request an
|
||||
// ownership hint, so the post-commit publish stays quiet on
|
||||
// `chat:ownership` when the resulting heartbeat is fresh.
|
||||
func TestTransitionAcquire_OverwritesFreshOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
|
||||
firstWorker := uuid.New()
|
||||
firstRunner := uuid.New()
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: firstWorker, RunnerID: firstRunner})
|
||||
return err
|
||||
}))
|
||||
|
||||
// The chat is now owned with a fresh (chat_id, firstRunner)
|
||||
// heartbeat written by the first Acquire.
|
||||
firstChat := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, firstWorker, firstChat.WorkerID.UUID)
|
||||
require.Equal(t, firstRunner, firstChat.RunnerID.UUID)
|
||||
_, err := f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
|
||||
ChatID: created.Chat.ID,
|
||||
RunnerID: firstRunner,
|
||||
})
|
||||
require.NoError(t, err, "first Acquire wrote a fresh heartbeat")
|
||||
// Sanity check: heartbeat is not stale by the same threshold the
|
||||
// machine uses for ownership-hint decisions.
|
||||
stale, err := f.DB.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{
|
||||
ChatID: created.Chat.ID,
|
||||
RunnerID: firstRunner,
|
||||
StaleSeconds: chatstate.HeartbeatStaleSeconds,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, stale, "first runner's heartbeat is fresh before the second Acquire")
|
||||
|
||||
// Snapshot publish counts before the takeover so we can assert
|
||||
// Acquire does not publish an ownership hint itself.
|
||||
ownershipBefore := f.Pub.ownershipPublishCount()
|
||||
beforeChat := f.readChat(ctx, t, created.Chat.ID)
|
||||
|
||||
secondWorker := uuid.New()
|
||||
secondRunner := uuid.New()
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: secondWorker, RunnerID: secondRunner})
|
||||
return err
|
||||
}))
|
||||
|
||||
after := f.readChat(ctx, t, created.Chat.ID)
|
||||
require.Equal(t, secondWorker, after.WorkerID.UUID, "ownership replaced")
|
||||
require.Equal(t, secondRunner, after.RunnerID.UUID, "runner replaced")
|
||||
require.Equal(t, beforeChat.SnapshotVersion+1, after.SnapshotVersion, "snapshot bumps exactly once")
|
||||
f.Pub.expectChatUpdate(t, created.Chat.ID, after.SnapshotVersion)
|
||||
|
||||
// The new (chat_id, secondRunner) heartbeat exists. The old
|
||||
// (chat_id, firstRunner) row may or may not exist; Acquire is not
|
||||
// responsible for cleaning it up.
|
||||
_, err = f.DB.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
|
||||
ChatID: created.Chat.ID,
|
||||
RunnerID: secondRunner,
|
||||
})
|
||||
require.NoError(t, err, "second Acquire wrote a heartbeat for the new runner")
|
||||
|
||||
// Acquire does not publish an ownership hint when it writes a fresh
|
||||
// heartbeat. The post-commit ownership-hint logic in Update stays
|
||||
// quiet because the new heartbeat is fresh, so no `chat:ownership`
|
||||
// notification fires.
|
||||
require.Equal(t, ownershipBefore, f.Pub.ownershipPublishCount(),
|
||||
"Acquire must not publish an ownership hint when the resulting heartbeat is fresh")
|
||||
}
|
||||
|
||||
// TestTransitionAcquire_ExecutionStateOrthogonal verifies that Acquire
|
||||
// preserves every execution-state field on the chat across
|
||||
// representative valid execution states, including idle, runnable, and
|
||||
// archived states. The transition only mutates ownership.
|
||||
func TestTransitionAcquire_ExecutionStateOrthogonal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Each setup leaves the chat in the named state and returns the
|
||||
// chat ID for downstream assertions.
|
||||
cases := []struct {
|
||||
name string
|
||||
state chatstate.ExecutionState
|
||||
setup func(t *testing.T, f *testFixture) uuid.UUID
|
||||
}{
|
||||
{
|
||||
name: "R0",
|
||||
state: chatstate.StateR0,
|
||||
setup: func(t *testing.T, f *testFixture) uuid.UUID {
|
||||
return createTestChat(t, f).Chat.ID
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "W",
|
||||
state: chatstate.StateW,
|
||||
setup: func(t *testing.T, f *testFixture) uuid.UUID {
|
||||
created := createTestChat(t, f)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
|
||||
return err
|
||||
}))
|
||||
return created.Chat.ID
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "E0",
|
||||
state: chatstate.StateE0,
|
||||
setup: func(t *testing.T, f *testFixture) uuid.UUID {
|
||||
created := createTestChat(t, f)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishError(chatstate.FinishErrorInput{
|
||||
LastError: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`{"message":"boom"}`),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}))
|
||||
return created.Chat.ID
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "I0",
|
||||
state: chatstate.StateI0,
|
||||
setup: func(t *testing.T, f *testFixture) uuid.UUID {
|
||||
created := createTestChat(t, f)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Interrupt(chatstate.InterruptInput{Reason: "test"})
|
||||
return err
|
||||
}))
|
||||
return created.Chat.ID
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "XW",
|
||||
state: chatstate.StateXW,
|
||||
setup: func(t *testing.T, f *testFixture) uuid.UUID {
|
||||
created := createTestChat(t, f)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, created.Chat.ID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.FinishTurn(chatstate.FinishTurnInput{})
|
||||
return err
|
||||
}))
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.SetArchived(chatstate.SetArchivedInput{Archived: true})
|
||||
return err
|
||||
}))
|
||||
return created.Chat.ID
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := newTestFixture(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
chatID := tc.setup(t, f)
|
||||
require.Equal(t, tc.state, f.classify(ctx, t, chatID), "test setup must leave chat in %s", tc.state)
|
||||
|
||||
before := f.readChat(ctx, t, chatID)
|
||||
queueBefore, err := f.DB.CountChatQueuedMessages(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
historyBefore := historyMessageIDs(ctx, t, f, chatID)
|
||||
|
||||
worker := uuid.New()
|
||||
runner := uuid.New()
|
||||
m := chatstate.NewChatMachine(f.DB, f.Pub, chatID, chatstate.Options{})
|
||||
require.NoError(t, m.Update(ctx, func(tx *chatstate.Tx) error {
|
||||
_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: worker, RunnerID: runner})
|
||||
return err
|
||||
}))
|
||||
|
||||
after := f.readChat(ctx, t, chatID)
|
||||
// Ownership updated.
|
||||
require.Equal(t, worker, after.WorkerID.UUID)
|
||||
require.Equal(t, runner, after.RunnerID.UUID)
|
||||
// Execution state preserved.
|
||||
require.Equal(t, before.Status, after.Status, "status preserved")
|
||||
require.Equal(t, before.Archived, after.Archived, "archived flag preserved")
|
||||
require.Equal(t, before.RequiresActionDeadlineAt, after.RequiresActionDeadlineAt, "requires-action deadline preserved")
|
||||
require.Equal(t, before.LastError, after.LastError, "last_error preserved")
|
||||
require.Equal(t, before.HistoryVersion, after.HistoryVersion, "history_version preserved")
|
||||
require.Equal(t, before.QueueVersion, after.QueueVersion, "queue_version preserved")
|
||||
require.Equal(t, before.GenerationAttempt, after.GenerationAttempt, "generation_attempt preserved")
|
||||
// Classified state unchanged.
|
||||
require.Equal(t, tc.state, f.classify(ctx, t, chatID), "execution state preserved by Acquire")
|
||||
// Queue and history rows untouched.
|
||||
queueAfter, err := f.DB.CountChatQueuedMessages(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, queueBefore, queueAfter, "queue cardinality preserved")
|
||||
require.Equal(t, historyBefore, historyMessageIDs(ctx, t, f, chatID), "history preserved")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,631 @@
|
||||
package chatstate_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// triggerFixture is a slim variant of testFixture that also exposes a
|
||||
// raw *sql.DB so the trigger tests can run UPDATE/INSERT statements
|
||||
// that bypass the typed sqlc layer. Tests that only need the typed
|
||||
// store should keep using newTestFixture.
|
||||
type triggerFixture struct {
|
||||
f *testFixture
|
||||
sqlDB *sql.DB
|
||||
}
|
||||
|
||||
func newTriggerFixture(t *testing.T) *triggerFixture {
|
||||
t.Helper()
|
||||
db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
dbgen.ChatProvider(t, db, database.ChatProvider{
|
||||
Provider: "openai",
|
||||
DisplayName: "openai",
|
||||
BaseUrl: "http://example.invalid",
|
||||
})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
Provider: "openai",
|
||||
IsDefault: true,
|
||||
})
|
||||
f := &testFixture{
|
||||
DB: db,
|
||||
PubSub: ps,
|
||||
Pub: newRecordingPubsub(),
|
||||
User: user,
|
||||
Org: org,
|
||||
Model: model,
|
||||
}
|
||||
return &triggerFixture{f: f, sqlDB: sqlDB}
|
||||
}
|
||||
|
||||
// userMessageContent returns a marshaled user message body suitable
|
||||
// for raw INSERT into chat_messages.
|
||||
func userMessageContent(t *testing.T, text string) []byte {
|
||||
t.Helper()
|
||||
raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)})
|
||||
require.NoError(t, err)
|
||||
return raw.RawMessage
|
||||
}
|
||||
|
||||
// TestMessageInsertAssignsRevisionAndHistoryVersion verifies that
|
||||
// inserting a chat message via the legacy InsertChatMessages query
|
||||
// assigns NEW.revision from chats.snapshot_version (BEFORE trigger)
|
||||
// and bumps chats.history_version + resets generation_attempt (AFTER
|
||||
// STATEMENT trigger).
|
||||
func TestMessageInsertAssignsRevisionAndHistoryVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
created := createTestChat(t, f)
|
||||
require.Equal(t, int64(1), created.Chat.SnapshotVersion)
|
||||
require.Equal(t, int64(1), created.Chat.HistoryVersion)
|
||||
|
||||
// Force generation_attempt > 0 so we can prove the trigger
|
||||
// resets it on a new history change.
|
||||
_, err := f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), before.GenerationAttempt)
|
||||
|
||||
// Bump snapshot_version directly to simulate a transition having
|
||||
// taken the row lock.
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, before.SnapshotVersion+1, bumped.SnapshotVersion)
|
||||
|
||||
// Insert a new assistant message via raw SQL so we know the
|
||||
// BEFORE+AFTER triggers (and only those) decide revision and
|
||||
// history_version.
|
||||
content := userMessageContent(t, "hello-after-bump")
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_messages (chat_id, role, content, content_version, visibility)
|
||||
VALUES ($1, 'assistant', $2::jsonb, $3, 'both')
|
||||
`, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion))
|
||||
require.NoError(t, err)
|
||||
|
||||
// History version equals snapshot_version, generation_attempt resets.
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bumped.SnapshotVersion, after.HistoryVersion)
|
||||
require.Equal(t, int64(0), after.GenerationAttempt)
|
||||
|
||||
// The inserted message picked up revision = bumped snapshot.
|
||||
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: created.Chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, msgs)
|
||||
last := msgs[len(msgs)-1]
|
||||
require.Equal(t, database.ChatMessageRoleAssistant, last.Role)
|
||||
require.Equal(t, bumped.SnapshotVersion, last.Revision)
|
||||
}
|
||||
|
||||
// TestMessageUpdateAssignsNewRevisionAndHistoryVersion verifies that
|
||||
// updating a chat message's content advances NEW.revision to the
|
||||
// current chats.snapshot_version and that chats.history_version
|
||||
// bumps to match.
|
||||
func TestMessageUpdateAssignsNewRevisionAndHistoryVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: created.Chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, msgs)
|
||||
target := msgs[0]
|
||||
originalRevision := target.Revision
|
||||
|
||||
// Bump the snapshot so the trigger sees a new revision target.
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, bumped.SnapshotVersion, originalRevision)
|
||||
|
||||
newContent := userMessageContent(t, "edited content")
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
UPDATE chat_messages SET content = $1::jsonb WHERE id = $2
|
||||
`, string(newContent), target.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
reloaded, err := f.DB.GetChatMessageByID(ctx, target.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bumped.SnapshotVersion, reloaded.Revision,
|
||||
"updated message picks up the current snapshot version")
|
||||
|
||||
chatAfter, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bumped.SnapshotVersion, chatAfter.HistoryVersion)
|
||||
require.Equal(t, int64(0), chatAfter.GenerationAttempt,
|
||||
"history change resets generation_attempt")
|
||||
}
|
||||
|
||||
// TestMessageRevisionCannotBeSetByRuntimeCode verifies the BEFORE
|
||||
// trigger rejects explicit revision values on INSERT and rejects
|
||||
// revision changes on UPDATE.
|
||||
func TestMessageRevisionCannotBeSetByRuntimeCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
content := userMessageContent(t, "explicit revision")
|
||||
_, err := tf.sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_messages (chat_id, role, content, content_version, visibility, revision)
|
||||
VALUES ($1, 'user', $2::jsonb, $3, 'both', 999)
|
||||
`, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion))
|
||||
require.Error(t, err, "INSERT with explicit revision must be rejected")
|
||||
require.Contains(t, err.Error(), "revision must be assigned by trigger")
|
||||
|
||||
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: created.Chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, msgs)
|
||||
target := msgs[0]
|
||||
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
UPDATE chat_messages SET revision = revision + 100 WHERE id = $1
|
||||
`, target.ID)
|
||||
require.Error(t, err, "UPDATE that changes revision must be rejected")
|
||||
require.Contains(t, err.Error(), "revision must be assigned by trigger")
|
||||
}
|
||||
|
||||
// TestMessageChatIDCannotChange verifies the BEFORE trigger rejects
|
||||
// updates that change chat_messages.chat_id.
|
||||
func TestMessageChatIDCannotChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
first := createTestChat(t, f)
|
||||
second := createTestChat(t, f)
|
||||
|
||||
firstMsgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: first.Chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, firstMsgs)
|
||||
target := firstMsgs[0]
|
||||
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
UPDATE chat_messages SET chat_id = $1 WHERE id = $2
|
||||
`, second.Chat.ID, target.ID)
|
||||
require.Error(t, err, "UPDATE that changes chat_id must be rejected")
|
||||
require.Contains(t, err.Error(), "chat_id is immutable")
|
||||
}
|
||||
|
||||
// TestNoopMessageUpdateDoesNotAdvanceHistoryVersion verifies that a
|
||||
// no-op UPDATE on a chat_messages row (one whose OLD and NEW are
|
||||
// indistinguishable) does NOT advance chats.history_version even
|
||||
// when the snapshot was previously bumped. This guards against the
|
||||
// AFTER UPDATE STATEMENT trigger naively reacting to every touched
|
||||
// row id regardless of whether the row actually changed.
|
||||
func TestNoopMessageUpdateDoesNotAdvanceHistoryVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
msgs, err := f.DB.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: created.Chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, msgs)
|
||||
target := msgs[0]
|
||||
originalRevision := target.Revision
|
||||
|
||||
// Bump snapshot so the AFTER STATEMENT guard
|
||||
// (history_version != snapshot_version) is now true.
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, bumped.SnapshotVersion, bumped.HistoryVersion,
|
||||
"snapshot bump leaves history_version trailing")
|
||||
|
||||
// No-op UPDATE: SET content = content. OLD IS NOT DISTINCT FROM NEW.
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
UPDATE chat_messages SET content = content WHERE id = $1
|
||||
`, target.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bumped.HistoryVersion, after.HistoryVersion,
|
||||
"no-op update must NOT advance history_version")
|
||||
|
||||
// And the row's revision is untouched.
|
||||
reloaded, err := f.DB.GetChatMessageByID(ctx, target.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalRevision, reloaded.Revision,
|
||||
"no-op update must NOT advance message revision")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Queue version triggers
|
||||
// =============================================================================
|
||||
|
||||
// TestQueueInsertUpdatesQueueVersion verifies that an INSERT into
|
||||
// chat_queued_messages bumps chats.queue_version to the current
|
||||
// snapshot_version.
|
||||
func TestQueueInsertUpdatesQueueVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), before.QueueVersion)
|
||||
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
content := userMessageContent(t, "queued")
|
||||
_, err = f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
|
||||
ChatID: created.Chat.ID,
|
||||
Content: content,
|
||||
CreatedBy: f.User.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bumped.SnapshotVersion, after.QueueVersion,
|
||||
"INSERT into chat_queued_messages bumps queue_version")
|
||||
}
|
||||
|
||||
// TestQueuedMessageCreatedByIsRequired verifies the database enforces
|
||||
// creator metadata for every queued message row.
|
||||
func TestQueuedMessageCreatedByIsRequired(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
content := userMessageContent(t, "queued-without-creator")
|
||||
_, err := tf.sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, created_by)
|
||||
VALUES ($1, $2::jsonb, NULL, NULL)
|
||||
`, created.Chat.ID, string(content))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "created_by")
|
||||
}
|
||||
|
||||
func TestLegacyQueuedMessageInsertUsesChatOwnerAsCreator(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
queued, err := f.DB.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
||||
ChatID: created.Chat.ID,
|
||||
Content: userMessageContent(t, "legacy-queued"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, created.Chat.OwnerID, queued.CreatedBy)
|
||||
}
|
||||
|
||||
// TestQueueUpdateContentUpdatesQueueVersion verifies that an UPDATE
|
||||
// of chat_queued_messages.content bumps queue_version. The
|
||||
// AFTER UPDATE trigger explicitly listens for content changes.
|
||||
func TestQueueUpdateContentUpdatesQueueVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
queued, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
|
||||
ChatID: created.Chat.ID,
|
||||
Content: userMessageContent(t, "initial"),
|
||||
CreatedBy: f.User.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, bumped.SnapshotVersion, before.QueueVersion)
|
||||
|
||||
updated := userMessageContent(t, "updated")
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
UPDATE chat_queued_messages SET content = $1::jsonb WHERE id = $2
|
||||
`, string(updated), queued.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bumped.SnapshotVersion, after.QueueVersion,
|
||||
"UPDATE of queued content bumps queue_version")
|
||||
}
|
||||
|
||||
// TestQueueUpdatePositionUpdatesQueueVersion verifies that an UPDATE
|
||||
// of chat_queued_messages.position (such as the reorder-to-head
|
||||
// path) bumps queue_version.
|
||||
func TestQueueUpdatePositionUpdatesQueueVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
q1, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
|
||||
ChatID: created.Chat.ID,
|
||||
Content: userMessageContent(t, "first"),
|
||||
CreatedBy: f.User.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
q2, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
|
||||
ChatID: created.Chat.ID,
|
||||
Content: userMessageContent(t, "second"),
|
||||
CreatedBy: f.User.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, q1.ID, q2.ID)
|
||||
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Move q2 to head by setting its position to q1.position - 1.
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
UPDATE chat_queued_messages SET position = $1 WHERE id = $2
|
||||
`, q1.Position-1, q2.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bumped.SnapshotVersion, after.QueueVersion,
|
||||
"UPDATE of queued position bumps queue_version")
|
||||
}
|
||||
|
||||
// TestQueueDeleteUpdatesQueueVersion verifies that DELETE from
|
||||
// chat_queued_messages bumps queue_version.
|
||||
func TestQueueDeleteUpdatesQueueVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
queued, err := f.DB.InsertChatQueuedMessageWithCreator(ctx, database.InsertChatQueuedMessageWithCreatorParams{
|
||||
ChatID: created.Chat.ID,
|
||||
Content: userMessageContent(t, "to delete"),
|
||||
CreatedBy: f.User.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := f.DB.DeleteChatQueuedMessageReturningCount(ctx, database.DeleteChatQueuedMessageReturningCountParams{
|
||||
ID: queued.ID,
|
||||
ChatID: created.Chat.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), rows)
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bumped.SnapshotVersion, after.QueueVersion,
|
||||
"DELETE from queue bumps queue_version")
|
||||
}
|
||||
|
||||
// TestNonQueueUpdateDoesNotUpdateQueueVersion verifies that mutations
|
||||
// on other chat-related tables do NOT bump queue_version. The
|
||||
// canonical case is inserting a chat message: it must update
|
||||
// history_version but leave queue_version untouched.
|
||||
func TestNonQueueUpdateDoesNotUpdateQueueVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
content := userMessageContent(t, "non-queue mutation")
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_messages (chat_id, role, content, content_version, visibility)
|
||||
VALUES ($1, 'assistant', $2::jsonb, $3, 'both')
|
||||
`, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion))
|
||||
require.NoError(t, err)
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, before.QueueVersion, after.QueueVersion,
|
||||
"chat_messages INSERT must not bump queue_version")
|
||||
// Sanity: history_version DID move.
|
||||
require.Equal(t, bumped.SnapshotVersion, after.HistoryVersion)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Retry state triggers
|
||||
// =============================================================================
|
||||
|
||||
func TestRetryStateDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
chat, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, chat.RetryState.Valid)
|
||||
require.Equal(t, int64(0), chat.RetryStateVersion)
|
||||
}
|
||||
|
||||
func TestRetryStateUpdateSetsRetryStateVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
after, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
|
||||
ID: created.Chat.ID,
|
||||
RetryState: []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, after.RetryState.Valid)
|
||||
require.JSONEq(t,
|
||||
`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`,
|
||||
string(after.RetryState.RawMessage))
|
||||
require.Equal(t, bumped.SnapshotVersion, after.RetryStateVersion)
|
||||
}
|
||||
|
||||
func TestRetryStateSameValueDoesNotUpdateRetryStateVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
payload := []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`)
|
||||
_, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
first, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
|
||||
ID: created.Chat.ID,
|
||||
RetryState: payload,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
second, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
|
||||
ID: created.Chat.ID,
|
||||
RetryState: payload,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, first.RetryStateVersion, second.RetryStateVersion,
|
||||
"same retry_state payload must not update retry_state_version")
|
||||
}
|
||||
|
||||
func TestGenerationAttemptClearsRetryState(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
_, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
withRetry, err := f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
|
||||
ID: created.Chat.ID,
|
||||
RetryState: []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, withRetry.RetryState.Valid)
|
||||
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
attempt, err := f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), attempt)
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, after.RetryState.Valid)
|
||||
require.Equal(t, bumped.SnapshotVersion, after.RetryStateVersion,
|
||||
"clearing retry_state on generation attempt bumps retry_state_version")
|
||||
}
|
||||
|
||||
func TestGenerationAttemptWithNullRetryStateDoesNotUpdateRetryStateVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
before, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, before.RetryState.Valid)
|
||||
|
||||
_, err = f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, after.RetryState.Valid)
|
||||
require.Equal(t, before.RetryStateVersion, after.RetryStateVersion,
|
||||
"generation attempt with null retry_state leaves retry_state_version unchanged")
|
||||
}
|
||||
|
||||
func TestRetryStateVersionCannotBeSetByRuntimeCode(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
_, err := tf.sqlDB.ExecContext(ctx, `
|
||||
UPDATE chats SET retry_state_version = retry_state_version + 1 WHERE id = $1
|
||||
`, created.Chat.ID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "retry_state_version must be assigned by trigger")
|
||||
}
|
||||
|
||||
func TestHistoryChangeClearsRetryState(t *testing.T) {
|
||||
t.Parallel()
|
||||
tf := newTriggerFixture(t)
|
||||
f := tf.f
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created := createTestChat(t, f)
|
||||
|
||||
_, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = f.DB.IncrementChatGenerationAttempt(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = f.DB.UpdateChatRetryState(ctx, database.UpdateChatRetryStateParams{
|
||||
ID: created.Chat.ID,
|
||||
RetryState: []byte(`{"attempt":1,"delay_ms":250,"error":"retry","retrying_at":"2026-05-29T00:00:00Z"}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
bumped, err := f.DB.LockChatAndBumpSnapshotVersion(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
content := userMessageContent(t, "history clears retry state")
|
||||
_, err = tf.sqlDB.ExecContext(ctx, `
|
||||
INSERT INTO chat_messages (chat_id, role, content, content_version, visibility)
|
||||
VALUES ($1, 'assistant', $2::jsonb, $3, 'both')
|
||||
`, created.Chat.ID, string(content), int(chatprompt.CurrentContentVersion))
|
||||
require.NoError(t, err)
|
||||
|
||||
after, err := f.DB.GetChatByID(ctx, created.Chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), after.GenerationAttempt)
|
||||
require.False(t, after.RetryState.Valid)
|
||||
require.Equal(t, bumped.SnapshotVersion, after.RetryStateVersion,
|
||||
"history reset of generation_attempt clears retry_state")
|
||||
}
|
||||
+40
-32
@@ -414,38 +414,46 @@ var auditableResourcesTypes = map[any]map[string]Action{
|
||||
"deleted_at": ActionIgnore, // Changes, but is implicit when a delete event is fired.
|
||||
},
|
||||
&database.Chat{}: {
|
||||
"id": ActionTrack,
|
||||
"owner_id": ActionTrack,
|
||||
"owner_username": ActionIgnore,
|
||||
"owner_name": ActionIgnore,
|
||||
"organization_id": ActionIgnore, // Never changes after creation.
|
||||
"workspace_id": ActionTrack,
|
||||
"build_id": ActionIgnore, // Internal lifecycle.
|
||||
"agent_id": ActionIgnore, // Internal lifecycle.
|
||||
"title": ActionSecret, // May contain sensitive content.
|
||||
"status": ActionIgnore, // Churns every message.
|
||||
"worker_id": ActionIgnore, // Internal.
|
||||
"started_at": ActionIgnore,
|
||||
"heartbeat_at": ActionIgnore, // Internal.
|
||||
"created_at": ActionIgnore, // Never changes.
|
||||
"updated_at": ActionIgnore, // Bumped on every mutation.
|
||||
"parent_chat_id": ActionIgnore, // Immutable after creation.
|
||||
"root_chat_id": ActionIgnore, // Immutable after creation.
|
||||
"last_model_config_id": ActionIgnore, // Churns every message.
|
||||
"archived": ActionTrack,
|
||||
"last_error": ActionIgnore, // Internal.
|
||||
"last_turn_summary": ActionIgnore, // Internal cached display text.
|
||||
"mode": ActionTrack,
|
||||
"mcp_server_ids": ActionTrack,
|
||||
"labels": ActionTrack,
|
||||
"user_acl": ActionTrack,
|
||||
"group_acl": ActionTrack,
|
||||
"pin_order": ActionTrack,
|
||||
"last_read_message_id": ActionIgnore, // User-scoped read cursor.
|
||||
"last_injected_context": ActionIgnore, // Internal lifecycle.
|
||||
"dynamic_tools": ActionIgnore, // Internal lifecycle.
|
||||
"plan_mode": ActionIgnore, // Can flip back and forth during a session.
|
||||
"client_type": ActionIgnore, // Set at creation.
|
||||
"id": ActionTrack,
|
||||
"owner_id": ActionTrack,
|
||||
"owner_username": ActionIgnore,
|
||||
"owner_name": ActionIgnore,
|
||||
"organization_id": ActionIgnore, // Never changes after creation.
|
||||
"workspace_id": ActionTrack,
|
||||
"build_id": ActionIgnore, // Internal lifecycle.
|
||||
"agent_id": ActionIgnore, // Internal lifecycle.
|
||||
"title": ActionSecret, // May contain sensitive content.
|
||||
"status": ActionIgnore, // Churns every message.
|
||||
"worker_id": ActionIgnore, // Internal.
|
||||
"started_at": ActionIgnore,
|
||||
"heartbeat_at": ActionIgnore, // Internal.
|
||||
"created_at": ActionIgnore, // Never changes.
|
||||
"updated_at": ActionIgnore, // Bumped on every mutation.
|
||||
"parent_chat_id": ActionIgnore, // Immutable after creation.
|
||||
"root_chat_id": ActionIgnore, // Immutable after creation.
|
||||
"last_model_config_id": ActionIgnore, // Churns every message.
|
||||
"archived": ActionTrack,
|
||||
"last_error": ActionIgnore, // Internal.
|
||||
"last_turn_summary": ActionIgnore, // Internal cached display text.
|
||||
"mode": ActionTrack,
|
||||
"mcp_server_ids": ActionTrack,
|
||||
"labels": ActionTrack,
|
||||
"user_acl": ActionTrack,
|
||||
"group_acl": ActionTrack,
|
||||
"pin_order": ActionTrack,
|
||||
"last_read_message_id": ActionIgnore, // User-scoped read cursor.
|
||||
"last_injected_context": ActionIgnore, // Internal lifecycle.
|
||||
"dynamic_tools": ActionIgnore, // Internal lifecycle.
|
||||
"plan_mode": ActionIgnore, // Can flip back and forth during a session.
|
||||
"client_type": ActionIgnore, // Set at creation.
|
||||
"snapshot_version": ActionIgnore, // Internal state machine version.
|
||||
"history_version": ActionIgnore, // Internal state machine version.
|
||||
"queue_version": ActionIgnore, // Internal state machine version.
|
||||
"retry_state": ActionIgnore, // Internal transient retry UI state.
|
||||
"retry_state_version": ActionIgnore, // Internal state machine version.
|
||||
"generation_attempt": ActionIgnore, // Internal retry counter.
|
||||
"runner_id": ActionIgnore, // Internal ownership identifier.
|
||||
"requires_action_deadline_at": ActionIgnore, // Internal pending-action deadline.
|
||||
},
|
||||
&database.UserSkill{}: {
|
||||
"id": ActionTrack,
|
||||
|
||||
Reference in New Issue
Block a user