mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add labels to chats (#23594)
## Summary
Adds a general-purpose `map[string]string` label system to chats, stored
as jsonb with a GIN index for efficient containment queries.
This is a standalone foundational feature that will be used by the
upcoming Automations feature for session identity (matching webhook
events to existing chats), replacing the need for bespoke session-key
tables.
## Changes
### Database
- **Migration 000451**: Adds `labels jsonb NOT NULL DEFAULT '{}'` column
to `chats` table with a GIN index (`idx_chats_labels`)
- **`InsertChat`**: Accepts labels on creation via `COALESCE(@labels,
'{}')`
- **`UpdateChatByID`**: Supports partial update —
`COALESCE(sqlc.narg('labels'), labels)` preserves existing labels when
NULL is passed
- **`GetChats`**: New `has_labels` filter using PostgreSQL `@>`
containment operator
- **`GetAuthorizedChats`**: Synced with generated `GetChats` (new column
scan + query param)
### API
- **Create chat** (`POST /chats`): Accepts optional `labels` field,
validated before creation
- **Update chat** (`PATCH /chats/{chat}`): Supports `labels` field for
atomic label replacement
- **List chats** (`GET /chats`): Supports `?label=key:value` query
parameters (multiple are AND-ed)
### SDK
- `Chat`, `CreateChatRequest`, `UpdateChatRequest`, `ListChatsOptions`
all gain `Labels` fields
- `UpdateChatRequest.Labels` is a pointer (`*map[string]string`) so
`nil` means "don't change" vs empty map means "clear all"
### Validation (`coderd/httpapi/labels.go`)
- Max 50 labels per chat
- Key: 1–64 chars, must match `[a-zA-Z0-9][a-zA-Z0-9._/-]*` (supports
namespaced keys like `github.repo`, `automation/pr-number`)
- Value: 1–256 chars
- 13 test cases covering all edge cases
### Chat runtime
- `chatd.CreateOptions` gains `Labels` field, threaded through to
`InsertChat`
- Existing `UpdateChatByID` callers (e.g., quickgen title updates) are
unaffected — NULL labels preserve existing values via COALESCE
This commit is contained in:
@@ -5641,6 +5641,17 @@ func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateCh
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatLabelsByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
|
||||
@@ -749,6 +749,16 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatLabelsByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatLabelsByIDParams{
|
||||
ID: chat.ID,
|
||||
Labels: []byte(`{"env":"prod"}`),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatLabelsByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
|
||||
@@ -4016,6 +4016,14 @@ func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatLabelsByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatLabelsByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLabelsByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg)
|
||||
|
||||
@@ -7582,6 +7582,21 @@ func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLabelsByID mocks base method.
|
||||
func (m *MockStore) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatLabelsByID", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatLabelsByID indicates an expected call of UpdateChatLabelsByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatLabelsByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLabelsByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLabelsByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatMCPServerIDs mocks base method.
|
||||
func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+4
-1
@@ -1398,7 +1398,8 @@ CREATE TABLE chats (
|
||||
archived boolean DEFAULT false NOT NULL,
|
||||
last_error text,
|
||||
mode chat_mode,
|
||||
mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL
|
||||
mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL,
|
||||
labels jsonb DEFAULT '{}'::jsonb NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -3726,6 +3727,8 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_labels ON chats USING gin (labels);
|
||||
|
||||
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
DROP INDEX IF EXISTS idx_chats_labels;
|
||||
|
||||
ALTER TABLE chats DROP COLUMN labels;
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE chats ADD COLUMN labels jsonb NOT NULL DEFAULT '{}';
|
||||
|
||||
CREATE INDEX idx_chats_labels ON chats USING GIN (labels);
|
||||
@@ -761,6 +761,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
arg.OwnerID,
|
||||
arg.Archived,
|
||||
arg.AfterID,
|
||||
arg.LabelFilter,
|
||||
arg.OffsetOpt,
|
||||
arg.LimitOpt,
|
||||
)
|
||||
@@ -789,6 +790,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4170,6 +4170,7 @@ type Chat struct {
|
||||
LastError sql.NullString `db:"last_error" json:"last_error"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
Labels StringMap `db:"labels" json:"labels"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
|
||||
@@ -823,6 +823,7 @@ type sqlcQuerier interface {
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error)
|
||||
UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error)
|
||||
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
|
||||
|
||||
@@ -10486,3 +10486,215 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(5_000_000), summary.MergedCostMicros)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatLabels(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
sqlDB := testSQLDB(t)
|
||||
err := migrations.Up(sqlDB)
|
||||
require.NoError(t, err)
|
||||
db := database.New(sqlDB)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 80,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("CreateWithLabels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
labels := database.StringMap{"github.repo": "coder/coder", "env": "prod"}
|
||||
labelsJSON, err := json.Marshal(labels)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "labeled-chat",
|
||||
Labels: pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.StringMap{"github.repo": "coder/coder", "env": "prod"}, chat.Labels)
|
||||
|
||||
// Read back and verify.
|
||||
fetched, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chat.Labels, fetched.Labels)
|
||||
})
|
||||
|
||||
t.Run("CreateWithoutLabels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "no-labels-chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Default should be an empty map, not nil.
|
||||
require.NotNil(t, chat.Labels)
|
||||
require.Empty(t, chat.Labels)
|
||||
})
|
||||
|
||||
t.Run("UpdateLabels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "update-labels-chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, chat.Labels)
|
||||
|
||||
// Set labels.
|
||||
newLabels, err := json.Marshal(database.StringMap{"team": "backend"})
|
||||
require.NoError(t, err)
|
||||
updated, err := db.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{
|
||||
ID: chat.ID,
|
||||
Labels: newLabels,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.StringMap{"team": "backend"}, updated.Labels)
|
||||
|
||||
// Title should be unchanged.
|
||||
require.Equal(t, "update-labels-chat", updated.Title)
|
||||
|
||||
// Clear labels by setting empty object.
|
||||
emptyLabels, err := json.Marshal(database.StringMap{})
|
||||
require.NoError(t, err)
|
||||
cleared, err := db.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{
|
||||
ID: chat.ID,
|
||||
Labels: emptyLabels,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cleared.Labels)
|
||||
})
|
||||
|
||||
t.Run("UpdateTitleDoesNotAffectLabels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
labels := database.StringMap{"pr": "1234"}
|
||||
labelsJSON, err := json.Marshal(labels)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "original-title",
|
||||
Labels: pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update title only — labels must survive.
|
||||
updated, err := db.UpdateChatByID(ctx, database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: "new-title",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new-title", updated.Title)
|
||||
require.Equal(t, database.StringMap{"pr": "1234"}, updated.Labels)
|
||||
})
|
||||
|
||||
t.Run("FilterByLabels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
// Create three chats with different labels.
|
||||
for _, tc := range []struct {
|
||||
title string
|
||||
labels database.StringMap
|
||||
}{
|
||||
{"filter-a", database.StringMap{"env": "prod", "team": "backend"}},
|
||||
{"filter-b", database.StringMap{"env": "prod", "team": "frontend"}},
|
||||
{"filter-c", database.StringMap{"env": "staging"}},
|
||||
} {
|
||||
labelsJSON, err := json.Marshal(tc.labels)
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: tc.title,
|
||||
Labels: pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Filter by env=prod — should match filter-a and filter-b.
|
||||
filterJSON, err := json.Marshal(database.StringMap{"env": "prod"})
|
||||
require.NoError(t, err)
|
||||
results, err := db.GetChats(ctx, database.GetChatsParams{
|
||||
OwnerID: owner.ID,
|
||||
LabelFilter: pqtype.NullRawMessage{
|
||||
RawMessage: filterJSON,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
titles := make([]string, 0, len(results))
|
||||
for _, c := range results {
|
||||
titles = append(titles, c.Title)
|
||||
}
|
||||
require.Contains(t, titles, "filter-a")
|
||||
require.Contains(t, titles, "filter-b")
|
||||
require.NotContains(t, titles, "filter-c")
|
||||
|
||||
// Filter by env=prod AND team=backend — should match only filter-a.
|
||||
filterJSON, err = json.Marshal(database.StringMap{"env": "prod", "team": "backend"})
|
||||
require.NoError(t, err)
|
||||
results, err = db.GetChats(ctx, database.GetChatsParams{
|
||||
OwnerID: owner.ID,
|
||||
LabelFilter: pqtype.NullRawMessage{
|
||||
RawMessage: filterJSON,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
require.Equal(t, "filter-a", results[0].Title)
|
||||
|
||||
// No filter — should return all chats for this owner.
|
||||
allChats, err := db.GetChats(ctx, database.GetChatsParams{
|
||||
OwnerID: owner.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, len(allChats), 3)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3823,7 +3823,7 @@ WHERE
|
||||
$3::int
|
||||
)
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
`
|
||||
|
||||
type AcquireChatsParams struct {
|
||||
@@ -3861,6 +3861,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) (
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4094,7 +4095,7 @@ func (q *sqlQuerier) DeleteChatUsageLimitUserOverride(ctx context.Context, userI
|
||||
|
||||
const getChatByID = `-- name: GetChatByID :one
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -4122,12 +4123,13 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids FROM chats WHERE id = $1::uuid FOR UPDATE
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels FROM chats WHERE id = $1::uuid FOR UPDATE
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) {
|
||||
@@ -4151,6 +4153,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -4995,7 +4998,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u
|
||||
|
||||
const getChats = `-- name: GetChats :many
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -5026,24 +5029,29 @@ WHERE
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN $4::jsonb IS NOT NULL THEN chats.labels @> $4::jsonb
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedChats
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(updated_at, id) DESC OFFSET $4
|
||||
(updated_at, id) DESC OFFSET $5
|
||||
LIMIT
|
||||
-- The chat list is unbounded and expected to grow large.
|
||||
-- Default to 50 to prevent accidental excessively large queries.
|
||||
COALESCE(NULLIF($5 :: int, 0), 50)
|
||||
COALESCE(NULLIF($6 :: int, 0), 50)
|
||||
`
|
||||
|
||||
type GetChatsParams struct {
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
Archived sql.NullBool `db:"archived" json:"archived"`
|
||||
AfterID uuid.UUID `db:"after_id" json:"after_id"`
|
||||
OffsetOpt int32 `db:"offset_opt" json:"offset_opt"`
|
||||
LimitOpt int32 `db:"limit_opt" json:"limit_opt"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
Archived sql.NullBool `db:"archived" json:"archived"`
|
||||
AfterID uuid.UUID `db:"after_id" json:"after_id"`
|
||||
LabelFilter pqtype.NullRawMessage `db:"label_filter" json:"label_filter"`
|
||||
OffsetOpt int32 `db:"offset_opt" json:"offset_opt"`
|
||||
LimitOpt int32 `db:"limit_opt" json:"limit_opt"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat, error) {
|
||||
@@ -5051,6 +5059,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat,
|
||||
arg.OwnerID,
|
||||
arg.Archived,
|
||||
arg.AfterID,
|
||||
arg.LabelFilter,
|
||||
arg.OffsetOpt,
|
||||
arg.LimitOpt,
|
||||
)
|
||||
@@ -5079,6 +5088,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]Chat,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -5144,7 +5154,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
|
||||
|
||||
const getStaleChats = `-- name: GetStaleChats :many
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -5181,6 +5191,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -5244,7 +5255,8 @@ INSERT INTO chats (
|
||||
last_model_config_id,
|
||||
title,
|
||||
mode,
|
||||
mcp_server_ids
|
||||
mcp_server_ids,
|
||||
labels
|
||||
) VALUES (
|
||||
$1::uuid,
|
||||
$2::uuid,
|
||||
@@ -5253,21 +5265,23 @@ INSERT INTO chats (
|
||||
$5::uuid,
|
||||
$6::text,
|
||||
$7::chat_mode,
|
||||
COALESCE($8::uuid[], '{}'::uuid[])
|
||||
COALESCE($8::uuid[], '{}'::uuid[]),
|
||||
COALESCE($9::jsonb, '{}'::jsonb)
|
||||
)
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
`
|
||||
|
||||
type InsertChatParams struct {
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
WorkspaceID uuid.NullUUID `db:"workspace_id" json:"workspace_id"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"`
|
||||
Labels pqtype.NullRawMessage `db:"labels" json:"labels"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) {
|
||||
@@ -5280,6 +5294,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
|
||||
arg.Title,
|
||||
arg.Mode,
|
||||
pq.Array(arg.MCPServerIDs),
|
||||
arg.Labels,
|
||||
)
|
||||
var i Chat
|
||||
err := row.Scan(
|
||||
@@ -5300,6 +5315,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -5695,7 +5711,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
`
|
||||
|
||||
type UpdateChatByIDParams struct {
|
||||
@@ -5724,6 +5740,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -5754,6 +5771,49 @@ func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHear
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
const updateChatLabelsByID = `-- name: UpdateChatLabelsByID :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
labels = $1::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
`
|
||||
|
||||
type UpdateChatLabelsByIDParams struct {
|
||||
Labels json.RawMessage `db:"labels" json:"labels"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateChatLabelsByID, arg.Labels, arg.ID)
|
||||
var i Chat
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatMCPServerIDs = `-- name: UpdateChatMCPServerIDs :one
|
||||
UPDATE
|
||||
chats
|
||||
@@ -5763,7 +5823,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
`
|
||||
|
||||
type UpdateChatMCPServerIDsParams struct {
|
||||
@@ -5792,6 +5852,7 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -5856,7 +5917,7 @@ SET
|
||||
WHERE
|
||||
id = $6::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
`
|
||||
|
||||
type UpdateChatStatusParams struct {
|
||||
@@ -5896,6 +5957,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -5909,7 +5971,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels
|
||||
`
|
||||
|
||||
type UpdateChatWorkspaceParams struct {
|
||||
@@ -5938,6 +6000,7 @@ func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWork
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -161,6 +161,10 @@ WHERE
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN sqlc.narg('label_filter')::jsonb IS NOT NULL THEN chats.labels @> sqlc.narg('label_filter')::jsonb
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedChats
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
@@ -181,7 +185,8 @@ INSERT INTO chats (
|
||||
last_model_config_id,
|
||||
title,
|
||||
mode,
|
||||
mcp_server_ids
|
||||
mcp_server_ids,
|
||||
labels
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
@@ -190,7 +195,8 @@ INSERT INTO chats (
|
||||
@last_model_config_id::uuid,
|
||||
@title::text,
|
||||
sqlc.narg('mode')::chat_mode,
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[])
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb)
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -288,6 +294,17 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatLabelsByID :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
labels = @labels::jsonb,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatWorkspace :one
|
||||
UPDATE
|
||||
chats
|
||||
|
||||
@@ -65,6 +65,9 @@ sql:
|
||||
- column: "provisioner_jobs.tags"
|
||||
go_type:
|
||||
type: "StringMap"
|
||||
- column: "chats.labels"
|
||||
go_type:
|
||||
type: "StringMap"
|
||||
- column: "users.rbac_roles"
|
||||
go_type: "github.com/lib/pq.StringArray"
|
||||
- column: "templates.user_acl"
|
||||
|
||||
+86
-5
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -191,10 +192,38 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var labelFilter pqtype.NullRawMessage
|
||||
if labelParams := r.URL.Query()["label"]; len(labelParams) > 0 {
|
||||
labelMap := make(map[string]string, len(labelParams))
|
||||
for _, lp := range labelParams {
|
||||
key, value, ok := strings.Cut(lp, ":")
|
||||
if !ok || key == "" || value == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Invalid label filter: %q (expected format key:value, both must be non-empty)", lp),
|
||||
})
|
||||
return
|
||||
}
|
||||
labelMap[key] = value
|
||||
}
|
||||
labelsJSON, err := json.Marshal(labelMap)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal label filter.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
labelFilter = pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
params := database.GetChatsParams{
|
||||
OwnerID: apiKey.UserID,
|
||||
Archived: searchParams.Archived,
|
||||
AfterID: paginationParams.AfterID,
|
||||
OwnerID: apiKey.UserID,
|
||||
Archived: searchParams.Archived,
|
||||
AfterID: paginationParams.AfterID,
|
||||
LabelFilter: labelFilter,
|
||||
// #nosec G115 - Pagination offsets are small and fit in int32
|
||||
OffsetOpt: int32(paginationParams.Offset),
|
||||
// #nosec G115 - Pagination limits are small and fit in int32
|
||||
@@ -320,6 +349,18 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
|
||||
labels := req.Labels
|
||||
if labels == nil {
|
||||
labels = map[string]string{}
|
||||
}
|
||||
if errs := httpapi.ValidateChatLabels(labels); len(errs) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid labels.",
|
||||
Validations: errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: apiKey.UserID,
|
||||
WorkspaceID: workspaceSelection.WorkspaceID,
|
||||
@@ -328,6 +369,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
SystemPrompt: api.resolvedChatSystemPrompt(ctx),
|
||||
InitialUserContent: contentBlocks,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
})
|
||||
if err != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
@@ -1407,8 +1449,8 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
|
||||
logger.Debug(ctx, "desktop Bicopy finished")
|
||||
}
|
||||
|
||||
// patchChat updates a chat resource. Currently supports toggling the
|
||||
// archived state via the Archived field.
|
||||
// patchChat updates a chat resource. Supports updating labels and
|
||||
// toggling the archived state.
|
||||
func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
@@ -1418,6 +1460,40 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Labels != nil {
|
||||
if errs := httpapi.ValidateChatLabels(*req.Labels); len(errs) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid labels.",
|
||||
Validations: errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
labelsJSON, err := json.Marshal(*req.Labels)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal labels.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
updatedChat, err := api.Database.UpdateChatLabelsByID(ctx, database.UpdateChatLabelsByIDParams{
|
||||
ID: chat.ID,
|
||||
Labels: labelsJSON,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update chat labels.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
chat = updatedChat
|
||||
}
|
||||
|
||||
if req.Archived != nil {
|
||||
archived := *req.Archived
|
||||
if archived == chat.Archived {
|
||||
@@ -3500,6 +3576,10 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
labels := map[string]string(c.Labels)
|
||||
if labels == nil {
|
||||
labels = map[string]string{}
|
||||
}
|
||||
chat := codersdk.Chat{
|
||||
ID: c.ID,
|
||||
OwnerID: c.OwnerID,
|
||||
@@ -3510,6 +3590,7 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.
|
||||
CreatedAt: c.CreatedAt,
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
}
|
||||
if c.LastError.Valid {
|
||||
chat.LastError = &c.LastError.String
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxLabelsPerChat is the maximum number of labels allowed on a
|
||||
// single chat.
|
||||
maxLabelsPerChat = 50
|
||||
// maxLabelKeyLength is the maximum length of a label key in bytes.
|
||||
maxLabelKeyLength = 64
|
||||
// maxLabelValueLength is the maximum length of a label value in
|
||||
// bytes.
|
||||
maxLabelValueLength = 256
|
||||
)
|
||||
|
||||
// labelKeyRegex validates that a label key starts with an alphanumeric
|
||||
// character and is followed by alphanumeric characters, dots, hyphens,
|
||||
// underscores, or forward slashes.
|
||||
var labelKeyRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._/-]*$`)
|
||||
|
||||
// ValidateChatLabels checks that the provided labels map conforms to the
|
||||
// labeling constraints for chats. It returns a list of validation
|
||||
// errors, one per violated constraint.
|
||||
func ValidateChatLabels(labels map[string]string) []codersdk.ValidationError {
|
||||
var errs []codersdk.ValidationError
|
||||
|
||||
if len(labels) > maxLabelsPerChat {
|
||||
errs = append(errs, codersdk.ValidationError{
|
||||
Field: "labels",
|
||||
Detail: fmt.Sprintf("too many labels (%d); maximum is %d", len(labels), maxLabelsPerChat),
|
||||
})
|
||||
}
|
||||
|
||||
for k, v := range labels {
|
||||
if k == "" {
|
||||
errs = append(errs, codersdk.ValidationError{
|
||||
Field: "labels",
|
||||
Detail: "label key must not be empty",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(k) > maxLabelKeyLength {
|
||||
errs = append(errs, codersdk.ValidationError{
|
||||
Field: "labels",
|
||||
Detail: fmt.Sprintf("label key %q exceeds maximum length of %d bytes", k, maxLabelKeyLength),
|
||||
})
|
||||
}
|
||||
|
||||
if !labelKeyRegex.MatchString(k) {
|
||||
errs = append(errs, codersdk.ValidationError{
|
||||
Field: "labels",
|
||||
Detail: fmt.Sprintf("label key %q contains invalid characters; must match %s", k, labelKeyRegex.String()),
|
||||
})
|
||||
}
|
||||
|
||||
if v == "" {
|
||||
errs = append(errs, codersdk.ValidationError{
|
||||
Field: "labels",
|
||||
Detail: fmt.Sprintf("label value for key %q must not be empty", k),
|
||||
})
|
||||
}
|
||||
|
||||
if len(v) > maxLabelValueLength {
|
||||
errs = append(errs, codersdk.ValidationError{
|
||||
Field: "labels",
|
||||
Detail: fmt.Sprintf("label value for key %q exceeds maximum length of %d bytes", k, maxLabelValueLength),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
package httpapi_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
)
|
||||
|
||||
func TestValidateChatLabels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NilMap", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
errs := httpapi.ValidateChatLabels(nil)
|
||||
require.Empty(t, errs)
|
||||
})
|
||||
|
||||
t.Run("EmptyMap", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
errs := httpapi.ValidateChatLabels(map[string]string{})
|
||||
require.Empty(t, errs)
|
||||
})
|
||||
|
||||
t.Run("ValidLabels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
labels := map[string]string{
|
||||
"env": "production",
|
||||
"github.repo": "coder/coder",
|
||||
"automation/pr": "12345",
|
||||
"team-backend": "core",
|
||||
"version_number": "v1.2.3",
|
||||
"A1.b2/c3-d4_e5": "mixed",
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
require.Empty(t, errs)
|
||||
})
|
||||
|
||||
t.Run("TooManyLabels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
labels := make(map[string]string, 51)
|
||||
for i := range 51 {
|
||||
labels[strings.Repeat("k", i+1)] = "v"
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
require.NotEmpty(t, errs)
|
||||
|
||||
found := false
|
||||
for _, e := range errs {
|
||||
if strings.Contains(e.Detail, "too many labels") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected a 'too many labels' error")
|
||||
})
|
||||
|
||||
t.Run("KeyTooLong", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
longKey := strings.Repeat("a", 65)
|
||||
labels := map[string]string{
|
||||
longKey: "value",
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
require.NotEmpty(t, errs)
|
||||
|
||||
found := false
|
||||
for _, e := range errs {
|
||||
if strings.Contains(e.Detail, "exceeds maximum length of 64 bytes") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected a key-too-long error")
|
||||
})
|
||||
|
||||
t.Run("ValueTooLong", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
longValue := strings.Repeat("v", 257)
|
||||
labels := map[string]string{
|
||||
"key": longValue,
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
require.NotEmpty(t, errs)
|
||||
|
||||
found := false
|
||||
for _, e := range errs {
|
||||
if strings.Contains(e.Detail, "exceeds maximum length of 256 bytes") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected a value-too-long error")
|
||||
})
|
||||
|
||||
t.Run("InvalidKeyWithSpaces", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
labels := map[string]string{
|
||||
"invalid key": "value",
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
require.NotEmpty(t, errs)
|
||||
|
||||
found := false
|
||||
for _, e := range errs {
|
||||
if strings.Contains(e.Detail, "contains invalid characters") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected an invalid-characters error for spaces")
|
||||
})
|
||||
|
||||
t.Run("InvalidKeyWithSpecialChars", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
labels := map[string]string{
|
||||
"key@value": "value",
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
require.NotEmpty(t, errs)
|
||||
|
||||
found := false
|
||||
for _, e := range errs {
|
||||
if strings.Contains(e.Detail, "contains invalid characters") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected an invalid-characters error for special chars")
|
||||
})
|
||||
|
||||
t.Run("KeyStartsWithNonAlphanumeric", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
labels := map[string]string{
|
||||
".dotfirst": "value",
|
||||
"-dashfirst": "value",
|
||||
"_underfirst": "value",
|
||||
"/slashfirst": "value",
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
// Each of the four keys should produce an error.
|
||||
require.Len(t, errs, 4)
|
||||
for _, e := range errs {
|
||||
assert.Contains(t, e.Detail, "contains invalid characters")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
labels := map[string]string{
|
||||
"": "value",
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
require.Len(t, errs, 1)
|
||||
assert.Contains(t, errs[0].Detail, "must not be empty")
|
||||
})
|
||||
|
||||
t.Run("EmptyValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
labels := map[string]string{
|
||||
"key": "",
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
require.Len(t, errs, 1)
|
||||
assert.Contains(t, errs[0].Detail, "must not be empty")
|
||||
})
|
||||
|
||||
t.Run("AllFieldsAreLabels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
labels := map[string]string{
|
||||
"bad key": "",
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
for _, e := range errs {
|
||||
assert.Equal(t, "labels", e.Field)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExactlyAtLimits", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Keys and values exactly at their limits should be valid.
|
||||
labels := map[string]string{
|
||||
strings.Repeat("a", 64): strings.Repeat("v", 256),
|
||||
}
|
||||
errs := httpapi.ValidateChatLabels(labels)
|
||||
require.Empty(t, errs)
|
||||
})
|
||||
}
|
||||
@@ -397,6 +397,7 @@ type CreateOptions struct {
|
||||
SystemPrompt string
|
||||
InitialUserContent []codersdk.ChatMessagePart
|
||||
MCPServerIDs []uuid.UUID
|
||||
Labels database.StringMap
|
||||
}
|
||||
|
||||
// SendMessageBusyBehavior controls what happens when a chat is already active.
|
||||
@@ -475,6 +476,9 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
if opts.MCPServerIDs == nil {
|
||||
opts.MCPServerIDs = []uuid.UUID{}
|
||||
}
|
||||
if opts.Labels == nil {
|
||||
opts.Labels = database.StringMap{}
|
||||
}
|
||||
|
||||
var chat database.Chat
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
@@ -482,6 +486,11 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
return limitErr
|
||||
}
|
||||
|
||||
labelsJSON, err := json.Marshal(opts.Labels)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal labels: %w", err)
|
||||
}
|
||||
|
||||
insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: opts.OwnerID,
|
||||
WorkspaceID: opts.WorkspaceID,
|
||||
@@ -491,6 +500,10 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
Title: opts.Title,
|
||||
Mode: opts.ChatMode,
|
||||
MCPServerIDs: opts.MCPServerIDs,
|
||||
Labels: pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert chat: %w", err)
|
||||
|
||||
+34
-21
@@ -46,20 +46,21 @@ const (
|
||||
|
||||
// Chat represents a chat session with an AI agent.
|
||||
type Chat struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
|
||||
LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"`
|
||||
Title string `json:"title"`
|
||||
Status ChatStatus `json:"status"`
|
||||
LastError *string `json:"last_error"`
|
||||
DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
Archived bool `json:"archived"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"`
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
OwnerID uuid.UUID `json:"owner_id" format:"uuid"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"`
|
||||
LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"`
|
||||
Title string `json:"title"`
|
||||
Status ChatStatus `json:"status"`
|
||||
LastError *string `json:"last_error"`
|
||||
DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
Archived bool `json:"archived"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"`
|
||||
Labels map[string]string `json:"labels"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a single message in a chat.
|
||||
@@ -311,16 +312,18 @@ type ChatInputPart struct {
|
||||
|
||||
// CreateChatRequest is the request to create a new chat.
|
||||
type CreateChatRequest struct {
|
||||
Content []ChatInputPart `json:"content"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
|
||||
Content []ChatInputPart `json:"content"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"`
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
|
||||
Labels map[string]string `json:"labels,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateChatRequest is the request to update a chat.
|
||||
type UpdateChatRequest struct {
|
||||
Title *string `json:"title,omitempty"`
|
||||
Archived *bool `json:"archived,omitempty"`
|
||||
Title *string `json:"title,omitempty"`
|
||||
Archived *bool `json:"archived,omitempty"`
|
||||
Labels *map[string]string `json:"labels,omitempty"`
|
||||
}
|
||||
|
||||
// CreateChatMessageRequest is the request to add a message to a chat.
|
||||
@@ -1166,7 +1169,8 @@ type ChatUsageLimitConfigResponse struct {
|
||||
|
||||
// ListChatsOptions are optional parameters for ListChats.
|
||||
type ListChatsOptions struct {
|
||||
Query string
|
||||
Query string
|
||||
Labels map[string]string
|
||||
Pagination
|
||||
}
|
||||
|
||||
@@ -1182,6 +1186,15 @@ func (c *ExperimentalClient) ListChats(ctx context.Context, opts *ListChatsOptio
|
||||
r.URL.RawQuery = q.Encode()
|
||||
})
|
||||
}
|
||||
if len(opts.Labels) > 0 {
|
||||
reqOpts = append(reqOpts, func(r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
for k, v := range opts.Labels {
|
||||
q.Add("label", k+":"+v)
|
||||
}
|
||||
r.URL.RawQuery = q.Encode()
|
||||
})
|
||||
}
|
||||
}
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats", nil, reqOpts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -78,6 +78,7 @@ const makeChat = (
|
||||
owner_id: "owner-1",
|
||||
last_model_config_id: "model-1",
|
||||
mcp_server_ids: [],
|
||||
labels: {},
|
||||
title: `Chat ${id}`,
|
||||
status: "running",
|
||||
created_at: "2025-01-01T00:00:00.000Z",
|
||||
|
||||
Generated
+4
@@ -1100,6 +1100,7 @@ export interface Chat {
|
||||
readonly updated_at: string;
|
||||
readonly archived: boolean;
|
||||
readonly mcp_server_ids: readonly string[];
|
||||
readonly labels: Record<string, string>;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
@@ -2245,6 +2246,7 @@ export interface CreateChatRequest {
|
||||
readonly workspace_id?: string;
|
||||
readonly model_config_id?: string;
|
||||
readonly mcp_server_ids?: readonly string[];
|
||||
readonly labels?: Record<string, string>;
|
||||
}
|
||||
|
||||
// From codersdk/users.go
|
||||
@@ -3777,6 +3779,7 @@ export interface LinkConfig {
|
||||
*/
|
||||
export interface ListChatsOptions extends Pagination {
|
||||
readonly Query: string;
|
||||
readonly Labels: Record<string, string>;
|
||||
}
|
||||
|
||||
// From codersdk/inboxnotification.go
|
||||
@@ -7008,6 +7011,7 @@ export interface UpdateChatProviderConfigRequest {
|
||||
export interface UpdateChatRequest {
|
||||
readonly title?: string;
|
||||
readonly archived?: boolean;
|
||||
readonly labels?: Record<string, string>;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
|
||||
@@ -112,6 +112,7 @@ const baseChatFields = {
|
||||
workspace_id: mockWorkspace.id,
|
||||
last_model_config_id: "model-config-1",
|
||||
mcp_server_ids: [],
|
||||
labels: {},
|
||||
created_at: "2026-02-18T00:00:00.000Z",
|
||||
updated_at: "2026-02-18T00:00:00.000Z",
|
||||
archived: false,
|
||||
|
||||
@@ -120,6 +120,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
|
||||
status: "completed",
|
||||
last_model_config_id: defaultModelConfigs[0].id,
|
||||
mcp_server_ids: [],
|
||||
labels: {},
|
||||
created_at: oneWeekAgo,
|
||||
updated_at: oneWeekAgo,
|
||||
archived: false,
|
||||
|
||||
@@ -201,6 +201,7 @@ const makeChat = (chatID: string): TypesGen.Chat => ({
|
||||
owner_id: "owner-1",
|
||||
last_model_config_id: "model-1",
|
||||
mcp_server_ids: [],
|
||||
labels: {},
|
||||
title: "test",
|
||||
status: "running",
|
||||
created_at: "2025-01-01T00:00:00.000Z",
|
||||
|
||||
@@ -52,6 +52,7 @@ export const WithParentChat: Story = {
|
||||
owner_id: "owner-id",
|
||||
last_model_config_id: "model-config-1",
|
||||
mcp_server_ids: [],
|
||||
labels: {},
|
||||
title: "Set up CI/CD pipeline",
|
||||
status: "completed",
|
||||
last_error: null,
|
||||
|
||||
@@ -39,6 +39,7 @@ const buildChat = (overrides: Partial<TypesGen.Chat> = {}): TypesGen.Chat => ({
|
||||
status: "completed",
|
||||
last_model_config_id: "model-config-1",
|
||||
mcp_server_ids: [],
|
||||
labels: {},
|
||||
created_at: oneWeekAgo,
|
||||
updated_at: oneWeekAgo,
|
||||
archived: false,
|
||||
|
||||
@@ -41,6 +41,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
|
||||
status: "completed",
|
||||
last_model_config_id: defaultModelConfigs[0].id,
|
||||
mcp_server_ids: [],
|
||||
labels: {},
|
||||
created_at: oneWeekAgo,
|
||||
updated_at: oneWeekAgo,
|
||||
archived: false,
|
||||
|
||||
@@ -64,6 +64,7 @@ const buildChat = (overrides: Partial<Chat> = {}): Chat => ({
|
||||
archived: false,
|
||||
last_error: null,
|
||||
mcp_server_ids: [],
|
||||
labels: {},
|
||||
...overrides,
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user