diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index b1af71aabd..da7b7a991f 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1572,6 +1572,17 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat { convertedDiffStatus := ChatDiffStatus(c.ID, diffStatus) chat.DiffStatus = &convertedDiffStatus } + if c.LastInjectedContext.Valid { + var parts []codersdk.ChatMessagePart + // Internal fields are stripped at write time in + // chatd.updateLastInjectedContext, so no + // StripInternal call is needed here. Unmarshal + // errors are suppressed — the column is written by + // us with a known schema. + if err := json.Unmarshal(c.LastInjectedContext.RawMessage, &parts); err == nil { + chat.LastInjectedContext = parts + } + } return chat } diff --git a/coderd/database/db2sdk/db2sdk_test.go b/coderd/database/db2sdk/db2sdk_test.go index 1a3d3b3c9d..4043580f90 100644 --- a/coderd/database/db2sdk/db2sdk_test.go +++ b/coderd/database/db2sdk/db2sdk_test.go @@ -541,6 +541,13 @@ func TestChat_AllFieldsPopulated(t *testing.T) { PinOrder: 1, MCPServerIDs: []uuid.UUID{uuid.New()}, Labels: database.StringMap{"env": "prod"}, + LastInjectedContext: pqtype.NullRawMessage{ + // Use a context-file part to verify internal + // fields are not present (they are stripped at + // write time by chatd, not at read time). + RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`), + Valid: true, + }, } // Only ChatID is needed here. This test checks that // Chat.DiffStatus is non-nil, not that every DiffStatus diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index bf75a2595a..c3a64a346c 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -5752,6 +5752,17 @@ func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateC return q.db.UpdateChatLabelsByID(ctx, arg) } +func (q *querier) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) { + chat, err := q.db.GetChatByID(ctx, arg.ID) + if err != nil { + return database.Chat{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return database.Chat{}, err + } + return q.db.UpdateChatLastInjectedContext(ctx, arg) +} + func (q *querier) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) { chat, err := q.db.GetChatByID(ctx, arg.ID) if err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 4e67955f3d..d7c01e5089 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1204,6 +1204,19 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpdateChatMCPServerIDs(gomock.Any(), arg).Return(chat, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) })) + s.Run("UpdateChatLastInjectedContext", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.UpdateChatLastInjectedContextParams{ + ID: chat.ID, + LastInjectedContext: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`[{"type":"text","text":"test"}]`), + Valid: true, + }, + } + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), arg).Return(chat, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) + })) s.Run("UpdateChatLastReadMessageID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := database.UpdateChatLastReadMessageIDParams{ diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 1ed0237ef5..51d030b86d 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -4112,6 +4112,14 @@ func (m queryMetricsStore) UpdateChatLabelsByID(ctx context.Context, arg databas return r0, r1 } +func (m queryMetricsStore) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) { + start := time.Now() + r0, r1 := m.s.UpdateChatLastInjectedContext(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatLastInjectedContext").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLastInjectedContext").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) { start := time.Now() r0, r1 := m.s.UpdateChatLastModelConfigByID(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index d789e2b1f5..f427bcb030 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -7790,6 +7790,21 @@ func (mr *MockStoreMockRecorder) UpdateChatLabelsByID(ctx, arg any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLabelsByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLabelsByID), ctx, arg) } +// UpdateChatLastInjectedContext mocks base method. +func (m *MockStore) UpdateChatLastInjectedContext(ctx context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateChatLastInjectedContext", ctx, arg) + ret0, _ := ret[0].(database.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateChatLastInjectedContext indicates an expected call of UpdateChatLastInjectedContext. +func (mr *MockStoreMockRecorder) UpdateChatLastInjectedContext(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLastInjectedContext", reflect.TypeOf((*MockStore)(nil).UpdateChatLastInjectedContext), ctx, arg) +} + // UpdateChatLastModelConfigByID mocks base method. func (m *MockStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 4205a31566..6abc4024d5 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1403,7 +1403,8 @@ CREATE TABLE chats ( build_id uuid, agent_id uuid, pin_order integer DEFAULT 0 NOT NULL, - last_read_message_id bigint + last_read_message_id bigint, + last_injected_context jsonb ); CREATE TABLE connection_logs ( diff --git a/coderd/database/migrations/000456_chat_last_injected_context.down.sql b/coderd/database/migrations/000456_chat_last_injected_context.down.sql new file mode 100644 index 0000000000..a91c2fa33a --- /dev/null +++ b/coderd/database/migrations/000456_chat_last_injected_context.down.sql @@ -0,0 +1 @@ +ALTER TABLE chats DROP COLUMN last_injected_context; diff --git a/coderd/database/migrations/000456_chat_last_injected_context.up.sql b/coderd/database/migrations/000456_chat_last_injected_context.up.sql new file mode 100644 index 0000000000..ef507553b5 --- /dev/null +++ b/coderd/database/migrations/000456_chat_last_injected_context.up.sql @@ -0,0 +1 @@ +ALTER TABLE chats ADD COLUMN last_injected_context JSONB; diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 4409ddc9d0..2b92947a14 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -795,6 +795,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, &i.Chat.AgentID, &i.Chat.PinOrder, &i.Chat.LastReadMessageID, + &i.Chat.LastInjectedContext, &i.HasUnread); err != nil { return nil, err } diff --git a/coderd/database/models.go b/coderd/database/models.go index 7bd3c811e5..41e70403c0 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4153,28 +4153,29 @@ type BoundaryUsageStat struct { } type Chat struct { - ID uuid.UUID `db:"id" json:"id"` - OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` - WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"` - Title string `db:"title" json:"title"` - Status ChatStatus `db:"status" json:"status"` - WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"` - StartedAt sql.NullTime `db:"started_at" json:"started_at"` - HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"` - RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"` - LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"` - Archived bool `db:"archived" json:"archived"` - LastError sql.NullString `db:"last_error" json:"last_error"` - Mode NullChatMode `db:"mode" json:"mode"` - MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` - Labels StringMap `db:"labels" json:"labels"` - BuildID uuid.NullUUID `db:"build_id" json:"build_id"` - AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` - PinOrder int32 `db:"pin_order" json:"pin_order"` - LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"` + ID uuid.UUID `db:"id" json:"id"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"` + Title string `db:"title" json:"title"` + Status ChatStatus `db:"status" json:"status"` + WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"` + StartedAt sql.NullTime `db:"started_at" json:"started_at"` + HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"` + RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"` + LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"` + Archived bool `db:"archived" json:"archived"` + LastError sql.NullString `db:"last_error" json:"last_error"` + Mode NullChatMode `db:"mode" json:"mode"` + MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` + Labels StringMap `db:"labels" json:"labels"` + BuildID uuid.NullUUID `db:"build_id" json:"build_id"` + AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"` + PinOrder int32 `db:"pin_order" json:"pin_order"` + LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"` + LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"` } type ChatDiffStatus struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index b6a7fc07ce..186603a66c 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -854,6 +854,11 @@ type sqlcQuerier interface { // replicas know the worker is still alive. UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error) + // Updates the cached injected context parts (AGENTS.md + + // skills) on the chat row. Called only when context changes + // (first workspace attach or agent change). updated_at is + // intentionally not touched to avoid reordering the chat list. + UpdateChatLastInjectedContext(ctx context.Context, arg UpdateChatLastInjectedContextParams) (Chat, error) UpdateChatLastModelConfigByID(ctx context.Context, arg UpdateChatLastModelConfigByIDParams) (Chat, error) // Updates the last read message ID for a chat. This is used to track // which messages the owner has seen, enabling unread indicators. diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 1f78c9e86b..5287e13814 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4065,7 +4065,7 @@ WHERE $3::int ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type AcquireChatsParams struct { @@ -4108,6 +4108,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) ( &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ); err != nil { return nil, err } @@ -4341,7 +4342,7 @@ func (q *sqlQuerier) DeleteChatUsageLimitUserOverride(ctx context.Context, userI const getChatByID = `-- name: GetChatByID :one SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context FROM chats WHERE @@ -4374,12 +4375,13 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id FROM chats WHERE id = $1::uuid FOR UPDATE +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context FROM chats WHERE id = $1::uuid FOR UPDATE ` func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) { @@ -4408,6 +4410,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } @@ -5319,7 +5322,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u const getChats = `-- name: GetChats :many SELECT - chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, + chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, EXISTS ( SELECT 1 FROM chat_messages cm WHERE cm.chat_id = chats.id @@ -5426,6 +5429,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha &i.Chat.AgentID, &i.Chat.PinOrder, &i.Chat.LastReadMessageID, + &i.Chat.LastInjectedContext, &i.HasUnread, ); err != nil { return nil, err @@ -5442,7 +5446,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha } const getChatsByWorkspaceIDs = `-- name: GetChatsByWorkspaceIDs :many -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context FROM chats WHERE archived = false AND workspace_id = ANY($1::uuid[]) @@ -5481,6 +5485,7 @@ func (q *sqlQuerier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ); err != nil { return nil, err } @@ -5546,7 +5551,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh const getStaleChats = `-- name: GetStaleChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context FROM chats WHERE @@ -5588,6 +5593,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ); err != nil { return nil, err } @@ -5669,7 +5675,7 @@ INSERT INTO chats ( COALESCE($11::jsonb, '{}'::jsonb) ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type InsertChatParams struct { @@ -5724,6 +5730,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } @@ -6237,7 +6244,7 @@ UPDATE chats SET updated_at = NOW() WHERE id = $3::uuid -RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type UpdateChatBuildAgentBindingParams struct { @@ -6272,6 +6279,7 @@ func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg Update &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } @@ -6285,7 +6293,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type UpdateChatByIDParams struct { @@ -6319,6 +6327,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } @@ -6358,7 +6367,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type UpdateChatLabelsByIDParams struct { @@ -6392,6 +6401,55 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, + ) + return i, err +} + +const updateChatLastInjectedContext = `-- name: UpdateChatLastInjectedContext :one +UPDATE chats SET + last_injected_context = $1::jsonb +WHERE + id = $2::uuid +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context +` + +type UpdateChatLastInjectedContextParams struct { + LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"` + ID uuid.UUID `db:"id" json:"id"` +} + +// Updates the cached injected context parts (AGENTS.md + +// skills) on the chat row. Called only when context changes +// (first workspace attach or agent change). updated_at is +// intentionally not touched to avoid reordering the chat list. +func (q *sqlQuerier) UpdateChatLastInjectedContext(ctx context.Context, arg UpdateChatLastInjectedContextParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatLastInjectedContext, arg.LastInjectedContext, arg.ID) + var i Chat + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.WorkspaceID, + &i.Title, + &i.Status, + &i.WorkerID, + &i.StartedAt, + &i.HeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ParentChatID, + &i.RootChatID, + &i.LastModelConfigID, + &i.Archived, + &i.LastError, + &i.Mode, + pq.Array(&i.MCPServerIDs), + &i.Labels, + &i.BuildID, + &i.AgentID, + &i.PinOrder, + &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } @@ -6405,7 +6463,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type UpdateChatLastModelConfigByIDParams struct { @@ -6439,6 +6497,7 @@ func (q *sqlQuerier) UpdateChatLastModelConfigByID(ctx context.Context, arg Upda &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } @@ -6470,7 +6529,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type UpdateChatMCPServerIDsParams struct { @@ -6504,6 +6563,7 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } @@ -6639,7 +6699,7 @@ SET WHERE id = $6::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type UpdateChatStatusParams struct { @@ -6684,6 +6744,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } @@ -6701,7 +6762,7 @@ SET WHERE id = $7::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type UpdateChatStatusPreserveUpdatedAtParams struct { @@ -6748,6 +6809,7 @@ func (q *sqlQuerier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } @@ -6759,7 +6821,7 @@ UPDATE chats SET agent_id = $3::uuid, updated_at = NOW() WHERE id = $4::uuid -RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context ` type UpdateChatWorkspaceBindingParams struct { @@ -6800,6 +6862,7 @@ func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateC &i.AgentID, &i.PinOrder, &i.LastReadMessageID, + &i.LastInjectedContext, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 85983ee22d..772967da76 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -528,6 +528,17 @@ WHERE id = @id::uuid RETURNING *; +-- name: UpdateChatLastInjectedContext :one +-- Updates the cached injected context parts (AGENTS.md + +-- skills) on the chat row. Called only when context changes +-- (first workspace attach or agent change). updated_at is +-- intentionally not touched to avoid reordering the chat list. +UPDATE chats SET + last_injected_context = sqlc.narg('last_injected_context')::jsonb +WHERE + id = @id::uuid +RETURNING *; + -- name: UpdateChatMCPServerIDs :one UPDATE chats diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index b038e818af..3e17ed14e2 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -4970,9 +4970,25 @@ func (p *Server) persistInstructionFiles( chatprompt.CurrentContentVersion, )) _, _ = p.db.InsertChatMessages(ctx, msgParams) + // Update the cache column: persist skills if any + // exist, or clear to NULL so stale data from a + // previous agent doesn't linger. + if len(discoveredSkills) > 0 { + skillParts := make([]codersdk.ChatMessagePart, 0, len(discoveredSkills)) + for _, s := range discoveredSkills { + skillParts = append(skillParts, codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: s.Name, + SkillDescription: s.Description, + ContextFileAgentID: uuid.NullUUID{UUID: agent.ID, Valid: true}, + }) + } + p.updateLastInjectedContext(ctx, chat.ID, skillParts) + } else { + p.updateLastInjectedContext(ctx, chat.ID, nil) + } return "", discoveredSkills, nil } - // Build context-file parts (one per instruction file) and // skill parts (one per discovered skill). parts := make([]codersdk.ChatMessagePart, 0, len(sections)+len(discoveredSkills)) @@ -5015,6 +5031,15 @@ func (p *Server) persistInstructionFiles( if _, err := p.db.InsertChatMessages(ctx, msgParams); err != nil { return "", nil, xerrors.Errorf("persist instruction files: %w", err) } + // Build stripped copies for the cache column so internal + // fields (full file content, OS, directory, skill paths) + // are never persisted or returned to API clients. + stripped := make([]codersdk.ChatMessagePart, len(parts)) + copy(stripped, parts) + for i := range stripped { + stripped[i].StripInternal() + } + p.updateLastInjectedContext(ctx, chat.ID, stripped) // Return the formatted instruction text and discovered skills // so the caller can inject them into this turn's prompt (since @@ -5022,6 +5047,35 @@ func (p *Server) persistInstructionFiles( return formatSystemInstructions(agent.OperatingSystem, directory, sections), discoveredSkills, nil } +// updateLastInjectedContext persists the injected context +// parts (AGENTS.md files and skills) on the chat row so they +// are directly queryable without scanning messages. This is +// best-effort — a failure here is logged but does not block +// the turn. +func (p *Server) updateLastInjectedContext(ctx context.Context, chatID uuid.UUID, parts []codersdk.ChatMessagePart) { + param := pqtype.NullRawMessage{Valid: false} + if parts != nil { + raw, err := json.Marshal(parts) + if err != nil { + p.logger.Warn(ctx, "failed to marshal injected context", + slog.F("chat_id", chatID), + slog.Error(err), + ) + return + } + param = pqtype.NullRawMessage{RawMessage: raw, Valid: true} + } + if _, err := p.db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{ + ID: chatID, + LastInjectedContext: param, + }); err != nil { + p.logger.Warn(ctx, "failed to update injected context", + slog.F("chat_id", chatID), + slog.Error(err), + ) + } +} + // resolveUserCompactionThreshold looks up the user's per-model // compaction threshold override. Returns the override value and // true if one exists and is valid, or 0 and false otherwise. diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 7632aa86ad..9d03e50409 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -484,6 +484,32 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) { agentID, ).Return(workspaceAgent, nil).Times(1) db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), + gomock.Cond(func(x any) bool { + arg, ok := x.(database.UpdateChatLastInjectedContextParams) + if !ok || arg.ID != chat.ID { + return false + } + if !arg.LastInjectedContext.Valid { + return false + } + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil { + return false + } + // Expect at least one context-file part for the + // working-directory AGENTS.md, with internal fields + // stripped (no content, OS, or directory). + for _, p := range parts { + if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFilePath != "" { + return p.ContextFileContent == "" && + p.ContextFileOS == "" && + p.ContextFileDirectory == "" + } + } + return false + }), + ).Return(database.Chat{}, nil).Times(1) conn := agentconnmock.NewMockAgentConn(ctrl) conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1) @@ -569,6 +595,247 @@ func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing require.Empty(t, instruction) } +func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + workspaceAgent := database.WorkspaceAgent{ + ID: agentID, + OperatingSystem: "linux", + Directory: "/home/coder/project", + ExpandedDirectory: "/home/coder/project", + } + + db.EXPECT().GetWorkspaceAgentByID( + gomock.Any(), + agentID, + ).Return(workspaceAgent, nil).Times(1) + db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), + gomock.Cond(func(x any) bool { + arg, ok := x.(database.UpdateChatLastInjectedContextParams) + if !ok || arg.ID != chat.ID { + return false + } + if !arg.LastInjectedContext.Valid { + return false + } + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal(arg.LastInjectedContext.RawMessage, &parts); err != nil { + return false + } + // The sentinel path should persist only skill parts + // with ContextFileAgentID set. + for _, p := range parts { + if p.Type == codersdk.ChatMessagePartTypeSkill && + p.SkillName == "my-skill" && + p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) { + return true + } + } + return false + }), + ).Return(database.Chat{}, nil).Times(1) + + conn := agentconnmock.NewMockAgentConn(ctrl) + conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1) + + // Home LS (.coder directory): return 404 so no home + // instruction file is found. + conn.EXPECT().LS(gomock.Any(), "", + gomock.Cond(func(x any) bool { + req, ok := x.(workspacesdk.LSRequest) + return ok && req.Relativity == workspacesdk.LSRelativityHome + }), + ).Return( + workspacesdk.LSResponse{}, + codersdk.NewTestError(404, "POST", "/api/v0/list-directory"), + ).Times(1) + + // Pwd AGENTS.md: return 404 so no working-directory + // instruction file is found either. + conn.EXPECT().ReadFile(gomock.Any(), + "/home/coder/project/AGENTS.md", + int64(0), + int64(maxInstructionFileBytes+1), + ).Return( + nil, "", + codersdk.NewTestError(404, "GET", "/api/v0/read-file"), + ).Times(1) + + // Skills LS (.agents/skills directory): return one skill + // directory so DiscoverSkills finds it. + conn.EXPECT().LS(gomock.Any(), "", + gomock.Cond(func(x any) bool { + req, ok := x.(workspacesdk.LSRequest) + return ok && req.Relativity == workspacesdk.LSRelativityRoot + }), + ).Return(workspacesdk.LSResponse{ + Contents: []workspacesdk.LSFile{{ + Name: "my-skill", + AbsolutePathString: "/home/coder/project/.agents/skills/my-skill", + IsDir: true, + }}, + }, nil).Times(1) + + // Skills SKILL.md ReadFile: return valid frontmatter. + skillContent := "---\nname: my-skill\ndescription: A test skill\n---\nSkill body" + conn.EXPECT().ReadFile(gomock.Any(), + "/home/coder/project/.agents/skills/my-skill/SKILL.md", + int64(0), + int64(64*1024+1), + ).Return( + io.NopCloser(strings.NewReader(skillContent)), + "", + nil, + ).Times(1) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := &Server{ + db: db, + logger: logger, + agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) { + return conn, func() {}, nil + }, + } + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: ¤tChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + t.Cleanup(workspaceCtx.close) + + instruction, skills, err := server.persistInstructionFiles( + ctx, + chat, + uuid.New(), + workspaceCtx.getWorkspaceAgent, + workspaceCtx.getWorkspaceConn, + ) + require.NoError(t, err) + // Sentinel path returns empty instruction string. + require.Empty(t, instruction) + // Skills are still discovered and returned. + require.Len(t, skills, 1) + require.Equal(t, "my-skill", skills[0].Name) +} + +func TestPersistInstructionFilesSentinelNoSkillsClearsColumn(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + workspaceAgent := database.WorkspaceAgent{ + ID: agentID, + OperatingSystem: "linux", + Directory: "/home/coder/project", + ExpandedDirectory: "/home/coder/project", + } + + db.EXPECT().GetWorkspaceAgentByID( + gomock.Any(), + agentID, + ).Return(workspaceAgent, nil).Times(1) + db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), + gomock.Cond(func(x any) bool { + arg, ok := x.(database.UpdateChatLastInjectedContextParams) + if !ok || arg.ID != chat.ID { + return false + } + // No skills discovered, so the column should be + // cleared to NULL. + return !arg.LastInjectedContext.Valid + }), + ).Return(database.Chat{}, nil).Times(1) + + conn := agentconnmock.NewMockAgentConn(ctrl) + conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1) + + // All LS calls return 404: no home .coder directory and no + // .agents/skills directory. + conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return( + workspacesdk.LSResponse{}, + codersdk.NewTestError(404, "POST", "/api/v0/list-directory"), + ).AnyTimes() + + // Pwd AGENTS.md: return 404. + conn.EXPECT().ReadFile(gomock.Any(), + "/home/coder/project/AGENTS.md", + int64(0), + int64(maxInstructionFileBytes+1), + ).Return( + nil, "", + codersdk.NewTestError(404, "GET", "/api/v0/read-file"), + ).Times(1) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := &Server{ + db: db, + logger: logger, + agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) { + return conn, func() {}, nil + }, + } + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: ¤tChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + t.Cleanup(workspaceCtx.close) + + instruction, skills, err := server.persistInstructionFiles( + ctx, + chat, + uuid.New(), + workspaceCtx.getWorkspaceAgent, + workspaceCtx.getWorkspaceConn, + ) + require.NoError(t, err) + // Sentinel path: empty instruction, no skills. + require.Empty(t, instruction) + require.Empty(t, skills) +} + func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) { t.Parallel() diff --git a/codersdk/chats.go b/codersdk/chats.go index 1b19d7a82b..aaacb1a034 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -68,6 +68,11 @@ type Chat struct { // the owner's read cursor, which updates on stream // connect and disconnect. HasUnread bool `json:"has_unread"` + // LastInjectedContext holds the most recently persisted + // injected context parts (AGENTS.md files and skills). It + // is updated only when context changes — first workspace + // attach or agent change. + LastInjectedContext []ChatMessagePart `json:"last_injected_context,omitempty"` } // ChatMessage represents a single message in a chat. diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 38e1c9687c..61941fb84b 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1200,6 +1200,13 @@ export interface Chat { * connect and disconnect. */ readonly has_unread: boolean; + /** + * LastInjectedContext holds the most recently persisted + * injected context parts (AGENTS.md files and skills). It + * is updated only when context changes — first workspace + * attach or agent change. + */ + readonly last_injected_context?: readonly ChatMessagePart[]; } // From codersdk/chats.go