From d4660d8a690dcbf499080d73f328398b50166193 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Wed, 25 Mar 2026 13:26:26 -0400 Subject: [PATCH] feat: add labels to chats (#23594) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Adds a general-purpose `map[string]string` label system to chats, stored as jsonb with a GIN index for efficient containment queries. This is a standalone foundational feature that will be used by the upcoming Automations feature for session identity (matching webhook events to existing chats), replacing the need for bespoke session-key tables. ## Changes ### Database - **Migration 000451**: Adds `labels jsonb NOT NULL DEFAULT '{}'` column to `chats` table with a GIN index (`idx_chats_labels`) - **`InsertChat`**: Accepts labels on creation via `COALESCE(@labels, '{}')` - **`UpdateChatByID`**: Supports partial update — `COALESCE(sqlc.narg('labels'), labels)` preserves existing labels when NULL is passed - **`GetChats`**: New `has_labels` filter using PostgreSQL `@>` containment operator - **`GetAuthorizedChats`**: Synced with generated `GetChats` (new column scan + query param) ### API - **Create chat** (`POST /chats`): Accepts optional `labels` field, validated before creation - **Update chat** (`PATCH /chats/{chat}`): Supports `labels` field for atomic label replacement - **List chats** (`GET /chats`): Supports `?label=key:value` query parameters (multiple are AND-ed) ### SDK - `Chat`, `CreateChatRequest`, `UpdateChatRequest`, `ListChatsOptions` all gain `Labels` fields - `UpdateChatRequest.Labels` is a pointer (`*map[string]string`) so `nil` means "don't change" vs empty map means "clear all" ### Validation (`coderd/httpapi/labels.go`) - Max 50 labels per chat - Key: 1–64 chars, must match `[a-zA-Z0-9][a-zA-Z0-9._/-]*` (supports namespaced keys like `github.repo`, `automation/pr-number`) - Value: 1–256 chars - 13 test cases covering all edge cases ### Chat runtime - `chatd.CreateOptions` gains `Labels` field, threaded through to `InsertChat` - Existing `UpdateChatByID` callers (e.g., quickgen title updates) are unaffected — NULL labels preserve existing values via COALESCE --- coderd/database/dbauthz/dbauthz.go | 11 + coderd/database/dbauthz/dbauthz_test.go | 10 + coderd/database/dbmetrics/querymetrics.go | 8 + coderd/database/dbmock/dbmock.go | 15 ++ coderd/database/dump.sql | 5 +- .../migrations/000451_chat_labels.down.sql | 3 + .../migrations/000451_chat_labels.up.sql | 3 + coderd/database/modelqueries.go | 2 + coderd/database/models.go | 1 + coderd/database/querier.go | 1 + coderd/database/querier_test.go | 212 ++++++++++++++++++ coderd/database/queries.sql.go | 117 +++++++--- coderd/database/queries/chats.sql | 21 +- coderd/database/sqlc.yaml | 3 + coderd/exp_chats.go | 91 +++++++- coderd/httpapi/chatlabels.go | 78 +++++++ coderd/httpapi/chatlabels_test.go | 191 ++++++++++++++++ coderd/x/chatd/chatd.go | 13 ++ codersdk/chats.go | 55 +++-- site/src/api/queries/chats.test.ts | 1 + site/src/api/typesGenerated.ts | 4 + .../pages/AgentsPage/AgentDetail.stories.tsx | 1 + .../AgentsPage/AgentsPageView.stories.tsx | 1 + .../AgentDetail/ChatContext.test.tsx | 1 + .../components/AgentDetail/TopBar.stories.tsx | 1 + .../components/AgentDetailView.stories.tsx | 1 + .../Sidebar/AgentsSidebar.stories.tsx | 1 + .../components/Sidebar/AgentsSidebar.test.tsx | 1 + 28 files changed, 796 insertions(+), 56 deletions(-) create mode 100644 coderd/database/migrations/000451_chat_labels.down.sql create mode 100644 coderd/database/migrations/000451_chat_labels.up.sql create mode 100644 coderd/httpapi/chatlabels.go create mode 100644 coderd/httpapi/chatlabels_test.go diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 347b5a1a1f..2d704add79 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -5641,6 +5641,17 @@ func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateCh return q.db.UpdateChatHeartbeat(ctx, arg) } +func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) { + chat, err := q.db.GetChatByID(ctx, arg.ID) + if err != nil { + return database.Chat{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return database.Chat{}, err + } + return q.db.UpdateChatLabelsByID(ctx, arg) +} + func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) { chat, err := q.db.GetChatByID(ctx, arg.ID) if err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 20640ed74d..31765996dc 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -749,6 +749,16 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpdateChatByID(gomock.Any(), arg).Return(chat, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) })) + s.Run("UpdateChatLabelsByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.UpdateChatLabelsByIDParams{ + ID: chat.ID, + Labels: []byte(`{"env":"prod"}`), + } + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().UpdateChatLabelsByID(gomock.Any(), arg).Return(chat, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) + })) s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := database.UpdateChatHeartbeatParams{ diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 2890e31535..e1f6c1e73b 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -4016,6 +4016,14 @@ func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database return r0, r1 } +func (m queryMetricsStore) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) { + start := time.Now() + r0, r1 := m.s.UpdateChatLabelsByID(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatLabelsByID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLabelsByID").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) { start := time.Now() r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 7dff5e62bb..50d7bd9e09 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -7582,6 +7582,21 @@ func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg) } +// UpdateChatLabelsByID mocks base method. +func (m *MockStore) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateChatLabelsByID", ctx, arg) + ret0, _ := ret[0].(database.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateChatLabelsByID indicates an expected call of UpdateChatLabelsByID. +func (mr *MockStoreMockRecorder) UpdateChatLabelsByID(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLabelsByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLabelsByID), ctx, arg) +} + // UpdateChatMCPServerIDs mocks base method. func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index f34566454a..1c3d158b59 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1398,7 +1398,8 @@ CREATE TABLE chats ( archived boolean DEFAULT false NOT NULL, last_error text, mode chat_mode, - mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL + mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL, + labels jsonb DEFAULT '{}'::jsonb NOT NULL ); CREATE TABLE connection_logs ( @@ -3726,6 +3727,8 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled); CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id); +CREATE INDEX idx_chats_labels ON chats USING gin (labels); + CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id); CREATE INDEX idx_chats_owner ON chats USING btree (owner_id); diff --git a/coderd/database/migrations/000451_chat_labels.down.sql b/coderd/database/migrations/000451_chat_labels.down.sql new file mode 100644 index 0000000000..baa6213bb5 --- /dev/null +++ b/coderd/database/migrations/000451_chat_labels.down.sql @@ -0,0 +1,3 @@ +DROP INDEX IF EXISTS idx_chats_labels; + +ALTER TABLE chats DROP COLUMN labels; diff --git a/coderd/database/migrations/000451_chat_labels.up.sql b/coderd/database/migrations/000451_chat_labels.up.sql new file mode 100644 index 0000000000..1d1e238e6b --- /dev/null +++ b/coderd/database/migrations/000451_chat_labels.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE chats ADD COLUMN labels jsonb NOT NULL DEFAULT '{}'; + +CREATE INDEX idx_chats_labels ON chats USING GIN (labels); diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index beb7a3de44..e62cb25922 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -761,6 +761,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, arg.OwnerID, arg.Archived, arg.AfterID, + arg.LabelFilter, arg.OffsetOpt, arg.LimitOpt, ) @@ -789,6 +790,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ); err != nil { return nil, err } diff --git a/coderd/database/models.go b/coderd/database/models.go index c7c558650a..08b06cd811 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4170,6 +4170,7 @@ type Chat struct { LastError sql.NullString `db:"last_error" json:"last_error"` Mode NullChatMode `db:"mode" json:"mode"` MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` + Labels StringMap `db:"labels" json:"labels"` } type ChatDiffStatus struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 1f8c00bc6a..cf6cc78ce2 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -823,6 +823,7 @@ type sqlcQuerier interface { // Bumps the heartbeat timestamp for a running chat so that other // replicas know the worker is still alive. UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) + UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error) UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 181902f16d..b1f98ffe42 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -10486,3 +10486,215 @@ func TestGetPRInsights(t *testing.T) { assert.Equal(t, int64(5_000_000), summary.MergedCostMicros) }) } + +func TestChatLabels(t *testing.T) { + t.Parallel() + if testing.Short() { + t.SkipNow() + } + + sqlDB := testSQLDB(t) + err := migrations.Up(sqlDB) + require.NoError(t, err) + db := database.New(sqlDB) + + ctx := testutil.Context(t, testutil.WaitMedium) + owner := dbgen.User(t, db, database.User{}) + + _, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + }) + require.NoError(t, err) + + modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + Provider: "openai", + Model: "test-model", + DisplayName: "Test Model", + CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true}, + Enabled: true, + IsDefault: true, + ContextLimit: 128000, + CompressionThreshold: 80, + Options: json.RawMessage(`{}`), + }) + require.NoError(t, err) + + t.Run("CreateWithLabels", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + + labels := database.StringMap{"github.repo": "coder/coder", "env": "prod"} + labelsJSON, err := json.Marshal(labels) + require.NoError(t, err) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: owner.ID, + LastModelConfigID: modelCfg.ID, + Title: "labeled-chat", + Labels: pqtype.NullRawMessage{ + RawMessage: labelsJSON, + Valid: true, + }, + }) + require.NoError(t, err) + require.Equal(t, database.StringMap{"github.repo": "coder/coder", "env": "prod"}, chat.Labels) + + // Read back and verify. + fetched, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, chat.Labels, fetched.Labels) + }) + + t.Run("CreateWithoutLabels", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: owner.ID, + LastModelConfigID: modelCfg.ID, + Title: "no-labels-chat", + }) + require.NoError(t, err) + // Default should be an empty map, not nil. + require.NotNil(t, chat.Labels) + require.Empty(t, chat.Labels) + }) + + t.Run("UpdateLabels", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: owner.ID, + LastModelConfigID: modelCfg.ID, + Title: "update-labels-chat", + }) + require.NoError(t, err) + require.Empty(t, chat.Labels) + + // Set labels. + newLabels, err := json.Marshal(database.StringMap{"team": "backend"}) + require.NoError(t, err) + updated, err := db.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{ + ID: chat.ID, + Labels: newLabels, + }) + require.NoError(t, err) + require.Equal(t, database.StringMap{"team": "backend"}, updated.Labels) + + // Title should be unchanged. + require.Equal(t, "update-labels-chat", updated.Title) + + // Clear labels by setting empty object. + emptyLabels, err := json.Marshal(database.StringMap{}) + require.NoError(t, err) + cleared, err := db.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{ + ID: chat.ID, + Labels: emptyLabels, + }) + require.NoError(t, err) + require.Empty(t, cleared.Labels) + }) + + t.Run("UpdateTitleDoesNotAffectLabels", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + + labels := database.StringMap{"pr": "1234"} + labelsJSON, err := json.Marshal(labels) + require.NoError(t, err) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: owner.ID, + LastModelConfigID: modelCfg.ID, + Title: "original-title", + Labels: pqtype.NullRawMessage{ + RawMessage: labelsJSON, + Valid: true, + }, + }) + require.NoError(t, err) + + // Update title only — labels must survive. + updated, err := db.UpdateChatByID(ctx, database.UpdateChatByIDParams{ + ID: chat.ID, + Title: "new-title", + }) + require.NoError(t, err) + require.Equal(t, "new-title", updated.Title) + require.Equal(t, database.StringMap{"pr": "1234"}, updated.Labels) + }) + + t.Run("FilterByLabels", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + + // Create three chats with different labels. + for _, tc := range []struct { + title string + labels database.StringMap + }{ + {"filter-a", database.StringMap{"env": "prod", "team": "backend"}}, + {"filter-b", database.StringMap{"env": "prod", "team": "frontend"}}, + {"filter-c", database.StringMap{"env": "staging"}}, + } { + labelsJSON, err := json.Marshal(tc.labels) + require.NoError(t, err) + _, err = db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: owner.ID, + LastModelConfigID: modelCfg.ID, + Title: tc.title, + Labels: pqtype.NullRawMessage{ + RawMessage: labelsJSON, + Valid: true, + }, + }) + require.NoError(t, err) + } + + // Filter by env=prod — should match filter-a and filter-b. + filterJSON, err := json.Marshal(database.StringMap{"env": "prod"}) + require.NoError(t, err) + results, err := db.GetChats(ctx, database.GetChatsParams{ + OwnerID: owner.ID, + LabelFilter: pqtype.NullRawMessage{ + RawMessage: filterJSON, + Valid: true, + }, + }) + require.NoError(t, err) + + titles := make([]string, 0, len(results)) + for _, c := range results { + titles = append(titles, c.Title) + } + require.Contains(t, titles, "filter-a") + require.Contains(t, titles, "filter-b") + require.NotContains(t, titles, "filter-c") + + // Filter by env=prod AND team=backend — should match only filter-a. + filterJSON, err = json.Marshal(database.StringMap{"env": "prod", "team": "backend"}) + require.NoError(t, err) + results, err = db.GetChats(ctx, database.GetChatsParams{ + OwnerID: owner.ID, + LabelFilter: pqtype.NullRawMessage{ + RawMessage: filterJSON, + Valid: true, + }, + }) + require.NoError(t, err) + require.Len(t, results, 1) + require.Equal(t, "filter-a", results[0].Title) + + // No filter — should return all chats for this owner. + allChats, err := db.GetChats(ctx, database.GetChatsParams{ + OwnerID: owner.ID, + }) + require.NoError(t, err) + require.GreaterOrEqual(t, len(allChats), 3) + }) +} diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 47e2c01040..d276c524fe 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3823,7 +3823,7 @@ WHERE $3::int ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels ` type AcquireChatsParams struct { @@ -3861,6 +3861,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) ( &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ); err != nil { return nil, err } @@ -4094,7 +4095,7 @@ func (q *sqlQuerier) DeleteChatUsageLimitUserOverride(ctx context.Context, userI const getChatByID = `-- name: GetChatByID :one SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels FROM chats WHERE @@ -4122,12 +4123,13 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ) return i, err } const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids FROM chats WHERE id = $1::uuid FOR UPDATE +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels FROM chats WHERE id = $1::uuid FOR UPDATE ` func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) { @@ -4151,6 +4153,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ) return i, err } @@ -4995,7 +4998,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u const getChats = `-- name: GetChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels FROM chats WHERE @@ -5026,24 +5029,29 @@ WHERE ) ELSE true END + AND CASE + WHEN $4::jsonb IS NOT NULL THEN chats.labels @> $4::jsonb + ELSE true + END -- Authorize Filter clause will be injected below in GetAuthorizedChats -- @authorize_filter ORDER BY -- Deterministic and consistent ordering of all rows, even if they share -- a timestamp. This is to ensure consistent pagination. - (updated_at, id) DESC OFFSET $4 + (updated_at, id) DESC OFFSET $5 LIMIT -- The chat list is unbounded and expected to grow large. -- Default to 50 to prevent accidental excessively large queries. - COALESCE(NULLIF($5 :: int, 0), 50) + COALESCE(NULLIF($6 :: int, 0), 50) ` type GetChatsParams struct { - OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` - Archived sql.NullBool `db:"archived" json:"archived"` - AfterID uuid.UUID `db:"after_id" json:"after_id"` - OffsetOpt int32 `db:"offset_opt" json:"offset_opt"` - LimitOpt int32 `db:"limit_opt" json:"limit_opt"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + Archived sql.NullBool `db:"archived" json:"archived"` + AfterID uuid.UUID `db:"after_id" json:"after_id"` + LabelFilter pqtype.NullRawMessage `db:"label_filter" json:"label_filter"` + OffsetOpt int32 `db:"offset_opt" json:"offset_opt"` + LimitOpt int32 `db:"limit_opt" json:"limit_opt"` } func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, error) { @@ -5051,6 +5059,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, arg.OwnerID, arg.Archived, arg.AfterID, + arg.LabelFilter, arg.OffsetOpt, arg.LimitOpt, ) @@ -5079,6 +5088,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ); err != nil { return nil, err } @@ -5144,7 +5154,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh const getStaleChats = `-- name: GetStaleChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels FROM chats WHERE @@ -5181,6 +5191,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ); err != nil { return nil, err } @@ -5244,7 +5255,8 @@ INSERT INTO chats ( last_model_config_id, title, mode, - mcp_server_ids + mcp_server_ids, + labels ) VALUES ( $1::uuid, $2::uuid, @@ -5253,21 +5265,23 @@ INSERT INTO chats ( $5::uuid, $6::text, $7::chat_mode, - COALESCE($8::uuid[], '{}'::uuid[]) + COALESCE($8::uuid[], '{}'::uuid[]), + COALESCE($9::jsonb, '{}'::jsonb) ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels ` type InsertChatParams struct { - OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` - WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"` - ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"` - RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"` - LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"` - Title string `db:"title" json:"title"` - Mode NullChatMode `db:"mode" json:"mode"` - MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"` + ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"` + RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"` + LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"` + Title string `db:"title" json:"title"` + Mode NullChatMode `db:"mode" json:"mode"` + MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` + Labels pqtype.NullRawMessage `db:"labels" json:"labels"` } func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) { @@ -5280,6 +5294,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat arg.Title, arg.Mode, pq.Array(arg.MCPServerIDs), + arg.Labels, ) var i Chat err := row.Scan( @@ -5300,6 +5315,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ) return i, err } @@ -5695,7 +5711,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels ` type UpdateChatByIDParams struct { @@ -5724,6 +5740,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ) return i, err } @@ -5754,6 +5771,49 @@ func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHear return result.RowsAffected() } +const updateChatLabelsByID = `-- name: UpdateChatLabelsByID :one +UPDATE + chats +SET + labels = $1::jsonb, + updated_at = NOW() +WHERE + id = $2::uuid +RETURNING + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels +` + +type UpdateChatLabelsByIDParams struct { + Labels json.RawMessage `db:"labels" json:"labels"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatLabelsByID, arg.Labels, arg.ID) + var i Chat + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.WorkspaceID, + &i.Title, + &i.Status, + &i.WorkerID, + &i.StartedAt, + &i.HeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ParentChatID, + &i.RootChatID, + &i.LastModelConfigID, + &i.Archived, + &i.LastError, + &i.Mode, + pq.Array(&i.MCPServerIDs), + &i.Labels, + ) + return i, err +} + const updateChatMCPServerIDs = `-- name: UpdateChatMCPServerIDs :one UPDATE chats @@ -5763,7 +5823,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels ` type UpdateChatMCPServerIDsParams struct { @@ -5792,6 +5852,7 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ) return i, err } @@ -5856,7 +5917,7 @@ SET WHERE id = $6::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels ` type UpdateChatStatusParams struct { @@ -5896,6 +5957,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ) return i, err } @@ -5909,7 +5971,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels ` type UpdateChatWorkspaceParams struct { @@ -5938,6 +6000,7 @@ func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWork &i.LastError, &i.Mode, pq.Array(&i.MCPServerIDs), + &i.Labels, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 5788f99858..0130d71774 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -161,6 +161,10 @@ WHERE ) ELSE true END + AND CASE + WHEN sqlc.narg('label_filter')::jsonb IS NOT NULL THEN chats.labels @> sqlc.narg('label_filter')::jsonb + ELSE true + END -- Authorize Filter clause will be injected below in GetAuthorizedChats -- @authorize_filter ORDER BY @@ -181,7 +185,8 @@ INSERT INTO chats ( last_model_config_id, title, mode, - mcp_server_ids + mcp_server_ids, + labels ) VALUES ( @owner_id::uuid, sqlc.narg('workspace_id')::uuid, @@ -190,7 +195,8 @@ INSERT INTO chats ( @last_model_config_id::uuid, @title::text, sqlc.narg('mode')::chat_mode, - COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]) + COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]), + COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb) ) RETURNING *; @@ -288,6 +294,17 @@ WHERE RETURNING *; +-- name: UpdateChatLabelsByID :one +UPDATE + chats +SET + labels = @labels::jsonb, + updated_at = NOW() +WHERE + id = @id::uuid +RETURNING + *; + -- name: UpdateChatWorkspace :one UPDATE chats diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index a6d5396b44..e9c933ca08 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -65,6 +65,9 @@ sql: - column: "provisioner_jobs.tags" go_type: type: "StringMap" + - column: "chats.labels" + go_type: + type: "StringMap" - column: "users.rbac_roles" go_type: "github.com/lib/pq.StringArray" - column: "templates.user_acl" diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 9d3540b11f..9baa3ef511 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -23,6 +23,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/shopspring/decimal" + "github.com/sqlc-dev/pqtype" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" @@ -191,10 +192,38 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) { return } + var labelFilter pqtype.NullRawMessage + if labelParams := r.URL.Query()["label"]; len(labelParams) > 0 { + labelMap := make(map[string]string, len(labelParams)) + for _, lp := range labelParams { + key, value, ok := strings.Cut(lp, ":") + if !ok || key == "" || value == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Invalid label filter: %q (expected format key:value, both must be non-empty)", lp), + }) + return + } + labelMap[key] = value + } + labelsJSON, err := json.Marshal(labelMap) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to marshal label filter.", + Detail: err.Error(), + }) + return + } + labelFilter = pqtype.NullRawMessage{ + RawMessage: labelsJSON, + Valid: true, + } + } + params := database.GetChatsParams{ - OwnerID: apiKey.UserID, - Archived: searchParams.Archived, - AfterID: paginationParams.AfterID, + OwnerID: apiKey.UserID, + Archived: searchParams.Archived, + AfterID: paginationParams.AfterID, + LabelFilter: labelFilter, // #nosec G115 - Pagination offsets are small and fit in int32 OffsetOpt: int32(paginationParams.Offset), // #nosec G115 - Pagination limits are small and fit in int32 @@ -320,6 +349,18 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { mcpServerIDs = []uuid.UUID{} } + labels := req.Labels + if labels == nil { + labels = map[string]string{} + } + if errs := httpapi.ValidateChatLabels(labels); len(errs) > 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid labels.", + Validations: errs, + }) + return + } + chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{ OwnerID: apiKey.UserID, WorkspaceID: workspaceSelection.WorkspaceID, @@ -328,6 +369,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { SystemPrompt: api.resolvedChatSystemPrompt(ctx), InitialUserContent: contentBlocks, MCPServerIDs: mcpServerIDs, + Labels: labels, }) if err != nil { if maybeWriteLimitErr(ctx, rw, err) { @@ -1407,8 +1449,8 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) { logger.Debug(ctx, "desktop Bicopy finished") } -// patchChat updates a chat resource. Currently supports toggling the -// archived state via the Archived field. +// patchChat updates a chat resource. Supports updating labels and +// toggling the archived state. func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() chat := httpmw.ChatParam(r) @@ -1418,6 +1460,40 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) { return } + if req.Labels != nil { + if errs := httpapi.ValidateChatLabels(*req.Labels); len(errs) > 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid labels.", + Validations: errs, + }) + return + } + labelsJSON, err := json.Marshal(*req.Labels) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to marshal labels.", + Detail: err.Error(), + }) + return + } + updatedChat, err := api.Database.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{ + ID: chat.ID, + Labels: labelsJSON, + }) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to update chat labels.", + Detail: err.Error(), + }) + return + } + chat = updatedChat + } + if req.Archived != nil { archived := *req.Archived if archived == chat.Archived { @@ -3500,6 +3576,10 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk. if mcpServerIDs == nil { mcpServerIDs = []uuid.UUID{} } + labels := map[string]string(c.Labels) + if labels == nil { + labels = map[string]string{} + } chat := codersdk.Chat{ ID: c.ID, OwnerID: c.OwnerID, @@ -3510,6 +3590,7 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk. CreatedAt: c.CreatedAt, UpdatedAt: c.UpdatedAt, MCPServerIDs: mcpServerIDs, + Labels: labels, } if c.LastError.Valid { chat.LastError = &c.LastError.String diff --git a/coderd/httpapi/chatlabels.go b/coderd/httpapi/chatlabels.go new file mode 100644 index 0000000000..c4796ee186 --- /dev/null +++ b/coderd/httpapi/chatlabels.go @@ -0,0 +1,78 @@ +package httpapi + +import ( + "fmt" + "regexp" + + "github.com/coder/coder/v2/codersdk" +) + +const ( + // maxLabelsPerChat is the maximum number of labels allowed on a + // single chat. + maxLabelsPerChat = 50 + // maxLabelKeyLength is the maximum length of a label key in bytes. + maxLabelKeyLength = 64 + // maxLabelValueLength is the maximum length of a label value in + // bytes. + maxLabelValueLength = 256 +) + +// labelKeyRegex validates that a label key starts with an alphanumeric +// character and is followed by alphanumeric characters, dots, hyphens, +// underscores, or forward slashes. +var labelKeyRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._/-]*$`) + +// ValidateChatLabels checks that the provided labels map conforms to the +// labeling constraints for chats. It returns a list of validation +// errors, one per violated constraint. +func ValidateChatLabels(labels map[string]string) []codersdk.ValidationError { + var errs []codersdk.ValidationError + + if len(labels) > maxLabelsPerChat { + errs = append(errs, codersdk.ValidationError{ + Field: "labels", + Detail: fmt.Sprintf("too many labels (%d); maximum is %d", len(labels), maxLabelsPerChat), + }) + } + + for k, v := range labels { + if k == "" { + errs = append(errs, codersdk.ValidationError{ + Field: "labels", + Detail: "label key must not be empty", + }) + continue + } + + if len(k) > maxLabelKeyLength { + errs = append(errs, codersdk.ValidationError{ + Field: "labels", + Detail: fmt.Sprintf("label key %q exceeds maximum length of %d bytes", k, maxLabelKeyLength), + }) + } + + if !labelKeyRegex.MatchString(k) { + errs = append(errs, codersdk.ValidationError{ + Field: "labels", + Detail: fmt.Sprintf("label key %q contains invalid characters; must match %s", k, labelKeyRegex.String()), + }) + } + + if v == "" { + errs = append(errs, codersdk.ValidationError{ + Field: "labels", + Detail: fmt.Sprintf("label value for key %q must not be empty", k), + }) + } + + if len(v) > maxLabelValueLength { + errs = append(errs, codersdk.ValidationError{ + Field: "labels", + Detail: fmt.Sprintf("label value for key %q exceeds maximum length of %d bytes", k, maxLabelValueLength), + }) + } + } + + return errs +} diff --git a/coderd/httpapi/chatlabels_test.go b/coderd/httpapi/chatlabels_test.go new file mode 100644 index 0000000000..86e82dbee1 --- /dev/null +++ b/coderd/httpapi/chatlabels_test.go @@ -0,0 +1,191 @@ +package httpapi_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/httpapi" +) + +func TestValidateChatLabels(t *testing.T) { + t.Parallel() + + t.Run("NilMap", func(t *testing.T) { + t.Parallel() + errs := httpapi.ValidateChatLabels(nil) + require.Empty(t, errs) + }) + + t.Run("EmptyMap", func(t *testing.T) { + t.Parallel() + errs := httpapi.ValidateChatLabels(map[string]string{}) + require.Empty(t, errs) + }) + + t.Run("ValidLabels", func(t *testing.T) { + t.Parallel() + labels := map[string]string{ + "env": "production", + "github.repo": "coder/coder", + "automation/pr": "12345", + "team-backend": "core", + "version_number": "v1.2.3", + "A1.b2/c3-d4_e5": "mixed", + } + errs := httpapi.ValidateChatLabels(labels) + require.Empty(t, errs) + }) + + t.Run("TooManyLabels", func(t *testing.T) { + t.Parallel() + labels := make(map[string]string, 51) + for i := range 51 { + labels[strings.Repeat("k", i+1)] = "v" + } + errs := httpapi.ValidateChatLabels(labels) + require.NotEmpty(t, errs) + + found := false + for _, e := range errs { + if strings.Contains(e.Detail, "too many labels") { + found = true + break + } + } + assert.True(t, found, "expected a 'too many labels' error") + }) + + t.Run("KeyTooLong", func(t *testing.T) { + t.Parallel() + longKey := strings.Repeat("a", 65) + labels := map[string]string{ + longKey: "value", + } + errs := httpapi.ValidateChatLabels(labels) + require.NotEmpty(t, errs) + + found := false + for _, e := range errs { + if strings.Contains(e.Detail, "exceeds maximum length of 64 bytes") { + found = true + break + } + } + assert.True(t, found, "expected a key-too-long error") + }) + + t.Run("ValueTooLong", func(t *testing.T) { + t.Parallel() + longValue := strings.Repeat("v", 257) + labels := map[string]string{ + "key": longValue, + } + errs := httpapi.ValidateChatLabels(labels) + require.NotEmpty(t, errs) + + found := false + for _, e := range errs { + if strings.Contains(e.Detail, "exceeds maximum length of 256 bytes") { + found = true + break + } + } + assert.True(t, found, "expected a value-too-long error") + }) + + t.Run("InvalidKeyWithSpaces", func(t *testing.T) { + t.Parallel() + labels := map[string]string{ + "invalid key": "value", + } + errs := httpapi.ValidateChatLabels(labels) + require.NotEmpty(t, errs) + + found := false + for _, e := range errs { + if strings.Contains(e.Detail, "contains invalid characters") { + found = true + break + } + } + assert.True(t, found, "expected an invalid-characters error for spaces") + }) + + t.Run("InvalidKeyWithSpecialChars", func(t *testing.T) { + t.Parallel() + labels := map[string]string{ + "key@value": "value", + } + errs := httpapi.ValidateChatLabels(labels) + require.NotEmpty(t, errs) + + found := false + for _, e := range errs { + if strings.Contains(e.Detail, "contains invalid characters") { + found = true + break + } + } + assert.True(t, found, "expected an invalid-characters error for special chars") + }) + + t.Run("KeyStartsWithNonAlphanumeric", func(t *testing.T) { + t.Parallel() + labels := map[string]string{ + ".dotfirst": "value", + "-dashfirst": "value", + "_underfirst": "value", + "/slashfirst": "value", + } + errs := httpapi.ValidateChatLabels(labels) + // Each of the four keys should produce an error. + require.Len(t, errs, 4) + for _, e := range errs { + assert.Contains(t, e.Detail, "contains invalid characters") + } + }) + + t.Run("EmptyKey", func(t *testing.T) { + t.Parallel() + labels := map[string]string{ + "": "value", + } + errs := httpapi.ValidateChatLabels(labels) + require.Len(t, errs, 1) + assert.Contains(t, errs[0].Detail, "must not be empty") + }) + + t.Run("EmptyValue", func(t *testing.T) { + t.Parallel() + labels := map[string]string{ + "key": "", + } + errs := httpapi.ValidateChatLabels(labels) + require.Len(t, errs, 1) + assert.Contains(t, errs[0].Detail, "must not be empty") + }) + + t.Run("AllFieldsAreLabels", func(t *testing.T) { + t.Parallel() + labels := map[string]string{ + "bad key": "", + } + errs := httpapi.ValidateChatLabels(labels) + for _, e := range errs { + assert.Equal(t, "labels", e.Field) + } + }) + + t.Run("ExactlyAtLimits", func(t *testing.T) { + t.Parallel() + // Keys and values exactly at their limits should be valid. + labels := map[string]string{ + strings.Repeat("a", 64): strings.Repeat("v", 256), + } + errs := httpapi.ValidateChatLabels(labels) + require.Empty(t, errs) + }) +} diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index c5775c38e0..7046e398a6 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -397,6 +397,7 @@ type CreateOptions struct { SystemPrompt string InitialUserContent []codersdk.ChatMessagePart MCPServerIDs []uuid.UUID + Labels database.StringMap } // SendMessageBusyBehavior controls what happens when a chat is already active. @@ -475,6 +476,9 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C if opts.MCPServerIDs == nil { opts.MCPServerIDs = []uuid.UUID{} } + if opts.Labels == nil { + opts.Labels = database.StringMap{} + } var chat database.Chat txErr := p.db.InTx(func(tx database.Store) error { @@ -482,6 +486,11 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C return limitErr } + labelsJSON, err := json.Marshal(opts.Labels) + if err != nil { + return xerrors.Errorf("marshal labels: %w", err) + } + insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{ OwnerID: opts.OwnerID, WorkspaceID: opts.WorkspaceID, @@ -491,6 +500,10 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C Title: opts.Title, Mode: opts.ChatMode, MCPServerIDs: opts.MCPServerIDs, + Labels: pqtype.NullRawMessage{ + RawMessage: labelsJSON, + Valid: true, + }, }) if err != nil { return xerrors.Errorf("insert chat: %w", err) diff --git a/codersdk/chats.go b/codersdk/chats.go index b4400436e8..355c10ed11 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -46,20 +46,21 @@ const ( // Chat represents a chat session with an AI agent. type Chat struct { - ID uuid.UUID `json:"id" format:"uuid"` - OwnerID uuid.UUID `json:"owner_id" format:"uuid"` - WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` - ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"` - RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"` - LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"` - Title string `json:"title"` - Status ChatStatus `json:"status"` - LastError *string `json:"last_error"` - DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"` - CreatedAt time.Time `json:"created_at" format:"date-time"` - UpdatedAt time.Time `json:"updated_at" format:"date-time"` - Archived bool `json:"archived"` - MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"` + ID uuid.UUID `json:"id" format:"uuid"` + OwnerID uuid.UUID `json:"owner_id" format:"uuid"` + WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` + ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"` + RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"` + LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"` + Title string `json:"title"` + Status ChatStatus `json:"status"` + LastError *string `json:"last_error"` + DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"` + CreatedAt time.Time `json:"created_at" format:"date-time"` + UpdatedAt time.Time `json:"updated_at" format:"date-time"` + Archived bool `json:"archived"` + MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"` + Labels map[string]string `json:"labels"` } // ChatMessage represents a single message in a chat. @@ -311,16 +312,18 @@ type ChatInputPart struct { // CreateChatRequest is the request to create a new chat. type CreateChatRequest struct { - Content []ChatInputPart `json:"content"` - WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` - ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` - MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"` + Content []ChatInputPart `json:"content"` + WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` + ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` + MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"` + Labels map[string]string `json:"labels,omitempty"` } // UpdateChatRequest is the request to update a chat. type UpdateChatRequest struct { - Title *string `json:"title,omitempty"` - Archived *bool `json:"archived,omitempty"` + Title *string `json:"title,omitempty"` + Archived *bool `json:"archived,omitempty"` + Labels *map[string]string `json:"labels,omitempty"` } // CreateChatMessageRequest is the request to add a message to a chat. @@ -1166,7 +1169,8 @@ type ChatUsageLimitConfigResponse struct { // ListChatsOptions are optional parameters for ListChats. type ListChatsOptions struct { - Query string + Query string + Labels map[string]string Pagination } @@ -1182,6 +1186,15 @@ func (c *ExperimentalClient) ListChats(ctx context.Context, opts *ListChatsOptio r.URL.RawQuery = q.Encode() }) } + if len(opts.Labels) > 0 { + reqOpts = append(reqOpts, func(r *http.Request) { + q := r.URL.Query() + for k, v := range opts.Labels { + q.Add("label", k+":"+v) + } + r.URL.RawQuery = q.Encode() + }) + } } res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats", nil, reqOpts...) if err != nil { diff --git a/site/src/api/queries/chats.test.ts b/site/src/api/queries/chats.test.ts index 8ca63da900..0534ffa83c 100644 --- a/site/src/api/queries/chats.test.ts +++ b/site/src/api/queries/chats.test.ts @@ -78,6 +78,7 @@ const makeChat = ( owner_id: "owner-1", last_model_config_id: "model-1", mcp_server_ids: [], + labels: {}, title: `Chat ${id}`, status: "running", created_at: "2025-01-01T00:00:00.000Z", diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 7acb7db689..c1fd4cc6c0 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1100,6 +1100,7 @@ export interface Chat { readonly updated_at: string; readonly archived: boolean; readonly mcp_server_ids: readonly string[]; + readonly labels: Record; } // From codersdk/chats.go @@ -2245,6 +2246,7 @@ export interface CreateChatRequest { readonly workspace_id?: string; readonly model_config_id?: string; readonly mcp_server_ids?: readonly string[]; + readonly labels?: Record; } // From codersdk/users.go @@ -3777,6 +3779,7 @@ export interface LinkConfig { */ export interface ListChatsOptions extends Pagination { readonly Query: string; + readonly Labels: Record; } // From codersdk/inboxnotification.go @@ -7008,6 +7011,7 @@ export interface UpdateChatProviderConfigRequest { export interface UpdateChatRequest { readonly title?: string; readonly archived?: boolean; + readonly labels?: Record; } // From codersdk/chats.go diff --git a/site/src/pages/AgentsPage/AgentDetail.stories.tsx b/site/src/pages/AgentsPage/AgentDetail.stories.tsx index bc31ed3a49..52eb84f617 100644 --- a/site/src/pages/AgentsPage/AgentDetail.stories.tsx +++ b/site/src/pages/AgentsPage/AgentDetail.stories.tsx @@ -112,6 +112,7 @@ const baseChatFields = { workspace_id: mockWorkspace.id, last_model_config_id: "model-config-1", mcp_server_ids: [], + labels: {}, created_at: "2026-02-18T00:00:00.000Z", updated_at: "2026-02-18T00:00:00.000Z", archived: false, diff --git a/site/src/pages/AgentsPage/AgentsPageView.stories.tsx b/site/src/pages/AgentsPage/AgentsPageView.stories.tsx index cbc48ae30f..cc5a3401da 100644 --- a/site/src/pages/AgentsPage/AgentsPageView.stories.tsx +++ b/site/src/pages/AgentsPage/AgentsPageView.stories.tsx @@ -120,6 +120,7 @@ const buildChat = (overrides: Partial = {}): Chat => ({ status: "completed", last_model_config_id: defaultModelConfigs[0].id, mcp_server_ids: [], + labels: {}, created_at: oneWeekAgo, updated_at: oneWeekAgo, archived: false, diff --git a/site/src/pages/AgentsPage/components/AgentDetail/ChatContext.test.tsx b/site/src/pages/AgentsPage/components/AgentDetail/ChatContext.test.tsx index e30e8daef0..5d4b080062 100644 --- a/site/src/pages/AgentsPage/components/AgentDetail/ChatContext.test.tsx +++ b/site/src/pages/AgentsPage/components/AgentDetail/ChatContext.test.tsx @@ -201,6 +201,7 @@ const makeChat = (chatID: string): TypesGen.Chat => ({ owner_id: "owner-1", last_model_config_id: "model-1", mcp_server_ids: [], + labels: {}, title: "test", status: "running", created_at: "2025-01-01T00:00:00.000Z", diff --git a/site/src/pages/AgentsPage/components/AgentDetail/TopBar.stories.tsx b/site/src/pages/AgentsPage/components/AgentDetail/TopBar.stories.tsx index 553f42c629..9146150415 100644 --- a/site/src/pages/AgentsPage/components/AgentDetail/TopBar.stories.tsx +++ b/site/src/pages/AgentsPage/components/AgentDetail/TopBar.stories.tsx @@ -52,6 +52,7 @@ export const WithParentChat: Story = { owner_id: "owner-id", last_model_config_id: "model-config-1", mcp_server_ids: [], + labels: {}, title: "Set up CI/CD pipeline", status: "completed", last_error: null, diff --git a/site/src/pages/AgentsPage/components/AgentDetailView.stories.tsx b/site/src/pages/AgentsPage/components/AgentDetailView.stories.tsx index dcd5a29a1f..452a74bbb4 100644 --- a/site/src/pages/AgentsPage/components/AgentDetailView.stories.tsx +++ b/site/src/pages/AgentsPage/components/AgentDetailView.stories.tsx @@ -39,6 +39,7 @@ const buildChat = (overrides: Partial = {}): TypesGen.Chat => ({ status: "completed", last_model_config_id: "model-config-1", mcp_server_ids: [], + labels: {}, created_at: oneWeekAgo, updated_at: oneWeekAgo, archived: false, diff --git a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.stories.tsx b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.stories.tsx index e2e48f2cd1..86c98fd726 100644 --- a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.stories.tsx +++ b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.stories.tsx @@ -41,6 +41,7 @@ const buildChat = (overrides: Partial = {}): Chat => ({ status: "completed", last_model_config_id: defaultModelConfigs[0].id, mcp_server_ids: [], + labels: {}, created_at: oneWeekAgo, updated_at: oneWeekAgo, archived: false, diff --git a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.test.tsx b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.test.tsx index d761efd2cc..809d72096b 100644 --- a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.test.tsx +++ b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.test.tsx @@ -64,6 +64,7 @@ const buildChat = (overrides: Partial = {}): Chat => ({ archived: false, last_error: null, mcp_server_ids: [], + labels: {}, ...overrides, });