diff --git a/coderd/chats.go b/coderd/chats.go index b7dc6a8028..a6db4c7a13 100644 --- a/coderd/chats.go +++ b/coderd/chats.go @@ -55,6 +55,7 @@ const ( defaultChatContextCompressionThreshold = int32(70) minChatContextCompressionThreshold = int32(0) maxChatContextCompressionThreshold = int32(100) + maxSystemPromptLenBytes = 131072 // 128 KiB ) // chatDiffRefreshBackoffSchedule defines the delays between successive @@ -284,7 +285,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { WorkspaceID: workspaceSelection.WorkspaceID, Title: title, ModelConfigID: modelConfigID, - SystemPrompt: defaultChatSystemPrompt(), + SystemPrompt: api.resolvedChatSystemPrompt(ctx), InitialUserContent: contentBlocks, ContentFileIDs: contentFileIDs, }) @@ -2262,7 +2263,63 @@ func detectChatFileType(data []byte) string { return http.DetectContentType(data) } -func defaultChatSystemPrompt() string { +//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler. +func (api *API) getChatSystemPrompt(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + prompt, err := api.Database.GetChatSystemPrompt(ctx) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching chat system prompt.", + Detail: err.Error(), + }) + return + } + httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPromptResponse{ + SystemPrompt: prompt, + }) +} + +func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + var req codersdk.UpdateChatSystemPromptRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + trimmedPrompt := strings.TrimSpace(req.SystemPrompt) + // 128 KiB is generous for a system prompt while still + // preventing abuse or accidental pastes of large content. + if len(trimmedPrompt) > maxSystemPromptLenBytes { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "System prompt exceeds maximum length.", + Detail: fmt.Sprintf("Maximum length is %d bytes, got %d.", maxSystemPromptLenBytes, len(trimmedPrompt)), + }) + return + } + err := api.Database.UpsertChatSystemPrompt(ctx, trimmedPrompt) + if httpapi.Is404Error(err) { // also catches authz error + httpapi.ResourceNotFound(rw) + return + } else if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error updating chat system prompt.", + Detail: err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +func (api *API) resolvedChatSystemPrompt(ctx context.Context) string { + custom, err := api.Database.GetChatSystemPrompt(ctx) + if err != nil { + // Log but don't fail chat creation — fall back to the + // built-in default so the user isn't blocked. + api.Logger.Error(ctx, "failed to fetch custom chat system prompt, using default", slog.Error(err)) + return chatd.DefaultSystemPrompt + } + if strings.TrimSpace(custom) != "" { + return custom + } return chatd.DefaultSystemPrompt } diff --git a/coderd/chats_test.go b/coderd/chats_test.go index faaeb491b6..c534556fc1 100644 --- a/coderd/chats_test.go +++ b/coderd/chats_test.go @@ -3118,6 +3118,81 @@ func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatM return modelConfig } +//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. +func TestChatSystemPrompt(t *testing.T) { + t.Parallel() + + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient) + memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) + + t.Run("ReturnsEmptyWhenUnset", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + resp, err := adminClient.GetChatSystemPrompt(ctx) + require.NoError(t, err) + require.Equal(t, "", resp.SystemPrompt) + }) + + t.Run("AdminCanSet", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ + SystemPrompt: "You are a helpful coding assistant.", + }) + require.NoError(t, err) + + resp, err := adminClient.GetChatSystemPrompt(ctx) + require.NoError(t, err) + require.Equal(t, "You are a helpful coding assistant.", resp.SystemPrompt) + }) + + t.Run("AdminCanUnset", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + // Unset by sending an empty string. + err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ + SystemPrompt: "", + }) + require.NoError(t, err) + + resp, err := adminClient.GetChatSystemPrompt(ctx) + require.NoError(t, err) + require.Equal(t, "", resp.SystemPrompt) + }) + + t.Run("NonAdminFails", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ + SystemPrompt: "This should fail.", + }) + requireSDKError(t, err, http.StatusNotFound) + }) + + t.Run("UnauthenticatedFails", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + anonClient := codersdk.New(adminClient.URL) + _, err := anonClient.GetChatSystemPrompt(ctx) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode()) + }) + + t.Run("TooLong", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + tooLong := strings.Repeat("a", 131073) + err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ + SystemPrompt: tooLong, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "System prompt exceeds maximum length.", sdkErr.Message) + }) +} + func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error { t.Helper() diff --git a/coderd/coderd.go b/coderd/coderd.go index 0e7b113650..9a084ef59f 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1127,6 +1127,11 @@ func New(options *Options) *API { r.Post("/", api.postChatFile) r.Get("/{file}", api.chatFileByID) }) + r.Route("/config", func(r chi.Router) { + r.Get("/system-prompt", api.getChatSystemPrompt) + r.Put("/system-prompt", api.putChatSystemPrompt) + }) + // TODO(cian): place under /api/experimental/chats/config r.Route("/providers", func(r chi.Router) { r.Get("/", api.listChatProviders) r.Post("/", api.createChatProvider) @@ -1135,6 +1140,7 @@ func New(options *Options) *API { r.Delete("/", api.deleteChatProvider) }) }) + // TODO(cian): place under /api/experimental/chats/config r.Route("/model-configs", func(r chi.Router) { r.Get("/", api.listChatModelConfigs) r.Post("/", api.createChatModelConfig) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 83710fca9e..9b63649ada 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2564,6 +2564,18 @@ func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ( return q.db.GetChatQueuedMessages(ctx, chatID) } +func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) { + // The system prompt is a deployment-wide setting read during chat + // creation by every authenticated user, so no RBAC policy check + // is needed. We still verify that a valid actor exists in the + // context to ensure this is never callable by an unauthenticated + // or system-internal path without an explicit actor. + if _, ok := ActorFromContext(ctx); !ok { + return "", ErrNoActor + } + return q.db.GetChatSystemPrompt(ctx) +} + func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) { return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID) } @@ -6536,6 +6548,13 @@ func (q *querier) UpsertChatDiffStatusReference(ctx context.Context, arg databas return q.db.UpsertChatDiffStatusReference(ctx, arg) } +func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.UpsertChatSystemPrompt(ctx, value) +} + func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil { return database.ConnectionLog{}, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 668d016db8..3dbb103c4c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -551,6 +551,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatQueuedMessages(gomock.Any(), chat.ID).Return(qms, nil).AnyTimes() check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(qms) })) + s.Run("GetChatSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes() + check.Args().Asserts() + })) s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{}) configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{}) @@ -758,6 +762,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus) })) + s.Run("UpsertChatSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes() + check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + })) } func (s *MethodTestSuite) TestFile() { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 82ad4baf61..876e4535df 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1103,6 +1103,14 @@ func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uui return r0, r1 } +func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, error) { + start := time.Now() + r0, r1 := m.s.GetChatSystemPrompt(ctx) + m.queryLatencies.WithLabelValues("GetChatSystemPrompt").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatSystemPrompt").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) { start := time.Now() r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID) @@ -4526,6 +4534,14 @@ func (m queryMetricsStore) UpsertChatDiffStatusReference(ctx context.Context, ar return r0, r1 } +func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value string) error { + start := time.Now() + r0 := m.s.UpsertChatSystemPrompt(ctx, value) + m.queryLatencies.WithLabelValues("UpsertChatSystemPrompt").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatSystemPrompt").Inc() + return r0 +} + func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) { start := time.Now() r0, r1 := m.s.UpsertConnectionLog(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 6a1b286ac5..8d81f43a43 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2017,6 +2017,21 @@ func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID) } +// GetChatSystemPrompt mocks base method. +func (m *MockStore) GetChatSystemPrompt(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatSystemPrompt", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatSystemPrompt indicates an expected call of GetChatSystemPrompt. +func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx) +} + // GetChatsByOwnerID mocks base method. func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) { m.ctrl.T.Helper() @@ -8463,6 +8478,20 @@ func (mr *MockStoreMockRecorder) UpsertChatDiffStatusReference(ctx, arg any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatusReference", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatusReference), ctx, arg) } +// UpsertChatSystemPrompt mocks base method. +func (m *MockStore) UpsertChatSystemPrompt(ctx context.Context, value string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertChatSystemPrompt", ctx, value) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertChatSystemPrompt indicates an expected call of UpsertChatSystemPrompt. +func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value) +} + // UpsertConnectionLog mocks base method. func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 014f9fd3b0..5d8138694f 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -230,6 +230,7 @@ type sqlcQuerier interface { GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error) GetChatProviders(ctx context.Context) ([]ChatProvider, error) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error) + GetChatSystemPrompt(ctx context.Context) (string, error) GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerIDParams) ([]Chat, error) GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) @@ -840,6 +841,7 @@ type sqlcQuerier interface { UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error) UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error) UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error) + UpsertChatSystemPrompt(ctx context.Context, value string) error UpsertConnectionLog(ctx context.Context, arg UpsertConnectionLogParams) (ConnectionLog, error) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error // The default proxy is implied and not actually stored in the database. diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index fe17470c9f..7606e160bd 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -14475,6 +14475,18 @@ func (q *sqlQuerier) GetApplicationName(ctx context.Context) (string, error) { return value, err } +const getChatSystemPrompt = `-- name: GetChatSystemPrompt :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_system_prompt'), '') :: text AS chat_system_prompt +` + +func (q *sqlQuerier) GetChatSystemPrompt(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getChatSystemPrompt) + var chat_system_prompt string + err := row.Scan(&chat_system_prompt) + return chat_system_prompt, err +} + const getCoordinatorResumeTokenSigningKey = `-- name: GetCoordinatorResumeTokenSigningKey :one SELECT value FROM site_configs WHERE key = 'coordinator_resume_token_signing_key' ` @@ -14689,6 +14701,16 @@ func (q *sqlQuerier) UpsertApplicationName(ctx context.Context, value string) er return err } +const upsertChatSystemPrompt = `-- name: UpsertChatSystemPrompt :exec +INSERT INTO site_configs (key, value) VALUES ('agents_chat_system_prompt', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_system_prompt' +` + +func (q *sqlQuerier) UpsertChatSystemPrompt(ctx context.Context, value string) error { + _, err := q.db.ExecContext(ctx, upsertChatSystemPrompt, value) + return err +} + const upsertCoordinatorResumeTokenSigningKey = `-- name: UpsertCoordinatorResumeTokenSigningKey :exec INSERT INTO site_configs (key, value) VALUES ('coordinator_resume_token_signing_key', $1) ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'coordinator_resume_token_signing_key' diff --git a/coderd/database/queries/siteconfig.sql b/coderd/database/queries/siteconfig.sql index 4ee19c6bd5..63b8076a46 100644 --- a/coderd/database/queries/siteconfig.sql +++ b/coderd/database/queries/siteconfig.sql @@ -153,3 +153,11 @@ DO UPDATE SET value = EXCLUDED.value WHERE site_configs.key = EXCLUDED.key; SELECT COALESCE((SELECT value FROM site_configs WHERE key = 'webpush_vapid_public_key'), '') :: text AS vapid_public_key, COALESCE((SELECT value FROM site_configs WHERE key = 'webpush_vapid_private_key'), '') :: text AS vapid_private_key; + +-- name: GetChatSystemPrompt :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_system_prompt'), '') :: text AS chat_system_prompt; + +-- name: UpsertChatSystemPrompt :exec +INSERT INTO site_configs (key, value) VALUES ('agents_chat_system_prompt', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_system_prompt'; diff --git a/codersdk/chats.go b/codersdk/chats.go index d354495a85..0b4ab414a9 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -202,6 +202,16 @@ type ChatModelsResponse struct { Providers []ChatModelProvider `json:"providers"` } +// ChatSystemPromptResponse is the response for getting the chat system prompt. +type ChatSystemPromptResponse struct { + SystemPrompt string `json:"system_prompt"` +} + +// UpdateChatSystemPromptRequest is the request to update the chat system prompt. +type UpdateChatSystemPromptRequest struct { + SystemPrompt string `json:"system_prompt"` +} + // ChatProviderConfigSource describes how a provider entry is sourced. type ChatProviderConfigSource string @@ -681,6 +691,33 @@ func (c *Client) DeleteChatModelConfig(ctx context.Context, modelConfigID uuid.U return nil } +// GetChatSystemPrompt returns the deployment-wide chat system prompt. +func (c *Client) GetChatSystemPrompt(ctx context.Context) (ChatSystemPromptResponse, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/system-prompt", nil) + if err != nil { + return ChatSystemPromptResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return ChatSystemPromptResponse{}, ReadBodyAsError(res) + } + var resp ChatSystemPromptResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// UpdateChatSystemPrompt updates the deployment-wide chat system prompt. +func (c *Client) UpdateChatSystemPrompt(ctx context.Context, req UpdateChatSystemPromptRequest) error { + res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/system-prompt", req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + // CreateChat creates a new chat. func (c *Client) CreateChat(ctx context.Context, req CreateChatRequest) (Chat, error) { res, err := c.Request(ctx, http.MethodPost, "/api/experimental/chats", req) diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 36c7d4ae9e..0d8f406a29 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -3052,6 +3052,20 @@ class ApiMethods { return response.data; }; + getChatSystemPrompt = + async (): Promise => { + const response = await this.axios.get( + "/api/experimental/chats/config/system-prompt", + ); + return response.data; + }; + + updateChatSystemPrompt = async ( + req: TypesGen.UpdateChatSystemPromptRequest, + ): Promise => { + await this.axios.put("/api/experimental/chats/config/system-prompt", req); + }; + getChatProviderConfigs = async (): Promise => { const response = await this.axios.get( chatProviderConfigsPath, diff --git a/site/src/api/queries/chats.ts b/site/src/api/queries/chats.ts index 1939002544..c254d66e79 100644 --- a/site/src/api/queries/chats.ts +++ b/site/src/api/queries/chats.ts @@ -196,6 +196,22 @@ export const chatDiffContents = (chatId: string) => ({ queryFn: () => API.getChatDiffContents(chatId), }); +const chatSystemPromptKey = ["chat-system-prompt"] as const; + +export const chatSystemPrompt = () => ({ + queryKey: chatSystemPromptKey, + queryFn: () => API.getChatSystemPrompt(), +}); + +export const updateChatSystemPrompt = (queryClient: QueryClient) => ({ + mutationFn: API.updateChatSystemPrompt, + onSuccess: async () => { + await queryClient.invalidateQueries({ + queryKey: chatSystemPromptKey, + }); + }, +}); + export const chatModelsKey = ["chat-models"] as const; export const chatModels = () => ({ diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 059eaf5a53..6dc2d94c77 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1611,6 +1611,14 @@ export interface ChatStreamStatus { readonly status: ChatStatus; } +// From codersdk/chats.go +/** + * ChatSystemPromptResponse is the response for getting the chat system prompt. + */ +export interface ChatSystemPromptResponse { + readonly system_prompt: string; +} + // From codersdk/chats.go /** * ChatWithMessages is a chat along with its messages. @@ -6303,6 +6311,14 @@ export interface UpdateChatRequest { readonly title: string; } +// From codersdk/chats.go +/** + * UpdateChatSystemPromptRequest is the request to update the chat system prompt. + */ +export interface UpdateChatSystemPromptRequest { + readonly system_prompt: string; +} + // From codersdk/updatecheck.go /** * UpdateCheckResponse contains information on the latest release of Coder. diff --git a/site/src/pages/AgentsPage/AgentsPage.stories.tsx b/site/src/pages/AgentsPage/AgentsPage.stories.tsx index 7b65c9b35f..fd8254e9a4 100644 --- a/site/src/pages/AgentsPage/AgentsPage.stories.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.stories.tsx @@ -22,8 +22,6 @@ const modelOptions = [ }, ] as const; -const behaviorStorageKey = "agents.system-prompt"; - const meta: Meta = { title: "pages/AgentsPage/AgentsEmptyState", component: AgentsEmptyState, @@ -49,6 +47,10 @@ const meta: Meta = { workspaces: [], count: 0, }); + spyOn(API, "getChatSystemPrompt").mockResolvedValue({ + system_prompt: "", + }); + spyOn(API, "updateChatSystemPrompt").mockResolvedValue(); }, }; @@ -186,9 +188,9 @@ export const SavesBehaviorPromptAndRestores: Story = { await userEvent.click(within(dialog).getByRole("button", { name: "Save" })); await waitFor(() => { - expect(localStorage.getItem(behaviorStorageKey)).toBe( - "You are a focused coding assistant.", - ); + expect(API.updateChatSystemPrompt).toHaveBeenCalledWith({ + system_prompt: "You are a focused coding assistant.", + }); }); }, }; diff --git a/site/src/pages/AgentsPage/AgentsPage.tsx b/site/src/pages/AgentsPage/AgentsPage.tsx index b3fd16b96a..0ead639647 100644 --- a/site/src/pages/AgentsPage/AgentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.tsx @@ -7,10 +7,12 @@ import { chatKey, chatModelConfigs, chatModels, + chatSystemPrompt, chats, chatsKey, createChat, unarchiveChat, + updateChatSystemPrompt, } from "api/queries/chats"; import { workspaces } from "api/queries/workspaces"; import type * as TypesGen from "api/typesGenerated"; @@ -67,7 +69,6 @@ import { WebPushButton } from "./WebPushButton"; export const emptyInputStorageKey = "agents.empty-input"; const selectedWorkspaceIdStorageKey = "agents.selected-workspace-id"; const lastModelConfigIDStorageKey = "agents.last-model-config-id"; -const systemPromptStorageKey = "agents.system-prompt"; const nilUUID = "00000000-0000-0000-0000-000000000000"; type ChatModelOption = ModelSelectorOption; @@ -704,14 +705,15 @@ export const AgentsEmptyState: FC = ({ onConfigureAgentsDialogOpenChange, }) => { const { organizations } = useDashboard(); + const queryClient = useQueryClient(); const { initialInputValue, handleContentChange, submitDraft, resetDraft } = useEmptyStateDraft(); - const initialSystemPrompt = () => { - if (typeof window === "undefined") { - return ""; - } - return localStorage.getItem(systemPromptStorageKey) ?? ""; - }; + const systemPromptQuery = useQuery(chatSystemPrompt()); + const { + mutate: saveSystemPrompt, + isPending: isSavingSystemPrompt, + isError: isSaveSystemPromptError, + } = useMutation(updateChatSystemPrompt(queryClient)); const [initialLastModelConfigID] = useState(() => { if (typeof window === "undefined") { return ""; @@ -771,10 +773,9 @@ export const AgentsEmptyState: FC = ({ modelOptions.some((modelOption) => modelOption.id === userSelectedModel) ? userSelectedModel : preferredModelID; - const [savedSystemPrompt, setSavedSystemPrompt] = - useState(initialSystemPrompt); - const [systemPromptDraft, setSystemPromptDraft] = - useState(initialSystemPrompt); + const serverPrompt = systemPromptQuery.data?.system_prompt ?? ""; + const [localEdit, setLocalEdit] = useState(null); + const systemPromptDraft = localEdit ?? serverPrompt; const workspacesQuery = useQuery(workspaces({ q: "owner:me", limit: 0 })); const [selectedWorkspaceId, setSelectedWorkspaceId] = useState( () => { @@ -832,7 +833,7 @@ export const AgentsEmptyState: FC = ({ selectedWorkspaceIdRef.current = selectedWorkspaceId; const selectedModelRef = useRef(selectedModel); selectedModelRef.current = selectedModel; - const isSystemPromptDirty = systemPromptDraft !== savedSystemPrompt; + const isSystemPromptDirty = localEdit !== null && localEdit !== serverPrompt; const handleWorkspaceChange = (value: string) => { if (value === autoCreateWorkspaceValue) { @@ -859,17 +860,12 @@ export const AgentsEmptyState: FC = ({ if (!isSystemPromptDirty) { return; } - - setSavedSystemPrompt(systemPromptDraft); - if (typeof window !== "undefined") { - if (systemPromptDraft) { - localStorage.setItem(systemPromptStorageKey, systemPromptDraft); - } else { - localStorage.removeItem(systemPromptStorageKey); - } - } + saveSystemPrompt( + { system_prompt: systemPromptDraft }, + { onSuccess: () => setLocalEdit(null) }, + ); }, - [isSystemPromptDirty, systemPromptDraft], + [isSystemPromptDirty, systemPromptDraft, saveSystemPrompt], ); const handleSend = useCallback( @@ -1013,10 +1009,11 @@ export const AgentsEmptyState: FC = ({ canManageChatModelConfigs={canManageChatModelConfigs} canSetSystemPrompt={canSetSystemPrompt} systemPromptDraft={systemPromptDraft} - onSystemPromptDraftChange={setSystemPromptDraft} + onSystemPromptDraftChange={setLocalEdit} onSaveSystemPrompt={handleSaveSystemPrompt} isSystemPromptDirty={isSystemPromptDirty} - isDisabled={isCreating} + saveSystemPromptError={isSaveSystemPromptError} + isDisabled={isCreating || isSavingSystemPrompt} /> )} diff --git a/site/src/pages/AgentsPage/ConfigureAgentsDialog.stories.tsx b/site/src/pages/AgentsPage/ConfigureAgentsDialog.stories.tsx index fa438716c8..895ee28e87 100644 --- a/site/src/pages/AgentsPage/ConfigureAgentsDialog.stories.tsx +++ b/site/src/pages/AgentsPage/ConfigureAgentsDialog.stories.tsx @@ -78,6 +78,7 @@ const meta: Meta = { onSystemPromptDraftChange: fn(), onSaveSystemPrompt: fn(), isSystemPromptDirty: false, + saveSystemPromptError: false, isDisabled: false, }, }; diff --git a/site/src/pages/AgentsPage/ConfigureAgentsDialog.tsx b/site/src/pages/AgentsPage/ConfigureAgentsDialog.tsx index 2f72f2cced..9487c94473 100644 --- a/site/src/pages/AgentsPage/ConfigureAgentsDialog.tsx +++ b/site/src/pages/AgentsPage/ConfigureAgentsDialog.tsx @@ -32,6 +32,7 @@ interface ConfigureAgentsDialogProps { onSystemPromptDraftChange: (value: string) => void; onSaveSystemPrompt: (event: FormEvent) => void; isSystemPromptDirty: boolean; + saveSystemPromptError: boolean; isDisabled: boolean; } @@ -44,6 +45,7 @@ export const ConfigureAgentsDialog: FC = ({ onSystemPromptDraftChange, onSaveSystemPrompt, isSystemPromptDirty, + saveSystemPromptError, isDisabled, }) => { const configureSectionOptions = useMemo< @@ -152,7 +154,8 @@ export const ConfigureAgentsDialog: FC = ({ System Prompt

- Admin-only instruction applied to all new chats. + Admin-only instruction applied to all new chats. When empty, + the built-in default prompt is used.

= ({ Save + {saveSystemPromptError && ( +

+ Failed to save system prompt. +

+ )}