diff --git a/agent/agentfiles/api.go b/agent/agentfiles/api.go index 8cfe10c65a..e7667b1f81 100644 --- a/agent/agentfiles/api.go +++ b/agent/agentfiles/api.go @@ -31,6 +31,7 @@ func (api *API) Routes() http.Handler { r := chi.NewRouter() r.Post("/list-directory", api.HandleLS) + r.Get("/resolve-path", api.HandleResolvePath) r.Get("/read-file", api.HandleReadFile) r.Get("/read-file-lines", api.HandleReadFileLines) r.Post("/write-file", api.HandleWriteFile) diff --git a/agent/agentfiles/files.go b/agent/agentfiles/files.go index 028ac9697e..0af5d6e5f7 100644 --- a/agent/agentfiles/files.go +++ b/agent/agentfiles/files.go @@ -14,7 +14,6 @@ import ( "syscall" "github.com/google/uuid" - "github.com/spf13/afero" "golang.org/x/xerrors" "cdr.dev/slog/v3" @@ -328,7 +327,7 @@ func (api *API) writeFile(ctx context.Context, r *http.Request, path string) (HT return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path) } - resolved, err := api.resolveSymlink(path) + resolved, err := api.resolvePath(path) if err != nil { return http.StatusInternalServerError, xerrors.Errorf("resolve symlink %q: %w", path, err) } @@ -447,7 +446,7 @@ func (api *API) prepareFileEdit(path string, edits []workspacesdk.FileEdit) (int return http.StatusBadRequest, nil, xerrors.New("must specify at least one edit") } - resolved, err := api.resolveSymlink(path) + resolved, err := api.resolvePath(path) if err != nil { return http.StatusInternalServerError, nil, xerrors.Errorf("resolve symlink %q: %w", path, err) } @@ -556,52 +555,6 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode, return 0, nil } -// resolveSymlink resolves a path through any symlinks so that -// subsequent operations (such as atomic rename) target the real -// file instead of replacing the symlink itself. -// -// The filesystem must implement afero.Lstater and afero.LinkReader -// for resolution to occur; if it does not (e.g. MemMapFs), the -// path is returned unchanged. -func (api *API) resolveSymlink(path string) (string, error) { - const maxDepth = 10 - - lstater, hasLstat := api.filesystem.(afero.Lstater) - if !hasLstat { - return path, nil - } - reader, hasReadlink := api.filesystem.(afero.LinkReader) - if !hasReadlink { - return path, nil - } - - for range maxDepth { - info, _, err := lstater.LstatIfPossible(path) - if err != nil { - // If the file does not exist yet (new file write), - // there is nothing to resolve. - if errors.Is(err, os.ErrNotExist) { - return path, nil - } - return "", err - } - if info.Mode()&os.ModeSymlink == 0 { - return path, nil - } - - target, err := reader.ReadlinkIfPossible(path) - if err != nil { - return "", err - } - if !filepath.IsAbs(target) { - target = filepath.Join(filepath.Dir(path), target) - } - path = target - } - - return "", xerrors.Errorf("too many levels of symlinks resolving %q", path) -} - // fuzzyReplace attempts to find `search` inside `content` and replace it // with `replace`. It uses a cascading match strategy inspired by // openai/codex's apply_patch: diff --git a/agent/agentfiles/resolvepath.go b/agent/agentfiles/resolvepath.go new file mode 100644 index 0000000000..3589d505b5 --- /dev/null +++ b/agent/agentfiles/resolvepath.go @@ -0,0 +1,119 @@ +package agentfiles + +import ( + "errors" + "net/http" + "os" + "path/filepath" + + "github.com/spf13/afero" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// HandleResolvePath resolves the existing portion of an absolute path through +// any symlinks and returns the resulting path. Missing trailing components are +// preserved so callers can validate future writes against the real target. +func (api *API) HandleResolvePath(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + query := r.URL.Query() + parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path") + path := parser.String(query, "", "path") + parser.ErrorExcessParams(query) + if len(parser.Errors) > 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Query parameters have invalid values.", + Validations: parser.Errors, + }) + return + } + + resolved, err := api.resolvePath(path) + if err != nil { + status := http.StatusInternalServerError + switch { + case !filepath.IsAbs(path): + status = http.StatusBadRequest + case errors.Is(err, os.ErrPermission): + status = http.StatusForbidden + } + httpapi.Write(ctx, rw, status, codersdk.Response{Message: err.Error()}) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ResolvePathResponse{ + ResolvedPath: resolved, + }) +} + +// resolvePath resolves any symlinks in the existing portion of path while +// preserving missing trailing components. +func (api *API) resolvePath(path string) (string, error) { + if !filepath.IsAbs(path) { + return "", xerrors.Errorf("file path must be absolute: %q", path) + } + + path = filepath.Clean(path) + + lstater, hasLstat := api.filesystem.(afero.Lstater) + if !hasLstat { + return path, nil + } + targetReader, hasReadlink := api.filesystem.(afero.LinkReader) + if !hasReadlink { + return path, nil + } + + const maxDepth = 40 + var resolve func(string, int) (string, error) + resolve = func(path string, depth int) (string, error) { + if depth > maxDepth { + return "", xerrors.Errorf("too many levels of symlinks resolving %q", path) + } + + info, _, err := lstater.LstatIfPossible(path) + switch { + case err == nil: + if info.Mode()&os.ModeSymlink == 0 { + dir := filepath.Dir(path) + if dir == path { + return path, nil + } + + resolvedDir, err := resolve(dir, depth) + if err != nil { + return "", err + } + return filepath.Join(resolvedDir, filepath.Base(path)), nil + } + + target, err := targetReader.ReadlinkIfPossible(path) + if err != nil { + return "", err + } + if !filepath.IsAbs(target) { + target = filepath.Join(filepath.Dir(path), target) + } + return resolve(filepath.Clean(target), depth+1) + case errors.Is(err, os.ErrNotExist): + dir := filepath.Dir(path) + if dir == path { + return path, nil + } + + resolvedDir, err := resolve(dir, depth) + if err != nil { + return "", err + } + return filepath.Join(resolvedDir, filepath.Base(path)), nil + default: + return "", err + } + } + + return resolve(path, 0) +} diff --git a/agent/agentfiles/resolvepath_test.go b/agent/agentfiles/resolvepath_test.go new file mode 100644 index 0000000000..6b8160e296 --- /dev/null +++ b/agent/agentfiles/resolvepath_test.go @@ -0,0 +1,137 @@ +package agentfiles_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/agent/agentfiles" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/testutil" +) + +func TestResolvePath_FollowsFileSymlink(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("symlinks are not reliably supported on Windows") + } + + dir := t.TempDir() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + osFs := afero.NewOsFs() + api := agentfiles.NewAPI(logger, osFs, nil) + + realPath := filepath.Join(dir, "real.txt") + err := afero.WriteFile(osFs, realPath, []byte("hello"), 0o644) + require.NoError(t, err) + + linkPath := filepath.Join(dir, "link.txt") + err = os.Symlink(realPath, linkPath) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + w := httptest.NewRecorder() + r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/resolve-path?path=%s", linkPath), nil) + api.Routes().ServeHTTP(w, r) + require.Equal(t, http.StatusOK, w.Code) + + var resp workspacesdk.ResolvePathResponse + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + require.Equal(t, mustEvalSymlinks(t, realPath), resp.ResolvedPath) +} + +func TestResolvePath_FollowsSymlinkedParentForMissingFile(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("symlinks are not reliably supported on Windows") + } + + dir := t.TempDir() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + osFs := afero.NewOsFs() + api := agentfiles.NewAPI(logger, osFs, nil) + + realPlansDir := filepath.Join(dir, "real-plans") + err := os.MkdirAll(realPlansDir, 0o755) + require.NoError(t, err) + + linkPlansDir := filepath.Join(dir, "link-plans") + err = os.Symlink(realPlansDir, linkPlansDir) + require.NoError(t, err) + + requestedPath := filepath.Join(linkPlansDir, "PLAN.md") + resolvedPath := filepath.Join(mustEvalSymlinks(t, realPlansDir), "PLAN.md") + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + w := httptest.NewRecorder() + r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/resolve-path?path=%s", requestedPath), nil) + api.Routes().ServeHTTP(w, r) + require.Equal(t, http.StatusOK, w.Code) + + var resp workspacesdk.ResolvePathResponse + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + require.Equal(t, resolvedPath, resp.ResolvedPath) +} + +func TestResolvePath_FollowsSymlinkedParentForExistingFile(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.Skip("symlinks are not reliably supported on Windows") + } + + dir := t.TempDir() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + osFs := afero.NewOsFs() + api := agentfiles.NewAPI(logger, osFs, nil) + + realPlansDir := filepath.Join(dir, "real-plans") + err := os.MkdirAll(realPlansDir, 0o755) + require.NoError(t, err) + + resolvedPath := filepath.Join(realPlansDir, "PLAN.md") + err = afero.WriteFile(osFs, resolvedPath, []byte("plan"), 0o644) + require.NoError(t, err) + + linkPlansDir := filepath.Join(dir, "link-plans") + err = os.Symlink(realPlansDir, linkPlansDir) + require.NoError(t, err) + + requestedPath := filepath.Join(linkPlansDir, "PLAN.md") + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + w := httptest.NewRecorder() + r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/resolve-path?path=%s", requestedPath), nil) + api.Routes().ServeHTTP(w, r) + require.Equal(t, http.StatusOK, w.Code) + + var resp workspacesdk.ResolvePathResponse + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + require.Equal(t, mustEvalSymlinks(t, resolvedPath), resp.ResolvedPath) +} + +func mustEvalSymlinks(t *testing.T, path string) string { + t.Helper() + resolvedPath, err := filepath.EvalSymlinks(path) + require.NoError(t, err) + return resolvedPath +} diff --git a/coderd/coderd.go b/coderd/coderd.go index 076825dd2f..0bec2abcd2 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1181,6 +1181,8 @@ func New(options *Options) *API { r.Route("/config", func(r chi.Router) { r.Get("/system-prompt", api.getChatSystemPrompt) r.Put("/system-prompt", api.putChatSystemPrompt) + r.Get("/plan-mode-instructions", api.getChatPlanModeInstructions) + r.Put("/plan-mode-instructions", api.putChatPlanModeInstructions) r.Get("/desktop-enabled", api.getChatDesktopEnabled) r.Put("/desktop-enabled", api.putChatDesktopEnabled) r.Get("/user-prompt", api.getUserChatCustomPrompt) diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 7d17f0c854..4715c6330a 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1601,6 +1601,9 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database if c.LastError.Valid { chat.LastError = &c.LastError.String } + if c.PlanMode.Valid { + chat.PlanMode = codersdk.ChatPlanMode(c.PlanMode.ChatPlanMode) + } if c.ParentChatID.Valid { parentChatID := c.ParentChatID.UUID chat.ParentChatID = &parentChatID diff --git a/coderd/database/db2sdk/db2sdk_test.go b/coderd/database/db2sdk/db2sdk_test.go index f41b14921e..2bb75afb34 100644 --- a/coderd/database/db2sdk/db2sdk_test.go +++ b/coderd/database/db2sdk/db2sdk_test.go @@ -817,6 +817,7 @@ func TestChat_AllFieldsPopulated(t *testing.T) { UpdatedAt: now, Archived: true, PinOrder: 1, + PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true}, MCPServerIDs: []uuid.UUID{uuid.New()}, Labels: database.StringMap{"env": "prod"}, LastInjectedContext: pqtype.NullRawMessage{ diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 628495e7c9..7afbea6df8 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2798,6 +2798,13 @@ func (q *querier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]databa return q.db.GetChatModelConfigsForTelemetry(ctx) } +func (q *querier) GetChatPlanModeInstructions(ctx context.Context) (string, error) { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return "", err + } + return q.db.GetChatPlanModeInstructions(ctx) +} + func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { return database.ChatProvider{}, err @@ -6090,6 +6097,17 @@ func (q *querier) UpdateChatPinOrder(ctx context.Context, arg database.UpdateCha return q.db.UpdateChatPinOrder(ctx, arg) } +func (q *querier) UpdateChatPlanModeByID(ctx context.Context, arg database.UpdateChatPlanModeByIDParams) (database.Chat, error) { + chat, err := q.db.GetChatByID(ctx, arg.ID) + if err != nil { + return database.Chat{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return database.Chat{}, err + } + return q.db.UpdateChatPlanModeByID(ctx, arg) +} + func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return database.ChatProvider{}, err @@ -7267,6 +7285,13 @@ func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, incl return q.db.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt) } +func (q *querier) UpsertChatPlanModeInstructions(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.UpsertChatPlanModeInstructions(ctx, value) +} + func (q *querier) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index c9b74458b6..645a6ce643 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -834,6 +834,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes() check.Args().Asserts() })) + s.Run("GetChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetChatPlanModeInstructions(gomock.Any()).Return("", nil).AnyTimes() + check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + })) s.Run("GetChatTemplateAllowlist", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetChatTemplateAllowlist(gomock.Any()).Return("", nil).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead) @@ -949,6 +953,16 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpdateChatLastModelConfigByID(gomock.Any(), arg).Return(chat, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) })) + s.Run("UpdateChatPlanModeByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.UpdateChatPlanModeByIDParams{ + ID: chat.ID, + PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true}, + } + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().UpdateChatPlanModeByID(gomock.Any(), arg).Return(chat, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) + })) s.Run("UpdateChatStatusPreserveUpdatedAt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := database.UpdateChatStatusPreserveUpdatedAtParams{ @@ -1115,6 +1129,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes() check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) + s.Run("UpsertChatPlanModeInstructions", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertChatPlanModeInstructions(gomock.Any(), "").Return(nil).AnyTimes() + check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + })) s.Run("UpsertChatTemplateAllowlist", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().UpsertChatTemplateAllowlist(gomock.Any(), "").Return(nil).AnyTimes() check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 4150bdb129..6c312614b7 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1312,6 +1312,14 @@ func (m queryMetricsStore) GetChatModelConfigsForTelemetry(ctx context.Context) return r0, r1 } +func (m queryMetricsStore) GetChatPlanModeInstructions(ctx context.Context) (string, error) { + start := time.Now() + r0, r1 := m.s.GetChatPlanModeInstructions(ctx) + m.queryLatencies.WithLabelValues("GetChatPlanModeInstructions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatPlanModeInstructions").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { start := time.Now() r0, r1 := m.s.GetChatProviderByID(ctx, id) @@ -4376,6 +4384,14 @@ func (m queryMetricsStore) UpdateChatPinOrder(ctx context.Context, arg database. return r0 } +func (m queryMetricsStore) UpdateChatPlanModeByID(ctx context.Context, arg database.UpdateChatPlanModeByIDParams) (database.Chat, error) { + start := time.Now() + r0, r1 := m.s.UpdateChatPlanModeByID(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatPlanModeByID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatPlanModeByID").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) { start := time.Now() r0, r1 := m.s.UpdateChatProvider(ctx, arg) @@ -5184,6 +5200,14 @@ func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Cont return r0 } +func (m queryMetricsStore) UpsertChatPlanModeInstructions(ctx context.Context, value string) error { + start := time.Now() + r0 := m.s.UpsertChatPlanModeInstructions(ctx, value) + m.queryLatencies.WithLabelValues("UpsertChatPlanModeInstructions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatPlanModeInstructions").Inc() + return r0 +} + func (m queryMetricsStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error { start := time.Now() r0 := m.s.UpsertChatRetentionDays(ctx, retentionDays) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 4e6bd174bf..76c7fc1c77 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2417,6 +2417,21 @@ func (mr *MockStoreMockRecorder) GetChatModelConfigsForTelemetry(ctx any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigsForTelemetry", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigsForTelemetry), ctx) } +// GetChatPlanModeInstructions mocks base method. +func (m *MockStore) GetChatPlanModeInstructions(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatPlanModeInstructions", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatPlanModeInstructions indicates an expected call of GetChatPlanModeInstructions. +func (mr *MockStoreMockRecorder) GetChatPlanModeInstructions(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatPlanModeInstructions", reflect.TypeOf((*MockStore)(nil).GetChatPlanModeInstructions), ctx) +} + // GetChatProviderByID mocks base method. func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { m.ctrl.T.Helper() @@ -8282,6 +8297,21 @@ func (mr *MockStoreMockRecorder) UpdateChatPinOrder(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatPinOrder", reflect.TypeOf((*MockStore)(nil).UpdateChatPinOrder), ctx, arg) } +// UpdateChatPlanModeByID mocks base method. +func (m *MockStore) UpdateChatPlanModeByID(ctx context.Context, arg database.UpdateChatPlanModeByIDParams) (database.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateChatPlanModeByID", ctx, arg) + ret0, _ := ret[0].(database.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateChatPlanModeByID indicates an expected call of UpdateChatPlanModeByID. +func (mr *MockStoreMockRecorder) UpdateChatPlanModeByID(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatPlanModeByID", reflect.TypeOf((*MockStore)(nil).UpdateChatPlanModeByID), ctx, arg) +} + // UpdateChatProvider mocks base method. func (m *MockStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) { m.ctrl.T.Helper() @@ -9741,6 +9771,20 @@ func (mr *MockStoreMockRecorder) UpsertChatIncludeDefaultSystemPrompt(ctx, inclu return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatIncludeDefaultSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatIncludeDefaultSystemPrompt), ctx, includeDefaultSystemPrompt) } +// UpsertChatPlanModeInstructions mocks base method. +func (m *MockStore) UpsertChatPlanModeInstructions(ctx context.Context, value string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertChatPlanModeInstructions", ctx, value) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertChatPlanModeInstructions indicates an expected call of UpsertChatPlanModeInstructions. +func (mr *MockStoreMockRecorder) UpsertChatPlanModeInstructions(ctx, value any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatPlanModeInstructions", reflect.TypeOf((*MockStore)(nil).UpsertChatPlanModeInstructions), ctx, value) +} + // UpsertChatRetentionDays mocks base method. func (m *MockStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 7f6f0f6d63..aa3485e42d 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -287,6 +287,10 @@ CREATE TYPE chat_mode AS ENUM ( 'computer_use' ); +CREATE TYPE chat_plan_mode AS ENUM ( + 'plan' +); + CREATE TYPE chat_status AS ENUM ( 'waiting', 'pending', @@ -1470,7 +1474,8 @@ CREATE TABLE chats ( last_read_message_id bigint, last_injected_context jsonb, dynamic_tools jsonb, - organization_id uuid NOT NULL + organization_id uuid NOT NULL, + plan_mode chat_plan_mode ); CREATE TABLE connection_logs ( diff --git a/coderd/database/migrations/000469_chat_turn_mode.down.sql b/coderd/database/migrations/000469_chat_turn_mode.down.sql new file mode 100644 index 0000000000..71c1a750c1 --- /dev/null +++ b/coderd/database/migrations/000469_chat_turn_mode.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE chats DROP COLUMN plan_mode; +DROP TYPE chat_plan_mode; diff --git a/coderd/database/migrations/000469_chat_turn_mode.up.sql b/coderd/database/migrations/000469_chat_turn_mode.up.sql new file mode 100644 index 0000000000..94ce9b810f --- /dev/null +++ b/coderd/database/migrations/000469_chat_turn_mode.up.sql @@ -0,0 +1,2 @@ +CREATE TYPE chat_plan_mode AS ENUM ('plan'); +ALTER TABLE chats ADD COLUMN plan_mode chat_plan_mode; diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 77ddca13a1..6e392e0df7 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -800,6 +800,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams, &i.Chat.LastInjectedContext, &i.Chat.DynamicTools, &i.Chat.OrganizationID, + &i.Chat.PlanMode, &i.HasUnread); err != nil { return nil, err } diff --git a/coderd/database/models.go b/coderd/database/models.go index 398b67d187..fe2664042e 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1287,6 +1287,61 @@ func AllChatModeValues() []ChatMode { } } +type ChatPlanMode string + +const ( + ChatPlanModePlan ChatPlanMode = "plan" +) + +func (e *ChatPlanMode) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = ChatPlanMode(s) + case string: + *e = ChatPlanMode(s) + default: + return fmt.Errorf("unsupported scan type for ChatPlanMode: %T", src) + } + return nil +} + +type NullChatPlanMode struct { + ChatPlanMode ChatPlanMode `json:"chat_plan_mode"` + Valid bool `json:"valid"` // Valid is true if ChatPlanMode is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullChatPlanMode) Scan(value interface{}) error { + if value == nil { + ns.ChatPlanMode, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.ChatPlanMode.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullChatPlanMode) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.ChatPlanMode), nil +} + +func (e ChatPlanMode) Valid() bool { + switch e { + case ChatPlanModePlan: + return true + } + return false +} + +func AllChatPlanModeValues() []ChatPlanMode { + return []ChatPlanMode{ + ChatPlanModePlan, + } +} + type ChatStatus string const ( @@ -4247,6 +4302,7 @@ type Chat struct { LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"` DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"` OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + PlanMode NullChatPlanMode `db:"plan_mode" json:"plan_mode"` } type ChatDebugRun struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3d00618488..19451b0852 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -306,6 +306,7 @@ type sqlcQuerier interface { GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) // Returns all model configurations for telemetry snapshot collection. GetChatModelConfigsForTelemetry(ctx context.Context) ([]GetChatModelConfigsForTelemetryRow, error) + GetChatPlanModeInstructions(ctx context.Context) (string, error) GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error) GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error) GetChatProviders(ctx context.Context) ([]ChatProvider, error) @@ -982,6 +983,7 @@ type sqlcQuerier interface { UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error) UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error) UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOrderParams) error + UpdateChatPlanModeByID(ctx context.Context, arg UpdateChatPlanModeByIDParams) (Chat, error) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error) @@ -1098,6 +1100,7 @@ type sqlcQuerier interface { UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error) UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error + UpsertChatPlanModeInstructions(ctx context.Context, value string) error UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error UpsertChatSystemPrompt(ctx context.Context, value string) error UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 01b80b65ea..6447b37dfe 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4816,7 +4816,7 @@ WHERE $3::int ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type AcquireChatsParams struct { @@ -4862,6 +4862,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) ( &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ); err != nil { return nil, err } @@ -5000,9 +5001,9 @@ WITH chats AS ( UPDATE chats SET archived = true, pin_order = 0, updated_at = NOW() WHERE id = $1::uuid OR root_chat_id = $1::uuid - RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ) -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode FROM chats ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC ` @@ -5042,6 +5043,7 @@ func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ); err != nil { return nil, err } @@ -5191,7 +5193,7 @@ func (q *sqlQuerier) DeleteOldChats(ctx context.Context, arg DeleteOldChatsParam } const getActiveChatsByAgentID = `-- name: GetActiveChatsByAgentID :many -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode FROM chats WHERE agent_id = $1::uuid AND archived = false @@ -5237,6 +5239,7 @@ func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.U &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ); err != nil { return nil, err } @@ -5253,7 +5256,7 @@ func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.U const getChatByID = `-- name: GetChatByID :one SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode FROM chats WHERE @@ -5289,12 +5292,13 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id FROM chats WHERE id = $1::uuid FOR UPDATE +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode FROM chats WHERE id = $1::uuid FOR UPDATE ` func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) { @@ -5326,6 +5330,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -6380,7 +6385,7 @@ func (q *sqlQuerier) GetChatUsageLimitUserOverride(ctx context.Context, userID u const getChats = `-- name: GetChats :many SELECT - chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, chats.dynamic_tools, chats.organization_id, + chats.id, chats.owner_id, chats.workspace_id, chats.title, chats.status, chats.worker_id, chats.started_at, chats.heartbeat_at, chats.created_at, chats.updated_at, chats.parent_chat_id, chats.root_chat_id, chats.last_model_config_id, chats.archived, chats.last_error, chats.mode, chats.mcp_server_ids, chats.labels, chats.build_id, chats.agent_id, chats.pin_order, chats.last_read_message_id, chats.last_injected_context, chats.dynamic_tools, chats.organization_id, chats.plan_mode, EXISTS ( SELECT 1 FROM chat_messages cm WHERE cm.chat_id = chats.id @@ -6494,6 +6499,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha &i.Chat.LastInjectedContext, &i.Chat.DynamicTools, &i.Chat.OrganizationID, + &i.Chat.PlanMode, &i.HasUnread, ); err != nil { return nil, err @@ -6510,7 +6516,7 @@ func (q *sqlQuerier) GetChats(ctx context.Context, arg GetChatsParams) ([]GetCha } const getChatsByWorkspaceIDs = `-- name: GetChatsByWorkspaceIDs :many -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode FROM chats WHERE archived = false AND workspace_id = ANY($1::uuid[]) @@ -6552,6 +6558,7 @@ func (q *sqlQuerier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ); err != nil { return nil, err } @@ -6679,7 +6686,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh const getStaleChats = `-- name: GetStaleChats :many SELECT - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode FROM chats WHERE @@ -6728,6 +6735,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ); err != nil { return nil, err } @@ -6816,6 +6824,7 @@ INSERT INTO chats ( last_model_config_id, title, mode, + plan_mode, status, mcp_server_ids, labels, @@ -6831,13 +6840,14 @@ INSERT INTO chats ( $8::uuid, $9::text, $10::chat_mode, - $11::chat_status, - COALESCE($12::uuid[], '{}'::uuid[]), - COALESCE($13::jsonb, '{}'::jsonb), - $14::jsonb + $11::chat_plan_mode, + $12::chat_status, + COALESCE($13::uuid[], '{}'::uuid[]), + COALESCE($14::jsonb, '{}'::jsonb), + $15::jsonb ) RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type InsertChatParams struct { @@ -6851,6 +6861,7 @@ type InsertChatParams struct { LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"` Title string `db:"title" json:"title"` Mode NullChatMode `db:"mode" json:"mode"` + PlanMode NullChatPlanMode `db:"plan_mode" json:"plan_mode"` Status ChatStatus `db:"status" json:"status"` MCPServerIDs []uuid.UUID `db:"mcp_server_ids" json:"mcp_server_ids"` Labels pqtype.NullRawMessage `db:"labels" json:"labels"` @@ -6869,6 +6880,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat arg.LastModelConfigID, arg.Title, arg.Mode, + arg.PlanMode, arg.Status, pq.Array(arg.MCPServerIDs), arg.Labels, @@ -6901,6 +6913,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -7429,9 +7442,9 @@ WITH chats AS ( archived = false, updated_at = NOW() WHERE id = $1::uuid OR root_chat_id = $1::uuid - RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ) -SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode FROM chats ORDER BY (id = $1::uuid) DESC, created_at ASC, id ASC ` @@ -7475,6 +7488,7 @@ func (q *sqlQuerier) UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Cha &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ); err != nil { return nil, err } @@ -7555,7 +7569,7 @@ UPDATE chats SET updated_at = NOW() WHERE id = $3::uuid -RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type UpdateChatBuildAgentBindingParams struct { @@ -7593,6 +7607,7 @@ func (q *sqlQuerier) UpdateChatBuildAgentBinding(ctx context.Context, arg Update &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -7606,7 +7621,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type UpdateChatByIDParams struct { @@ -7643,6 +7658,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -7701,7 +7717,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type UpdateChatLabelsByIDParams struct { @@ -7738,6 +7754,7 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -7747,7 +7764,7 @@ UPDATE chats SET last_injected_context = $1::jsonb WHERE id = $2::uuid -RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type UpdateChatLastInjectedContextParams struct { @@ -7788,6 +7805,7 @@ func (q *sqlQuerier) UpdateChatLastInjectedContext(ctx context.Context, arg Upda &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -7801,7 +7819,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type UpdateChatLastModelConfigByIDParams struct { @@ -7838,6 +7856,7 @@ func (q *sqlQuerier) UpdateChatLastModelConfigByID(ctx context.Context, arg Upda &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -7869,7 +7888,7 @@ SET WHERE id = $2::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type UpdateChatMCPServerIDsParams struct { @@ -7906,6 +7925,7 @@ func (q *sqlQuerier) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatM &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -8028,6 +8048,57 @@ func (q *sqlQuerier) UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOr return err } +const updateChatPlanModeByID = `-- name: UpdateChatPlanModeByID :one +UPDATE + chats +SET + -- NOTE: updated_at is intentionally NOT touched here to avoid changing list ordering. + plan_mode = $1::chat_plan_mode +WHERE + id = $2::uuid +RETURNING + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode +` + +type UpdateChatPlanModeByIDParams struct { + PlanMode NullChatPlanMode `db:"plan_mode" json:"plan_mode"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateChatPlanModeByID(ctx context.Context, arg UpdateChatPlanModeByIDParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatPlanModeByID, arg.PlanMode, arg.ID) + var i Chat + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.WorkspaceID, + &i.Title, + &i.Status, + &i.WorkerID, + &i.StartedAt, + &i.HeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ParentChatID, + &i.RootChatID, + &i.LastModelConfigID, + &i.Archived, + &i.LastError, + &i.Mode, + pq.Array(&i.MCPServerIDs), + &i.Labels, + &i.BuildID, + &i.AgentID, + &i.PinOrder, + &i.LastReadMessageID, + &i.LastInjectedContext, + &i.DynamicTools, + &i.OrganizationID, + &i.PlanMode, + ) + return i, err +} + const updateChatStatus = `-- name: UpdateChatStatus :one UPDATE chats @@ -8041,7 +8112,7 @@ SET WHERE id = $6::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type UpdateChatStatusParams struct { @@ -8089,6 +8160,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -8106,7 +8178,7 @@ SET WHERE id = $7::uuid RETURNING - id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type UpdateChatStatusPreserveUpdatedAtParams struct { @@ -8156,6 +8228,7 @@ func (q *sqlQuerier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -8167,7 +8240,7 @@ UPDATE chats SET agent_id = $3::uuid, updated_at = NOW() WHERE id = $4::uuid -RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id +RETURNING id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools, organization_id, plan_mode ` type UpdateChatWorkspaceBindingParams struct { @@ -8211,6 +8284,7 @@ func (q *sqlQuerier) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateC &i.LastInjectedContext, &i.DynamicTools, &i.OrganizationID, + &i.PlanMode, ) return i, err } @@ -19839,6 +19913,18 @@ func (q *sqlQuerier) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (boo return include_default_system_prompt, err } +const getChatPlanModeInstructions = `-- name: GetChatPlanModeInstructions :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_plan_mode_instructions'), '') :: text AS plan_mode_instructions +` + +func (q *sqlQuerier) GetChatPlanModeInstructions(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getChatPlanModeInstructions) + var plan_mode_instructions string + err := row.Scan(&plan_mode_instructions) + return plan_mode_instructions, err +} + const getChatRetentionDays = `-- name: GetChatRetentionDays :one SELECT COALESCE( (SELECT value::integer FROM site_configs @@ -20182,6 +20268,16 @@ func (q *sqlQuerier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, i return err } +const upsertChatPlanModeInstructions = `-- name: UpsertChatPlanModeInstructions :exec +INSERT INTO site_configs (key, value) VALUES ('agents_chat_plan_mode_instructions', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_plan_mode_instructions' +` + +func (q *sqlQuerier) UpsertChatPlanModeInstructions(ctx context.Context, value string) error { + _, err := q.db.ExecContext(ctx, upsertChatPlanModeInstructions, value) + return err +} + const upsertChatRetentionDays = `-- name: UpsertChatRetentionDays :exec INSERT INTO site_configs (key, value) VALUES ('agents_chat_retention_days', CAST($1 AS integer)::text) diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 04887c91e0..018c058704 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -402,6 +402,7 @@ INSERT INTO chats ( last_model_config_id, title, mode, + plan_mode, status, mcp_server_ids, labels, @@ -417,6 +418,7 @@ INSERT INTO chats ( @last_model_config_id::uuid, @title::text, sqlc.narg('mode')::chat_mode, + sqlc.narg('plan_mode')::chat_plan_mode, @status::chat_status, COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]), COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb), @@ -518,6 +520,17 @@ WHERE RETURNING *; +-- name: UpdateChatPlanModeByID :one +UPDATE + chats +SET + -- NOTE: updated_at is intentionally NOT touched here to avoid changing list ordering. + plan_mode = sqlc.narg('plan_mode')::chat_plan_mode +WHERE + id = @id::uuid +RETURNING + *; + -- name: UpdateChatLastModelConfigByID :one UPDATE chats diff --git a/coderd/database/queries/siteconfig.sql b/coderd/database/queries/siteconfig.sql index ac7df002b2..984dbf1aa9 100644 --- a/coderd/database/queries/siteconfig.sql +++ b/coderd/database/queries/siteconfig.sql @@ -159,6 +159,14 @@ SELECT INSERT INTO site_configs (key, value) VALUES ('agents_chat_system_prompt', $1) ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_system_prompt'; +-- name: GetChatPlanModeInstructions :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_plan_mode_instructions'), '') :: text AS plan_mode_instructions; + +-- name: UpsertChatPlanModeInstructions :exec +INSERT INTO site_configs (key, value) VALUES ('agents_chat_plan_mode_instructions', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_plan_mode_instructions'; + -- name: GetChatDesktopEnabled :one SELECT COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop; diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index bd310cc175..5e9b7c567f 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -379,6 +379,16 @@ func (api *API) getChatDiffStatusesByChatID( return statusesByChatID, nil } +func planModeToNullChatPlanMode(mode codersdk.ChatPlanMode) database.NullChatPlanMode { + if mode == "" { + return database.NullChatPlanMode{} + } + return database.NullChatPlanMode{ + ChatPlanMode: database.ChatPlanMode(mode), + Valid: true, + } +} + // EXPERIMENTAL: this endpoint is experimental and is subject to change. func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -457,6 +467,16 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { return } + switch req.PlanMode { + case codersdk.ChatPlanModePlan, "": + // Valid. + default: + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid plan_mode value.", + }) + return + } + // Validate MCP server IDs exist. if len(req.MCPServerIDs) > 0 { //nolint:gocritic // Need to validate MCP server IDs exist. @@ -554,6 +574,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { WorkspaceID: workspaceSelection.WorkspaceID, Title: title, ModelConfigID: modelConfigID, + PlanMode: planModeToNullChatPlanMode(req.PlanMode), SystemPrompt: req.SystemPrompt, InitialUserContent: contentBlocks, MCPServerIDs: mcpServerIDs, @@ -1728,7 +1749,7 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) { } // patchChat updates a chat resource. Supports updating labels, -// archiving, pinning, and pinned-chat ordering. +// workspace binding, archiving, pinning, and pinned-chat ordering. func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() chat := httpmw.ChatParam(r) @@ -1738,6 +1759,21 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) { return } + var planModeUpdate *database.NullChatPlanMode + if req.PlanMode != nil { + switch *req.PlanMode { + case codersdk.ChatPlanModePlan, "": + // Valid. + default: + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid plan_mode value.", + }) + return + } + resolvedPlanMode := planModeToNullChatPlanMode(*req.PlanMode) + planModeUpdate = &resolvedPlanMode + } + if req.Labels != nil { if errs := httpapi.ValidateChatLabels(*req.Labels); len(errs) > 0 { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ @@ -1863,6 +1899,64 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) { } } + if req.WorkspaceID != nil { + workspaceID := uuid.NullUUID{} + workspace := database.Workspace{} + if *req.WorkspaceID != uuid.Nil { + var status int + var resp *codersdk.Response + workspaceID, workspace, status, resp = api.validateChatWorkspaceSelection(ctx, r, req.WorkspaceID) + if resp != nil { + httpapi.Write(ctx, rw, status, *resp) + return + } + if workspace.OrganizationID != chat.OrganizationID { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Workspace does not belong to this chat's organization.", + }) + return + } + } + + updatedChat, err := api.Database.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ + ID: chat.ID, + WorkspaceID: workspaceID, + BuildID: uuid.NullUUID{}, + AgentID: uuid.NullUUID{}, + }) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to update chat workspace binding.", + Detail: err.Error(), + }) + return + } + chat = updatedChat + } + + if planModeUpdate != nil { + updatedChat, err := api.Database.UpdateChatPlanModeByID(ctx, database.UpdateChatPlanModeByIDParams{ + PlanMode: *planModeUpdate, + ID: chat.ID, + }) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to update chat plan mode.", + Detail: err.Error(), + }) + return + } + chat = updatedChat + } + rw.WriteHeader(http.StatusNoContent) } @@ -1937,6 +2031,24 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { } } + if req.PlanMode != nil { + switch *req.PlanMode { + case codersdk.ChatPlanModePlan, "": + // Valid. + default: + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid plan_mode value.", + }) + return + } + } + + var sendPlanMode *database.NullChatPlanMode + if req.PlanMode != nil { + resolvedPlanMode := planModeToNullChatPlanMode(*req.PlanMode) + sendPlanMode = &resolvedPlanMode + } + busyBehavior := chatd.SendMessageBusyBehaviorQueue switch req.BusyBehavior { case codersdk.ChatBusyBehaviorInterrupt: @@ -1959,6 +2071,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { Content: contentBlocks, ModelConfigID: req.ModelConfigID, BusyBehavior: busyBehavior, + PlanMode: sendPlanMode, MCPServerIDs: req.MCPServerIDs, }, ) @@ -2929,6 +3042,46 @@ type createChatWorkspaceSelection struct { WorkspaceID uuid.NullUUID } +func (api *API) validateChatWorkspaceSelection( + ctx context.Context, + r *http.Request, + workspaceID *uuid.UUID, +) ( + uuid.NullUUID, + database.Workspace, + int, + *codersdk.Response, +) { + if workspaceID == nil { + return uuid.NullUUID{}, database.Workspace{}, 0, nil + } + + workspace, err := api.Database.GetWorkspaceByID(ctx, *workspaceID) + if err != nil { + if httpapi.Is404Error(err) { + return uuid.NullUUID{}, database.Workspace{}, http.StatusBadRequest, &codersdk.Response{ + Message: "Workspace not found or you do not have access to this resource", + } + } + return uuid.NullUUID{}, database.Workspace{}, http.StatusInternalServerError, &codersdk.Response{ + Message: "Failed to get workspace.", + Detail: err.Error(), + } + } + + selection := uuid.NullUUID{ + UUID: workspace.ID, + Valid: true, + } + if !api.Authorize(r, policy.ActionSSH, workspace) { + return uuid.NullUUID{}, database.Workspace{}, http.StatusBadRequest, &codersdk.Response{ + Message: "Workspace not found or you do not have access to this resource", + } + } + + return selection, workspace, 0, nil +} + func (api *API) validateCreateChatWorkspaceSelection( ctx context.Context, r *http.Request, @@ -2939,39 +3092,20 @@ func (api *API) validateCreateChatWorkspaceSelection( *codersdk.Response, ) { selection := createChatWorkspaceSelection{} - if req.WorkspaceID == nil { + workspaceID, workspace, status, resp := api.validateChatWorkspaceSelection(ctx, r, req.WorkspaceID) + if resp != nil { + return selection, status, resp + } + selection.WorkspaceID = workspaceID + if !workspaceID.Valid { return selection, 0, nil } - - workspace, err := api.Database.GetWorkspaceByID(ctx, *req.WorkspaceID) - if err != nil { - if httpapi.Is404Error(err) { - return selection, http.StatusBadRequest, &codersdk.Response{ - Message: "Workspace not found or you do not have access to this resource", - } - } - return selection, http.StatusInternalServerError, &codersdk.Response{ - Message: "Failed to get workspace.", - Detail: err.Error(), - } - } - selection.WorkspaceID = uuid.NullUUID{ - UUID: workspace.ID, - Valid: true, - } - if workspace.OrganizationID != req.OrganizationID { return selection, http.StatusBadRequest, &codersdk.Response{ Message: "Workspace does not belong to the specified organization.", } } - if !api.Authorize(r, policy.ActionSSH, workspace) { - return selection, http.StatusBadRequest, &codersdk.Response{ - Message: "Workspace not found or you do not have access to this resource", - } - } - return selection, 0, nil } @@ -3151,6 +3285,67 @@ func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusNoContent) } +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// +//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler. +func (api *API) getChatPlanModeInstructions(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { + httpapi.ResourceNotFound(rw) + return + } + + instructions, err := api.Database.GetChatPlanModeInstructions(ctx) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching plan mode instructions.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatPlanModeInstructionsResponse{ + PlanModeInstructions: instructions, + }) +} + +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +func (api *API) putChatPlanModeInstructions(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { + httpapi.Forbidden(rw) + return + } + + // Cap the raw request body to prevent excessive memory use from + // payloads padded with invisible characters that sanitize away. + r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes)) + + var req codersdk.UpdateChatPlanModeInstructionsRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + sanitizedInstructions := chatd.SanitizePromptText(req.PlanModeInstructions) + if len(sanitizedInstructions) > maxSystemPromptLenBytes { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Plan mode instructions exceed maximum length.", + Detail: fmt.Sprintf("Maximum length is %d bytes, got %d.", maxSystemPromptLenBytes, len(sanitizedInstructions)), + }) + return + } + + if err := api.Database.UpsertChatPlanModeInstructions(ctx, sanitizedInstructions); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error updating plan mode instructions.", + Detail: err.Error(), + }) + return + } + + rw.WriteHeader(http.StatusNoContent) +} + // EXPERIMENTAL: this endpoint is experimental and is subject to change. // //nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler. diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index ffbae26413..5188914117 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -664,12 +664,10 @@ func TestPostChats(t *testing.T) { _, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: uuid.Nil, - Content: []codersdk.ChatInputPart{ - { - Type: codersdk.ChatInputPartTypeText, - Text: "hello", - }, - }, + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "hello", + }}, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "organization_id is required.", sdkErr.Message) @@ -692,12 +690,10 @@ func TestPostChats(t *testing.T) { _, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: secondOrg.ID, - Content: []codersdk.ChatInputPart{ - { - Type: codersdk.ChatInputPartTypeText, - Text: "hello", - }, - }, + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "hello", + }}, }) sdkErr := requireSDKError(t, err, http.StatusForbidden) require.Equal(t, "You are not a member of the specified organization.", sdkErr.Message) @@ -727,12 +723,10 @@ func TestPostChats(t *testing.T) { _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: secondOrg.ID, - Content: []codersdk.ChatInputPart{ - { - Type: codersdk.ChatInputPartTypeText, - Text: "hello", - }, - }, + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "hello", + }}, WorkspaceID: &workspaceBuild.Workspace.ID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) @@ -3615,6 +3609,251 @@ func TestGetChat(t *testing.T) { }) } +func TestPatchChat(t *testing.T) { + t.Parallel() + + createChat := func(ctx context.Context, t *testing.T, client *codersdk.ExperimentalClient, orgID uuid.UUID, text string) codersdk.Chat { + t.Helper() + + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + OrganizationID: orgID, + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: text, + }, + }, + }) + require.NoError(t, err) + return chat + } + + getChat := func(ctx context.Context, t *testing.T, client *codersdk.ExperimentalClient, chatID uuid.UUID) codersdk.Chat { + t.Helper() + + chat, err := client.GetChat(ctx, chatID) + require.NoError(t, err) + return chat + } + + createStoredChat := func( + ctx context.Context, + t *testing.T, + db database.Store, + ownerID uuid.UUID, + orgID uuid.UUID, + modelConfigID uuid.UUID, + title string, + ) codersdk.Chat { + t.Helper() + + dbChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: orgID, + Status: database.ChatStatusWaiting, + OwnerID: ownerID, + LastModelConfigID: modelConfigID, + Title: title, + }) + require.NoError(t, err) + return db2sdk.Chat(dbChat, nil, nil) + } + t.Run("PlanMode", func(t *testing.T) { + t.Parallel() + + t.Run("SetToPlan", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + chat := createChat(ctx, t, client, firstUser.OrganizationID, "set plan mode") + err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ + PlanMode: ptr.Ref(codersdk.ChatPlanModePlan), + }) + require.NoError(t, err) + + updated := getChat(ctx, t, client, chat.ID) + require.Equal(t, codersdk.ChatPlanModePlan, updated.PlanMode) + }) + + t.Run("Clear", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + chat := createChat(ctx, t, client, firstUser.OrganizationID, "clear plan mode") + err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ + PlanMode: ptr.Ref(codersdk.ChatPlanModePlan), + }) + require.NoError(t, err) + + err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ + PlanMode: ptr.Ref(codersdk.ChatPlanMode("")), + }) + require.NoError(t, err) + + updated := getChat(ctx, t, client, chat.ID) + require.Empty(t, updated.PlanMode) + }) + + t.Run("RejectsInvalidValue", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + chat := createChat(ctx, t, client, firstUser.OrganizationID, "invalid plan mode") + invalidPlanMode := codersdk.ChatPlanMode("invalid") + err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ + PlanMode: &invalidPlanMode, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid plan_mode value.", sdkErr.Message) + }) + }) + + t.Run("WorkspaceBinding", func(t *testing.T) { + t.Parallel() + + t.Run("BindValidWorkspace", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: firstUser.OrganizationID, + OwnerID: firstUser.UserID, + }).WithAgent().Do() + chat := createStoredChat( + ctx, + t, + db, + firstUser.UserID, + firstUser.OrganizationID, + modelConfig.ID, + "bind workspace", + ) + + err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ + WorkspaceID: &workspaceBuild.Workspace.ID, + }) + require.NoError(t, err) + + updated := getChat(ctx, t, client, chat.ID) + require.NotNil(t, updated.WorkspaceID) + require.Equal(t, workspaceBuild.Workspace.ID, *updated.WorkspaceID) + }) + + t.Run("WorkspaceNotFound", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + chat := createStoredChat( + ctx, + t, + db, + firstUser.UserID, + firstUser.OrganizationID, + modelConfig.ID, + "missing workspace", + ) + workspaceID := uuid.New() + err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ + WorkspaceID: &workspaceID, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Workspace not found or you do not have access to this resource", sdkErr.Message) + }) + + t.Run("RejectsCrossOrgWorkspaceBinding", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + secondOrg := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + OrganizationID: secondOrg.ID, + UserID: firstUser.UserID, + }) + workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: secondOrg.ID, + OwnerID: firstUser.UserID, + }).WithAgent().Do() + chat := createStoredChat( + ctx, + t, + db, + firstUser.UserID, + firstUser.OrganizationID, + modelConfig.ID, + "cross org workspace binding", + ) + + err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ + WorkspaceID: &workspaceBuild.Workspace.ID, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Workspace does not belong to this chat's organization.", sdkErr.Message) + }) + + t.Run("ClearWorkspaceBinding", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: firstUser.OrganizationID, + OwnerID: firstUser.UserID, + }).WithAgent().Do() + chat := createStoredChat( + ctx, + t, + db, + firstUser.UserID, + firstUser.OrganizationID, + modelConfig.ID, + "clear workspace binding", + ) + + err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ + WorkspaceID: &workspaceBuild.Workspace.ID, + }) + require.NoError(t, err) + + workspaceID := uuid.Nil + err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ + WorkspaceID: &workspaceID, + }) + require.NoError(t, err) + + updated := getChat(ctx, t, client, chat.ID) + require.Nil(t, updated.WorkspaceID) + require.Nil(t, updated.BuildID) + require.Nil(t, updated.AgentID) + }) + }) +} + func TestArchiveChat(t *testing.T) { t.Parallel() @@ -7930,6 +8169,80 @@ func TestChatSystemPrompt(t *testing.T) { }) } +//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. +func TestChatPlanModeInstructions(t *testing.T) { + t.Parallel() + + adminClient, _ := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + _ = createChatModelConfig(t, adminClient) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + updateChatPlanModeInstructions := func(t *testing.T, ctx context.Context, req codersdk.UpdateChatPlanModeInstructionsRequest) { + t.Helper() + + err := adminClient.UpdateChatPlanModeInstructions(ctx, req) + require.NoError(t, err) + } + + getChatPlanModeInstructions := func(t *testing.T, ctx context.Context) codersdk.ChatPlanModeInstructionsResponse { + t.Helper() + + resp, err := adminClient.GetChatPlanModeInstructions(ctx) + require.NoError(t, err) + return resp + } + + roundTripTests := []struct { + name string + updates []string + want string + }{ + { + name: "DefaultGETReturnsEmpty", + want: "", + }, + { + name: "PUTThenGETRoundTrips", + updates: []string{"Use plan mode for multi-step changes."}, + want: "Use plan mode for multi-step changes.", + }, + } + for _, tt := range roundTripTests { + t.Run(tt.name, func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + for _, instructions := range tt.updates { + updateChatPlanModeInstructions(t, ctx, codersdk.UpdateChatPlanModeInstructionsRequest{ + PlanModeInstructions: instructions, + }) + } + + resp := getChatPlanModeInstructions(t, ctx) + require.Equal(t, tt.want, resp.PlanModeInstructions) + }) + } + + t.Run("OversizedPayloadReturns400", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + tooLong := strings.Repeat("a", 131073) + + err := adminClient.UpdateChatPlanModeInstructions(ctx, codersdk.UpdateChatPlanModeInstructionsRequest{ + PlanModeInstructions: tooLong, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Plan mode instructions exceed maximum length.", sdkErr.Message) + }) + + t.Run("NonAdminGETReturns404", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + _, err := memberClient.GetChatPlanModeInstructions(ctx) + requireSDKError(t, err, http.StatusNotFound) + }) +} + func TestChatDesktopEnabled(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index d920e66174..04927110d8 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -788,6 +788,7 @@ type CreateOptions struct { Title string ModelConfigID uuid.UUID ChatMode database.NullChatMode + PlanMode database.NullChatPlanMode SystemPrompt string InitialUserContent []codersdk.ChatMessagePart MCPServerIDs []uuid.UUID @@ -815,6 +816,7 @@ type SendMessageOptions struct { Content []codersdk.ChatMessagePart ModelConfigID *uuid.UUID BusyBehavior SendMessageBusyBehavior + PlanMode *database.NullChatPlanMode MCPServerIDs *[]uuid.UUID } @@ -882,6 +884,8 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C // another pool checkout. deploymentPrompt := p.resolveDeploymentSystemPrompt(ctx) + effectivePlanMode := opts.PlanMode + var chat database.Chat txErr := p.db.InTx(func(tx database.Store) error { if limitErr := p.checkUsageLimit(ctx, tx, opts.OwnerID, uuid.NullUUID{UUID: opts.OrganizationID, Valid: true}); limitErr != nil { @@ -904,6 +908,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C LastModelConfigID: opts.ModelConfigID, Title: opts.Title, Mode: opts.ChatMode, + PlanMode: effectivePlanMode, // Chats created with an initial user message start pending. // Waiting is reserved for idle chats with no pending work. Status: database.ChatStatusPending, @@ -1040,6 +1045,8 @@ func (p *Server) SendMessage( return SendMessageResult{}, xerrors.Errorf("marshal message content: %w", err) } + requestedPlanMode := opts.PlanMode + var ( result SendMessageResult queuedMessagesSDK []codersdk.ChatQueuedMessage @@ -1056,6 +1063,16 @@ func (p *Server) SendMessage( return limitErr } + if requestedPlanMode != nil { + lockedChat, err = tx.UpdateChatPlanModeByID(ctx, database.UpdateChatPlanModeByIDParams{ + PlanMode: *requestedPlanMode, + ID: opts.ChatID, + }) + if err != nil { + return xerrors.Errorf("update chat plan mode: %w", err) + } + } + modelConfigID := lockedChat.LastModelConfigID if opts.ModelConfigID != nil { modelConfigID = *opts.ModelConfigID @@ -4365,6 +4382,336 @@ type runChatResult struct { PendingDynamicToolCalls []chatloop.PendingToolCall } +func allowedPlanToolNames( + allTools []fantasy.AgentTool, + mode database.NullChatPlanMode, + parentChatID uuid.NullUUID, +) []string { + isPlanModeTurn := mode.Valid && mode.ChatPlanMode == database.ChatPlanModePlan + isRootChat := !parentChatID.Valid + builtinPlanPolicy := map[string]bool{ + "read_file": true, + "write_file": isRootChat, + "edit_files": isRootChat, + "execute": true, + "process_output": true, + "process_list": false, + "process_signal": false, + "list_templates": isRootChat, + "read_template": isRootChat, + "create_workspace": isRootChat, + "start_workspace": isRootChat, + "propose_plan": isRootChat, + "spawn_agent": isRootChat, + "wait_agent": isRootChat, + "message_agent": false, + "close_agent": false, + "spawn_computer_use_agent": false, + "read_skill": true, + "read_skill_file": true, + "ask_user_question": isRootChat, + } + if !isPlanModeTurn { + toolNames := make([]string, 0, len(allTools)) + for _, tool := range allTools { + toolNames = append(toolNames, tool.Info().Name) + } + return toolNames + } + + toolNames := make([]string, 0, len(allTools)) + for _, tool := range allTools { + name := tool.Info().Name + if builtinPlanPolicy[name] { + toolNames = append(toolNames, name) + } + } + return toolNames +} + +func stopAfterPlanTools(mode database.NullChatPlanMode, parentChatID uuid.NullUUID) map[string]struct{} { + if !mode.Valid || mode.ChatPlanMode != database.ChatPlanModePlan { + return nil + } + stopTools := map[string]struct{}{ + "propose_plan": {}, + } + if !parentChatID.Valid { + stopTools["ask_user_question"] = struct{}{} + } + return stopTools +} + +type systemPromptPlanContext struct { + mode database.NullChatPlanMode + planModeInstructions string + isRootChat bool +} + +// buildSystemPrompt applies system-level prompt injections in the +// canonical order. It is used by both the initial prompt assembly +// and the ReloadMessages callback to keep them in sync. +func buildSystemPrompt( + prompt []fantasy.Message, + subagentInstruction string, + instruction string, + skills []chattool.SkillMeta, + userPrompt string, + planContext systemPromptPlanContext, +) []fantasy.Message { + if subagentInstruction != "" { + prompt = chatprompt.InsertSystem(prompt, subagentInstruction) + } + if instruction != "" { + prompt = chatprompt.InsertSystem(prompt, instruction) + } + if skillIndex := chattool.FormatSkillIndex(skills); skillIndex != "" { + prompt = chatprompt.InsertSystem(prompt, skillIndex) + } + if userPrompt != "" { + prompt = chatprompt.InsertSystem(prompt, userPrompt) + } + isPlanModeTurn := planContext.mode.Valid && planContext.mode.ChatPlanMode == database.ChatPlanModePlan + if isPlanModeTurn { + if planContext.isRootChat { + prompt = chatprompt.InsertSystem(prompt, PlanningOverlayPrompt) + if planContext.planModeInstructions != "" { + prompt = chatprompt.InsertSystem(prompt, planContext.planModeInstructions) + } + } else { + prompt = chatprompt.InsertSystem(prompt, PlanningSubagentOverlayPrompt) + } + } + return prompt +} + +type rootChatToolsOptions struct { + chat database.Chat + modelConfigID uuid.UUID + workspaceCtx *turnWorkspaceContext + workspaceMu *sync.Mutex + instruction *string + skills *[]chattool.SkillMeta + resolvePlanPath func(context.Context) (string, string, error) + isPlanModeTurn bool +} + +func (p *Server) loadPlanModeInstructions( + ctx context.Context, + mode database.NullChatPlanMode, + logger slog.Logger, +) string { + if !mode.Valid || mode.ChatPlanMode != database.ChatPlanModePlan { + return "" + } + + // Plan-mode instructions live in deployment config, but chat workers do + // not carry a deployment-config actor during background execution. + //nolint:gocritic // Required to read deployment config during background chat processing. + systemCtx := dbauthz.AsSystemRestricted(ctx) + fetched, err := p.db.GetChatPlanModeInstructions(systemCtx) + if err != nil { + logger.Warn(ctx, + "failed to fetch plan mode instructions", + slog.Error(err), + ) + return "" + } + + return fetched +} + +func (p *Server) appendRootChatTools( + ctx context.Context, + tools []fantasy.AgentTool, + opts rootChatToolsOptions, +) []fantasy.AgentTool { + onChatUpdated := func(updatedChat database.Chat) { + opts.workspaceCtx.selectWorkspace(updatedChat) + // Notify the frontend immediately so it can start streaming + // build logs before the tool completes. + p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) + + // When a workspace is first attached mid-turn (e.g. via + // create_workspace), fetch and persist instruction files + // immediately so the LLM has AGENTS.md context for the remainder + // of this turn. The persisted marker prevents redundant fetches on + // subsequent turns. + if *opts.instruction == "" && updatedChat.WorkspaceID.Valid { + newInstruction, discoveredSkills, persistErr := p.persistInstructionFiles( + ctx, + updatedChat, + opts.modelConfigID, + opts.workspaceCtx.getWorkspaceAgent, + opts.workspaceCtx.getWorkspaceConn, + ) + if persistErr != nil { + p.logger.Warn(ctx, "failed to persist instruction files on workspace attach", + slog.F("chat_id", updatedChat.ID), + slog.Error(persistErr), + ) + } else { + *opts.instruction = newInstruction + if len(discoveredSkills) > 0 { + *opts.skills = discoveredSkills + } + } + } + } + + tools = append(tools, + chattool.ListTemplates(opts.chat.OrganizationID, p.db, chattool.ListTemplatesOptions{ + OwnerID: opts.chat.OwnerID, + AllowedTemplateIDs: p.chatTemplateAllowlist, + }), + chattool.ReadTemplate(opts.chat.OrganizationID, p.db, chattool.ReadTemplateOptions{ + OwnerID: opts.chat.OwnerID, + AllowedTemplateIDs: p.chatTemplateAllowlist, + }), + chattool.CreateWorkspace(opts.chat.OrganizationID, p.db, chattool.CreateWorkspaceOptions{ + OwnerID: opts.chat.OwnerID, + ChatID: opts.chat.ID, + CreateFn: p.createWorkspaceFn, + AgentConnFn: chattool.AgentConnFunc(p.agentConnFn), + AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout, + WorkspaceMu: opts.workspaceMu, + OnChatUpdated: onChatUpdated, + Logger: p.logger, + AllowedTemplateIDs: p.chatTemplateAllowlist, + }), + chattool.StartWorkspace(chattool.StartWorkspaceOptions{ + DB: p.db, + OwnerID: opts.chat.OwnerID, + ChatID: opts.chat.ID, + StartFn: p.startWorkspaceFn, + AgentConnFn: chattool.AgentConnFunc(p.agentConnFn), + WorkspaceMu: opts.workspaceMu, + OnChatUpdated: onChatUpdated, + Logger: p.logger, + }), + ) + if opts.isPlanModeTurn { + tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{ + GetWorkspaceConn: opts.workspaceCtx.getWorkspaceConn, + ResolvePlanPath: opts.resolvePlanPath, + IsPlanTurn: opts.isPlanModeTurn, + StoreFile: func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error) { + return p.storePlanSnapshotFile(ctx, opts.workspaceCtx, name, mediaType, data) + }, + })) + } + + return append(tools, p.subagentTools(ctx, func() database.Chat { + return opts.chat + })...) +} + +func (p *Server) storePlanSnapshotFile( + ctx context.Context, + workspaceCtx *turnWorkspaceContext, + name string, + mediaType string, + data []byte, +) (uuid.UUID, error) { + chatSnapshot := workspaceCtx.currentChatSnapshot() + if !chatSnapshot.WorkspaceID.Valid { + return uuid.Nil, xerrors.New("no workspace is associated with this chat. Use the create_workspace tool to create one") + } + + ws, err := p.db.GetWorkspaceByID(ctx, chatSnapshot.WorkspaceID.UUID) + if err != nil { + return uuid.Nil, xerrors.Errorf("resolve workspace: %w", err) + } + + row, err := p.db.InsertChatFile(ctx, database.InsertChatFileParams{ + OwnerID: chatSnapshot.OwnerID, + OrganizationID: ws.OrganizationID, + Name: name, + Mimetype: mediaType, + Data: data, + }) + if err != nil { + return uuid.Nil, xerrors.Errorf("insert chat file: %w", err) + } + + // Cap enforcement and dedup are handled atomically in SQL. + // rejected > 0 means the cap was exceeded. + rejected, err := p.db.LinkChatFiles(ctx, database.LinkChatFilesParams{ + ChatID: chatSnapshot.ID, + MaxFileLinks: int32(codersdk.MaxChatFileIDs), + FileIds: []uuid.UUID{row.ID}, + }) + switch { + case err != nil: + p.logger.Error(ctx, "failed to link file to chat", + slog.F("chat_id", chatSnapshot.ID), + slog.F("file_id", row.ID), + slog.Error(err), + ) + case rejected > 0: + p.logger.Warn(ctx, "file cap reached, file not linked to chat", + slog.F("chat_id", chatSnapshot.ID), + slog.F("file_id", row.ID), + slog.F("max_file_links", codersdk.MaxChatFileIDs), + ) + } + + return row.ID, nil +} + +func appendDynamicTools( + ctx context.Context, + logger slog.Logger, + tools []fantasy.AgentTool, + raw pqtype.NullRawMessage, + mode database.NullChatPlanMode, + parentChatID uuid.NullUUID, +) ([]fantasy.AgentTool, map[string]bool, error) { + if mode.Valid && mode.ChatPlanMode == database.ChatPlanModePlan { + return tools, nil, nil + } + + dynamicToolNames, err := parseDynamicToolNames(raw) + if err != nil { + return nil, nil, xerrors.Errorf("parse dynamic tool names: %w", err) + } + if len(dynamicToolNames) == 0 { + return tools, dynamicToolNames, nil + } + + var dynamicToolDefs []codersdk.DynamicTool + if raw.Valid { + if err := json.Unmarshal(raw.RawMessage, &dynamicToolDefs); err != nil { + return nil, nil, xerrors.Errorf("unmarshal dynamic tools: %w", err) + } + } + + activeToolNames := make(map[string]struct{}, len(tools)) + for _, name := range allowedPlanToolNames(tools, mode, parentChatID) { + activeToolNames[name] = struct{}{} + } + for _, t := range tools { + info := t.Info() + if _, active := activeToolNames[info.Name]; !active { + continue + } + if dynamicToolNames[info.Name] { + logger.Warn(ctx, "dynamic tool name collides with built-in tool, built-in takes precedence", + slog.F("tool_name", info.Name)) + delete(dynamicToolNames, info.Name) + } + } + + var filteredDefs []codersdk.DynamicTool + for _, dt := range dynamicToolDefs { + if dynamicToolNames[dt.Name] { + filteredDefs = append(filteredDefs, dt) + } + } + + return append(tools, dynamicToolsFromSDK(logger, filteredDefs)...), dynamicToolNames, nil +} + func (p *Server) runChat( ctx context.Context, chat database.Chat, @@ -4378,6 +4725,7 @@ func (p *Server) runChat( providerKeys chatprovider.ProviderAPIKeys callConfig codersdk.ChatModelCallConfig messages []database.ChatMessage + err error ) // Load MCP server configs and user tokens in parallel with @@ -4445,6 +4793,14 @@ func (p *Server) runChat( if err := g.Wait(); err != nil { return result, err } + + // Capture the current turn's mode from the chat plan mode so prompt + // and tool behavior can be resolved consistently for the rest of the + // turn. + currentPlanMode := chat.PlanMode + isPlanModeTurn := currentPlanMode.Valid && currentPlanMode.ChatPlanMode == database.ChatPlanModePlan + planModeInstructions := p.loadPlanModeInstructions(ctx, currentPlanMode, logger) + chainInfo := resolveChainMode(messages) result.PushSummaryModel = model result.ProviderKeys = providerKeys @@ -4713,9 +5069,23 @@ func (p *Server) runChat( if err := g2.Wait(); err != nil { return result, err } - if chat.ParentChatID.Valid { - prompt = chatprompt.InsertSystem(prompt, defaultSubagentInstruction) + isRootChat := !chat.ParentChatID.Valid + subagentInstruction := "" + if !isRootChat { + subagentInstruction = defaultSubagentInstruction } + prompt = buildSystemPrompt( + prompt, + subagentInstruction, + instruction, + skills, + resolvedUserPrompt, + systemPromptPlanContext{ + mode: currentPlanMode, + planModeInstructions: planModeInstructions, + isRootChat: isRootChat, + }, + ) if mcpCleanup != nil { defer mcpCleanup() } @@ -4730,19 +5100,8 @@ func (p *Server) runChat( } } - var instructionInjected bool - if instruction != "" { - prompt = chatprompt.InsertSystem(prompt, instruction) - instructionInjected = true - } + instructionInjected := instruction != "" prompt = renderPlanPathPrompt(prompt, resolvePlanPathBlock(ctx)) - if skillIndex := chattool.FormatSkillIndex(skills); skillIndex != "" { - prompt = chatprompt.InsertSystem(prompt, skillIndex) - } - if resolvedUserPrompt != "" { - prompt = chatprompt.InsertSystem(prompt, resolvedUserPrompt) - } - // Use the model config's context_limit as a fallback when the LLM // provider doesn't include context_limit in its response metadata // (which is the common case). @@ -5042,6 +5401,7 @@ func (p *Server) runChat( model = cuModel } + allowAskUserQuestion := isPlanModeTurn && isRootChat tools := []fantasy.AgentTool{ chattool.ReadFile(chattool.ReadFileOptions{ GetWorkspaceConn: workspaceCtx.getWorkspaceConn, @@ -5049,10 +5409,12 @@ func (p *Server) runChat( chattool.WriteFile(chattool.WriteFileOptions{ GetWorkspaceConn: workspaceCtx.getWorkspaceConn, ResolvePlanPath: resolvePlanPathForTools, + IsPlanTurn: isPlanModeTurn, }), chattool.EditFiles(chattool.EditFilesOptions{ GetWorkspaceConn: workspaceCtx.getWorkspaceConn, ResolvePlanPath: resolvePlanPathForTools, + IsPlanTurn: isPlanModeTurn, }), chattool.Execute(chattool.ExecuteOptions{ GetWorkspaceConn: workspaceCtx.getWorkspaceConn, @@ -5067,134 +5429,24 @@ func (p *Server) runChat( GetWorkspaceConn: workspaceCtx.getWorkspaceConn, }), } + if allowAskUserQuestion { + tools = append(tools, chattool.NewAskUserQuestionTool()) + } // Only root chats (not delegated subagents) get workspace // provisioning and subagent tools. Child agents must not - // create workspaces or spawn further subagents — they should + // create workspaces or spawn further subagents. They should // focus on completing their delegated task. - if !chat.ParentChatID.Valid { - // Workspace provisioning tools. - onChatUpdated := func(updatedChat database.Chat) { - workspaceCtx.selectWorkspace(updatedChat) - // Notify the frontend immediately so it can - // start streaming build logs before the tool - // completes. - p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) - - // When a workspace is first attached mid-turn - // (e.g. via create_workspace), fetch and persist - // instruction files immediately so the LLM has - // AGENTS.md context for the remainder of this - // turn. The persisted marker prevents redundant - // fetches on subsequent turns. - if instruction == "" && updatedChat.WorkspaceID.Valid { - newInstruction, discoveredSkills, persistErr := p.persistInstructionFiles( - ctx, - updatedChat, - modelConfig.ID, - workspaceCtx.getWorkspaceAgent, - workspaceCtx.getWorkspaceConn, - ) - if persistErr != nil { - p.logger.Warn(ctx, "failed to persist instruction files on workspace attach", - slog.F("chat_id", updatedChat.ID), - slog.Error(persistErr), - ) - } else { - instruction = newInstruction - if len(discoveredSkills) > 0 { - skills = discoveredSkills - } - } - } - } - tools = append(tools, - chattool.ListTemplates(chat.OrganizationID, p.db, chattool.ListTemplatesOptions{ - OwnerID: chat.OwnerID, - AllowedTemplateIDs: p.chatTemplateAllowlist, - }), - chattool.ReadTemplate(chat.OrganizationID, p.db, chattool.ReadTemplateOptions{ - OwnerID: chat.OwnerID, - AllowedTemplateIDs: p.chatTemplateAllowlist, - }), - chattool.CreateWorkspace(chat.OrganizationID, p.db, chattool.CreateWorkspaceOptions{ - OwnerID: chat.OwnerID, - ChatID: chat.ID, - CreateFn: p.createWorkspaceFn, - AgentConnFn: chattool.AgentConnFunc(p.agentConnFn), - AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout, - WorkspaceMu: &workspaceMu, - OnChatUpdated: onChatUpdated, - Logger: p.logger, - AllowedTemplateIDs: p.chatTemplateAllowlist, - }), - - chattool.StartWorkspace(chattool.StartWorkspaceOptions{ - DB: p.db, - OwnerID: chat.OwnerID, - ChatID: chat.ID, - StartFn: p.startWorkspaceFn, - AgentConnFn: chattool.AgentConnFunc(p.agentConnFn), - WorkspaceMu: &workspaceMu, - OnChatUpdated: onChatUpdated, - Logger: p.logger, - }), - ) - // Plan presentation tool. - tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{ - GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - ResolvePlanPath: resolvePlanPathForTools, - StoreFile: func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error) { - workspaceCtx.chatStateMu.Lock() - chatSnapshot := *workspaceCtx.currentChat - workspaceCtx.chatStateMu.Unlock() - - if !chatSnapshot.WorkspaceID.Valid { - return uuid.Nil, xerrors.New("no workspace is associated with this chat. Use the create_workspace tool to create one") - } - - ws, err := p.db.GetWorkspaceByID(ctx, chatSnapshot.WorkspaceID.UUID) - if err != nil { - return uuid.Nil, xerrors.Errorf("resolve workspace: %w", err) - } - - row, err := p.db.InsertChatFile(ctx, database.InsertChatFileParams{ - OwnerID: chatSnapshot.OwnerID, - OrganizationID: ws.OrganizationID, - Name: name, - Mimetype: mediaType, - Data: data, - }) - if err != nil { - return uuid.Nil, xerrors.Errorf("insert chat file: %w", err) - } - - // Cap enforcement and dedup are handled atomically - // in SQL. rejected > 0 = cap exceeded. - rejected, err := p.db.LinkChatFiles(ctx, database.LinkChatFilesParams{ - ChatID: chatSnapshot.ID, - MaxFileLinks: int32(codersdk.MaxChatFileIDs), - FileIds: []uuid.UUID{row.ID}, - }) - switch { - case err != nil: - p.logger.Error(ctx, "failed to link file to chat", - slog.F("chat_id", chatSnapshot.ID), - slog.F("file_id", row.ID), - slog.Error(err), - ) - case rejected > 0: - p.logger.Warn(ctx, "file cap reached, file not linked to chat", - slog.F("chat_id", chatSnapshot.ID), - slog.F("file_id", row.ID), - slog.F("max_file_links", codersdk.MaxChatFileIDs), - ) - } - return row.ID, nil - }, - })) - tools = append(tools, p.subagentTools(ctx, func() database.Chat { - return chat - })...) + if isRootChat { + tools = p.appendRootChatTools(ctx, tools, rootChatToolsOptions{ + chat: chat, + modelConfigID: modelConfig.ID, + workspaceCtx: &workspaceCtx, + workspaceMu: &workspaceMu, + instruction: &instruction, + skills: &skills, + resolvePlanPath: resolvePlanPathForTools, + isPlanModeTurn: isPlanModeTurn, + }) } // Append skill tools when the workspace has skills. @@ -5221,49 +5473,35 @@ func (p *Server) runChat( // Append tools from external MCP servers. These appear // after the built-in tools so the LLM sees them as // additional capabilities. - tools = append(tools, mcpTools...) - tools = append(tools, workspaceMCPTools...) + if !isPlanModeTurn { + tools = append(tools, mcpTools...) + tools = append(tools, workspaceMCPTools...) + } // Append dynamic tools declared by the client at chat // creation time. These appear in the LLM's tool list but - // are never executed by the chatloop — the client handles + // are never executed by the chatloop. The client handles // execution via POST /tool-results. - dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools) + var dynamicToolNames map[string]bool + tools, dynamicToolNames, err = appendDynamicTools( + ctx, + logger, + tools, + chat.DynamicTools, + currentPlanMode, + chat.ParentChatID, + ) if err != nil { - return result, xerrors.Errorf("parse dynamic tool names: %w", err) - } - // Unmarshal the full definitions separately so we can - // build the filtered list below. parseDynamicToolNames - // already validated the JSON, so this cannot fail. - var dynamicToolDefs []codersdk.DynamicTool - if chat.DynamicTools.Valid { - if err := json.Unmarshal(chat.DynamicTools.RawMessage, &dynamicToolDefs); err != nil { - return result, xerrors.Errorf("unmarshal dynamic tools: %w", err) - } - } - for _, t := range tools { - info := t.Info() - if dynamicToolNames[info.Name] { - logger.Warn(ctx, "dynamic tool name collides with built-in tool, built-in takes precedence", - slog.F("tool_name", info.Name)) - delete(dynamicToolNames, info.Name) - } + return result, err } - var filteredDefs []codersdk.DynamicTool - for _, dt := range dynamicToolDefs { - if dynamicToolNames[dt.Name] { - filteredDefs = append(filteredDefs, dt) - } - } - tools = append(tools, dynamicToolsFromSDK(p.logger, filteredDefs)...) // Build provider-native tools (e.g., web search) based on // the model configuration. var providerTools []chatloop.ProviderTool - if callConfig.ProviderOptions != nil { + if !isPlanModeTurn && callConfig.ProviderOptions != nil { providerTools = buildProviderTools(model.Provider(), callConfig.ProviderOptions) } - if isComputerUse { + if !isPlanModeTurn && isComputerUse { desktopGeometry := workspacesdk.DefaultDesktopGeometry() providerTools = append(providerTools, chatloop.ProviderTool{ Definition: chattool.ComputerUseProviderTool( @@ -5291,7 +5529,8 @@ func (p *Server) runChat( chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) && chainInfo.previousResponseID != "" && chainInfo.contributingTrailingUserCount > 0 && - chainInfo.modelConfigID == modelConfig.ID + chainInfo.modelConfigID == modelConfig.ID && + !isPlanModeTurn if chainModeActive { providerOptions = chatprovider.CloneWithPreviousResponseID( providerOptions, @@ -5300,9 +5539,12 @@ func (p *Server) runChat( prompt = filterPromptForChainMode(prompt, chainInfo) } err = chatloop.Run(ctx, chatloop.RunOptions{ - Model: model, - Messages: prompt, - Tools: tools, MaxSteps: maxChatSteps, + Model: model, + Messages: prompt, + Tools: tools, + ActiveTools: allowedPlanToolNames(tools, currentPlanMode, chat.ParentChatID), + StopAfterTools: stopAfterPlanTools(currentPlanMode, chat.ParentChatID), + MaxSteps: maxChatSteps, Metrics: p.metrics, BuiltinToolNames: builtinToolNames, @@ -5337,9 +5579,6 @@ func (p *Server) runChat( if err != nil { return nil, xerrors.Errorf("convert reloaded messages: %w", err) } - if chat.ParentChatID.Valid { - reloadedPrompt = chatprompt.InsertSystem(reloadedPrompt, defaultSubagentInstruction) - } // Re-derive instruction and skills from the reloaded // messages so that any context added during the // chatloop (e.g. via persistInstructionFiles when @@ -5351,22 +5590,26 @@ func (p *Server) runChat( reloadedInstruction = instructionFromContextFiles(reloadedMsgs) } if reloadedInstruction != "" { - reloadedPrompt = chatprompt.InsertSystem(reloadedPrompt, reloadedInstruction) instructionInjected = true } - reloadedPrompt = renderPlanPathPrompt(reloadedPrompt, resolvePlanPathBlock(reloadCtx)) reloadedSkills := skillsFromParts(reloadedMsgs) if len(reloadedSkills) == 0 { reloadedSkills = skills } - - if skillIndex := chattool.FormatSkillIndex(reloadedSkills); skillIndex != "" { - reloadedPrompt = chatprompt.InsertSystem(reloadedPrompt, skillIndex) - } reloadUserPrompt := p.resolveUserPrompt(reloadCtx, chat.OwnerID) - if reloadUserPrompt != "" { - reloadedPrompt = chatprompt.InsertSystem(reloadedPrompt, reloadUserPrompt) - } + reloadedPrompt = buildSystemPrompt( + reloadedPrompt, + subagentInstruction, + reloadedInstruction, + reloadedSkills, + reloadUserPrompt, + systemPromptPlanContext{ + mode: currentPlanMode, + planModeInstructions: planModeInstructions, + isRootChat: isRootChat, + }, + ) + reloadedPrompt = renderPlanPathPrompt(reloadedPrompt, resolvePlanPathBlock(reloadCtx)) if chainModeActive { reloadedPrompt = filterPromptForChainMode( reloadedPrompt, @@ -5416,6 +5659,9 @@ func (p *Server) runChat( p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err)) }, }) + if errors.Is(err, chatloop.ErrStopAfterTool) { + err = nil + } if errors.Is(err, chatloop.ErrDynamicToolCall) { // The stream event is published in processChat's // defer after the DB status transitions to diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 0a675311d3..35856b5bc4 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -33,6 +33,209 @@ import ( "github.com/coder/quartz" ) +type testAgentTool struct { + info fantasy.ToolInfo + providerOptions fantasy.ProviderOptions +} + +func newTestAgentTool(name string) fantasy.AgentTool { + return &testAgentTool{info: fantasy.ToolInfo{Name: name}} +} + +func (t *testAgentTool) Info() fantasy.ToolInfo { + return t.info +} + +func (t *testAgentTool) Run(context.Context, fantasy.ToolCall) (fantasy.ToolResponse, error) { + _ = t + return fantasy.ToolResponse{}, nil +} + +func (t *testAgentTool) ProviderOptions() fantasy.ProviderOptions { + return t.providerOptions +} + +func (t *testAgentTool) SetProviderOptions(opts fantasy.ProviderOptions) { + t.providerOptions = opts +} + +func TestAllowedPlanToolNames(t *testing.T) { + t.Parallel() + + makeTools := func(names ...string) []fantasy.AgentTool { + tools := make([]fantasy.AgentTool, 0, len(names)) + for _, name := range names { + tools = append(tools, newTestAgentTool(name)) + } + return tools + } + + planMode := database.NullChatPlanMode{ + ChatPlanMode: database.ChatPlanModePlan, + Valid: true, + } + + t.Run("NormalModeReturnsAllRegisteredTools", func(t *testing.T) { + t.Parallel() + + got := allowedPlanToolNames(makeTools( + "read_file", + "propose_plan", + "custom_tool", + "execute", + ), database.NullChatPlanMode{}, uuid.NullUUID{}) + + require.Equal(t, []string{ + "read_file", + "propose_plan", + "custom_tool", + "execute", + }, got) + }) + + t.Run("PlanModeIncludesOnlyAllowlistedBuiltIns", func(t *testing.T) { + t.Parallel() + + got := allowedPlanToolNames(makeTools( + "read_file", + "write_file", + "edit_files", + "execute", + "process_output", + "process_list", + "process_signal", + "list_templates", + "read_template", + "create_workspace", + "start_workspace", + "propose_plan", + "spawn_agent", + "wait_agent", + "message_agent", + "close_agent", + "spawn_computer_use_agent", + "read_skill", + "read_skill_file", + "ask_user_question", + ), planMode, uuid.NullUUID{}) + + require.Equal(t, []string{ + "read_file", + "write_file", + "edit_files", + "execute", + "process_output", + "list_templates", + "read_template", + "create_workspace", + "start_workspace", + "propose_plan", + "spawn_agent", + "wait_agent", + "read_skill", + "read_skill_file", + "ask_user_question", + }, got) + }) + + t.Run("PlanModeChildChatsAllowExplorationOnly", func(t *testing.T) { + t.Parallel() + + got := allowedPlanToolNames(makeTools( + "read_file", + "write_file", + "edit_files", + "execute", + "process_output", + "list_templates", + "read_template", + "create_workspace", + "start_workspace", + "propose_plan", + "spawn_agent", + "wait_agent", + "read_skill", + "read_skill_file", + "ask_user_question", + ), planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true}) + + require.Equal(t, []string{ + "read_file", + "execute", + "process_output", + "read_skill", + "read_skill_file", + }, got) + require.NotContains(t, got, "write_file") + require.NotContains(t, got, "edit_files") + require.NotContains(t, got, "ask_user_question") + require.NotContains(t, got, "propose_plan") + }) + + t.Run("PlanModeStillExcludesDangerousTools", func(t *testing.T) { + t.Parallel() + + got := allowedPlanToolNames(makeTools( + "execute", + "process_output", + "message_agent", + "spawn_computer_use_agent", + "propose_plan", + ), planMode, uuid.NullUUID{}) + + require.Equal(t, []string{"execute", "process_output", "propose_plan"}, got) + require.NotContains(t, got, "message_agent") + require.NotContains(t, got, "spawn_computer_use_agent") + }) + + t.Run("PlanModeExcludesUnknownTools", func(t *testing.T) { + t.Parallel() + + got := allowedPlanToolNames(makeTools( + "read_file", + "custom_tool", + "another_custom_tool", + "propose_plan", + ), planMode, uuid.NullUUID{}) + + require.Equal(t, []string{ + "read_file", + "propose_plan", + }, got) + require.NotContains(t, got, "custom_tool") + require.NotContains(t, got, "another_custom_tool") + }) +} + +func TestStopAfterPlanTools(t *testing.T) { + t.Parallel() + + planMode := database.NullChatPlanMode{ + ChatPlanMode: database.ChatPlanModePlan, + Valid: true, + } + + t.Run("NormalModeReturnsNil", func(t *testing.T) { + t.Parallel() + require.Nil(t, stopAfterPlanTools(database.NullChatPlanMode{}, uuid.NullUUID{})) + }) + + t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) { + t.Parallel() + require.Equal(t, map[string]struct{}{ + "propose_plan": {}, + "ask_user_question": {}, + }, stopAfterPlanTools(planMode, uuid.NullUUID{})) + }) + + t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) { + t.Parallel() + require.Equal(t, map[string]struct{}{ + "propose_plan": {}, + }, stopAfterPlanTools(planMode, uuid.NullUUID{UUID: uuid.New(), Valid: true})) + }) +} + func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 483d4667d5..ee5841a3be 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -54,6 +54,90 @@ import ( "github.com/coder/quartz" ) +type recordedOpenAIRequest struct { + Messages []chattest.OpenAIMessage + Tools []string + Store *bool + PreviousResponseID *string + ContentLength int64 +} + +func recordOpenAIRequest(req *chattest.OpenAIRequest) recordedOpenAIRequest { + messages := append([]chattest.OpenAIMessage(nil), req.Messages...) + tools := make([]string, 0, len(req.Tools)) + for _, tool := range req.Tools { + tools = append(tools, tool.Function.Name) + } + + var store *bool + if req.Store != nil { + value := *req.Store + store = &value + } + + var previousResponseID *string + if req.PreviousResponseID != nil { + value := *req.PreviousResponseID + previousResponseID = &value + } + + var contentLength int64 + if req.Request != nil { + contentLength = req.Request.ContentLength + } + + return recordedOpenAIRequest{ + Messages: messages, + Tools: tools, + Store: store, + PreviousResponseID: previousResponseID, + ContentLength: contentLength, + } +} + +func requestHasSystemSubstring(req recordedOpenAIRequest, want string) bool { + for _, msg := range req.Messages { + if msg.Role == "system" && strings.Contains(msg.Content, want) { + return true + } + } + return false +} + +func newWorkspaceToolTestServer( + t *testing.T, + db database.Store, + ps dbpubsub.Pubsub, + agentID uuid.UUID, + planContent string, +) *chatd.Server { + t.Helper() + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes() + mockConn.EXPECT().ContextConfig(gomock.Any()). + Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes() + mockConn.EXPECT().ListMCPTools(gomock.Any()). + Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes() + mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()). + Return(workspacesdk.LSResponse{}, nil).AnyTimes() + mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, path string, _, _ int64) (io.ReadCloser, string, error) { + if path == "/home/coder/PLAN.md" { + return io.NopCloser(strings.NewReader(planContent)), "", nil + } + return io.NopCloser(strings.NewReader("")), "", nil + }).AnyTimes() + + return newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, gotAgentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, agentID, gotAgentID) + return mockConn, func() {}, nil + } + }) +} + func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) { t.Parallel() @@ -224,7 +308,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { require.GreaterOrEqual(t, len(recorded), 2, "expected at least 2 streamed LLM calls (root + subagent)") - workspaceTools := []string{"propose_plan", "list_templates", "read_template", "create_workspace"} + workspaceTools := []string{"list_templates", "read_template", "create_workspace"} subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"} // Identify root and subagent calls. Root chat calls include @@ -255,6 +339,10 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { "root chat should have subagent tool %q", tool) } + // Standard turns (no turn mode) should hide propose_plan. + require.NotContains(t, rootCalls[0], "propose_plan", + "standard-turn root chat should NOT have propose_plan") + // Subagent calls must NOT include workspace or subagent tools. for _, tool := range workspaceTools { require.NotContains(t, childCalls[0], tool, @@ -266,6 +354,153 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { } } +func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)} + client := coderdtest.New(t, &coderdtest.Options{ + DeploymentValues: deploymentValues, + IncludeProvisionerDaemon: true, + }) + user := coderdtest.CreateFirstUser(t, client) + expClient := codersdk.NewExperimentalClient(client) + + agentToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: echo.PlanComplete, + ProvisionApply: echo.ApplyComplete, + ProvisionGraph: echo.ProvisionGraphWithAgent(agentToken), + }) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + _ = agenttest.New(t, client.URL, agentToken) + + var toolsMu sync.Mutex + toolsByCall := make([][]string, 0, 2) + requestsByCall := make([]recordedOpenAIRequest, 0, 2) + + var callCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("ok") + } + + names := make([]string, 0, len(req.Tools)) + for _, tool := range req.Tools { + names = append(names, tool.Function.Name) + } + toolsMu.Lock() + toolsByCall = append(toolsByCall, names) + requestsByCall = append(requestsByCall, recordOpenAIRequest(req)) + toolsMu.Unlock() + + if callCount.Add(1) == 1 { + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("spawn_agent", `{"prompt":"inspect the codebase","title":"sub"}`), + ) + } + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("done")..., + ) + }) + + _, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ + Provider: "openai-compat", + APIKey: "test-api-key", + BaseURL: openAIURL, + }) + require.NoError(t, err) + + contextLimit := int64(4096) + isDefault := true + _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "openai-compat", + Model: "gpt-4o-mini", + ContextLimit: &contextLimit, + IsDefault: &isDefault, + }) + require.NoError(t, err) + + chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{ + OrganizationID: user.OrganizationID, + PlanMode: codersdk.ChatPlanModePlan, + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "Spawn a subagent to inspect the codebase.", + }, + }, + }) + require.NoError(t, err) + + require.Eventually(t, func() bool { + got, getErr := expClient.GetChat(ctx, chat.ID) + if getErr != nil { + return false + } + if got.Status != codersdk.ChatStatusWaiting && got.Status != codersdk.ChatStatusError { + return false + } + toolsMu.Lock() + n := len(toolsByCall) + toolsMu.Unlock() + return n >= 3 + }, testutil.WaitLong, testutil.IntervalFast) + + toolsMu.Lock() + recorded := append([][]string(nil), toolsByCall...) + recordedRequests := append([]recordedOpenAIRequest(nil), requestsByCall...) + toolsMu.Unlock() + + require.GreaterOrEqual(t, len(recorded), 2, + "expected at least 2 streamed LLM calls (root + subagent)") + require.Len(t, recordedRequests, len(recorded)) + + var rootCalls, childCalls [][]string + var rootRequests, childRequests []recordedOpenAIRequest + for i, tools := range recorded { + if slice.Contains(tools, "spawn_agent") { + rootCalls = append(rootCalls, tools) + rootRequests = append(rootRequests, recordedRequests[i]) + continue + } + childCalls = append(childCalls, tools) + childRequests = append(childRequests, recordedRequests[i]) + } + + require.NotEmpty(t, rootCalls, "expected at least one root chat LLM call") + require.NotEmpty(t, childCalls, "expected at least one subagent LLM call") + require.NotEmpty(t, rootRequests, "expected at least one root prompt") + require.NotEmpty(t, childRequests, "expected at least one subagent prompt") + require.Contains(t, rootCalls[0], "ask_user_question", + "root plan-mode chat should have ask_user_question") + require.Contains(t, rootCalls[0], "write_file", + "root plan-mode chat should have write_file") + require.Contains(t, rootCalls[0], "edit_files", + "root plan-mode chat should have edit_files") + require.Contains(t, rootCalls[0], "execute", + "root plan-mode chat should have execute") + require.Contains(t, rootCalls[0], "process_output", + "root plan-mode chat should have process_output") + require.NotContains(t, childCalls[0], "ask_user_question", + "plan-mode subagent should NOT have ask_user_question") + require.NotContains(t, childCalls[0], "write_file", + "plan-mode subagent should NOT have write_file") + require.NotContains(t, childCalls[0], "edit_files", + "plan-mode subagent should NOT have edit_files") + require.Contains(t, childCalls[0], "execute", + "plan-mode subagent should have execute") + require.Contains(t, childCalls[0], "process_output", + "plan-mode subagent should have process_output") + require.True(t, requestHasSystemSubstring(rootRequests[0], "You are in Plan Mode.")) + require.True(t, requestHasSystemSubstring(childRequests[0], "You are in Plan Mode as a delegated sub-agent.")) + require.False(t, requestHasSystemSubstring(childRequests[0], "When the plan is ready, call propose_plan")) +} + func TestInterruptChatClearsWorkerInDatabase(t *testing.T) { t.Parallel() @@ -579,6 +814,77 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) { require.Len(t, messages, 1) } +func TestPlanTurnPromptContract(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + + var ( + requests []recordedOpenAIRequest + requestsMu sync.Mutex + ) + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + + requestsMu.Lock() + requests = append(requests, recordOpenAIRequest(req)) + requestsMu.Unlock() + + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("plan acknowledged")..., + ) + }) + + user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + planModeInstructions := "Ask about deployment sequencing before finalizing the plan." + err := db.UpsertChatPlanModeInstructions(dbauthz.AsSystemRestricted(ctx), planModeInstructions) + require.NoError(t, err) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + server := newWorkspaceToolTestServer(t, db, ps, dbAgent.ID, "# Plan\n") + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + OrganizationID: org.ID, + Title: "plan-turn-prompt-contract", + ModelConfigID: model.ID, + PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true}, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Plan the rollout."), + }, + }) + require.NoError(t, err) + + waitForChatProcessed(ctx, t, db, chat.ID, server) + + requestsMu.Lock() + recorded := append([]recordedOpenAIRequest(nil), requests...) + requestsMu.Unlock() + + require.Len(t, recorded, 1, "expected exactly 1 streamed model call") + require.True(t, requestHasSystemSubstring(recorded[0], "You are in Plan Mode.")) + require.True(t, requestHasSystemSubstring(recorded[0], "The only intentional authored workspace artifact is the plan file")) + require.True(t, requestHasSystemSubstring(recorded[0], "You may use execute and process_output for exploration")) + require.True(t, requestHasSystemSubstring(recorded[0], "After a successful propose_plan call, stop immediately")) + require.True(t, requestHasSystemSubstring(recorded[0], planModeInstructions)) + for _, msg := range recorded[0].Messages { + if msg.Role != "system" { + continue + } + // The overlay constant includes a placeholder that is replaced at + // runtime, so strip only the stable body text before checking. + overlayBody := strings.TrimSuffix( + chatd.PlanningOverlayPrompt, + "{{CODER_CHAT_PLAN_FILE_PATH_BLOCK}}", + ) + sanitized := strings.ReplaceAll(msg.Content, overlayBody, "") + require.NotContains(t, sanitized, "propose_plan") + } +} + func TestSendMessageQueuesWhenWaitingWithQueuedBacklog(t *testing.T) { t.Parallel() @@ -2162,6 +2468,91 @@ func TestDynamicToolCallPausesAndResumes(t *testing.T) { "expected second LLM call to include the submitted dynamic tool result") } +func TestDynamicToolNamedProposePlanRemainsAvailableOutsidePlanMode(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + var streamedCallsMu sync.Mutex + streamedCalls := make([]chattest.OpenAIRequest, 0, 1) + + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("Dynamic tool collision test") + } + + streamedCallsMu.Lock() + streamedCalls = append(streamedCalls, chattest.OpenAIRequest{ + Messages: append([]chattest.OpenAIMessage(nil), req.Messages...), + Tools: append([]chattest.OpenAITool(nil), req.Tools...), + Stream: req.Stream, + }) + streamedCallsMu.Unlock() + + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("Dynamic tool list captured.")..., + ) + }) + + user, org, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + server := newActiveTestServer(t, db, ps) + + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "propose_plan", + Description: "A dynamic tool whose name collides with the hidden built-in.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "input": map[string]any{"type": "string"}, + }, + Required: []string{"input"}, + }, + }}) + require.NoError(t, err) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "dynamic-propose-plan-collision", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("List the available tools."), + }, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + + var chatResult database.Chat + require.Eventually(t, func() bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError + }, testutil.WaitLong, testutil.IntervalFast) + + if chatResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String) + } + + streamedCallsMu.Lock() + recordedCalls := append([]chattest.OpenAIRequest(nil), streamedCalls...) + streamedCallsMu.Unlock() + require.NotEmpty(t, recordedCalls) + + var foundDynamicTool bool + for _, tool := range recordedCalls[0].Tools { + if tool.Function.Name == "propose_plan" { + foundDynamicTool = true + break + } + } + require.True(t, foundDynamicTool, + "expected the dynamic propose_plan tool to remain visible outside plan mode") +} + func TestDynamicToolCallMixedWithBuiltIn(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatloop/chatloop.go b/coderd/x/chatd/chatloop/chatloop.go index 951554bbb6..bca8461b9b 100644 --- a/coderd/x/chatd/chatloop/chatloop.go +++ b/coderd/x/chatd/chatloop/chatloop.go @@ -43,6 +43,10 @@ const ( var ( ErrInterrupted = xerrors.New("chat interrupted") ErrDynamicToolCall = xerrors.New("dynamic tool call") + // ErrStopAfterTool is returned when a tool listed in + // StopAfterTools produces a successful result, indicating + // the run should terminate cleanly after persistence. + ErrStopAfterTool = xerrors.New("stop after tool") errStartupTimeout = xerrors.New( "chat response did not start before the startup timeout", @@ -114,6 +118,11 @@ type RunOptions struct { // the chatloop persists partial results and exits with // ErrDynamicToolCall instead of executing the tool. DynamicToolNames map[string]bool + // StopAfterTools lists tool names that, when they produce a + // successful result, cause the run to stop after persisting + // the current step. This is used for plan turns where + // propose_plan should terminate the run on success. + StopAfterTools map[string]struct{} // ModelConfig holds per-call LLM parameters (temperature, // max tokens, etc.) read from the chat model configuration. @@ -472,7 +481,7 @@ func Run(ctx context.Context, opts RunOptions) error { } // Execute only built-in tools. - toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, builtinCalls, opts.Metrics, provider, opts.BuiltinToolNames, func(tr fantasy.ToolResultContent, completedAt time.Time) { + toolResults = executeTools(ctx, opts.Tools, opts.ActiveTools, opts.ProviderTools, builtinCalls, opts.Metrics, provider, opts.BuiltinToolNames, func(tr fantasy.ToolResultContent, completedAt time.Time) { recordToolResultTimestamp(&result, tr.ToolCallID, completedAt) ssePart := chatprompt.PartFromContent(tr) ssePart.CreatedAt = &completedAt @@ -566,6 +575,12 @@ func Run(ctx context.Context, opts RunOptions) error { lastUsage = result.usage lastProviderMetadata = result.providerMetadata + // Check if any executed tool triggers an early stop. + if shouldStopAfterTools(opts.StopAfterTools, toolResults) { + tryCompactOnExit(ctx, opts, result.usage, result.providerMetadata) + return ErrStopAfterTool + } + // When chain mode is active (PreviousResponseID set), exit // it after persisting the first chained step. Continuation // steps include tool-result messages, which fantasy rejects @@ -1022,6 +1037,7 @@ func processStepStream( func executeTools( ctx context.Context, allTools []fantasy.AgentTool, + activeTools []string, providerTools []ProviderTool, toolCalls []fantasy.ToolCallContent, metrics *Metrics, @@ -1051,11 +1067,14 @@ func executeTools( for _, t := range allTools { toolMap[t.Info().Name] = t } + providerRunnerNames := make(map[string]struct{}, len(providerTools)) // Include runners from provider tools so locally-executed // provider tools (e.g. computer use) can be dispatched. for _, pt := range providerTools { if pt.Runner != nil { - toolMap[pt.Runner.Info().Name] = pt.Runner + name := pt.Runner.Info().Name + toolMap[name] = pt.Runner + providerRunnerNames[name] = struct{}{} } } @@ -1081,7 +1100,7 @@ func executeTools( // accurate individual completion times. completedAt[i] = dbtime.Now() }() - results[i] = executeSingleTool(ctx, toolMap, tc, metrics, provider, builtinToolNames) + results[i] = executeSingleTool(ctx, toolMap, tc, metrics, provider, builtinToolNames, activeTools, providerRunnerNames) }() } wg.Wait() @@ -1105,6 +1124,8 @@ func executeSingleTool( metrics *Metrics, provider string, builtinToolNames map[string]bool, + activeTools []string, + providerRunnerNames map[string]struct{}, ) fantasy.ToolResultContent { result := fantasy.ToolResultContent{ ToolCallID: tc.ToolCallID, @@ -1121,6 +1142,13 @@ func executeSingleTool( ) }() + if _, isProviderRunner := providerRunnerNames[tc.ToolName]; !isProviderRunner && !isToolActive(tc.ToolName, activeTools) { + result.Result = fantasy.ToolResultOutputContentError{ + Error: xerrors.New("Tool not active in this turn: " + tc.ToolName), + } + return result + } + tool, exists := toolMap[tc.ToolName] if !exists { result.Result = fantasy.ToolResultOutputContentError{ @@ -1325,6 +1353,10 @@ func tryCompactOnExit( } } +func isToolActive(name string, activeTools []string) bool { + return len(activeTools) == 0 || slices.Contains(activeTools, name) +} + // buildToolDefinitions converts AgentTool definitions into the // fantasy.Tool slice expected by fantasy.Call. When activeTools // is non-empty, only function tools whose name appears in the @@ -1334,7 +1366,7 @@ func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, provi prepared := make([]fantasy.Tool, 0, len(tools)+len(providerTools)) for _, tool := range tools { info := tool.Info() - if len(activeTools) > 0 && !slices.Contains(activeTools, info.Name) { + if !isToolActive(info.Name, activeTools) { continue } @@ -1361,6 +1393,24 @@ func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, provi return prepared } +// shouldStopAfterTools returns true if any tool result in the +// slice matches a name in stopTools and produced a successful +// (non-error) result. +func shouldStopAfterTools(stopTools map[string]struct{}, results []fantasy.ToolResultContent) bool { + if len(stopTools) == 0 { + return false + } + for _, tr := range results { + if _, ok := stopTools[tr.ToolName]; !ok { + continue + } + if _, isErr := tr.Result.(fantasy.ToolResultOutputContentError); !isErr { + return true + } + } + return false +} + func shouldApplyAnthropicPromptCaching(model fantasy.LanguageModel) bool { if model == nil { return false diff --git a/coderd/x/chatd/chatloop/chatloop_test.go b/coderd/x/chatd/chatloop/chatloop_test.go index b33f8ac796..a24bcee6e3 100644 --- a/coderd/x/chatd/chatloop/chatloop_test.go +++ b/coderd/x/chatd/chatloop/chatloop_test.go @@ -101,6 +101,150 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) { require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4])) } +func TestRun_ActiveToolsRejectsDisallowedExecution(t *testing.T) { + t.Parallel() + + var blockedCalls atomic.Int32 + blockedToolName := "write_file" + model := &chattest.FakeModel{ + ProviderName: "fake", + StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + return streamFromParts([]fantasy.StreamPart{ + {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-blocked", ToolCallName: blockedToolName}, + {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-blocked", Delta: `{"path":"/tmp/nope"}`}, + {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-blocked"}, + { + Type: fantasy.StreamPartTypeToolCall, + ID: "tc-blocked", + ToolCallName: blockedToolName, + ToolCallInput: `{"path":"/tmp/nope"}`, + }, + {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, + }), nil + }, + } + + blockedTool := fantasy.NewAgentTool( + blockedToolName, + "blocked tool", + func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { + blockedCalls.Add(1) + return fantasy.NewTextResponse("should not run"), nil + }, + ) + + var persistedStep PersistedStep + err := Run(context.Background(), RunOptions{ + Model: model, + Messages: []fantasy.Message{ + textMessage(fantasy.MessageRoleUser, "try the blocked tool"), + }, + Tools: []fantasy.AgentTool{ + newNoopTool(activeToolName), + blockedTool, + }, + ActiveTools: []string{activeToolName}, + MaxSteps: 1, + PersistStep: func(_ context.Context, step PersistedStep) error { + persistedStep = step + return nil + }, + }) + require.NoError(t, err) + require.Zero(t, blockedCalls.Load(), "disallowed tool must not execute") + + var foundToolError bool + for _, block := range persistedStep.Content { + toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) + if !ok || toolResult.ToolName != blockedToolName { + continue + } + errResult, ok := toolResult.Result.(fantasy.ToolResultOutputContentError) + require.True(t, ok) + assert.EqualError(t, errResult.Error, "Tool not active in this turn: "+blockedToolName) + foundToolError = true + } + require.True(t, foundToolError, "persisted step should include the rejected tool result") +} + +func TestRun_ActiveToolsAllowsProviderRunnerExecution(t *testing.T) { + t.Parallel() + + providerRunnerName := "computer" + var runnerCalls atomic.Int32 + model := &chattest.FakeModel{ + ProviderName: "fake", + StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + return streamFromParts([]fantasy.StreamPart{ + {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-provider-runner", ToolCallName: providerRunnerName}, + {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-provider-runner", Delta: `{}`}, + {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-provider-runner"}, + { + Type: fantasy.StreamPartTypeToolCall, + ID: "tc-provider-runner", + ToolCallName: providerRunnerName, + ToolCallInput: `{}`, + }, + {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, + }), nil + }, + } + + runnerTool := fantasy.NewAgentTool( + providerRunnerName, + "provider runner", + func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { + runnerCalls.Add(1) + return fantasy.NewTextResponse("ran provider runner"), nil + }, + ) + + var persistedStep PersistedStep + err := Run(context.Background(), RunOptions{ + Model: model, + Messages: []fantasy.Message{ + textMessage(fantasy.MessageRoleUser, "use the computer"), + }, + Tools: []fantasy.AgentTool{newNoopTool(activeToolName)}, + ActiveTools: []string{activeToolName}, + ProviderTools: []ProviderTool{ + { + Definition: fantasy.FunctionTool{ + Name: providerRunnerName, + Description: "provider runner", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, + }, + Runner: runnerTool, + }, + }, + MaxSteps: 1, + PersistStep: func(_ context.Context, step PersistedStep) error { + persistedStep = step + return nil + }, + }) + require.NoError(t, err) + require.Equal(t, int32(1), runnerCalls.Load(), + "provider runner should execute even when omitted from active tools") + + var foundToolResult bool + for _, block := range persistedStep.Content { + toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) + if !ok || toolResult.ToolName != providerRunnerName { + continue + } + textResult, ok := toolResult.Result.(fantasy.ToolResultOutputContentText) + require.True(t, ok) + assert.Equal(t, "ran provider runner", textResult.Text) + foundToolResult = true + } + require.True(t, foundToolResult, + "persisted step should include the provider runner result") +} + func TestProcessStepStream_AnthropicUsageMatchesFinalDelta(t *testing.T) { t.Parallel() @@ -921,6 +1065,144 @@ func TestRun_MultiStepToolExecution(t *testing.T) { "tool-result timestamp must be >= tool-call timestamp") } +func TestStopAfterTool_Success(t *testing.T) { + t.Parallel() + + streamCalls := 0 + model := &chattest.FakeModel{ + ProviderName: "fake", + StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + streamCalls++ + return streamFromParts([]fantasy.StreamPart{ + {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-plan", ToolCallName: "propose_plan"}, + {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-plan", Delta: `{"path":"/tmp/plan.md"}`}, + {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-plan"}, + { + Type: fantasy.StreamPartTypeToolCall, + ID: "tc-plan", + ToolCallName: "propose_plan", + ToolCallInput: `{"path":"/tmp/plan.md"}`, + }, + {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, + }), nil + }, + } + + proposePlanTool := fantasy.NewAgentTool( + "propose_plan", + "writes a plan", + func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("plan saved"), nil + }, + ) + + var persistedSteps []PersistedStep + persistStepCalls := 0 + + err := Run(context.Background(), RunOptions{ + Model: model, + Messages: []fantasy.Message{ + textMessage(fantasy.MessageRoleUser, "propose a plan"), + }, + Tools: []fantasy.AgentTool{proposePlanTool}, + MaxSteps: 5, + StopAfterTools: map[string]struct{}{ + "propose_plan": {}, + }, + PersistStep: func(_ context.Context, step PersistedStep) error { + persistStepCalls++ + persistedSteps = append(persistedSteps, step) + return nil + }, + }) + require.ErrorIs(t, err, ErrStopAfterTool) + require.Equal(t, 1, streamCalls) + require.Equal(t, 1, persistStepCalls) + require.Len(t, persistedSteps, 1) + + var foundToolResult bool + for _, block := range persistedSteps[0].Content { + toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) + if !ok || toolResult.ToolName != "propose_plan" { + continue + } + foundToolResult = true + _, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError) + require.False(t, isErr, "stop-after-tool should only trigger on successful tool results") + } + require.True(t, foundToolResult, "persisted step should include the successful tool result before stopping") +} + +func TestStopAfterTool_IgnoresErrorResults(t *testing.T) { + t.Parallel() + + streamCalls := 0 + model := &chattest.FakeModel{ + ProviderName: "fake", + StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + streamCalls++ + if streamCalls == 1 { + return streamFromParts([]fantasy.StreamPart{ + {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-plan", ToolCallName: "propose_plan"}, + {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-plan", Delta: `{"path":"/tmp/plan.md"}`}, + {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-plan"}, + { + Type: fantasy.StreamPartTypeToolCall, + ID: "tc-plan", + ToolCallName: "propose_plan", + ToolCallInput: `{"path":"/tmp/plan.md"}`, + }, + {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, + }), nil + } + return streamFromParts([]fantasy.StreamPart{ + {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, + {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "tool failed, continue"}, + {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, + {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, + }), nil + }, + } + + proposePlanTool := fantasy.NewAgentTool( + "propose_plan", + "writes a plan", + func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextErrorResponse("plan failed"), nil + }, + ) + + var persistedSteps []PersistedStep + err := Run(context.Background(), RunOptions{ + Model: model, + Messages: []fantasy.Message{ + textMessage(fantasy.MessageRoleUser, "propose a plan"), + }, + Tools: []fantasy.AgentTool{proposePlanTool}, + MaxSteps: 5, + StopAfterTools: map[string]struct{}{ + "propose_plan": {}, + }, + PersistStep: func(_ context.Context, step PersistedStep) error { + persistedSteps = append(persistedSteps, step) + return nil + }, + }) + require.NoError(t, err) + require.Equal(t, 2, streamCalls) + require.Len(t, persistedSteps, 2) + + var foundToolError bool + for _, block := range persistedSteps[0].Content { + toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) + if !ok || toolResult.ToolName != "propose_plan" { + continue + } + _, foundToolError = toolResult.Result.(fantasy.ToolResultOutputContentError) + } + require.True(t, foundToolError, "first step should persist the failed tool result") +} + func TestRun_ParallelToolExecutionTimestamps(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chattool/askuserquestion.go b/coderd/x/chatd/chattool/askuserquestion.go new file mode 100644 index 0000000000..a4f106d3f7 --- /dev/null +++ b/coderd/x/chatd/chattool/askuserquestion.go @@ -0,0 +1,153 @@ +package chattool + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "charm.land/fantasy" + "golang.org/x/xerrors" +) + +const ( + askUserQuestionToolName = "ask_user_question" + askUserQuestionToolDesc = "Ask the user one or more structured clarification questions during plan mode. Use this instead of listing open questions in prose. Each question should have a short label, a detailed question, and 2-4 answer options." +) + +var ( + _ fantasy.AgentTool = (*askUserQuestionTool)(nil) + _ fantasy.Tool = (*askUserQuestionTool)(nil) +) + +type askUserQuestionOption struct { + Label string `json:"label"` + Description string `json:"description"` +} + +type askUserQuestion struct { + Header string `json:"header"` + Question string `json:"question"` + Options []askUserQuestionOption `json:"options"` +} + +type askUserQuestionArgs struct { + Questions []askUserQuestion `json:"questions"` +} + +// NewAskUserQuestionTool creates the ask_user_question tool. +func NewAskUserQuestionTool() fantasy.AgentTool { + return &askUserQuestionTool{} +} + +type askUserQuestionTool struct { + providerOptions fantasy.ProviderOptions +} + +func (*askUserQuestionTool) GetType() fantasy.ToolType { + return fantasy.ToolTypeFunction +} + +func (*askUserQuestionTool) GetName() string { + return askUserQuestionToolName +} + +func (*askUserQuestionTool) Info() fantasy.ToolInfo { + return fantasy.ToolInfo{ + Name: askUserQuestionToolName, + Description: askUserQuestionToolDesc, + Parameters: map[string]any{ + "questions": map[string]any{ + "type": "array", + "description": "The structured clarification questions to present to the user.", + "minItems": 1, + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "header": map[string]any{ + "type": "string", + "description": "A short label for the question.", + }, + "question": map[string]any{ + "type": "string", + "description": "The detailed question text.", + }, + "options": map[string]any{ + "type": "array", + "description": "The answer options the user can choose from. Do not include an 'Other' or freeform option; one is provided automatically by the UI.", + "minItems": 2, + "maxItems": 4, + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "label": map[string]any{ + "type": "string", + "description": "A short answer label.", + }, + "description": map[string]any{ + "type": "string", + "description": "More detail about what this option means.", + }, + }, + "required": []string{"label", "description"}, + }, + }, + }, + "required": []string{"header", "question", "options"}, + }, + }, + }, + Required: []string{"questions"}, + } +} + +func (*askUserQuestionTool) Run(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + var args askUserQuestionArgs + if err := json.Unmarshal([]byte(call.Input), &args); err != nil { + return fantasy.NewTextErrorResponse(fmt.Sprintf("invalid parameters: %s", err)), nil + } + + if err := validateAskUserQuestionArgs(args); err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + + data, err := json.Marshal(map[string]any{"questions": args.Questions}) + if err != nil { + return fantasy.NewTextErrorResponse("failed to marshal questions: " + err.Error()), nil + } + return fantasy.NewTextResponse(string(data)), nil +} + +func (t *askUserQuestionTool) ProviderOptions() fantasy.ProviderOptions { + return t.providerOptions +} + +func (t *askUserQuestionTool) SetProviderOptions(opts fantasy.ProviderOptions) { + t.providerOptions = opts +} + +func validateAskUserQuestionArgs(args askUserQuestionArgs) error { + if len(args.Questions) == 0 { + return xerrors.New("questions is required") + } + for i, question := range args.Questions { + if strings.TrimSpace(question.Header) == "" { + return xerrors.Errorf("questions[%d].header is required", i) + } + if strings.TrimSpace(question.Question) == "" { + return xerrors.Errorf("questions[%d].question is required", i) + } + if len(question.Options) < 2 || len(question.Options) > 4 { + return xerrors.Errorf("questions[%d].options must contain 2-4 items", i) + } + for j, option := range question.Options { + if strings.TrimSpace(option.Label) == "" { + return xerrors.Errorf("questions[%d].options[%d].label is required", i, j) + } + if strings.TrimSpace(option.Description) == "" { + return xerrors.Errorf("questions[%d].options[%d].description is required", i, j) + } + } + } + return nil +} diff --git a/coderd/x/chatd/chattool/askuserquestion_test.go b/coderd/x/chatd/chattool/askuserquestion_test.go new file mode 100644 index 0000000000..a5d270c1d9 --- /dev/null +++ b/coderd/x/chatd/chattool/askuserquestion_test.go @@ -0,0 +1,141 @@ +package chattool //nolint:testpackage // Uses internal symbols. + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateAskUserQuestionArgs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + args askUserQuestionArgs + wantErr string + }{ + { + name: "QuestionsRequired", + args: askUserQuestionArgs{}, + wantErr: "questions is required", + }, + { + name: "HeaderRequired", + args: askUserQuestionArgs{Questions: []askUserQuestion{{ + Header: " \t ", + Question: "What should we build?", + Options: validAskUserQuestionOptions(2), + }}}, + wantErr: "questions[0].header is required", + }, + { + name: "QuestionRequired", + args: askUserQuestionArgs{Questions: []askUserQuestion{{ + Header: "Scope", + Question: "\n\t ", + Options: validAskUserQuestionOptions(2), + }}}, + wantErr: "questions[0].question is required", + }, + { + name: "TooFewOptions", + args: askUserQuestionArgs{Questions: []askUserQuestion{{ + Header: "Scope", + Question: "What should we build?", + Options: validAskUserQuestionOptions(1), + }}}, + wantErr: "questions[0].options must contain 2-4 items", + }, + { + name: "TooManyOptions", + args: askUserQuestionArgs{Questions: []askUserQuestion{{ + Header: "Scope", + Question: "What should we build?", + Options: validAskUserQuestionOptions(5), + }}}, + wantErr: "questions[0].options must contain 2-4 items", + }, + { + name: "OptionLabelRequired", + args: askUserQuestionArgs{Questions: []askUserQuestion{{ + Header: "Scope", + Question: "What should we build?", + Options: []askUserQuestionOption{ + {Label: " ", Description: "Build the API first."}, + {Label: "Frontend", Description: "Build the UI first."}, + }, + }}}, + wantErr: "questions[0].options[0].label is required", + }, + { + name: "OptionDescriptionRequired", + args: askUserQuestionArgs{Questions: []askUserQuestion{{ + Header: "Scope", + Question: "What should we build?", + Options: []askUserQuestionOption{ + {Label: "Backend", Description: "\t"}, + {Label: "Frontend", Description: "Build the UI first."}, + }, + }}}, + wantErr: "questions[0].options[0].description is required", + }, + { + name: "ValidTwoOptions", + args: askUserQuestionArgs{Questions: []askUserQuestion{{ + Header: "Scope", + Question: "What should we build?", + Options: validAskUserQuestionOptions(2), + }}}, + }, + { + name: "ValidFourOptions", + args: askUserQuestionArgs{Questions: []askUserQuestion{{ + Header: "Scope", + Question: "What should we build?", + Options: validAskUserQuestionOptions(4), + }}}, + }, + { + name: "SecondQuestionInvalid", + args: askUserQuestionArgs{Questions: []askUserQuestion{ + { + Header: "Scope", + Question: "What should we build?", + Options: validAskUserQuestionOptions(2), + }, + { + Header: "Timeline", + Question: "\t ", + Options: validAskUserQuestionOptions(2), + }, + }}, + wantErr: "questions[1].question is required", + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + err := validateAskUserQuestionArgs(testCase.args) + if testCase.wantErr == "" { + require.NoError(t, err) + return + } + + require.EqualError(t, err, testCase.wantErr) + }) + } +} + +func validAskUserQuestionOptions(count int) []askUserQuestionOption { + options := []askUserQuestionOption{ + {Label: "Backend", Description: "Build the API first."}, + {Label: "Frontend", Description: "Build the UI first."}, + {Label: "Docs", Description: "Write the docs first."}, + {Label: "Tests", Description: "Start with tests first."}, + {Label: "Research", Description: "Investigate the problem first."}, + } + + return append([]askUserQuestionOption(nil), options[:count]...) +} diff --git a/coderd/x/chatd/chattool/editfiles.go b/coderd/x/chatd/chattool/editfiles.go index 100598e13d..d669362fbd 100644 --- a/coderd/x/chatd/chattool/editfiles.go +++ b/coderd/x/chatd/chattool/editfiles.go @@ -12,6 +12,7 @@ import ( type EditFilesOptions struct { GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error) ResolvePlanPath func(context.Context) (chatPath string, home string, err error) + IsPlanTurn bool } type EditFilesArgs struct { @@ -24,6 +25,20 @@ func EditFiles(options EditFilesOptions) fantasy.AgentTool { "Perform search-and-replace edits on one or more files in the workspace."+ " Each file can have multiple edits applied atomically.", func(ctx context.Context, args EditFilesArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + var planPath string + if options.IsPlanTurn && len(args.Files) > 0 { + resolvedPlanPath, err := resolvePlanTurnPath(ctx, options.ResolvePlanPath) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + for i := range args.Files { + args.Files[i].Path = strings.TrimSpace(args.Files[i].Path) + if args.Files[i].Path != resolvedPlanPath { + return fantasy.NewTextErrorResponse("during plan turns, edit_files is restricted to " + resolvedPlanPath), nil + } + } + planPath = resolvedPlanPath + } if options.GetWorkspaceConn == nil { return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil } @@ -31,6 +46,11 @@ func EditFiles(options EditFilesOptions) fantasy.AgentTool { if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil } + if planPath != "" { + if err := ensurePlanPathResolvesToItself(ctx, conn, planPath); err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + } return executeEditFilesTool(ctx, conn, args, options.ResolvePlanPath) }, ) diff --git a/coderd/x/chatd/chattool/editfiles_test.go b/coderd/x/chatd/chattool/editfiles_test.go index d16be9588a..4438a43174 100644 --- a/coderd/x/chatd/chattool/editfiles_test.go +++ b/coderd/x/chatd/chattool/editfiles_test.go @@ -2,6 +2,7 @@ package chattool_test import ( "context" + "net/http" "testing" "charm.land/fantasy" @@ -18,6 +19,164 @@ import ( func TestEditFiles(t *testing.T) { t.Parallel() + t.Run("PlanTurnRejectsNonPlanPath", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + planPath := "/home/coder/.coder/plans/PLAN-test-uuid.md" + getWorkspaceConnCalled := false + tool := chattool.EditFiles(chattool.EditFilesOptions{ + GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { + getWorkspaceConnCalled = true + return mockConn, nil + }, + ResolvePlanPath: func(context.Context) (string, string, error) { + return planPath, "/home/coder", nil + }, + IsPlanTurn: true, + }) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "edit_files", + Input: `{"files":[{"path":"/home/coder/README.md","edits":[{"search":"old","replace":"new"}]}]}`, + }) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Equal(t, "during plan turns, edit_files is restricted to "+planPath, resp.Content) + assert.False(t, getWorkspaceConnCalled) + }) + + t.Run("PlanTurnRejectsMixedPaths", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + planPath := "/home/coder/.coder/plans/PLAN-test-uuid.md" + getWorkspaceConnCalled := false + tool := chattool.EditFiles(chattool.EditFilesOptions{ + GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { + getWorkspaceConnCalled = true + return mockConn, nil + }, + ResolvePlanPath: func(context.Context) (string, string, error) { + return planPath, "/home/coder", nil + }, + IsPlanTurn: true, + }) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "edit_files", + Input: `{"files":[` + + `{"path":"` + planPath + `","edits":[{"search":"old","replace":"new"}]},` + + `{"path":"/home/coder/README.md","edits":[{"search":"old","replace":"new"}]}` + + `]}`, + }) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Equal(t, "during plan turns, edit_files is restricted to "+planPath, resp.Content) + assert.False(t, getWorkspaceConnCalled) + }) + + t.Run("PlanTurnAllowsResolvedPlanPath", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + planPath := "/home/coder/.coder/plans/PLAN-test-uuid.md" + resolvePlanPathCalls := 0 + mockConn.EXPECT().ResolvePath(gomock.Any(), planPath).Return(planPath, nil) + request := workspacesdk.FileEditRequest{Files: []workspacesdk.FileEdits{{ + Path: planPath, + Edits: []workspacesdk.FileEdit{{ + Search: "old", + Replace: "new", + }}, + }}} + mockConn.EXPECT().EditFiles(gomock.Any(), request).Return(nil) + + tool := chattool.EditFiles(chattool.EditFilesOptions{ + GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { + return mockConn, nil + }, + ResolvePlanPath: func(context.Context) (string, string, error) { + resolvePlanPathCalls++ + return planPath, "/home/coder", nil + }, + IsPlanTurn: true, + }) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "edit_files", + Input: `{"files":[{"path":"` + planPath + `","edits":[{"search":"old","replace":"new"}]}]}`, + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + assert.Equal(t, 1, resolvePlanPathCalls) + }) + + t.Run("PlanTurnAllowsLegacyAgentWithoutResolvePath", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + planPath := "/home/coder/.coder/plans/PLAN-test-uuid.md" + mockConn.EXPECT(). + ResolvePath(gomock.Any(), planPath). + Return("", statusError{statusCode: http.StatusNotFound, message: "missing resolve-path endpoint"}) + request := workspacesdk.FileEditRequest{Files: []workspacesdk.FileEdits{{ + Path: planPath, + Edits: []workspacesdk.FileEdit{{ + Search: "old", + Replace: "new", + }}, + }}} + mockConn.EXPECT().EditFiles(gomock.Any(), request).Return(nil) + + tool := chattool.EditFiles(chattool.EditFilesOptions{ + GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { + return mockConn, nil + }, + ResolvePlanPath: func(context.Context) (string, string, error) { + return planPath, "/home/coder", nil + }, + IsPlanTurn: true, + }) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "edit_files", + Input: `{"files":[{"path":"` + planPath + `","edits":[{"search":"old","replace":"new"}]}]}`, + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + }) + + t.Run("PlanTurnRejectsSymlinkedPlanPath", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + planPath := "/home/coder/.coder/plans/PLAN-test-uuid.md" + mockConn.EXPECT().ResolvePath(gomock.Any(), planPath).Return("/home/coder/README.md", nil) + tool := chattool.EditFiles(chattool.EditFilesOptions{ + GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { + return mockConn, nil + }, + ResolvePlanPath: func(context.Context) (string, string, error) { + return planPath, "/home/coder", nil + }, + IsPlanTurn: true, + }) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "edit_files", + Input: `{"files":[{"path":"` + planPath + `","edits":[{"search":"old","replace":"new"}]}]}`, + }) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Equal(t, "the chat-specific plan path /home/coder/.coder/plans/PLAN-test-uuid.md resolves to /home/coder/README.md; symlinked plan paths are not allowed during plan turns", resp.Content) + }) + t.Run("RejectsPlanPathsWhenResolvePlanPathIsConfigured", func(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chattool/planpath.go b/coderd/x/chatd/chattool/planpath.go index 964dde0d03..f1c4e4852c 100644 --- a/coderd/x/chatd/chattool/planpath.go +++ b/coderd/x/chatd/chattool/planpath.go @@ -53,6 +53,26 @@ func PlanPathForChat(home string, chatID uuid.UUID) string { ) } +func resolvePlanTurnPath( + ctx context.Context, + resolvePlanPath func(context.Context) (chatPath string, home string, err error), +) (string, error) { + if resolvePlanPath == nil { + return "", xerrors.New("chat-specific plan path resolver is not configured") + } + + planPath, _, err := resolvePlanPath(ctx) + if err != nil { + return "", xerrors.Errorf("resolve chat-specific plan path: %w", err) + } + planPath = strings.TrimSpace(planPath) + if planPath == "" { + return "", xerrors.New("chat-specific plan path is empty") + } + + return planPath, nil +} + // chatd consumes agent-normalized POSIX paths. Workspace agents are // expected to convert separators to forward slashes before these // helpers run. diff --git a/coderd/x/chatd/chattool/planpath_helpers_test.go b/coderd/x/chatd/chattool/planpath_helpers_test.go new file mode 100644 index 0000000000..f223773d82 --- /dev/null +++ b/coderd/x/chatd/chattool/planpath_helpers_test.go @@ -0,0 +1,19 @@ +package chattool_test + +func sharedPlanPathResolvedMessage(requestedPath, planPath string) string { + return "the plan path " + requestedPath + + " is no longer supported at the home root; use the chat-specific plan path: " + planPath +} + +func planPathVerificationMessage(requestedPath string) string { + return "the plan path " + requestedPath + + " could not be verified because the workspace is currently unavailable to resolve the chat-specific plan path, try again shortly" +} + +func editFilesBatchRejectedMessage(message string) string { + return message + "; no files in this batch were applied" +} + +func relativePlanPathMessage() string { + return "plan files must use absolute paths; use the chat-specific absolute plan path" +} diff --git a/coderd/x/chatd/chattool/planpathmessage.go b/coderd/x/chatd/chattool/planpathmessage.go index a6845a82d2..d757653228 100644 --- a/coderd/x/chatd/chattool/planpathmessage.go +++ b/coderd/x/chatd/chattool/planpathmessage.go @@ -46,6 +46,14 @@ func sharedPlanPathMessage(requestedPath, chatPath string) string { ) } +func symlinkedPlanPathMessage(planPath, resolvedPath string) string { + return fmt.Sprintf( + "the chat-specific plan path %s resolves to %s; symlinked plan paths are not allowed during plan turns", + planPath, + resolvedPath, + ) +} + func planPathVerificationMessage(requestedPath string) string { return fmt.Sprintf( "the plan path %s could not be verified because the workspace is currently unavailable to resolve the chat-specific plan path, try again shortly", diff --git a/coderd/x/chatd/chattool/planpathresolve.go b/coderd/x/chatd/chattool/planpathresolve.go new file mode 100644 index 0000000000..e506e6d579 --- /dev/null +++ b/coderd/x/chatd/chattool/planpathresolve.go @@ -0,0 +1,54 @@ +package chattool + +import ( + "context" + "net/http" + "path" + "path/filepath" + "strings" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +func ensurePlanPathResolvesToItself( + ctx context.Context, + conn workspacesdk.AgentConn, + planPath string, +) error { + if conn == nil { + return xerrors.New("workspace connection is required") + } + + normalizedPlanPath := normalizeWorkspacePath(planPath) + resolvedPath, err := conn.ResolvePath(ctx, planPath) + if err != nil { + if resolvePathUnsupported(err) { + // Older workspace agents do not expose /resolve-path yet. Keep + // plan turns working during rolling upgrades, even though they + // cannot enforce the symlink guard until the agent is upgraded. + return nil + } + return xerrors.Errorf("resolve plan path: %w", err) + } + resolvedPath = normalizeWorkspacePath(resolvedPath) + if resolvedPath != normalizedPlanPath { + return xerrors.New(symlinkedPlanPathMessage(normalizedPlanPath, resolvedPath)) + } + + return nil +} + +func resolvePathUnsupported(err error) bool { + var statusErr interface{ StatusCode() int } + return xerrors.As(err, &statusErr) && statusErr.StatusCode() == http.StatusNotFound +} + +func normalizeWorkspacePath(pathString string) string { + pathString = strings.TrimSpace(pathString) + if pathString == "" { + return "" + } + return path.Clean(filepath.ToSlash(pathString)) +} diff --git a/coderd/x/chatd/chattool/proposeplan.go b/coderd/x/chatd/chattool/proposeplan.go index bc46e7392c..2b18ee6db5 100644 --- a/coderd/x/chatd/chattool/proposeplan.go +++ b/coderd/x/chatd/chattool/proposeplan.go @@ -19,6 +19,7 @@ type ProposePlanOptions struct { GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error) ResolvePlanPath func(context.Context) (chatPath string, home string, err error) StoreFile func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error) + IsPlanTurn bool } // ProposePlanArgs are the arguments for the propose_plan tool. @@ -36,6 +37,21 @@ func ProposePlan(options ProposePlanOptions) fantasy.AgentTool { "Pass the absolute file path to the plan. Important: use the chat-specific absolute plan path, not a generic path like PLAN.md in the home directory. "+ "The tool reads the content from the workspace.", func(ctx context.Context, args ProposePlanArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + if options.IsPlanTurn { + planPath, err := resolvePlanTurnPath(ctx, options.ResolvePlanPath) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + path := strings.TrimSpace(args.Path) + switch { + case path == "": + args.Path = planPath + case path != planPath: + return fantasy.NewTextErrorResponse("during plan turns, propose_plan path must be " + planPath), nil + default: + args.Path = path + } + } if options.GetWorkspaceConn == nil { return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil } @@ -90,6 +106,9 @@ func executeProposePlanTool( if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil } + if len(data) == 0 || strings.TrimSpace(string(data)) == "" { + return fantasy.NewTextErrorResponse("plan file is empty; write your plan to " + requestedPath + " before proposing"), nil + } if int64(len(data)) > maxProposePlanSize { return fantasy.NewTextErrorResponse("plan file exceeds 32 KiB size limit"), nil } diff --git a/coderd/x/chatd/chattool/proposeplan_test.go b/coderd/x/chatd/chattool/proposeplan_test.go index 92b645818c..c6d87b5857 100644 --- a/coderd/x/chatd/chattool/proposeplan_test.go +++ b/coderd/x/chatd/chattool/proposeplan_test.go @@ -6,7 +6,6 @@ import ( "io" "strings" "testing" - "testing/iotest" "charm.land/fantasy" "github.com/google/uuid" @@ -31,13 +30,13 @@ type proposePlanResponse struct { func TestProposePlan(t *testing.T) { t.Parallel() - t.Run("EmptyPathReturnsError", func(t *testing.T) { + t.Run("RejectsEmptyPath", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) mockConn := agentconnmock.NewMockAgentConn(ctrl) storeFile, _ := fakeStoreFile(t) - tool := newProposePlanTool(t, mockConn, storeFile) + tool := newProposePlanToolWithPlanPath(t, mockConn, storeFile, nil, false) resp, err := tool.Run(context.Background(), fantasy.ToolCall{ ID: "call-1", Name: "propose_plan", @@ -45,326 +44,133 @@ func TestProposePlan(t *testing.T) { }) require.NoError(t, err) assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "path is required") + assert.Equal(t, "path is required (use the chat-specific absolute plan path)", resp.Content) }) - t.Run("WhitespaceOnlyPathReturnsError", func(t *testing.T) { + t.Run("RejectsNonMarkdownPath", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) mockConn := agentconnmock.NewMockAgentConn(ctrl) storeFile, _ := fakeStoreFile(t) - tool := newProposePlanTool(t, mockConn, storeFile) + tool := newProposePlanToolWithPlanPath(t, mockConn, storeFile, nil, false) resp, err := tool.Run(context.Background(), fantasy.ToolCall{ ID: "call-1", Name: "propose_plan", - Input: `{"path":" "}`, + Input: `{"path":"/home/coder/.coder/plans/PLAN-chat.txt"}`, }) require.NoError(t, err) assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "path is required") + assert.Equal(t, "path must end with .md", resp.Content) }) - t.Run("NonMdPathReturnsError", func(t *testing.T) { + t.Run("PlanTurnDefaultsEmptyPathToResolvedPath", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) mockConn := agentconnmock.NewMockAgentConn(ctrl) - - storeFile, _ := fakeStoreFile(t) - tool := newProposePlanTool(t, mockConn, storeFile) - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/plan.txt"}`, - }) - require.NoError(t, err) - assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "path must end with .md") - }) - - t.Run("RelativePlanPathReturnsError", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - - storeFile, _ := fakeStoreFile(t) - resolvePlanPathCalled := false - tool := newProposePlanToolWithPlanPath( - t, - mockConn, - storeFile, - func(context.Context) (string, string, error) { - resolvePlanPathCalled = true - return "/home/coder/.coder/plans/PLAN-chat.md", "/home/coder", nil - }, - ) - - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"plan.md"}`, - }) - require.NoError(t, err) - assert.True(t, resp.IsError) - assert.False(t, resolvePlanPathCalled) - assert.Equal(t, relativePlanPathMessage(), resp.Content) - }) - - t.Run("OversizedFileRejected", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - largeContent := strings.Repeat("x", 32*1024+1) - - mockConn.EXPECT(). - ReadFile(gomock.Any(), "/home/coder/PLAN.md", int64(0), int64(32*1024+1)). - Return(io.NopCloser(strings.NewReader(largeContent)), "text/markdown", nil) - - storeFile, _ := fakeStoreFile(t) - tool := newProposePlanTool(t, mockConn, storeFile) - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/PLAN.md"}`, - }) - require.NoError(t, err) - assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "plan file exceeds 32 KiB size limit") - }) - - t.Run("ExactBoundaryFileSucceeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - content := strings.Repeat("x", 32*1024) - - mockConn.EXPECT(). - ReadFile(gomock.Any(), "/home/coder/PLAN.md", int64(0), int64(32*1024+1)). - Return(io.NopCloser(strings.NewReader(content)), "text/markdown", nil) - - storeFile, _ := fakeStoreFile(t) - tool := newProposePlanTool(t, mockConn, storeFile) - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/PLAN.md"}`, - }) - require.NoError(t, err) - assert.False(t, resp.IsError) - }) - - t.Run("ValidPlanReadsFile", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - - mockConn.EXPECT(). - ReadFile(gomock.Any(), "/home/coder/docs/PLAN.md", int64(0), int64(32*1024+1)). - Return(io.NopCloser(strings.NewReader("# Plan\n\nContent")), "text/markdown", nil) - - storeFile, stored := fakeStoreFile(t) - planPathCalled := false - tool := newProposePlanToolWithPlanPath( - t, - mockConn, - storeFile, - func(context.Context) (string, string, error) { - planPathCalled = true - return "/home/coder/.coder/plans/PLAN-xxx.md", "/home/coder", nil - }, - ) - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/docs/PLAN.md"}`, - }) - require.NoError(t, err) - assert.False(t, resp.IsError) - assert.True(t, planPathCalled) - - result := decodeProposePlanResponse(t, resp) - assert.True(t, result.OK) - assert.Equal(t, "/home/coder/docs/PLAN.md", result.Path) - assert.Equal(t, "plan", result.Kind) - assert.Equal(t, "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", result.FileID) - assert.Equal(t, "text/markdown", result.MediaType) - assert.Equal(t, []byte("# Plan\n\nContent"), *stored) - assert.NotContains(t, resp.Content, "content") - }) - - t.Run("NestedPlanPathUnderHomeIsAllowed", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - - mockConn.EXPECT(). - ReadFile(gomock.Any(), "/home/coder/myproject/plan.md", int64(0), int64(32*1024+1)). - Return(io.NopCloser(strings.NewReader("# Nested Plan")), "text/markdown", nil) - - storeFile, stored := fakeStoreFile(t) - planPathCalled := false - tool := newProposePlanToolWithPlanPath( - t, - mockConn, - storeFile, - func(context.Context) (string, string, error) { - planPathCalled = true - return "/home/coder/.coder/plans/PLAN-chat.md", "/home/coder", nil - }, - ) - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/myproject/plan.md"}`, - }) - require.NoError(t, err) - assert.False(t, resp.IsError) - assert.True(t, planPathCalled) - - result := decodeProposePlanResponse(t, resp) - assert.True(t, result.OK) - assert.Equal(t, "/home/coder/myproject/plan.md", result.Path) - assert.Equal(t, []byte("# Nested Plan"), *stored) - }) - - t.Run("FileNotFound", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - - mockConn.EXPECT(). - ReadFile(gomock.Any(), "/home/coder/PLAN.md", int64(0), int64(32*1024+1)). - Return(nil, "", xerrors.New("file not found")) - - storeFile, _ := fakeStoreFile(t) - tool := newProposePlanTool(t, mockConn, storeFile) - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/PLAN.md"}`, - }) - require.NoError(t, err) - assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "file not found") - }) - - t.Run("ReadAllError", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - - mockConn.EXPECT(). - ReadFile(gomock.Any(), "/home/coder/PLAN.md", int64(0), int64(32*1024+1)). - Return(io.NopCloser(iotest.ErrReader(xerrors.New("connection reset"))), "text/markdown", nil) - - storeFile, _ := fakeStoreFile(t) - tool := newProposePlanTool(t, mockConn, storeFile) - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/PLAN.md"}`, - }) - require.NoError(t, err) - assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "connection reset") - }) - - t.Run("StoreFileError", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - - mockConn.EXPECT(). - ReadFile(gomock.Any(), "/home/coder/PLAN.md", int64(0), int64(32*1024+1)). - Return(io.NopCloser(strings.NewReader("# Plan")), "text/markdown", nil) - - tool := newProposePlanTool(t, mockConn, func(_ context.Context, _ string, _ string, _ []byte) (uuid.UUID, error) { - return uuid.Nil, xerrors.New("storage unavailable") - }) - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/PLAN.md"}`, - }) - require.NoError(t, err) - assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "storage unavailable") - }) - - t.Run("RejectsSharedPlanPathWithResolvedPath", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - - storeFile, _ := fakeStoreFile(t) - tool := newProposePlanToolWithPlanPath( - t, - mockConn, - storeFile, - func(context.Context) (string, string, error) { - return "/home/coder/.coder/plans/PLAN-chat.md", "/home/coder", nil - }, - ) - - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"` + chattool.LegacySharedPlanPath + `"}`, - }) - require.NoError(t, err) - assert.True(t, resp.IsError) - assert.Equal( - t, - sharedPlanPathResolvedMessage(chattool.LegacySharedPlanPath, "/home/coder/.coder/plans/PLAN-chat.md"), - resp.Content, - ) - }) - - t.Run("RejectsSharedPlanPathWhenResolverFails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - - storeFile, _ := fakeStoreFile(t) - tool := newProposePlanToolWithPlanPath( - t, - mockConn, - storeFile, - func(context.Context) (string, string, error) { - return "", "", xerrors.New("workspace unavailable") - }, - ) - - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"` + chattool.LegacySharedPlanPath + `"}`, - }) - require.NoError(t, err) - assert.True(t, resp.IsError) - assert.Equal(t, planPathVerificationMessage(chattool.LegacySharedPlanPath), resp.Content) - }) - - t.Run("PerChatPlanPathIsAllowed", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - mockConn := agentconnmock.NewMockAgentConn(ctrl) - chatPlanPath := "/home/coder/.coder/plans/PLAN-123e4567-e89b-12d3-a456-426614174000.md" + chatPlanPath := "/home/coder/.coder/plans/PLAN-chat.md" mockConn.EXPECT(). ReadFile(gomock.Any(), chatPlanPath, int64(0), int64(32*1024+1)). - Return(io.NopCloser(strings.NewReader("# Per-Chat Plan")), "text/markdown", nil) + Return(io.NopCloser(strings.NewReader("# Plan")), "text/markdown", nil) storeFile, stored := fakeStoreFile(t) - resolvePlanPathCalled := false tool := newProposePlanToolWithPlanPath( t, mockConn, storeFile, func(context.Context) (string, string, error) { - resolvePlanPathCalled = true return chatPlanPath, "/home/coder", nil }, + true, + ) + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "propose_plan", + Input: `{"path":""}`, + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + + result := decodeProposePlanResponse(t, resp) + assert.True(t, result.OK) + assert.Equal(t, chatPlanPath, result.Path) + assert.Equal(t, "plan", result.Kind) + assert.Equal(t, "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", result.FileID) + assert.Equal(t, "text/markdown", result.MediaType) + assert.Equal(t, "# Plan", string(*stored)) + }) + + t.Run("PlanTurnRejectsWrongPath", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + chatPlanPath := "/home/coder/.coder/plans/PLAN-chat.md" + + storeFile, _ := fakeStoreFile(t) + tool := newProposePlanToolWithPlanPath( + t, + mockConn, + storeFile, + func(context.Context) (string, string, error) { + return chatPlanPath, "/home/coder", nil + }, + true, + ) + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "propose_plan", + Input: `{"path":"/home/coder/README.md"}`, + }) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Equal(t, "during plan turns, propose_plan path must be "+chatPlanPath, resp.Content) + }) + t.Run("RejectsReadFileErrors", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + chatPlanPath := "/home/coder/.coder/plans/PLAN-chat.md" + + mockConn.EXPECT(). + ReadFile(gomock.Any(), chatPlanPath, int64(0), int64(32*1024+1)). + Return(nil, "", xerrors.New("read failed")) + + storeFile, _ := fakeStoreFile(t) + tool := newProposePlanToolWithPlanPath(t, mockConn, storeFile, nil, false) + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "propose_plan", + Input: `{"path":"` + chatPlanPath + `"}`, + }) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Equal(t, "read failed", resp.Content) + }) + + t.Run("PlanTurnRejectsEmptyPlan", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + chatPlanPath := "/home/coder/.coder/plans/PLAN-chat.md" + + mockConn.EXPECT(). + ReadFile(gomock.Any(), chatPlanPath, int64(0), int64(32*1024+1)). + Return(io.NopCloser(strings.NewReader("")), "text/markdown", nil) + + storeFile, stored := fakeStoreFile(t) + storeCalled := false + tool := newProposePlanToolWithPlanPath( + t, + mockConn, + func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error) { + storeCalled = true + return storeFile(ctx, name, mediaType, data) + }, + func(context.Context) (string, string, error) { + return chatPlanPath, "/home/coder", nil + }, + true, ) resp, err := tool.Run(context.Background(), fantasy.ToolCall{ ID: "call-1", @@ -372,117 +178,83 @@ func TestProposePlan(t *testing.T) { Input: `{"path":"` + chatPlanPath + `"}`, }) require.NoError(t, err) - assert.False(t, resp.IsError) - assert.False(t, resolvePlanPathCalled) - - result := decodeProposePlanResponse(t, resp) - assert.True(t, result.OK) - assert.Equal(t, chatPlanPath, result.Path) - assert.Equal(t, []byte("# Per-Chat Plan"), *stored) + assert.True(t, resp.IsError) + assert.Contains(t, resp.Content, "plan file is empty") + assert.Contains(t, resp.Content, chatPlanPath) + assert.False(t, storeCalled) + assert.Nil(t, *stored) }) - t.Run("NestedPlanPathAllowedWhenResolverFails", func(t *testing.T) { + t.Run("RejectsOversizedPlan", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) mockConn := agentconnmock.NewMockAgentConn(ctrl) + chatPlanPath := "/home/coder/.coder/plans/PLAN-chat.md" mockConn.EXPECT(). - ReadFile(gomock.Any(), "/home/coder/myproject/plan.md", int64(0), int64(32*1024+1)). - Return(io.NopCloser(strings.NewReader("# Nested Plan")), "text/markdown", nil) + ReadFile(gomock.Any(), chatPlanPath, int64(0), int64(32*1024+1)). + Return(io.NopCloser(strings.NewReader(strings.Repeat("x", 32*1024+1))), "text/markdown", nil) storeFile, stored := fakeStoreFile(t) + storeCalled := false tool := newProposePlanToolWithPlanPath( t, mockConn, - storeFile, - func(context.Context) (string, string, error) { - return "", "", xerrors.New("workspace unavailable") + func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error) { + storeCalled = true + return storeFile(ctx, name, mediaType, data) }, + nil, + false, ) resp, err := tool.Run(context.Background(), fantasy.ToolCall{ ID: "call-1", Name: "propose_plan", - Input: `{"path":"/home/coder/myproject/plan.md"}`, - }) - require.NoError(t, err) - assert.False(t, resp.IsError) - - result := decodeProposePlanResponse(t, resp) - assert.True(t, result.OK) - assert.Equal(t, "/home/coder/myproject/plan.md", result.Path) - assert.Equal(t, []byte("# Nested Plan"), *stored) - }) - - t.Run("WorkspaceConnectionError", func(t *testing.T) { - t.Parallel() - storeFile, _ := fakeStoreFile(t) - tool := chattool.ProposePlan(chattool.ProposePlanOptions{ - GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { - return nil, xerrors.New("connection failed") - }, - StoreFile: storeFile, - }) - - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/PLAN.md"}`, + Input: `{"path":"` + chatPlanPath + `"}`, }) require.NoError(t, err) assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "connection failed") + assert.Equal(t, "plan file exceeds 32 KiB size limit", resp.Content) + assert.False(t, storeCalled) + assert.Nil(t, *stored) }) - t.Run("NilWorkspaceResolver", func(t *testing.T) { - t.Parallel() - tool := chattool.ProposePlan(chattool.ProposePlanOptions{}) - - resp, err := tool.Run(context.Background(), fantasy.ToolCall{ - ID: "call-1", - Name: "propose_plan", - Input: `{"path":"/home/coder/PLAN.md"}`, - }) - require.NoError(t, err) - assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "workspace connection resolver is not configured") - }) - - t.Run("NilStoreFile", func(t *testing.T) { + t.Run("PropagatesStoreFileErrors", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) mockConn := agentconnmock.NewMockAgentConn(ctrl) + chatPlanPath := "/home/coder/.coder/plans/PLAN-chat.md" - tool := chattool.ProposePlan(chattool.ProposePlanOptions{ - GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) { - return mockConn, nil + mockConn.EXPECT(). + ReadFile(gomock.Any(), chatPlanPath, int64(0), int64(32*1024+1)). + Return(io.NopCloser(strings.NewReader("# Plan")), "text/markdown", nil) + + tool := newProposePlanToolWithPlanPath( + t, + mockConn, + func(context.Context, string, string, []byte) (uuid.UUID, error) { + return uuid.Nil, xerrors.New("store failed") }, - }) - + nil, + false, + ) resp, err := tool.Run(context.Background(), fantasy.ToolCall{ ID: "call-1", Name: "propose_plan", - Input: `{"path":"/home/coder/PLAN.md"}`, + Input: `{"path":"` + chatPlanPath + `"}`, }) require.NoError(t, err) assert.True(t, resp.IsError) - assert.Contains(t, resp.Content, "file storage is not configured") + assert.Equal(t, "failed to store plan file: store failed", resp.Content) }) } -func newProposePlanTool( - t *testing.T, - mockConn *agentconnmock.MockAgentConn, - storeFile func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error), -) fantasy.AgentTool { - t.Helper() - return newProposePlanToolWithPlanPath(t, mockConn, storeFile, nil) -} - func newProposePlanToolWithPlanPath( t *testing.T, mockConn *agentconnmock.MockAgentConn, storeFile func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error), resolvePlanPath func(context.Context) (string, string, error), + isPlanTurn bool, ) fantasy.AgentTool { t.Helper() return chattool.ProposePlan(chattool.ProposePlanOptions{ @@ -491,27 +263,10 @@ func newProposePlanToolWithPlanPath( }, ResolvePlanPath: resolvePlanPath, StoreFile: storeFile, + IsPlanTurn: isPlanTurn, }) } -func sharedPlanPathResolvedMessage(requestedPath, planPath string) string { - return "the plan path " + requestedPath + - " is no longer supported at the home root; use the chat-specific plan path: " + planPath -} - -func planPathVerificationMessage(requestedPath string) string { - return "the plan path " + requestedPath + - " could not be verified because the workspace is currently unavailable to resolve the chat-specific plan path, try again shortly" -} - -func editFilesBatchRejectedMessage(message string) string { - return message + "; no files in this batch were applied" -} - -func relativePlanPathMessage() string { - return "plan files must use absolute paths; use the chat-specific absolute plan path" -} - func fakeStoreFile(t *testing.T) (func(ctx context.Context, name string, mediaType string, data []byte) (uuid.UUID, error), *[]byte) { t.Helper() diff --git a/coderd/x/chatd/chattool/teststatuserror_test.go b/coderd/x/chatd/chattool/teststatuserror_test.go new file mode 100644 index 0000000000..8b5510dfb6 --- /dev/null +++ b/coderd/x/chatd/chattool/teststatuserror_test.go @@ -0,0 +1,19 @@ +package chattool_test + +import "fmt" + +type statusError struct { + statusCode int + message string +} + +func (e statusError) Error() string { + if e.message != "" { + return e.message + } + return fmt.Sprintf("status %d", e.statusCode) +} + +func (e statusError) StatusCode() int { + return e.statusCode +} diff --git a/coderd/x/chatd/chattool/writefile.go b/coderd/x/chatd/chattool/writefile.go index 983b387209..0999f18a97 100644 --- a/coderd/x/chatd/chattool/writefile.go +++ b/coderd/x/chatd/chattool/writefile.go @@ -12,6 +12,7 @@ import ( type WriteFileOptions struct { GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error) ResolvePlanPath func(context.Context) (chatPath string, home string, err error) + IsPlanTurn bool } type WriteFileArgs struct { @@ -24,6 +25,18 @@ func WriteFile(options WriteFileOptions) fantasy.AgentTool { "write_file", "Write a file to the workspace.", func(ctx context.Context, args WriteFileArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + var planPath string + if options.IsPlanTurn { + args.Path = strings.TrimSpace(args.Path) + resolvedPlanPath, err := resolvePlanTurnPath(ctx, options.ResolvePlanPath) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + if args.Path != resolvedPlanPath { + return fantasy.NewTextErrorResponse("during plan turns, write_file is restricted to " + resolvedPlanPath), nil + } + planPath = resolvedPlanPath + } if options.GetWorkspaceConn == nil { return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil } @@ -31,6 +44,11 @@ func WriteFile(options WriteFileOptions) fantasy.AgentTool { if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil } + if planPath != "" { + if err := ensurePlanPathResolvesToItself(ctx, conn, planPath); err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + } return executeWriteFileTool(ctx, conn, args, options.ResolvePlanPath) }, ) diff --git a/coderd/x/chatd/chattool/writefile_test.go b/coderd/x/chatd/chattool/writefile_test.go index 2fa563156e..c006c911db 100644 --- a/coderd/x/chatd/chattool/writefile_test.go +++ b/coderd/x/chatd/chattool/writefile_test.go @@ -3,6 +3,7 @@ package chattool_test import ( "context" "io" + "net/http" "strings" "testing" @@ -20,6 +21,136 @@ import ( func TestWriteFile(t *testing.T) { t.Parallel() + t.Run("PlanTurnRejectsNonPlanPath", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + planPath := "/home/coder/.coder/plans/PLAN-test-uuid.md" + getWorkspaceConnCalled := false + tool := chattool.WriteFile(chattool.WriteFileOptions{ + GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { + getWorkspaceConnCalled = true + return mockConn, nil + }, + ResolvePlanPath: func(context.Context) (string, string, error) { + return planPath, "/home/coder", nil + }, + IsPlanTurn: true, + }) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "write_file", + Input: `{"path":"/home/coder/README.md","content":"# Plan"}`, + }) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Equal(t, "during plan turns, write_file is restricted to "+planPath, resp.Content) + assert.False(t, getWorkspaceConnCalled) + }) + + t.Run("PlanTurnAllowsResolvedPlanPath", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + planPath := "/home/coder/.coder/plans/PLAN-test-uuid.md" + resolvePlanPathCalls := 0 + mockConn.EXPECT().ResolvePath(gomock.Any(), planPath).Return(planPath, nil) + mockConn.EXPECT(). + WriteFile(gomock.Any(), planPath, gomock.Any()). + DoAndReturn(func(_ context.Context, path string, reader io.Reader) error { + data, err := io.ReadAll(reader) + require.NoError(t, err) + require.Equal(t, planPath, path) + require.Equal(t, "# Plan", string(data)) + return nil + }) + + tool := chattool.WriteFile(chattool.WriteFileOptions{ + GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { + return mockConn, nil + }, + ResolvePlanPath: func(context.Context) (string, string, error) { + resolvePlanPathCalls++ + return planPath, "/home/coder", nil + }, + IsPlanTurn: true, + }) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "write_file", + Input: `{"path":"` + planPath + `","content":"# Plan"}`, + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + assert.Equal(t, 1, resolvePlanPathCalls) + assert.Equal(t, `{"ok":true}`, strings.TrimSpace(resp.Content)) + }) + + t.Run("PlanTurnAllowsLegacyAgentWithoutResolvePath", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + planPath := "/home/coder/.coder/plans/PLAN-test-uuid.md" + mockConn.EXPECT(). + ResolvePath(gomock.Any(), planPath). + Return("", statusError{statusCode: http.StatusNotFound, message: "missing resolve-path endpoint"}) + mockConn.EXPECT(). + WriteFile(gomock.Any(), planPath, gomock.Any()). + DoAndReturn(func(_ context.Context, path string, reader io.Reader) error { + data, err := io.ReadAll(reader) + require.NoError(t, err) + require.Equal(t, planPath, path) + require.Equal(t, "# Plan", string(data)) + return nil + }) + tool := chattool.WriteFile(chattool.WriteFileOptions{ + GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { + return mockConn, nil + }, + ResolvePlanPath: func(context.Context) (string, string, error) { + return planPath, "/home/coder", nil + }, + IsPlanTurn: true, + }) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "write_file", + Input: `{"path":"` + planPath + `","content":"# Plan"}`, + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + assert.Equal(t, `{"ok":true}`, strings.TrimSpace(resp.Content)) + }) + + t.Run("PlanTurnRejectsSymlinkedPlanPath", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + planPath := "/home/coder/.coder/plans/PLAN-test-uuid.md" + mockConn.EXPECT().ResolvePath(gomock.Any(), planPath).Return("/home/coder/README.md", nil) + tool := chattool.WriteFile(chattool.WriteFileOptions{ + GetWorkspaceConn: func(context.Context) (workspacesdk.AgentConn, error) { + return mockConn, nil + }, + ResolvePlanPath: func(context.Context) (string, string, error) { + return planPath, "/home/coder", nil + }, + IsPlanTurn: true, + }) + + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "write_file", + Input: `{"path":"` + planPath + `","content":"# Plan"}`, + }) + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Equal(t, "the chat-specific plan path /home/coder/.coder/plans/PLAN-test-uuid.md resolves to /home/coder/README.md; symlinked plan paths are not allowed during plan turns", resp.Content) + }) + t.Run("RejectsHomeRootPlanVariantsWhenResolvePlanPathIsConfigured", func(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/prompt.go b/coderd/x/chatd/prompt.go index d6764f8be7..635af7e86f 100644 --- a/coderd/x/chatd/prompt.go +++ b/coderd/x/chatd/prompt.go @@ -21,7 +21,7 @@ Ask concise clarifying questions only when: - architecture, tooling, or style preferences would change the implementation; - the action is destructive, irreversible, or expensive; or - you cannot make progress with confidence. -If a task is too ambiguous to implement with confidence, or the user asks for a plan, write a plan before implementing. Use propose_plan to present it for review. +If a task is too ambiguous to implement with confidence, ask for clarification before proceeding. @@ -94,12 +94,35 @@ Once a workspace is available: chat-specific path from the block below when it is available. 3. Iterate on the plan with edit_files if needed. -4. Call propose_plan with the same absolute plan file path from the - block below. -5. Wait for the user to review and approve the plan before starting implementation. +4. Present the plan to the user and wait for review before starting implementation. -The propose_plan tool reads the file from the workspace. Do not pass content directly. Write the file first, then present it. All file paths must be absolute. When the block below is present, use that exact path. ` + defaultSystemPromptPlanPathBlockPlaceholder + ` ` + +// PlanningOverlayPrompt contains plan-mode-only instructions appended +// when the chat is in plan mode. +const PlanningOverlayPrompt = `You are in Plan Mode. +Every response must work toward producing a plan. +The only intentional authored workspace artifact is the plan file at the path specified in the block below. +You may use execute and process_output for exploration, including cloning repositories, searching code, and running inspection commands needed to build the plan. +Do not use Plan Mode to implement the requested changes or intentionally modify project files outside the plan file. +If no workspace is attached to this chat yet, create and start one with create_workspace and start_workspace before investigating. +If the plan file already exists, read it first with read_file before replacing or refining it. +Use read_file, execute, process_output, list_templates, read_template, and spawn_agent to gather context. In Plan Mode, spawn_agent delegation is for investigation and planning support, not code writing or implementation. +Use write_file to create the plan file and edit_files to refine it. +Use ask_user_question for structured clarification instead of freeform questions. +When the plan is ready, call propose_plan with the plan file path. +After a successful propose_plan call, stop immediately. Do not produce follow-up output. +` + defaultSystemPromptPlanPathBlockPlaceholder + +// PlanningSubagentOverlayPrompt contains plan-mode instructions for +// delegated child chats. Child chats may investigate with shell tools +// but should return findings to the parent instead of authoring the +// final plan. +const PlanningSubagentOverlayPrompt = `You are in Plan Mode as a delegated sub-agent. +Every response must help the parent agent produce a plan. +You may use read_file, execute, process_output, read_skill, and read_skill_file for exploration, including cloning repositories, searching code, and running inspection commands. +Do not implement changes or intentionally modify workspace files. +Return concise findings and recommendations to the parent agent.` diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 25620aa5ef..692b4dbda5 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -97,26 +97,36 @@ func (p *Server) isDesktopEnabled(ctx context.Context) bool { } func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool { + var planMode database.NullChatPlanMode + if currentChat != nil { + planMode = currentChat().PlanMode + } + + spawnAgentDescription := "Spawn a delegated child agent to work on a clearly scoped, " + + "independent task in parallel. Use this when the task is " + + "self-contained and would benefit from a separate agent " + + "(e.g. fixing a specific bug, writing a single module, " + + "running a migration). Do NOT use for simple or quick " + + "operations you can handle directly with execute, " + + "read_file, or write_file - for example, reading a group " + + "of files and outputting them verbatim does not need a " + + "subagent. Reserve subagents for tasks that require " + + "intellectual work such as code analysis, writing new " + + "code, or complex refactoring. Be careful when running " + + "parallel subagents: if two subagents modify the same " + + "files they will conflict with each other, so ensure " + + "parallel subagent tasks are independent. " + + "The child agent receives the same workspace tools but " + + "cannot spawn its own subagents. After spawning, use " + + "wait_agent to collect the result." + if planMode.Valid && planMode.ChatPlanMode == database.ChatPlanModePlan { + spawnAgentDescription += " During plan mode, spawned agents may use shell commands for exploration, such as cloning repositories, searching code, and running inspection commands, but they must not implement changes or intentionally modify workspace files." + } + tools := []fantasy.AgentTool{ fantasy.NewAgentTool( "spawn_agent", - "Spawn a delegated child agent to work on a clearly scoped, "+ - "independent task in parallel. Use this when the task is "+ - "self-contained and would benefit from a separate agent "+ - "(e.g. fixing a specific bug, writing a single module, "+ - "running a migration). Do NOT use for simple or quick "+ - "operations you can handle directly with execute, "+ - "read_file, or write_file - for example, reading a group "+ - "of files and outputting them verbatim does not need a "+ - "subagent. Reserve subagents for tasks that require "+ - "intellectual work such as code analysis, writing new "+ - "code, or complex refactoring. Be careful when running "+ - "parallel subagents: if two subagents modify the same "+ - "files they will conflict with each other, so ensure "+ - "parallel subagent tasks are independent. "+ - "The child agent receives the same workspace tools but "+ - "cannot spawn its own subagents. After spawning, use "+ - "wait_agent to collect the result.", + spawnAgentDescription, func(ctx context.Context, args spawnAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { if currentChat == nil { return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil @@ -131,11 +141,12 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database. if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil } - childChat, err := p.createChildSubagentChat( + childChat, err := p.createChildSubagentChatWithOptions( ctx, parent, args.Prompt, args.Title, + childSubagentChatOptions{}, ) if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil @@ -474,6 +485,7 @@ func (p *Server) createChildSubagentChatWithOptions( LastModelConfigID: parent.LastModelConfigID, Title: title, Mode: opts.chatMode, + PlanMode: parent.PlanMode, Status: database.ChatStatusPending, MCPServerIDs: mcpServerIDs, Labels: pqtype.NullRawMessage{ diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index 999c552be8..06ddce0226 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -256,6 +256,43 @@ func TestCreateChildSubagentChatInheritsWorkspaceBinding(t *testing.T) { require.Equal(t, parentChat.AgentID, childChat.AgentID) } +func TestCreateChildSubagentChatCopiesPlanMode(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + ctx := chatdTestContext(t) + user, org, model := seedInternalChatDeps(ctx, t, db) + planMode := database.NullChatPlanMode{ + ChatPlanMode: database.ChatPlanModePlan, + Valid: true, + } + + parent, err := server.CreateChat(ctx, CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "plan-parent", + ModelConfigID: model.ID, + PlanMode: planMode, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("plan this change"), + }, + }) + require.NoError(t, err) + + parentChat, err := db.GetChatByID(ctx, parent.ID) + require.NoError(t, err) + require.Equal(t, planMode, parentChat.PlanMode) + + child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "") + require.NoError(t, err) + + childChat, err := db.GetChatByID(ctx, child.ID) + require.NoError(t, err) + require.Equal(t, planMode, childChat.PlanMode) +} + func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) { t.Parallel() diff --git a/codersdk/chats.go b/codersdk/chats.go index c13fd222af..8baccef57b 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -65,6 +65,7 @@ type Chat struct { LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"` Title string `json:"title"` Status ChatStatus `json:"status"` + PlanMode ChatPlanMode `json:"plan_mode,omitempty"` LastError *string `json:"last_error"` DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"` CreatedAt time.Time `json:"created_at" format:"date-time"` @@ -393,12 +394,14 @@ type CreateChatRequest struct { // LLM can invoke. This API is highly experimental and highly // subject to change. UnsafeDynamicTools []DynamicTool `json:"unsafe_dynamic_tools,omitempty"` + PlanMode ChatPlanMode `json:"plan_mode,omitempty"` } // UpdateChatRequest is the request to update a chat. type UpdateChatRequest struct { - Title *string `json:"title,omitempty"` - Archived *bool `json:"archived,omitempty"` + Title *string `json:"title,omitempty"` + Archived *bool `json:"archived,omitempty"` + WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` // PinOrder controls the chat's pinned state and position. // - nil: no change to pin state. // - 0: unpin the chat. @@ -410,6 +413,9 @@ type UpdateChatRequest struct { // value is clamped to [1, pinned_count]. PinOrder *int32 `json:"pin_order,omitempty"` Labels *map[string]string `json:"labels,omitempty"` + // PlanMode switches the chat's persistent plan mode. + // nil: no change, ptr to "plan": enable, ptr to "": clear. + PlanMode *ChatPlanMode `json:"plan_mode,omitempty"` } // ChatBusyBehavior controls what happens when a user sends a message @@ -427,12 +433,23 @@ const ( ChatBusyBehaviorInterrupt ChatBusyBehavior = "interrupt" ) +// ChatPlanMode represents the persistent plan mode state of a chat. +type ChatPlanMode string + +const ( + // ChatPlanModePlan activates plan mode for the chat. + ChatPlanModePlan ChatPlanMode = "plan" +) + // CreateChatMessageRequest is the request to add a message to a chat. type CreateChatMessageRequest struct { Content []ChatInputPart `json:"content"` ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` MCPServerIDs *[]uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"` BusyBehavior ChatBusyBehavior `json:"busy_behavior,omitempty" enums:"queue,interrupt"` + // PlanMode switches the chat's persistent plan mode. + // nil: no change, ptr to "plan": enable, ptr to "": clear. + PlanMode *ChatPlanMode `json:"plan_mode,omitempty"` } // EditChatMessageRequest is the request to edit a user message in a chat. @@ -514,6 +531,18 @@ type UpdateChatSystemPromptRequest struct { IncludeDefaultSystemPrompt *bool `json:"include_default_system_prompt,omitempty"` } +// ChatPlanModeInstructionsResponse is the response body for the +// plan mode instructions configuration endpoint. +type ChatPlanModeInstructionsResponse struct { + PlanModeInstructions string `json:"plan_mode_instructions"` +} + +// UpdateChatPlanModeInstructionsRequest is the request body for +// updating the plan mode instructions configuration. +type UpdateChatPlanModeInstructionsRequest struct { + PlanModeInstructions string `json:"plan_mode_instructions"` +} + // UserChatCustomPrompt is the request and response body for the // user chat custom prompt configuration endpoint. type UserChatCustomPrompt struct { @@ -1891,6 +1920,33 @@ func (c *ExperimentalClient) UpdateChatSystemPrompt(ctx context.Context, req Upd return nil } +// GetChatPlanModeInstructions returns the deployment-wide plan mode instructions. +func (c *ExperimentalClient) GetChatPlanModeInstructions(ctx context.Context) (ChatPlanModeInstructionsResponse, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/plan-mode-instructions", nil) + if err != nil { + return ChatPlanModeInstructionsResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return ChatPlanModeInstructionsResponse{}, ReadBodyAsError(res) + } + var resp ChatPlanModeInstructionsResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// UpdateChatPlanModeInstructions updates the deployment-wide plan mode instructions. +func (c *ExperimentalClient) UpdateChatPlanModeInstructions(ctx context.Context, req UpdateChatPlanModeInstructionsRequest) error { + res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/plan-mode-instructions", req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + // GetUserChatCustomPrompt fetches the user's custom chat prompt. func (c *ExperimentalClient) GetUserChatCustomPrompt(ctx context.Context) (UserChatCustomPrompt, error) { res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/user-prompt", nil) diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index 9968a2f27b..cbdc73c1b6 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/netip" + neturl "net/url" "strconv" "sync" "time" @@ -81,6 +82,7 @@ type AgentConn interface { SignalProcess(ctx context.Context, id string, signal string) error StartProcess(ctx context.Context, req StartProcessRequest) (StartProcessResponse, error) LS(ctx context.Context, path string, req LSRequest) (LSResponse, error) + ResolvePath(ctx context.Context, path string) (string, error) ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error) ReadFileLines(ctx context.Context, path string, offset, limit int64, limits ReadFileLinesLimits) (ReadFileLinesResponse, error) WriteFile(ctx context.Context, path string, reader io.Reader) error @@ -855,7 +857,9 @@ func (c *agentConn) LS(ctx context.Context, path string, req LSRequest) (LSRespo ctx, span := tracing.StartSpan(ctx) defer span.End() - res, err := c.apiRequest(ctx, http.MethodPost, fmt.Sprintf("/api/v0/list-directory?path=%s", path), req) + res, err := c.apiRequest(ctx, http.MethodPost, agentAPIPath("/api/v0/list-directory", neturl.Values{ + "path": []string{path}, + }), req) if err != nil { return LSResponse{}, xerrors.Errorf("do request: %w", err) } @@ -871,16 +875,50 @@ func (c *agentConn) LS(ctx context.Context, path string, req LSRequest) (LSRespo return m, nil } +// ResolvePathResponse is the response from the agent's path-resolution endpoint. +type ResolvePathResponse struct { + ResolvedPath string `json:"resolved_path"` +} + +// ResolvePath resolves the existing portion of an absolute path through any +// symlinks and preserves missing trailing components. +func (c *agentConn) ResolvePath(ctx context.Context, path string) (string, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + res, err := c.apiRequest(ctx, http.MethodGet, agentAPIPath("/api/v0/resolve-path", neturl.Values{ + "path": []string{path}, + }), nil) + if err != nil { + return "", xerrors.Errorf("do request: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return "", codersdk.ReadBodyAsError(res) + } + + var m ResolvePathResponse + if err := json.NewDecoder(res.Body).Decode(&m); err != nil { + return "", xerrors.Errorf("decode response body: %w", err) + } + return m.ResolvedPath, nil +} + // ReadFileLines reads a file with line-based offset and limit, returning // line-numbered content with safety limits. func (c *agentConn) ReadFileLines(ctx context.Context, path string, offset, limit int64, limits ReadFileLinesLimits) (ReadFileLinesResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - res, err := c.apiRequest(ctx, http.MethodGet, fmt.Sprintf( - "/api/v0/read-file-lines?path=%s&offset=%d&limit=%d&max_file_size=%d&max_line_bytes=%d&max_response_lines=%d&max_response_bytes=%d", - path, offset, limit, limits.MaxFileSize, limits.MaxLineBytes, limits.MaxResponseLines, limits.MaxResponseBytes, - ), nil) + res, err := c.apiRequest(ctx, http.MethodGet, agentAPIPath("/api/v0/read-file-lines", neturl.Values{ + "path": []string{path}, + "offset": []string{strconv.FormatInt(offset, 10)}, + "limit": []string{strconv.FormatInt(limit, 10)}, + "max_file_size": []string{strconv.FormatInt(limits.MaxFileSize, 10)}, + "max_line_bytes": []string{strconv.Itoa(limits.MaxLineBytes)}, + "max_response_lines": []string{strconv.Itoa(limits.MaxResponseLines)}, + "max_response_bytes": []string{strconv.Itoa(limits.MaxResponseBytes)}, + }), nil) if err != nil { return ReadFileLinesResponse{}, xerrors.Errorf("do request: %w", err) } @@ -903,7 +941,11 @@ func (c *agentConn) ReadFile(ctx context.Context, path string, offset, limit int defer span.End() //nolint:bodyclose // we want to return the body so the caller can stream. - res, err := c.apiRequest(ctx, http.MethodGet, fmt.Sprintf("/api/v0/read-file?path=%s&offset=%d&limit=%d", path, offset, limit), nil) + res, err := c.apiRequest(ctx, http.MethodGet, agentAPIPath("/api/v0/read-file", neturl.Values{ + "path": []string{path}, + "offset": []string{strconv.FormatInt(offset, 10)}, + "limit": []string{strconv.FormatInt(limit, 10)}, + }), nil) if err != nil { return nil, "", xerrors.Errorf("do request: %w", err) } @@ -925,7 +967,9 @@ func (c *agentConn) WriteFile(ctx context.Context, path string, reader io.Reader ctx, span := tracing.StartSpan(ctx) defer span.End() - res, err := c.apiRequest(ctx, http.MethodPost, fmt.Sprintf("/api/v0/write-file?path=%s", path), reader) + res, err := c.apiRequest(ctx, http.MethodPost, agentAPIPath("/api/v0/write-file", neturl.Values{ + "path": []string{path}, + }), reader) if err != nil { return xerrors.Errorf("do request: %w", err) } @@ -1195,6 +1239,14 @@ func (c *agentConn) EditFiles(ctx context.Context, edits FileEditRequest) error return nil } +func agentAPIPath(path string, query neturl.Values) string { + if len(query) == 0 { + return path + } + + return path + "?" + query.Encode() +} + // apiRequest makes a request to the workspace agent's HTTP API server. func (c *agentConn) apiRequest(ctx context.Context, method, path string, body interface{}) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) diff --git a/codersdk/workspacesdk/agentconn_test.go b/codersdk/workspacesdk/agentconn_test.go new file mode 100644 index 0000000000..617c3d7b79 --- /dev/null +++ b/codersdk/workspacesdk/agentconn_test.go @@ -0,0 +1,52 @@ +//nolint:testpackage // This test exercises the internal query builder directly because agent requests need a live tailnet connection. +package workspacesdk + +import ( + neturl "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAgentAPIPath(t *testing.T) { + t.Parallel() + + t.Run("encodes reserved query characters", func(t *testing.T) { + t.Parallel() + + path := "/tmp/a&b ?#%c.md" + got := agentAPIPath("/api/v0/resolve-path", neturl.Values{ + "path": []string{path}, + }) + + parsed, err := neturl.Parse(got) + require.NoError(t, err) + require.Equal(t, "/api/v0/resolve-path", parsed.Path) + require.Equal(t, path, parsed.Query().Get("path")) + }) + + t.Run("preserves all query values", func(t *testing.T) { + t.Parallel() + + got := agentAPIPath("/api/v0/read-file-lines", neturl.Values{ + "path": []string{"/tmp/plan v1#.md"}, + "offset": []string{"10"}, + "limit": []string{"20"}, + "max_file_size": []string{"30"}, + "max_line_bytes": []string{"40"}, + "max_response_lines": []string{"50"}, + "max_response_bytes": []string{"60"}, + }) + + parsed, err := neturl.Parse(got) + require.NoError(t, err) + require.Equal(t, "/api/v0/read-file-lines", parsed.Path) + require.Equal(t, "/tmp/plan v1#.md", parsed.Query().Get("path")) + require.Equal(t, "10", parsed.Query().Get("offset")) + require.Equal(t, "20", parsed.Query().Get("limit")) + require.Equal(t, "30", parsed.Query().Get("max_file_size")) + require.Equal(t, "40", parsed.Query().Get("max_line_bytes")) + require.Equal(t, "50", parsed.Query().Get("max_response_lines")) + require.Equal(t, "60", parsed.Query().Get("max_response_bytes")) + }) +} diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go index f895a4c9ad..b782038ea8 100644 --- a/codersdk/workspacesdk/agentconnmock/agentconnmock.go +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -449,6 +449,21 @@ func (mr *MockAgentConnMockRecorder) RecreateDevcontainer(ctx, devcontainerID an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecreateDevcontainer", reflect.TypeOf((*MockAgentConn)(nil).RecreateDevcontainer), ctx, devcontainerID) } +// ResolvePath mocks base method. +func (m *MockAgentConn) ResolvePath(ctx context.Context, path string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResolvePath", ctx, path) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ResolvePath indicates an expected call of ResolvePath. +func (mr *MockAgentConnMockRecorder) ResolvePath(ctx, path any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolvePath", reflect.TypeOf((*MockAgentConn)(nil).ResolvePath), ctx, path) +} + // SSH mocks base method. func (m *MockAgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { m.ctrl.T.Helper() diff --git a/docs/ai-coder/agents/architecture.md b/docs/ai-coder/agents/architecture.md index ec4f0917c5..7528ced69a 100644 --- a/docs/ai-coder/agents/architecture.md +++ b/docs/ai-coder/agents/architecture.md @@ -149,13 +149,21 @@ workspace connection. Platform and orchestration tools are only available to root chats — sub-agents spawned by `spawn_agent` do not have access to them and cannot create workspaces or spawn further sub-agents. -| Tool | What it does | -|--------------------|-----------------------------------------------------------------------------------------| -| `list_templates` | Browses available workspace templates, sorted by popularity. | -| `read_template` | Gets template details and configurable parameters. | -| `create_workspace` | Creates a workspace from a template and waits for it to be ready. | -| `start_workspace` | Starts the chat's workspace if it is currently stopped. Idempotent if already running. | -| `propose_plan` | Presents a Markdown plan file from the workspace for user review before implementation. | +| Tool | What it does | +|---------------------|-----------------------------------------------------------------------------------------| +| `list_templates` | Browses available workspace templates, sorted by popularity. | +| `read_template` | Gets template details and configurable parameters. | +| `create_workspace` | Creates a workspace from a template and waits for it to be ready. | +| `start_workspace` | Starts the chat's workspace if it is currently stopped. Idempotent if already running. | +| `propose_plan` | Presents a Markdown plan file from the workspace for user review before implementation. | +| `ask_user_question` | Asks the user structured clarification questions during plan mode. | + +`propose_plan` and `ask_user_question` are only exposed while plan mode is +active. In that mode, `write_file` and `edit_files` are restricted to the +chat-specific plan file, while `execute` and `process_output` remain available +for exploration such as cloning repositories, searching code, and running +inspection commands. MCP, dynamic, provider-native, and computer-use tools are +not available. ### Orchestration tools diff --git a/docs/ai-coder/agents/index.md b/docs/ai-coder/agents/index.md index 8e1835ac6d..153953cf57 100644 --- a/docs/ai-coder/agents/index.md +++ b/docs/ai-coder/agents/index.md @@ -236,6 +236,7 @@ tasks: | `create_workspace` | Create a workspace from a template | | `start_workspace` | Start a stopped workspace for the current chat | | `propose_plan` | Present a Markdown plan file for user review | +| `ask_user_question` | Ask the user structured clarification questions during plan mode | | `read_file` | Read file contents from the workspace | | `write_file` | Write a file to the workspace | | `edit_files` | Perform search-and-replace edits across files | @@ -257,7 +258,7 @@ web terminals and IDE access. No additional ports or services are required in the workspace. Platform tools (`list_templates`, `read_template`, `create_workspace`, -`start_workspace`, `propose_plan`) and orchestration tools (`spawn_agent`, +`start_workspace`, `propose_plan`, `ask_user_question`) and orchestration tools (`spawn_agent`, `wait_agent`, `message_agent`, `close_agent`, `spawn_computer_use_agent`) are only available to root chats. Sub-agents do not have access to these tools and cannot create workspaces or spawn further sub-agents. @@ -267,6 +268,41 @@ the virtual desktop feature to be enabled by an administrator. `read_skill` and `read_skill_file` are available when the workspace contains skills in its `.agents/skills/` directory. +`propose_plan` and `ask_user_question` are only available while plan mode is +active. In plan mode, the agent can still inspect the workspace and template +metadata, execute shell commands for exploration, and read process output. +`write_file` and `edit_files` remain available only for the chat-specific plan +file under `.coder/plans/`. MCP, dynamic, provider-native, and computer-use +tools are blocked. + +## Plan mode + +Plan mode lets you ask the agent to investigate first and present a plan before +implementation. Open the chat input menu and choose **Plan first** to enable it +for the current chat. After you enable it, later turns in that chat stay in +plan mode until you turn it off or click **Implement plan** after a proposed +plan. Because the mode is stored on the chat, reloading the page preserves the +current setting. + +While plan mode is active: + +- the agent can inspect repository files, workspace state, and available + templates +- `write_file` and `edit_files` can only modify the chat-specific plan file + under `.coder/plans/` +- `ask_user_question` can gather structured clarification from the user before + a plan is proposed +- `propose_plan` snapshots the current plan file into the transcript so you can + review it before implementation starts +- `execute` and `process_output` remain available for exploration, such as + cloning repositories, searching code, and running inspection commands +- MCP tools, dynamic tools, provider-native tools, and computer-use tools are + not available + +This keeps planning turns focused on analysis and plan authoring rather than +implementation. Once you click **Implement plan**, the next turn runs in normal +mode again. + ## Comparison to Coder Tasks Coder Agents is a new approach that differs from diff --git a/docs/ai-coder/agents/platform-controls/index.md b/docs/ai-coder/agents/platform-controls/index.md index a4685878a7..292c08f132 100644 --- a/docs/ai-coder/agents/platform-controls/index.md +++ b/docs/ai-coder/agents/platform-controls/index.md @@ -55,6 +55,21 @@ commit message formats, preferred libraries, or repository-specific context. The system prompt configuration is only accessible to administrators in the dashboard. Developers do not see or interact with it. +### Plan mode instructions + +Administrators can add deployment-wide instructions that apply only when a chat +enters plan mode. These instructions supplement the built-in planning behavior +and are useful for organization-specific planning requirements such as required +plan sections, approval checkpoints, or review workflows. + +This setting is available under **Agents** > **Settings** > **Behavior**. +Developers do not edit it directly. + +The same value is exposed over the experimental chat configuration API: + +- `GET /api/experimental/chats/config/plan-mode-instructions` +- `PUT /api/experimental/chats/config/plan-mode-instructions` + ### Template routing Platform teams control which templates are available to agents and how the agent diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 3cb622454b..3563557bdc 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -3050,6 +3050,22 @@ export type CreateTaskFeedbackRequest = { comment?: string; }; +export type ChatPlanModeOrClear = TypesGen.ChatPlanMode | ""; + +export type CreateChatMessageRequestWithClearablePlanMode = Omit< + TypesGen.CreateChatMessageRequest, + "plan_mode" +> & { + readonly plan_mode?: ChatPlanModeOrClear; +}; + +type UpdateChatRequestWithClearablePlanMode = Omit< + TypesGen.UpdateChatRequest, + "plan_mode" +> & { + readonly plan_mode?: ChatPlanModeOrClear; +}; + // Experimental API methods call endpoints under the /api/experimental/ prefix. // These endpoints are not stable and may change or be removed at any time. // @@ -3143,7 +3159,7 @@ class ExperimentalApiMethods { updateChat = async ( chatId: string, - req: TypesGen.UpdateChatRequest, + req: UpdateChatRequestWithClearablePlanMode, ): Promise => { await this.axios.patch(`/api/experimental/chats/${chatId}`, req); }; @@ -3157,7 +3173,7 @@ class ExperimentalApiMethods { createChatMessage = async ( chatId: string, - req: TypesGen.CreateChatMessageRequest, + req: CreateChatMessageRequestWithClearablePlanMode, ): Promise => { const response = await this.axios.post( `/api/experimental/chats/${chatId}/messages`, @@ -3242,6 +3258,24 @@ class ExperimentalApiMethods { await this.axios.put("/api/experimental/chats/config/system-prompt", req); }; + getChatPlanModeInstructions = + async (): Promise => { + const response = + await this.axios.get( + "/api/experimental/chats/config/plan-mode-instructions", + ); + return response.data; + }; + + updateChatPlanModeInstructions = async ( + req: TypesGen.UpdateChatPlanModeInstructionsRequest, + ): Promise => { + await this.axios.put( + "/api/experimental/chats/config/plan-mode-instructions", + req, + ); + }; + getChatDesktopEnabled = async (): Promise => { const response = diff --git a/site/src/api/queries/chats.test.ts b/site/src/api/queries/chats.test.ts index fccde9fedc..c6299d5a14 100644 --- a/site/src/api/queries/chats.test.ts +++ b/site/src/api/queries/chats.test.ts @@ -26,6 +26,7 @@ import { reorderPinnedChat, unarchiveChat, unpinChat, + updateChatPlanMode, updateInfiniteChatsCache, } from "./chats"; @@ -199,6 +200,34 @@ describe("invalidateChatListQueries", () => { }); }); +describe("updateChatPlanMode optimistic update", () => { + it("invalidates the chat list on error without a detail cache", async () => { + const queryClient = createTestQueryClient(); + const chatId = "chat-1"; + seedInfiniteChats(queryClient, [makeChat(chatId)]); + + const mutation = updateChatPlanMode(queryClient); + const context = await mutation.onMutate({ + chatId, + planMode: "plan", + }); + + expect(context?.previousChat).toBeUndefined(); + expect(readInfiniteChats(queryClient)?.[0].plan_mode).toBe("plan"); + + mutation.onError( + new Error("server error"), + { chatId, planMode: "plan" }, + context, + ); + + expect( + queryClient.getQueryState(infiniteChatsTestKey)?.isInvalidated, + "chat list should be invalidated when rollback lacks detail cache", + ).toBe(true); + }); +}); + describe("archiveChat optimistic update", () => { it("optimistically sets archived to true in the chats list", async () => { const queryClient = createTestQueryClient(); diff --git a/site/src/api/queries/chats.ts b/site/src/api/queries/chats.ts index 6e9677aa7e..eefd2d6ad2 100644 --- a/site/src/api/queries/chats.ts +++ b/site/src/api/queries/chats.ts @@ -3,7 +3,11 @@ import type { QueryClient, UseInfiniteQueryOptions, } from "react-query"; -import { API } from "#/api/api"; +import { + API, + type ChatPlanModeOrClear, + type CreateChatMessageRequestWithClearablePlanMode, +} from "#/api/api"; import type * as TypesGen from "#/api/typesGenerated"; import type { UsePaginatedQueryOptions } from "#/hooks/usePaginatedQuery"; import { @@ -207,6 +211,19 @@ export const cancelChatListRefetches = (queryClient: QueryClient) => { }; const DEFAULT_CHAT_PAGE_LIMIT = 50; +const nilUUID = "00000000-0000-0000-0000-000000000000"; + +type UpdateChatWorkspaceVariables = { + chatId: string; + workspaceId: string | null; +}; + +type UpdateChatPlanModeVariables = { + chatId: string; + planMode?: ChatPlanModeOrClear; +}; + +const clearPlanMode = "" satisfies ChatPlanModeOrClear; export const infiniteChats = (opts?: { q?: string; archived?: boolean }) => { const limit = DEFAULT_CHAT_PAGE_LIMIT; @@ -387,6 +404,138 @@ export const unarchiveChat = (queryClient: QueryClient) => ({ }, }); +export const updateChatPlanMode = (queryClient: QueryClient) => ({ + mutationFn: ({ chatId, planMode }: UpdateChatPlanModeVariables) => + API.experimental.updateChat(chatId, { + plan_mode: planMode ?? clearPlanMode, + }), + onMutate: async ({ chatId, planMode }: UpdateChatPlanModeVariables) => { + await queryClient.cancelQueries({ + queryKey: chatsKey, + predicate: isChatListQuery, + }); + await queryClient.cancelQueries({ + queryKey: chatKey(chatId), + exact: true, + }); + const previousChat = queryClient.getQueryData( + chatKey(chatId), + ); + const nextPlanMode = planMode === clearPlanMode ? undefined : planMode; + updateInfiniteChatsCache(queryClient, (chats) => + chats.map((chat) => + chat.id === chatId ? { ...chat, plan_mode: nextPlanMode } : chat, + ), + ); + if (previousChat) { + queryClient.setQueryData(chatKey(chatId), { + ...previousChat, + plan_mode: nextPlanMode, + }); + } + return { previousChat }; + }, + onError: ( + _error: unknown, + { chatId }: UpdateChatPlanModeVariables, + context: + | { + previousChat?: TypesGen.Chat; + } + | undefined, + ) => { + void invalidateChatListQueries(queryClient); + const previousChat = context?.previousChat; + if (!previousChat) { + return; + } + updateInfiniteChatsCache(queryClient, (chats) => + chats.map((chat) => + chat.id === chatId + ? { + ...chat, + plan_mode: previousChat.plan_mode, + } + : chat, + ), + ); + queryClient.setQueryData(chatKey(chatId), previousChat); + }, +}); + +export const updateChatWorkspace = (queryClient: QueryClient) => ({ + mutationFn: ({ chatId, workspaceId }: UpdateChatWorkspaceVariables) => + API.experimental.updateChat(chatId, { + workspace_id: workspaceId ?? nilUUID, + }), + onMutate: async ({ chatId, workspaceId }: UpdateChatWorkspaceVariables) => { + await queryClient.cancelQueries({ + queryKey: chatsKey, + predicate: isChatListQuery, + }); + await queryClient.cancelQueries({ + queryKey: chatKey(chatId), + exact: true, + }); + const previousChat = queryClient.getQueryData( + chatKey(chatId), + ); + updateInfiniteChatsCache(queryClient, (chats) => + chats.map((chat) => + chat.id === chatId + ? { ...chat, workspace_id: workspaceId ?? undefined } + : chat, + ), + ); + if (previousChat) { + queryClient.setQueryData(chatKey(chatId), { + ...previousChat, + workspace_id: workspaceId ?? undefined, + }); + } + return { previousChat }; + }, + onError: ( + _error: unknown, + { chatId }: UpdateChatWorkspaceVariables, + context: + | { + previousChat?: TypesGen.Chat; + } + | undefined, + ) => { + void invalidateChatListQueries(queryClient); + const previousChat = context?.previousChat; + if (previousChat) { + updateInfiniteChatsCache(queryClient, (chats) => + chats.map((chat) => + chat.id === chatId + ? { + ...chat, + workspace_id: previousChat.workspace_id, + } + : chat, + ), + ); + queryClient.setQueryData(chatKey(chatId), previousChat); + } + }, + onSettled: async ( + _data: unknown, + _error: unknown, + { chatId }: UpdateChatWorkspaceVariables, + ) => { + await invalidateChatListQueries(queryClient); + await queryClient.invalidateQueries({ + queryKey: chatKey(chatId), + exact: true, + }); + await queryClient.invalidateQueries({ + queryKey: chatsByWorkspaceKeyPrefix, + }); + }, +}); + export const pinChat = (queryClient: QueryClient) => ({ mutationFn: (chatId: string) => API.experimental.updateChat(chatId, { pin_order: 1 }), @@ -596,7 +745,7 @@ export const createChatMessage = ( _queryClient: QueryClient, chatId: string, ) => ({ - mutationFn: (req: TypesGen.CreateChatMessageRequest) => + mutationFn: (req: CreateChatMessageRequestWithClearablePlanMode) => API.experimental.createChatMessage(chatId, req), // No onSuccess invalidation needed: the per-chat WebSocket delivers // the response message via upsertDurableMessage, and the global @@ -748,6 +897,23 @@ export const updateChatSystemPrompt = (queryClient: QueryClient) => ({ }, }); +const chatPlanModeInstructionsKey = ["chat-plan-mode-instructions"] as const; + +export const chatPlanModeInstructions = () => ({ + queryKey: chatPlanModeInstructionsKey, + queryFn: () => API.experimental.getChatPlanModeInstructions(), +}); + +export const updateChatPlanModeInstructions = (queryClient: QueryClient) => ({ + mutationFn: (req: TypesGen.UpdateChatPlanModeInstructionsRequest) => + API.experimental.updateChatPlanModeInstructions(req), + onSuccess: async () => { + await queryClient.invalidateQueries({ + queryKey: chatPlanModeInstructionsKey, + }); + }, +}); + const chatDesktopEnabledKey = ["chat-desktop-enabled"] as const; export const chatDesktopEnabled = () => ({ diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index b85ec07ef2..c9371353e0 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1236,6 +1236,7 @@ export interface Chat { readonly last_model_config_id: string; readonly title: string; readonly status: ChatStatus; + readonly plan_mode?: ChatPlanMode; readonly last_error: string | null; readonly diff_status?: ChatDiffStatus; readonly created_at: string; @@ -2023,6 +2024,20 @@ export interface ChatModelsResponse { readonly providers: readonly ChatModelProvider[]; } +// From codersdk/chats.go +export type ChatPlanMode = "plan"; + +// From codersdk/chats.go +/** + * ChatPlanModeInstructionsResponse is the response body for the + * plan mode instructions configuration endpoint. + */ +export interface ChatPlanModeInstructionsResponse { + readonly plan_mode_instructions: string; +} + +export const ChatPlanModes: ChatPlanMode[] = ["plan"]; + // From codersdk/chats.go /** * ChatProviderConfig is an admin-managed provider configuration. @@ -2603,6 +2618,11 @@ export interface CreateChatMessageRequest { readonly model_config_id?: string; readonly mcp_server_ids?: string[]; readonly busy_behavior?: ChatBusyBehavior; + /** + * PlanMode switches the chat's persistent plan mode. + * nil: no change, ptr to "plan": enable, ptr to "": clear. + */ + readonly plan_mode?: ChatPlanMode; } // From codersdk/chats.go @@ -2664,6 +2684,7 @@ export interface CreateChatRequest { * subject to change. */ readonly unsafe_dynamic_tools?: readonly DynamicTool[]; + readonly plan_mode?: ChatPlanMode; } // From codersdk/users.go @@ -7568,6 +7589,15 @@ export interface UpdateChatModelConfigRequest { readonly model_config?: ChatModelCallConfig; } +// From codersdk/chats.go +/** + * UpdateChatPlanModeInstructionsRequest is the request body for + * updating the plan mode instructions configuration. + */ +export interface UpdateChatPlanModeInstructionsRequest { + readonly plan_mode_instructions: string; +} + // From codersdk/chats.go /** * UpdateChatProviderConfigRequest updates a chat provider config. @@ -7589,6 +7619,7 @@ export interface UpdateChatProviderConfigRequest { export interface UpdateChatRequest { readonly title?: string; readonly archived?: boolean; + readonly workspace_id?: string; /** * PinOrder controls the chat's pinned state and position. * - nil: no change to pin state. @@ -7602,6 +7633,11 @@ export interface UpdateChatRequest { */ readonly pin_order?: number; readonly labels?: Record; + /** + * PlanMode switches the chat's persistent plan mode. + * nil: no change, ptr to "plan": enable, ptr to "": clear. + */ + readonly plan_mode?: ChatPlanMode; } // From codersdk/chats.go diff --git a/site/src/pages/AgentsPage/AgentChatPage.stories.tsx b/site/src/pages/AgentsPage/AgentChatPage.stories.tsx index 9ce69ff735..81df5f6a44 100644 --- a/site/src/pages/AgentsPage/AgentChatPage.stories.tsx +++ b/site/src/pages/AgentsPage/AgentChatPage.stories.tsx @@ -229,6 +229,7 @@ const meta: Meta = { beforeEach: () => { localStorage.removeItem(RIGHT_PANEL_OPEN_KEY); spyOn(API, "getApiKey").mockRejectedValue(new Error("missing API key")); + spyOn(API.experimental, "updateChat").mockResolvedValue(); spyOn(API.experimental, "getMCPServerConfigs").mockResolvedValue([]); return () => localStorage.removeItem(RIGHT_PANEL_OPEN_KEY); }, @@ -609,6 +610,45 @@ export const Loading: Story = { }, }; +export const PlanModeFromChatState: Story = { + parameters: { + queries: buildQueries( + { + id: CHAT_ID, + ...baseChatFields, + title: "Plan mode persists", + status: "completed", + plan_mode: "plan", + }, + { messages: [], queued_messages: [], has_more: false }, + { diffUrl: undefined }, + ), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const body = within(canvasElement.ownerDocument.body); + const user = userEvent.setup(); + + expect(await canvas.findByText("Planning")).toBeVisible(); + + await user.click(canvas.getByRole("button", { name: "More options" })); + await body.findByRole("dialog"); + const toggles = await body.findAllByRole("menuitemcheckbox", { + name: "Plan first", + }); + const toggle = toggles.at(-1); + if (!toggle) { + throw new Error("Plan mode toggle did not render."); + } + expect(toggle).toHaveAttribute("aria-checked", "true"); + await user.click(toggle); + + await waitFor(() => { + expect(canvas.queryByText("Planning")).not.toBeInTheDocument(); + }); + }, +}; + /** Full layout with actions menu and diff panel portaled to the right slot. */ export const CompletedWithDiffPanel: Story = { beforeEach: () => { diff --git a/site/src/pages/AgentsPage/AgentChatPage.test.ts b/site/src/pages/AgentsPage/AgentChatPage.test.ts index 2c1287e36a..99055dec52 100644 --- a/site/src/pages/AgentsPage/AgentChatPage.test.ts +++ b/site/src/pages/AgentsPage/AgentChatPage.test.ts @@ -1,11 +1,14 @@ import { act, renderHook } from "@testing-library/react"; import { createRef } from "react"; import { beforeEach, describe, expect, it, vi } from "vitest"; +import type * as TypesGen from "#/api/typesGenerated"; import { draftInputStorageKeyPrefix, + filterWorkspaceOptionsByOrganization, getPersistedDraftInputValue, restoreOptimisticRequestSnapshot, useConversationEditingState, + waitForPendingChatSettingsSyncs, } from "./AgentChatPage"; import type { ChatMessageInputRef } from "./components/AgentChatInput"; import { createChatStore } from "./components/ChatConversation/chatStore"; @@ -67,6 +70,85 @@ const setMobileViewport = (isMobile: boolean) => { }); }; +type Deferred = { + promise: Promise; + resolve: (value: T | PromiseLike) => void; + reject: (reason?: unknown) => void; +}; + +const createDeferred = (): Deferred => { + let resolve!: (value: T | PromiseLike) => void; + let reject!: (reason?: unknown) => void; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + return { promise, resolve, reject }; +}; + +describe("waitForPendingChatSettingsSyncs", () => { + it("waits for plan-mode and workspace updates before resolving", async () => { + const planModeUpdate = createDeferred(); + const workspaceUpdate = createDeferred(); + let settled = false; + + const waitPromise = waitForPendingChatSettingsSyncs([ + planModeUpdate.promise, + workspaceUpdate.promise, + ]).then((result) => { + settled = true; + return result; + }); + + await Promise.resolve(); + expect(settled).toBe(false); + + planModeUpdate.resolve(undefined); + await Promise.resolve(); + expect(settled).toBe(false); + + workspaceUpdate.resolve(undefined); + await expect(waitPromise).resolves.toBeUndefined(); + expect(settled).toBe(true); + }); + + it("rejects when a chat-setting update fails", async () => { + const workspaceUpdate = createDeferred(); + const waitPromise = waitForPendingChatSettingsSyncs([ + workspaceUpdate.promise, + ]); + + workspaceUpdate.reject(new Error("boom")); + await expect(waitPromise).rejects.toThrow("boom"); + }); +}); + +describe("filterWorkspaceOptionsByOrganization", () => { + const makeWorkspace = (id: string, organizationID: string) => + ({ id, organization_id: organizationID }) as TypesGen.Workspace; + + it("returns only workspaces from the active chat organization", () => { + const workspaces = [ + makeWorkspace("workspace-1", "org-a"), + makeWorkspace("workspace-2", "org-b"), + makeWorkspace("workspace-3", "org-a"), + ]; + + expect(filterWorkspaceOptionsByOrganization(workspaces, "org-a")).toEqual([ + workspaces[0], + workspaces[2], + ]); + }); + + it("returns an empty list until the chat organization is known", () => { + const workspaces = [makeWorkspace("workspace-1", "org-a")]; + + expect(filterWorkspaceOptionsByOrganization(workspaces, undefined)).toEqual( + [], + ); + }); +}); + describe("getPersistedDraftInputValue", () => { const chatID = "chat-abc-123"; const expectedKey = `${draftInputStorageKeyPrefix}${chatID}`; @@ -423,6 +505,29 @@ describe("useConversationEditingState", () => { unmount(); }); + it("preserves the composer and draft when send fails", async () => { + const { result, onSend, unmount } = renderEditing(); + const mockInput = createMockChatInputHandle("hello"); + result.current.chatInputRef.current = mockInput.handle; + onSend.mockRejectedValueOnce(new Error("boom")); + + act(() => { + result.current.handleContentChange("hello", "hello", false); + }); + + await act(async () => { + await expect(result.current.handleSendFromInput("hello")).rejects.toThrow( + "boom", + ); + }); + + expect(mockInput.clear).not.toHaveBeenCalled(); + expect(mockInput.focus).not.toHaveBeenCalled(); + expect(result.current.inputValueRef.current).toBe("hello"); + expect(localStorage.getItem(expectedKey)).toBe("hello"); + unmount(); + }); + it("clears the composer and persisted draft after a successful send", async () => { localStorage.setItem(expectedKey, "draft to clear"); const { result, onSend, unmount } = renderEditing(); diff --git a/site/src/pages/AgentsPage/AgentChatPage.tsx b/site/src/pages/AgentsPage/AgentChatPage.tsx index 071b523046..1b72fcf220 100644 --- a/site/src/pages/AgentsPage/AgentChatPage.tsx +++ b/site/src/pages/AgentsPage/AgentChatPage.tsx @@ -7,13 +7,19 @@ import { useQueryClient, } from "react-query"; import { useOutletContext, useParams } from "react-router"; +import { toast } from "sonner"; import type { UrlTransform } from "streamdown"; -import { watchWorkspace } from "#/api/api"; -import { isApiError } from "#/api/errors"; +import { + type ChatPlanModeOrClear, + type CreateChatMessageRequestWithClearablePlanMode, + watchWorkspace, +} from "#/api/api"; +import { getErrorMessage, isApiError } from "#/api/errors"; import { buildOptimisticEditedMessage } from "#/api/queries/chatMessageEdits"; import { chat, chatDesktopEnabled, + chatKey, chatMessagesForInfiniteScroll, chatModelConfigs, chatModels, @@ -23,10 +29,17 @@ import { interruptChat, mcpServerConfigs, promoteChatQueuedMessage, + updateChatPlanMode, + updateChatWorkspace, + updateInfiniteChatsCache, userCompactionThresholds, } from "#/api/queries/chats"; import { deploymentSSHConfig } from "#/api/queries/deployment"; -import { workspaceById, workspaceByIdKey } from "#/api/queries/workspaces"; +import { + workspaceById, + workspaceByIdKey, + workspaces, +} from "#/api/queries/workspaces"; import type * as TypesGen from "#/api/typesGenerated"; import type { ChatMessagePart } from "#/api/typesGenerated"; import { useProxy } from "#/contexts/ProxyContext"; @@ -83,6 +96,10 @@ const lastModelConfigIDStorageKey = "agents.last-model-config-id"; /** @internal Exported for testing. */ export const draftInputStorageKeyPrefix = "agents.draft-input."; +const clearChatPlanMode = "" satisfies ChatPlanModeOrClear; + +type PlanModeSwitch = TypesGen.ChatPlanMode | "clear"; + /** * Read the persisted plain-text draft for a given chat ID. * Returns the text portion of the draft (stripping Lexical JSON @@ -122,6 +139,33 @@ export const restoreOptimisticRequestSnapshot = ( }); }; +/** @internal Exported for testing. */ +export const waitForPendingChatSettingsSyncs = async ( + pendingSyncs: readonly (Promise | null | undefined)[], +): Promise => { + const activeSyncs = pendingSyncs.filter( + (pendingSync): pendingSync is Promise => + pendingSync !== null && pendingSync !== undefined, + ); + if (activeSyncs.length === 0) { + return; + } + await Promise.all(activeSyncs); +}; + +/** @internal Exported for testing. */ +export const filterWorkspaceOptionsByOrganization = ( + workspaceOptions: readonly TypesGen.Workspace[], + organizationID: string | undefined, +): readonly TypesGen.Workspace[] => { + if (!organizationID) { + return []; + } + return workspaceOptions.filter( + (workspace) => workspace.organization_id === organizationID, + ); +}; + const buildAttachmentMediaTypes = ( attachments?: readonly PendingAttachment[], ): ReadonlyMap | undefined => { @@ -576,6 +620,11 @@ const AgentChatPage: FC = () => { const userThresholdsQuery = useQuery(userCompactionThresholds()); const desktopEnabledQuery = useQuery(chatDesktopEnabled()); const mcpServersQuery = useQuery(mcpServerConfigs()); + const workspacesQuery = useQuery(workspaces({ q: "owner:me", limit: 0 })); + const workspaceOptions = filterWorkspaceOptionsByOrganization( + workspacesQuery.data?.workspaces ?? [], + chatQuery.data?.organization_id, + ); const desktopEnabled = desktopEnabledQuery.data?.enable_desktop ?? false; // MCP server selection state. @@ -664,6 +713,7 @@ const AgentChatPage: FC = () => { const { proxy } = useProxy(); const chatRecord = chatQuery.data; + const planModeEnabled = chatRecord?.plan_mode === "plan"; // Initialize MCP selection from chat record or defaults. const effectiveMCPServerIds = (() => { @@ -740,6 +790,58 @@ const AgentChatPage: FC = () => { const { mutateAsync: promoteQueuedMessage } = useMutation( promoteChatQueuedMessage(queryClient, agentId ?? ""), ); + const updateChatWorkspaceBase = updateChatWorkspace(queryClient); + const { + isPending: isUpdateChatWorkspacePending, + mutateAsync: updateChatWorkspaceAsync, + } = useMutation({ + ...updateChatWorkspaceBase, + onError: (error, variables, context) => { + updateChatWorkspaceBase.onError(error, variables, context); + toast.error(getErrorMessage(error, "Failed to update workspace.")); + }, + }); + + const updateChatPlanModeBase = updateChatPlanMode(queryClient); + const { + isPending: isUpdateChatPlanModePending, + mutateAsync: updateChatPlanModeAsync, + } = useMutation({ + ...updateChatPlanModeBase, + onError: (error, variables, context) => { + updateChatPlanModeBase.onError(error, variables, context); + toast.error(getErrorMessage(error, "Failed to update plan mode.")); + }, + }); + const setCachedChatPlanMode = ( + chatId: string, + planMode?: TypesGen.ChatPlanMode, + ) => { + updateInfiniteChatsCache(queryClient, (chats) => + chats.map((chat) => + chat.id === chatId ? { ...chat, plan_mode: planMode } : chat, + ), + ); + queryClient.setQueryData(chatKey(chatId), (previousChat) => + previousChat ? { ...previousChat, plan_mode: planMode } : previousChat, + ); + }; + + const pendingPlanModeSyncRef = useRef | null>(null); + const pendingWorkspaceSyncRef = useRef | null>(null); + const trackPendingChatSettingSync = ( + syncPromise: Promise, + syncRef: { current: Promise | null }, + ) => { + let trackedSync: Promise; + trackedSync = syncPromise.finally(() => { + if (syncRef.current === trackedSync) { + syncRef.current = null; + } + }); + syncRef.current = trackedSync; + void trackedSync.catch(() => undefined); + }; const { store, clearStreamError, upsertCacheMessages } = useChatStore({ chatID: agentId, @@ -836,7 +938,26 @@ const AgentChatPage: FC = () => { }); const isSubmissionPending = isSendPending || isEditPending || isInterruptPending; - const isInputDisabled = !hasModelOptions || isArchived; + const isChatSettingsPending = + isUpdateChatPlanModePending || isUpdateChatWorkspacePending; + const isInputDisabled = + !hasModelOptions || isArchived || isChatSettingsPending; + const selectedWorkspaceId = chatQuery.data?.workspace_id ?? null; + + const isWorkspaceLoading = + workspacesQuery.isLoading || isUpdateChatWorkspacePending; + const handlePlanModeToggle = (enabled: boolean) => { + if (!agentId || enabled === planModeEnabled) { + return; + } + trackPendingChatSettingSync( + updateChatPlanModeAsync({ + chatId: agentId, + planMode: enabled ? "plan" : undefined, + }), + pendingPlanModeSyncRef, + ); + }; const handleUsageLimitError = (error: unknown): void => { if (!agentId) { @@ -863,144 +984,6 @@ const AgentChatPage: FC = () => { } }; - const handleSend = async ( - message: string, - attachments?: readonly PendingAttachment[], - editedMessageID?: number, - ) => { - const chatInputHandle = ( - editing.chatInputRef as React.RefObject - )?.current; - - // Walk the Lexical tree in document order so file-reference - // parts appear at the correct position relative to the - // surrounding text the user typed. - const editorParts = chatInputHandle?.getContentParts() ?? []; - const hasFileReferences = editorParts.some( - (p) => p.type === "file-reference", - ); - const hasContent = - message.trim() || - (attachments && attachments.length > 0) || - hasFileReferences; - if (!hasContent || isSubmissionPending || !agentId || !hasModelOptions) { - return; - } - - const content: TypesGen.ChatInputPart[] = []; - - // Emit parts in document order — text segments and - // file-reference chips are interleaved as they appear in - // the editor. - for (const part of editorParts) { - if (part.type === "text") { - const trimmed = part.text.trim(); - if (trimmed) { - content.push({ type: "text", text: part.text }); - } - } else { - const r = part.reference; - content.push({ - type: "file-reference", - file_name: r.fileName, - start_line: r.startLine, - end_line: r.endLine, - content: r.content, - }); - } - } - - // Add pre-uploaded file attachments. - if (attachments && attachments.length > 0) { - for (const { fileId } of attachments) { - content.push({ type: "file", file_id: fileId }); - } - } - if (editedMessageID !== undefined) { - const request: TypesGen.EditChatMessageRequest = { content }; - const originalEditedMessage = chatMessagesList?.find( - (existingMessage) => existingMessage.id === editedMessageID, - ); - const optimisticMessage = originalEditedMessage - ? buildOptimisticEditedMessage({ - requestContent: request.content, - originalMessage: originalEditedMessage, - attachmentMediaTypes: buildAttachmentMediaTypes(attachments), - }) - : undefined; - const previousSnapshot = store.getSnapshot(); - clearChatErrorReason(agentId); - clearStreamError(); - store.batch(() => { - store.setQueuedMessages([]); - store.setChatStatus("running"); - store.clearStreamState(); - }); - scrollToBottomRef.current?.(); - try { - await editMessage({ - messageId: editedMessageID, - optimisticMessage, - req: request, - }); - } catch (error) { - restoreOptimisticRequestSnapshot(store, previousSnapshot); - handleUsageLimitError(error); - throw error; - } - return; - } - const selectedModelConfigID = effectiveSelectedModel || undefined; - const request: TypesGen.CreateChatMessageRequest = { - content, - model_config_id: selectedModelConfigID, - mcp_server_ids: - effectiveMCPServerIds.length > 0 - ? [...effectiveMCPServerIds] - : undefined, - }; - clearChatErrorReason(agentId); - clearStreamError(); - scrollToBottomRef.current?.(); - - // Don't clear stream state before the POST completes. - // For queued sends the WebSocket status events handle - // clearing; for non-queued sends we clear explicitly - // below. Clearing eagerly causes a visible cutoff. - let response: Awaited>; - try { - response = await sendMessage(request); - } catch (error) { - handleUsageLimitError(error); - throw error; - } - // When the server accepts the message immediately (not - // queued), clear the stream and insert the user's message - // so it appears in the timeline without waiting for the - // WebSocket stream. - if (!response.queued) { - store.clearStreamState(); - // Optimistically set status to "running" so the - // "Thinking..." indicator appears immediately. - // The server accepted the message (not queued), - // so it will start processing. The WebSocket - // status:running event no-ops via the - // setChatStatus guard. If the server transitions - // to error/pending instead, the WebSocket event - // overrides this optimistic value. - store.setChatStatus("running"); - if (response.message) { - store.upsertDurableMessage(response.message); - upsertCacheMessages([response.message]); - } - } - if (selectedModelConfigID) { - localStorage.setItem(lastModelConfigIDStorageKey, selectedModelConfigID); - } else { - localStorage.removeItem(lastModelConfigIDStorageKey); - } - }; - const handleInterrupt = () => { if (!agentId || isInterruptPending) { return; @@ -1008,6 +991,19 @@ const AgentChatPage: FC = () => { void interrupt(); }; + const handleWorkspaceChange = (nextWorkspaceId: string | null) => { + if (!agentId || nextWorkspaceId === selectedWorkspaceId) { + return; + } + trackPendingChatSettingSync( + updateChatWorkspaceAsync({ + chatId: agentId, + workspaceId: nextWorkspaceId, + }), + pendingWorkspaceSyncRef, + ); + }; + const handleDeleteQueuedMessage = async (id: number) => { const previousQueuedMessages = store.getSnapshot().queuedMessages; store.setQueuedMessages( @@ -1125,6 +1121,197 @@ const AgentChatPage: FC = () => { return rewriteLocalhostURL(url, proxyHost, agentName, wsName, wsOwner); }; + function buildChatInputContent({ + message, + attachments, + useComposerContent = true, + }: { + message: string; + attachments?: readonly PendingAttachment[]; + useComposerContent?: boolean; + }): { content: TypesGen.ChatInputPart[]; hasContent: boolean } { + const content: TypesGen.ChatInputPart[] = []; + + if (useComposerContent) { + const chatInputHandle = ( + editing.chatInputRef as React.RefObject + )?.current; + const editorParts = chatInputHandle?.getContentParts() ?? []; + + // Walk the Lexical tree in document order so file-reference + // parts appear at the correct position relative to the + // surrounding text the user typed. + for (const part of editorParts) { + if (part.type === "text") { + if (part.text.trim()) { + content.push({ type: "text", text: part.text }); + } + } else { + const reference = part.reference; + content.push({ + type: "file-reference", + file_name: reference.fileName, + start_line: reference.startLine, + end_line: reference.endLine, + content: reference.content, + }); + } + } + + if (content.length === 0 && message.trim()) { + content.push({ type: "text", text: message }); + } + } else if (message.trim()) { + content.push({ type: "text", text: message }); + } + + if (attachments && attachments.length > 0) { + for (const { fileId } of attachments) { + content.push({ type: "file", file_id: fileId }); + } + } + + return { content, hasContent: content.length > 0 }; + } + + async function submitChatTurn({ + message, + attachments, + editedMessageID, + useComposerContent = true, + planModeSwitch, + }: { + message: string; + attachments?: readonly PendingAttachment[]; + editedMessageID?: number; + useComposerContent?: boolean; + planModeSwitch?: PlanModeSwitch; + }) { + const { content, hasContent } = buildChatInputContent({ + message, + attachments, + useComposerContent, + }); + if (!hasContent || isSubmissionPending || !agentId || !hasModelOptions) { + return; + } + // Wait for chat-setting mutations to settle before sending so the + // message observes the workspace and plan-mode choices the user just made. + await waitForPendingChatSettingsSyncs([ + pendingPlanModeSyncRef.current, + pendingWorkspaceSyncRef.current, + ]); + + if (editedMessageID !== undefined) { + const request: TypesGen.EditChatMessageRequest = { content }; + const originalEditedMessage = chatMessagesList?.find( + (existingMessage) => existingMessage.id === editedMessageID, + ); + const optimisticMessage = originalEditedMessage + ? buildOptimisticEditedMessage({ + requestContent: request.content, + originalMessage: originalEditedMessage, + attachmentMediaTypes: buildAttachmentMediaTypes(attachments), + }) + : undefined; + const previousSnapshot = store.getSnapshot(); + clearChatErrorReason(agentId); + clearStreamError(); + store.batch(() => { + store.setQueuedMessages([]); + store.setChatStatus("running"); + store.clearStreamState(); + }); + scrollToBottomRef.current?.(); + try { + await editMessage({ + messageId: editedMessageID, + optimisticMessage, + req: request, + }); + } catch (error) { + restoreOptimisticRequestSnapshot(store, previousSnapshot); + handleUsageLimitError(error); + throw error; + } + return; + } + + const selectedModelConfigID = effectiveSelectedModel || undefined; + const request: CreateChatMessageRequestWithClearablePlanMode = { + content, + model_config_id: selectedModelConfigID, + mcp_server_ids: + effectiveMCPServerIds.length > 0 + ? [...effectiveMCPServerIds] + : undefined, + ...(planModeSwitch !== undefined + ? { + plan_mode: + planModeSwitch === "clear" ? clearChatPlanMode : planModeSwitch, + } + : {}), + }; + clearChatErrorReason(agentId); + clearStreamError(); + scrollToBottomRef.current?.(); + + // Don't clear stream state before the POST completes. + // For queued sends the WebSocket status events handle + // clearing; for non-queued sends we clear explicitly + // below. Clearing eagerly causes a visible cutoff. + let response: Awaited>; + try { + response = await sendMessage(request); + } catch (error) { + handleUsageLimitError(error); + throw error; + } + // When the server accepts the message immediately (not + // queued), clear the stream and insert the user's message + // so it appears in the timeline without waiting for the + // WebSocket stream. + if (!response.queued) { + store.clearStreamState(); + // Optimistically set status to "running" so the + // "Thinking..." indicator appears immediately. + // The server accepted the message (not queued), + // so it will start processing. The WebSocket + // status:running event no-ops via the + // setChatStatus guard. If the server transitions + // to error/pending instead, the WebSocket event + // overrides this optimistic value. + store.setChatStatus("running"); + if (response.message) { + store.upsertDurableMessage(response.message); + upsertCacheMessages([response.message]); + } + } + if (selectedModelConfigID) { + localStorage.setItem(lastModelConfigIDStorageKey, selectedModelConfigID); + } else { + localStorage.removeItem(lastModelConfigIDStorageKey); + } + if (planModeSwitch !== undefined) { + setCachedChatPlanMode( + agentId, + planModeSwitch === "clear" ? undefined : planModeSwitch, + ); + } + } + + async function handleSend( + message: string, + attachments?: readonly PendingAttachment[], + editedMessageID?: number, + ) { + await submitChatTurn({ + message, + attachments, + editedMessageID, + }); + } + const handleRegenerateTitle = () => { if (!agentId || isRegenerateTitleDisabled || !onRegenerateTitle) { return; @@ -1132,6 +1319,21 @@ const AgentChatPage: FC = () => { onRegenerateTitle(agentId); }; + const handleSendAskUserQuestionResponse = async (message: string) => { + await submitChatTurn({ + message, + useComposerContent: false, + }); + }; + + const handleImplementPlan = async () => { + await submitChatTurn({ + message: "Implement the plan.", + planModeSwitch: "clear", + useComposerContent: false, + }); + }; + if (chatQuery.isLoading || chatMessagesQuery.isLoading) { return ( { modelSelectorPlaceholder={modelSelectorPlaceholder} hasModelOptions={hasModelOptions} isModelCatalogLoading={isModelCatalogLoading} + planModeEnabled={planModeEnabled} + onPlanModeToggle={handlePlanModeToggle} isSidebarCollapsed={isSidebarCollapsed} onToggleSidebarCollapsed={onToggleSidebarCollapsed} showRightPanel={showSidebarPanel} @@ -1180,10 +1384,16 @@ const AgentChatPage: FC = () => { modelSelectorHelp={modelSelectorHelp} hasModelOptions={hasModelOptions} isModelCatalogLoading={isModelCatalogLoading} + planModeEnabled={planModeEnabled} + onPlanModeToggle={handlePlanModeToggle} compressionThreshold={compressionThreshold} isInputDisabled={isInputDisabled} isSubmissionPending={isSubmissionPending} isInterruptPending={isInterruptPending} + workspaceOptions={workspaceOptions} + selectedWorkspaceId={selectedWorkspaceId} + onWorkspaceChange={handleWorkspaceChange} + isWorkspaceLoading={isWorkspaceLoading} isSidebarCollapsed={isSidebarCollapsed} onToggleSidebarCollapsed={onToggleSidebarCollapsed} showSidebarPanel={showSidebarPanel} @@ -1196,6 +1406,8 @@ const AgentChatPage: FC = () => { handleInterrupt={handleInterrupt} handleDeleteQueuedMessage={handleDeleteQueuedMessage} handlePromoteQueuedMessage={handlePromoteQueuedMessage} + onImplementPlan={handleImplementPlan} + onSendAskUserQuestionResponse={handleSendAskUserQuestionResponse} handleArchiveAgentAction={handleArchiveAgentAction} handleUnarchiveAgentAction={handleUnarchiveAgentAction} handleArchiveAndDeleteWorkspaceAction={ diff --git a/site/src/pages/AgentsPage/AgentChatPageView.tsx b/site/src/pages/AgentsPage/AgentChatPageView.tsx index e1ace66aba..6bb4be2dc3 100644 --- a/site/src/pages/AgentsPage/AgentChatPageView.tsx +++ b/site/src/pages/AgentsPage/AgentChatPageView.tsx @@ -99,10 +99,16 @@ interface AgentChatPageViewProps { modelSelectorHelp?: ReactNode; hasModelOptions: boolean; isModelCatalogLoading?: boolean; + planModeEnabled?: boolean; + onPlanModeToggle?: (enabled: boolean) => void; compressionThreshold: number | undefined; isInputDisabled: boolean; isSubmissionPending: boolean; isInterruptPending: boolean; + workspaceOptions?: readonly TypesGen.Workspace[]; + selectedWorkspaceId?: string | null; + onWorkspaceChange?: (workspaceId: string | null) => void; + isWorkspaceLoading?: boolean; // Sidebar / panel state. isSidebarCollapsed: boolean; @@ -130,6 +136,9 @@ interface AgentChatPageViewProps { handleDeleteQueuedMessage: (id: number) => Promise; handlePromoteQueuedMessage: (id: number) => Promise; + onImplementPlan?: () => Promise | void; + onSendAskUserQuestionResponse?: (message: string) => Promise | void; + // Archive actions. handleArchiveAgentAction: () => void; handleUnarchiveAgentAction: () => void; @@ -180,10 +189,16 @@ export const AgentChatPageView: FC = ({ modelSelectorHelp, hasModelOptions, isModelCatalogLoading = false, + planModeEnabled, + onPlanModeToggle, compressionThreshold, isInputDisabled, isSubmissionPending, isInterruptPending, + workspaceOptions = [], + selectedWorkspaceId = null, + onWorkspaceChange = () => {}, + isWorkspaceLoading = false, isSidebarCollapsed, onToggleSidebarCollapsed, showSidebarPanel, @@ -196,6 +211,8 @@ export const AgentChatPageView: FC = ({ handleInterrupt, handleDeleteQueuedMessage, handlePromoteQueuedMessage, + onImplementPlan, + onSendAskUserQuestionResponse, handleArchiveAgentAction, handleUnarchiveAgentAction, handleArchiveAndDeleteWorkspaceAction, @@ -219,6 +236,11 @@ export const AgentChatPageView: FC = ({ // Wrap the git watcher refresh to also invalidate the cached // remote/PR diff contents so the panel re-fetches from GitHub. + const canSendAskUserQuestionResponse = + !isInputDisabled && !isSubmissionPending + ? onSendAskUserQuestionResponse + : undefined; + const handleRefresh = () => { const sent = gitWatcher.refresh(); if (sent && agentId) { @@ -275,6 +297,7 @@ export const AgentChatPageView: FC = ({ ); const statusIcon = ; return { + id: workspace.id, name: workspace.name, route: workspaceRoute, statusIcon, @@ -368,6 +391,8 @@ export const AgentChatPageView: FC = ({ editingMessageId={editing.editingMessageId} urlTransform={urlTransform} mcpServers={mcpServers} + onImplementPlan={onImplementPlan} + onSendAskUserQuestionResponse={canSendAskUserQuestionResponse} /> @@ -389,7 +414,13 @@ export const AgentChatPageView: FC = ({ modelOptions={modelOptions} modelSelectorPlaceholder={modelSelectorPlaceholder} modelSelectorHelp={modelSelectorHelp} + planModeEnabled={planModeEnabled} + onPlanModeToggle={onPlanModeToggle} isModelCatalogLoading={isModelCatalogLoading} + workspaceOptions={workspaceOptions} + selectedWorkspaceId={selectedWorkspaceId} + onWorkspaceChange={onWorkspaceChange} + isWorkspaceLoading={isWorkspaceLoading} inputRef={editing.chatInputRef} initialValue={editing.editorInitialValue} initialEditorState={editing.initialEditorState} @@ -493,6 +524,8 @@ interface AgentChatPageLoadingViewProps { modelSelectorPlaceholder: string; hasModelOptions: boolean; isModelCatalogLoading?: boolean; + planModeEnabled?: boolean; + onPlanModeToggle?: (enabled: boolean) => void; isSidebarCollapsed: boolean; onToggleSidebarCollapsed: () => void; showRightPanel: boolean; @@ -507,6 +540,8 @@ export const AgentChatPageLoadingView: FC = ({ modelSelectorPlaceholder, hasModelOptions, isModelCatalogLoading = false, + planModeEnabled, + onPlanModeToggle, isSidebarCollapsed, onToggleSidebarCollapsed, showRightPanel, @@ -556,6 +591,8 @@ export const AgentChatPageLoadingView: FC = ({ onModelChange={setSelectedModel} modelOptions={modelOptions} modelSelectorPlaceholder={modelSelectorPlaceholder} + planModeEnabled={planModeEnabled} + onPlanModeToggle={onPlanModeToggle} isModelCatalogLoading={isModelCatalogLoading} hasModelOptions={hasModelOptions} /> diff --git a/site/src/pages/AgentsPage/AgentCreatePage.tsx b/site/src/pages/AgentsPage/AgentCreatePage.tsx index b0f332e973..2e6486294a 100644 --- a/site/src/pages/AgentsPage/AgentCreatePage.tsx +++ b/site/src/pages/AgentsPage/AgentCreatePage.tsx @@ -18,6 +18,7 @@ import { AgentPageHeader } from "./components/AgentPageHeader"; import { ChimeButton } from "./components/ChimeButton"; import { WebPushButton } from "./components/WebPushButton"; import { getModelOptionsFromConfigs } from "./utils/modelOptions"; +import { buildAgentChatPath } from "./utils/navigation"; const lastModelConfigIDStorageKey = "agents.last-model-config-id"; const nilUUID = "00000000-0000-0000-0000-000000000000"; @@ -45,6 +46,7 @@ const AgentCreatePage: FC = () => { model, mcpServerIds, organizationId, + planMode, }: CreateChatOptions) => { const modelConfigID = model || nilUUID; const content: TypesGen.ChatInputPart[] = []; @@ -63,6 +65,7 @@ const AgentCreatePage: FC = () => { model_config_id: modelConfigID, mcp_server_ids: mcpServerIds && mcpServerIds.length > 0 ? mcpServerIds : undefined, + plan_mode: planMode === "plan" ? "plan" : undefined, }); if (modelConfigID !== nilUUID) { @@ -70,7 +73,7 @@ const AgentCreatePage: FC = () => { } else { localStorage.removeItem(lastModelConfigIDStorageKey); } - navigate(`/agents/${createdChat.id}`); + navigate(buildAgentChatPath({ chatId: createdChat.id })); }; return ( diff --git a/site/src/pages/AgentsPage/AgentSettingsBehaviorPage.tsx b/site/src/pages/AgentsPage/AgentSettingsBehaviorPage.tsx index cb6d3634b5..49ca5609aa 100644 --- a/site/src/pages/AgentsPage/AgentSettingsBehaviorPage.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsBehaviorPage.tsx @@ -3,12 +3,14 @@ import { useMutation, useQuery, useQueryClient } from "react-query"; import { chatDesktopEnabled, chatModelConfigs, + chatPlanModeInstructions, chatRetentionDays, chatSystemPrompt, chatUserCustomPrompt, chatWorkspaceTTL, deleteUserCompactionThreshold, updateChatDesktopEnabled, + updateChatPlanModeInstructions, updateChatRetentionDays, updateChatSystemPrompt, updateChatWorkspaceTTL, @@ -30,6 +32,13 @@ const AgentSettingsBehaviorPage: FC = () => { const saveSystemPromptMutation = useMutation( updateChatSystemPrompt(queryClient), ); + const planModeInstructionsQuery = useQuery({ + ...chatPlanModeInstructions(), + enabled: permissions.editDeploymentConfig, + }); + const savePlanModeInstructionsMutation = useMutation( + updateChatPlanModeInstructions(queryClient), + ); const userPromptQuery = useQuery(chatUserCustomPrompt()); const saveUserPromptMutation = useMutation( @@ -77,6 +86,7 @@ const AgentSettingsBehaviorPage: FC = () => { { onSaveSystemPrompt={saveSystemPromptMutation.mutate} isSavingSystemPrompt={saveSystemPromptMutation.isPending} isSaveSystemPromptError={saveSystemPromptMutation.isError} + onSavePlanModeInstructions={savePlanModeInstructionsMutation.mutate} + isSavingPlanModeInstructions={savePlanModeInstructionsMutation.isPending} + isSavePlanModeInstructionsError={savePlanModeInstructionsMutation.isError} onSaveUserPrompt={saveUserPromptMutation.mutate} isSavingUserPrompt={saveUserPromptMutation.isPending} isSaveUserPromptError={saveUserPromptMutation.isError} diff --git a/site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.stories.tsx b/site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.stories.tsx index c6ebfdf913..6d2f289fae 100644 --- a/site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.stories.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.stories.tsx @@ -15,6 +15,9 @@ const baseProps = { include_default_system_prompt: true, default_system_prompt: mockDefaultSystemPrompt, } as TypesGen.ChatSystemPromptResponse, + planModeInstructionsData: { + plan_mode_instructions: "", + } as TypesGen.ChatPlanModeInstructionsResponse, userPromptData: { custom_prompt: "" } as TypesGen.UserChatCustomPrompt, desktopEnabledData: { enable_desktop: false, @@ -32,6 +35,8 @@ const baseProps = { thresholdsError: undefined as unknown, isSavingSystemPrompt: false, isSaveSystemPromptError: false, + isSavingPlanModeInstructions: false, + isSavePlanModeInstructionsError: false, isSavingUserPrompt: false, isSaveUserPromptError: false, isSavingDesktopEnabled: false, @@ -46,6 +51,7 @@ const meta = { args: { ...baseProps, onSaveSystemPrompt: fn(), + onSavePlanModeInstructions: fn(), onSaveUserPrompt: fn(), onSaveDesktopEnabled: fn(), onSaveWorkspaceTTL: fn(), diff --git a/site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx b/site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx index 4793f61cd2..550f8b725e 100644 --- a/site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx @@ -2,6 +2,7 @@ import type { FC } from "react"; import type * as TypesGen from "#/api/typesGenerated"; import { ChatFullWidthSettings } from "./components/ChatFullWidthSettings"; import { PersonalInstructionsSettings } from "./components/PersonalInstructionsSettings"; +import { PlanModeInstructionsSettings } from "./components/PlanModeInstructionsSettings"; import { RetentionPeriodSettings } from "./components/RetentionPeriodSettings"; import { SectionHeader } from "./components/SectionHeader"; import { SystemInstructionsSettings } from "./components/SystemInstructionsSettings"; @@ -19,6 +20,9 @@ interface AgentSettingsBehaviorPageViewProps { // Raw query data systemPromptData: TypesGen.ChatSystemPromptResponse | undefined; + planModeInstructionsData: + | TypesGen.ChatPlanModeInstructionsResponse + | undefined; userPromptData: TypesGen.UserChatCustomPrompt | undefined; desktopEnabledData: TypesGen.ChatDesktopEnabledResponse | undefined; workspaceTTLData: TypesGen.ChatWorkspaceTTLResponse | undefined; @@ -49,6 +53,13 @@ interface AgentSettingsBehaviorPageViewProps { isSavingSystemPrompt: boolean; isSaveSystemPromptError: boolean; + onSavePlanModeInstructions: ( + req: TypesGen.UpdateChatPlanModeInstructionsRequest, + options?: MutationCallbacks, + ) => void; + isSavingPlanModeInstructions: boolean; + isSavePlanModeInstructionsError: boolean; + onSaveUserPrompt: ( req: TypesGen.UserChatCustomPrompt, options?: MutationCallbacks, @@ -83,6 +94,7 @@ export const AgentSettingsBehaviorPageView: FC< > = ({ canSetSystemPrompt, systemPromptData, + planModeInstructionsData, userPromptData, desktopEnabledData, workspaceTTLData, @@ -102,6 +114,9 @@ export const AgentSettingsBehaviorPageView: FC< onSaveSystemPrompt, isSavingSystemPrompt, isSaveSystemPromptError, + onSavePlanModeInstructions, + isSavingPlanModeInstructions, + isSavePlanModeInstructionsError, onSaveUserPrompt, isSavingUserPrompt, isSaveUserPromptError, @@ -115,7 +130,8 @@ export const AgentSettingsBehaviorPageView: FC< isSavingRetentionDays, isSaveRetentionDaysError, }) => { - const isAnyPromptSaving = isSavingSystemPrompt || isSavingUserPrompt; + const isAnyPromptSaving = + isSavingSystemPrompt || isSavingUserPrompt || isSavingPlanModeInstructions; return ( <> @@ -158,6 +174,13 @@ export const AgentSettingsBehaviorPageView: FC< isAnyPromptSaving={isAnyPromptSaving} />
+ +
{ include_default_system_prompt: true, default_system_prompt: "You are Coder, an AI coding assistant...", }} + planModeInstructionsData={{ + plan_mode_instructions: "", + }} userPromptData={{ custom_prompt: "" }} desktopEnabledData={{ enable_desktop: false }} workspaceTTLData={{ workspace_ttl_ms: 0 }} @@ -174,6 +177,9 @@ const BehaviorRouteElement = () => { onSaveSystemPrompt={fn()} isSavingSystemPrompt={false} isSaveSystemPromptError={false} + onSavePlanModeInstructions={fn()} + isSavingPlanModeInstructions={false} + isSavePlanModeInstructionsError={false} onSaveUserPrompt={fn()} isSavingUserPrompt={false} isSaveUserPromptError={false} diff --git a/site/src/pages/AgentsPage/components/AgentChatInput.stories.tsx b/site/src/pages/AgentsPage/components/AgentChatInput.stories.tsx index 8992d7573d..f96eb754bb 100644 --- a/site/src/pages/AgentsPage/components/AgentChatInput.stories.tsx +++ b/site/src/pages/AgentsPage/components/AgentChatInput.stories.tsx @@ -1,4 +1,5 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; +import { MonitorDotIcon } from "lucide-react"; import { useEffect, useRef } from "react"; import { expect, fn, userEvent, waitFor, within } from "storybook/test"; import type * as TypesGen from "#/api/typesGenerated"; @@ -585,6 +586,101 @@ export const PlusMenuOpen: Story = { }, }; +export const PlanFirstMenuItem: Story = { + args: { + onPlanModeToggle: fn(), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const body = within(canvasElement.ownerDocument.body); + await userEvent.click(canvas.getByRole("button", { name: "More options" })); + await body.findByRole("dialog"); + const toggles = await body.findAllByRole("menuitemcheckbox", { + name: "Plan first", + }); + const toggle = toggles.at(-1)!; + expect(toggle).toBeInTheDocument(); + }, +}; + +export const PlanningIndicator: Story = { + args: { + planModeEnabled: true, + onPlanModeToggle: fn(), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + expect(canvas.getByText("Planning")).toBeVisible(); + }, +}; + +export const PlanFirstCheckedState: Story = { + args: { + planModeEnabled: true, + onPlanModeToggle: fn(), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const body = within(canvasElement.ownerDocument.body); + await userEvent.click(canvas.getByRole("button", { name: "More options" })); + await body.findByRole("dialog"); + const toggles = await body.findAllByRole("menuitemcheckbox", { + name: "Plan first", + }); + const toggle = toggles.at(-1)!; + expect(toggle).toHaveAttribute("aria-checked", "true"); + }, +}; + +export const DetailPageWorkspacePicker: Story = { + args: { + workspaceOptions: [ + { + id: "ws-detail", + name: "agents-workspace", + owner_name: "mike", + }, + ], + selectedWorkspaceId: "ws-detail", + onWorkspaceChange: fn(), + attachedWorkspace: { + id: "ws-detail", + name: "agents-workspace", + route: "/@mike/agents-workspace", + statusIcon: , + statusLabel: "Workspace running", + }, + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + + expect(canvas.getAllByText("agents-workspace")).toHaveLength(1); + expect( + canvas.queryByRole("button", { + name: "Remove workspace agents-workspace", + }), + ).not.toBeInTheDocument(); + + const moreOptionsButton = canvas.getByRole("button", { + name: "More options", + }); + await userEvent.click(moreOptionsButton); + await waitFor(() => { + const plusMenuId = moreOptionsButton.getAttribute("aria-controls"); + if (!plusMenuId) { + throw new Error("Expected More options to control a menu dialog."); + } + + const plusMenu = canvasElement.ownerDocument.getElementById(plusMenuId); + if (!(plusMenu instanceof HTMLElement)) { + throw new Error("Expected More options menu dialog to render."); + } + + expect(within(plusMenu).getByText("Attach workspace")).toBeVisible(); + }); + }, +}; + const confluenceMCP = makeMCPServer({ id: "mcp-confluence", display_name: "Confluence Cloud", diff --git a/site/src/pages/AgentsPage/components/AgentChatInput.tsx b/site/src/pages/AgentsPage/components/AgentChatInput.tsx index 65b42a462d..be99684ce6 100644 --- a/site/src/pages/AgentsPage/components/AgentChatInput.tsx +++ b/site/src/pages/AgentsPage/components/AgentChatInput.tsx @@ -102,6 +102,8 @@ interface AgentChatInputProps { modelOptions: readonly ModelSelectorOption[]; modelSelectorPlaceholder: string; hasModelOptions: boolean; + planModeEnabled?: boolean; + onPlanModeToggle?: (enabled: boolean) => void; isModelCatalogLoading?: boolean; // Streaming controls (optional, for the detail page). isStreaming?: boolean; @@ -158,6 +160,7 @@ interface AgentChatInputProps { } export interface AttachedWorkspaceInfo { + id: string; name: string; route: string; statusIcon: React.ReactNode; @@ -168,16 +171,16 @@ type ToolBadgeData = | ({ kind: "attached-workspace" } & AttachedWorkspaceInfo) | { kind: "mcp"; server: TypesGen.MCPServerConfig }; +const toolBadgeClassName = + "inline-flex shrink-0 items-center gap-1 rounded-full bg-surface-secondary px-2 py-0.5 text-xs font-medium text-content-secondary"; + const ToolBadge: FC<{ badge: ToolBadgeData; onRemoveWorkspace?: () => void; onRemoveMcp?: (serverId: string) => void; className?: string; }> = ({ badge, onRemoveWorkspace, onRemoveMcp, className }) => { - const badgeCls = cn( - "inline-flex shrink-0 items-center gap-1 rounded-full bg-surface-secondary px-2 py-0.5 text-xs font-medium text-content-secondary", - className, - ); + const badgeCls = cn(toolBadgeClassName, className); if (badge.kind === "attached-workspace") { return ( @@ -262,6 +265,8 @@ export const AgentChatInput: FC = ({ modelOptions, modelSelectorPlaceholder, hasModelOptions, + planModeEnabled = false, + onPlanModeToggle, isModelCatalogLoading = false, isStreaming = false, onInterrupt, @@ -392,6 +397,11 @@ export const AgentChatInput: FC = ({ (ws) => ws.id === selectedWorkspaceId, ); + const shouldShowSelectedWorkspaceBadge = selectedWorkspace + ? Boolean(onWorkspaceChange) && + selectedWorkspace.id !== attachedWorkspace?.id + : false; + const enabledMcpServers = mcpServers?.filter((s) => s.enabled) ?? []; const activeMcpServers = enabledMcpServers.filter( (s) => @@ -412,7 +422,7 @@ export const AgentChatInput: FC = ({ if (!(workspace && workspaceAgent && chatId) && attachedWorkspace) { allBadges.push({ kind: "attached-workspace", ...attachedWorkspace }); } - if (selectedWorkspace && onWorkspaceChange) { + if (shouldShowSelectedWorkspaceBadge && selectedWorkspace) { allBadges.push({ kind: "workspace", name: selectedWorkspace.name }); } for (const s of activeMcpServers) { @@ -427,6 +437,11 @@ export const AgentChatInput: FC = ({ const handleRemoveMcp = (serverId: string) => handleMcpToggle(serverId, false); + const handlePlanModeToggle = () => { + onPlanModeToggle?.(!planModeEnabled); + setPlusMenuOpen(false); + }; + const fileInputRef = useRef(null); const handleFileSelect = (e: React.ChangeEvent) => { if (e.target.files && onAttach) { @@ -793,6 +808,22 @@ export const AgentChatInput: FC = ({ Attach image )} + {onPlanModeToggle && ( + + )} {workspaceOptions && onWorkspaceChange && ( = ({ dropdownAlign="center" /> )} + {planModeEnabled && ( + + + Planning + + )} {/* Badge row — all badges and the pill always * render so the DOM structure never changes. * Overflow badges use invisible + order-1 to diff --git a/site/src/pages/AgentsPage/components/AgentCreateForm.tsx b/site/src/pages/AgentsPage/components/AgentCreateForm.tsx index 153ca00383..a284e9323c 100644 --- a/site/src/pages/AgentsPage/components/AgentCreateForm.tsx +++ b/site/src/pages/AgentsPage/components/AgentCreateForm.tsx @@ -47,6 +47,7 @@ export type CreateChatOptions = { model?: string; mcpServerIds?: string[]; organizationId: string; + planMode?: TypesGen.ChatPlanMode; }; /** @@ -225,6 +226,7 @@ export const AgentCreateForm: FC = ({ const [pendingOrgChange, setPendingOrgChange] = useState(null); const organizationId = selectedOrg?.id ?? ""; + const [planModeEnabled, setPlanModeEnabled] = useState(false); const hasModelOptions = modelOptions.length > 0; const hasConfiguredModels = hasConfiguredModelsInCatalog(modelCatalog); const hasUserFixableModelProviders = hasUserFixableProviders(modelCatalog); @@ -321,6 +323,7 @@ export const AgentCreateForm: FC = ({ effectiveMCPServerIds.length > 0 ? [...effectiveMCPServerIds] : undefined, + planMode: planModeEnabled ? "plan" : undefined, }).catch((err) => { resetDraft(); throw err; @@ -470,6 +473,8 @@ export const AgentCreateForm: FC = ({ modelSelectorPlaceholder={modelSelectorPlaceholder} isModelCatalogLoading={isModelCatalogLoading} hasModelOptions={hasModelOptions} + planModeEnabled={planModeEnabled} + onPlanModeToggle={setPlanModeEnabled} attachments={attachments} onAttach={handleAttach} onRemoveAttachment={handleRemoveAttachment} diff --git a/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.stories.tsx b/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.stories.tsx index 66b2e9cee9..553be086a9 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.stories.tsx @@ -16,6 +16,48 @@ const baseMessage = { created_at: "2026-03-10T00:00:00.000Z", } as const; +const askUserQuestionPayload = { + questions: [ + { + header: "Implementation Approach", + question: "How should we structure the database migration?", + options: [ + { + label: "Single migration", + description: + "One migration file with all changes. Simpler but harder to roll back.", + }, + { + label: "Incremental migrations", + description: + "Split into multiple sequential migrations. More flexible rollback.", + }, + ], + }, + { + header: "Release Plan", + question: "Which rollout path should we use for the new agent workflow?", + options: [ + { + label: "Internal dry run", + description: + "Ship to the team first and confirm the migration flow before broader rollout.", + }, + { + label: "Small beta", + description: + "Start with a limited set of workspaces so we can gather feedback quickly.", + }, + ], + }, + ], +}; + +const askUserQuestionSubmittedResponse = [ + "1. Implementation Approach: Incremental migrations", + "2. Release Plan: Small beta", +].join("\n"); + const TEXT_ATTACHMENT_RESPONSES = new Map([ [ "storybook-test-text", @@ -682,6 +724,79 @@ export const AssistantMessageCopyButton: Story = { }, }; +/** Persisted ask-user-question answers survive reloads. */ +export const AskUserQuestionSubmittedAnswer: Story = { + args: { + ...defaultArgs, + isChatCompleted: true, + parsedMessages: buildMessages([ + { + ...baseMessage, + id: 1, + role: "user", + content: [{ type: "text", text: "Help me pick a rollout plan." }], + }, + { + ...baseMessage, + id: 2, + role: "assistant", + content: [ + { + type: "tool-call", + tool_call_id: "ask-tool-1", + tool_name: "ask_user_question", + }, + ], + }, + { + ...baseMessage, + id: 3, + role: "tool", + content: [ + { + type: "tool-result", + tool_call_id: "ask-tool-1", + result: { + output: JSON.stringify(askUserQuestionPayload), + }, + }, + ], + }, + { + ...baseMessage, + id: 4, + role: "user", + content: [{ type: "text", text: askUserQuestionSubmittedResponse }], + }, + ]), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + + // The submitted-answer summary is hidden after the follow-up user message. + expect( + canvas.getByText("How should we structure the database migration?"), + ).toBeInTheDocument(); + expect(canvas.queryAllByRole("radio")).toHaveLength(0); + expect( + canvas.queryByRole("button", { name: "Submit" }), + ).not.toBeInTheDocument(); + const userMessages = canvasElement.querySelectorAll('[data-role="user"]'); + const latestUserMessage = userMessages[userMessages.length - 1]; + if (!(latestUserMessage instanceof HTMLElement)) { + throw new Error("Expected a submitted user message bubble."); + } + expect( + within(latestUserMessage).getByText( + /Implementation Approach: Incremental migrations/, + ), + ).toBeInTheDocument(); + expect( + within(latestUserMessage).getByText(/Release Plan: Small beta/), + ).toBeInTheDocument(); + }, +}; + /** No copy button when assistant message has no markdown content. */ export const AssistantMessageNoCopyWhenToolOnly: Story = { args: { diff --git a/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.tsx b/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.tsx index 689262f02e..ca687c90fe 100644 --- a/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.tsx +++ b/site/src/pages/AgentsPage/components/ChatConversation/ConversationTimeline.tsx @@ -45,6 +45,23 @@ import type { RenderBlock, } from "./types"; +const getChatMessageTextContent = ( + content: readonly TypesGen.ChatMessagePart[] | undefined, +): string | undefined => { + if (!content) { + return undefined; + } + + let textContent = ""; + for (const part of content) { + if (part.type === "text") { + textContent += part.text; + } + } + + return textContent.length > 0 ? textContent : undefined; +}; + const ReasoningDisclosure = memo<{ id: string; text: string; @@ -253,6 +270,12 @@ export const BlockList: FC<{ mcpServers?: readonly TypesGen.MCPServerConfig[]; onImageClick?: (src: string) => void; onTextFileClick?: (content: string) => void; + onImplementPlan?: () => Promise | void; + onSendAskUserQuestionResponse?: (message: string) => Promise | void; + isChatCompleted?: boolean; + latestAskUserQuestionToolId?: string; + askUserQuestionResponseTextByToolId?: ReadonlyMap; + hasUserResponseAfterAskQuestion?: boolean; urlTransform?: UrlTransform; }> = ({ blocks, @@ -266,6 +289,12 @@ export const BlockList: FC<{ mcpServers, onImageClick, onTextFileClick, + onImplementPlan, + onSendAskUserQuestionResponse, + isChatCompleted, + latestAskUserQuestionToolId, + askUserQuestionResponseTextByToolId, + hasUserResponseAfterAskQuestion = false, urlTransform, }) => { const toolByID = new Map(tools.map((tool) => [tool.id, tool])); @@ -369,6 +398,18 @@ export const BlockList: FC<{ } mcpServerConfigId={tool.mcpServerConfigId} mcpServers={mcpServers} + onImplementPlan={onImplementPlan} + onSendAskUserQuestionResponse={onSendAskUserQuestionResponse} + isChatCompleted={isChatCompleted} + isLatestAskUserQuestion={ + tool.id === latestAskUserQuestionToolId && + !hasUserResponseAfterAskQuestion + } + previousResponseText={ + tool.name === "ask_user_question" + ? askUserQuestionResponseTextByToolId?.get(tool.id) + : undefined + } modelIntent={tool.modelIntent} /> ); @@ -410,6 +451,18 @@ export const BlockList: FC<{ } mcpServerConfigId={tool.mcpServerConfigId} mcpServers={mcpServers} + onImplementPlan={onImplementPlan} + onSendAskUserQuestionResponse={onSendAskUserQuestionResponse} + isChatCompleted={isChatCompleted} + isLatestAskUserQuestion={ + tool.id === latestAskUserQuestionToolId && + !hasUserResponseAfterAskQuestion + } + previousResponseText={ + tool.name === "ask_user_question" + ? askUserQuestionResponseTextByToolId?.get(tool.id) + : undefined + } modelIntent={tool.modelIntent} /> ))} @@ -433,11 +486,17 @@ const ChatMessageItem = memo<{ // that fades text out toward the bottom. Used by the sticky // overlay to indicate truncated content. fadeFromBottom?: boolean; + onImplementPlan?: () => Promise | void; urlTransform?: UrlTransform; mcpServers?: readonly TypesGen.MCPServerConfig[]; subagentTitles?: Map; computerUseSubagentIds?: Set; showDesktopPreviews?: boolean; + onSendAskUserQuestionResponse?: (message: string) => Promise | void; + isChatCompleted?: boolean; + latestAskUserQuestionToolId?: string; + askUserQuestionResponseTextByToolId?: ReadonlyMap; + hasUserResponseAfterAskQuestion?: boolean; }>( ({ message, @@ -447,6 +506,12 @@ const ChatMessageItem = memo<{ isAfterEditingMessage = false, hideActions = false, fadeFromBottom = false, + onImplementPlan, + onSendAskUserQuestionResponse, + isChatCompleted, + latestAskUserQuestionToolId, + askUserQuestionResponseTextByToolId, + hasUserResponseAfterAskQuestion = false, urlTransform, mcpServers, @@ -610,6 +675,18 @@ const ChatMessageItem = memo<{ subagentTitles={subagentTitles} computerUseSubagentIds={computerUseSubagentIds} showDesktopPreviews={showDesktopPreviews} + onImplementPlan={onImplementPlan} + onSendAskUserQuestionResponse={ + onSendAskUserQuestionResponse + } + isChatCompleted={isChatCompleted} + latestAskUserQuestionToolId={latestAskUserQuestionToolId} + askUserQuestionResponseTextByToolId={ + askUserQuestionResponseTextByToolId + } + hasUserResponseAfterAskQuestion={ + hasUserResponseAfterAskQuestion + } onImageClick={setPreviewImage} onTextFileClick={setPreviewText} urlTransform={urlTransform} @@ -986,6 +1063,9 @@ interface ConversationTimelineProps { fileBlocks?: readonly TypesGen.ChatMessagePart[], ) => void; editingMessageId?: number | null; + onImplementPlan?: () => Promise | void; + onSendAskUserQuestionResponse?: (message: string) => Promise | void; + isChatCompleted?: boolean; urlTransform?: UrlTransform; mcpServers?: readonly TypesGen.MCPServerConfig[]; computerUseSubagentIds?: Set; @@ -999,6 +1079,9 @@ export const ConversationTimeline = memo( subagentTitles, onEditUserMessage, editingMessageId, + onImplementPlan, + onSendAskUserQuestionResponse, + isChatCompleted, urlTransform, mcpServers, computerUseSubagentIds, @@ -1024,6 +1107,42 @@ export const ConversationTimeline = memo( } } + let latestAskUserQuestionToolId: string | undefined; + let hasUserResponseAfterAskQuestion = false; + const askUserQuestionResponseTextByToolId = new Map(); + let pendingAskUserQuestionToolId: string | undefined; + for (const { message, parsed } of parsedMessages) { + let askUserQuestionToolIdInMessage: string | undefined; + for (const tool of parsed.tools) { + if (tool.name === "ask_user_question") { + askUserQuestionToolIdInMessage = tool.id; + latestAskUserQuestionToolId = tool.id; + hasUserResponseAfterAskQuestion = false; + } + } + + if (askUserQuestionToolIdInMessage) { + pendingAskUserQuestionToolId = askUserQuestionToolIdInMessage; + } + + if (pendingAskUserQuestionToolId && message.role === "user") { + hasUserResponseAfterAskQuestion = + pendingAskUserQuestionToolId === latestAskUserQuestionToolId; + const responseText = getChatMessageTextContent(message.content); + if (responseText !== undefined) { + askUserQuestionResponseTextByToolId.set( + pendingAskUserQuestionToolId, + responseText, + ); + } + pendingAskUserQuestionToolId = undefined; + } + } + const historicalAskUserQuestionResponseTextByToolId = + askUserQuestionResponseTextByToolId.size > 0 + ? askUserQuestionResponseTextByToolId + : undefined; + return (
{parsedMessages.map(({ message, parsed }, msgIdx) => { @@ -1048,6 +1167,14 @@ export const ConversationTimeline = memo( key={message.id} message={message} parsed={parsed} + onImplementPlan={onImplementPlan} + onSendAskUserQuestionResponse={onSendAskUserQuestionResponse} + isChatCompleted={isChatCompleted} + latestAskUserQuestionToolId={latestAskUserQuestionToolId} + askUserQuestionResponseTextByToolId={ + historicalAskUserQuestionResponseTextByToolId + } + hasUserResponseAfterAskQuestion={hasUserResponseAfterAskQuestion} urlTransform={urlTransform} isAfterEditingMessage={afterEditingMessageIds.has(message.id)} hideActions={!isLastInChain} diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.stories.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.stories.tsx new file mode 100644 index 0000000000..9181c7dd13 --- /dev/null +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.stories.tsx @@ -0,0 +1,457 @@ +import type { Meta, StoryObj } from "@storybook/react-vite"; +import { expect, fn, userEvent, within } from "storybook/test"; +import { Tool } from "./Tool"; + +const runningPayload = { + questions: [ + { + header: "Implementation Approach", + question: "How should we structure the database migration?", + options: [ + { + label: "Single migration", + description: + "One migration file with all changes. Simpler but harder to roll back.", + }, + { + label: "Incremental migrations", + description: + "Split into multiple sequential migrations. More flexible rollback.", + }, + ], + }, + ], +}; + +const singleQuestionPayload = { + questions: [ + { + header: "Implementation Approach", + question: "How should we structure the database migration?", + options: [ + { + label: "Single migration", + description: + "One migration file with all changes. Simpler but harder to roll back.", + }, + { + label: "Incremental migrations", + description: + "Split into multiple sequential migrations. More flexible rollback.", + }, + ], + }, + ], +}; + +const multipleQuestionsPayload = { + questions: [ + { + header: "Implementation Approach", + question: "How should we structure the database migration?", + options: [ + { + label: "Single migration", + description: + "One migration file with all changes. Simpler but harder to roll back.", + }, + { + label: "Incremental migrations", + description: + "Split into multiple sequential migrations. More flexible rollback.", + }, + ], + }, + { + header: "Release Plan", + question: "Which rollout path should we use for the new agent workflow?", + options: [ + { + label: "Internal dry run", + description: + "Ship to the team first and confirm the migration flow before broader rollout.", + }, + { + label: "Small beta", + description: + "Start with a limited set of workspaces so we can gather feedback quickly.", + }, + { + label: "General rollout", + description: + "Release to every workspace after validation is complete.", + }, + ], + }, + ], +}; + +const submittedWizardResponse = [ + "1. Implementation Approach: Incremental migrations", + "2. Release Plan: Small beta", +].join("\n"); + +const meta: Meta = { + title: "pages/AgentsPage/ChatElements/tools/AskUserQuestion", + component: Tool, + decorators: [ + (Story) => ( +
+ +
+ ), + ], + args: { name: "ask_user_question" }, +}; + +export default meta; + +type Story = StoryObj; + +export const Running: Story = { + args: { + status: "running", + args: runningPayload, + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + + expect(canvas.getByText("Asking for clarification...")).toBeInTheDocument(); + expect( + canvas.getByTestId("ask-user-question-loading-icon"), + ).toBeInTheDocument(); + expect(canvas.getAllByRole("radio")).toHaveLength(3); + }, +}; + +export const InteractiveSingleQuestion: Story = { + args: { + status: "completed", + result: JSON.stringify(singleQuestionPayload), + isChatCompleted: true, + isLatestAskUserQuestion: true, + onSendAskUserQuestionResponse: fn(), + }, + play: async ({ canvasElement, args }) => { + const canvas = within(canvasElement); + const submitButton = canvas.getByRole("button", { name: "Submit" }); + + expect(submitButton).toBeEnabled(); + expect(canvas.getAllByRole("radio")).toHaveLength(3); + + await userEvent.click( + canvas.getByRole("radio", { name: /single migration/i }), + ); + expect(submitButton).toBeEnabled(); + + await userEvent.click(submitButton); + if (!args.onSendAskUserQuestionResponse) { + throw new Error("Missing ask-user-question response callback."); + } + expect(args.onSendAskUserQuestionResponse).toHaveBeenCalledWith( + "Single migration", + ); + expect(canvas.getByText("Submitted answer")).toBeInTheDocument(); + expect(canvas.getByText("Single migration")).toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Submit" }), + ).not.toBeInTheDocument(); + }, +}; + +export const InteractiveSingleQuestionOther: Story = { + args: { + status: "completed", + result: JSON.stringify(singleQuestionPayload), + isChatCompleted: true, + isLatestAskUserQuestion: true, + onSendAskUserQuestionResponse: fn(), + }, + play: async ({ canvasElement, args }) => { + const canvas = within(canvasElement); + const submitButton = canvas.getByRole("button", { name: "Submit" }); + + await userEvent.click(canvas.getByRole("radio", { name: /^other/i })); + const otherInput = canvas.getByRole("textbox", { name: /other response/i }); + expect(otherInput).toHaveFocus(); + expect(submitButton).toBeDisabled(); + + await userEvent.type(otherInput, "Use a canary rollout"); + expect(submitButton).toBeEnabled(); + + await userEvent.click(submitButton); + if (!args.onSendAskUserQuestionResponse) { + throw new Error("Missing ask-user-question response callback."); + } + expect(args.onSendAskUserQuestionResponse).toHaveBeenCalledWith( + "Other: Use a canary rollout", + ); + expect(canvas.getByText("Other: Use a canary rollout")).toBeInTheDocument(); + }, +}; + +export const KeyboardNavigation: Story = { + args: { + status: "completed", + result: JSON.stringify(singleQuestionPayload), + isChatCompleted: true, + isLatestAskUserQuestion: true, + onSendAskUserQuestionResponse: fn(), + }, + play: async ({ canvasElement, args }) => { + const canvas = within(canvasElement); + const firstRadio = canvas.getByRole("radio", { + name: /single migration/i, + }); + const secondRadio = canvas.getByRole("radio", { + name: /incremental migrations/i, + }); + const submitButton = canvas.getByRole("button", { name: "Submit" }); + + expect(firstRadio).toBeChecked(); + + await userEvent.tab(); + expect(firstRadio).toHaveFocus(); + + await userEvent.keyboard("{ArrowDown}"); + expect(secondRadio).toHaveFocus(); + + await userEvent.keyboard(" "); + expect(secondRadio).toBeChecked(); + + await userEvent.tab(); + expect(submitButton).toHaveFocus(); + + await userEvent.keyboard("{Enter}"); + + if (!args.onSendAskUserQuestionResponse) { + throw new Error("Missing ask-user-question response callback."); + } + expect(args.onSendAskUserQuestionResponse).toHaveBeenCalledWith( + "Incremental migrations", + ); + expect(canvas.getByText("Submitted answer")).toBeInTheDocument(); + }, +}; + +export const KeyboardOtherSubmit: Story = { + args: { + status: "completed", + result: JSON.stringify(singleQuestionPayload), + isChatCompleted: true, + isLatestAskUserQuestion: true, + onSendAskUserQuestionResponse: fn(), + }, + play: async ({ canvasElement, args }) => { + const canvas = within(canvasElement); + const firstRadio = canvas.getByRole("radio", { + name: /single migration/i, + }); + const submitButton = canvas.getByRole("button", { name: "Submit" }); + + expect(firstRadio).toBeChecked(); + + await userEvent.tab(); + expect(firstRadio).toHaveFocus(); + + const secondRadio = canvas.getByRole("radio", { + name: /incremental migrations/i, + }); + const otherRadio = canvas.getByRole("radio", { name: /other/i }); + + await userEvent.keyboard("{ArrowDown}"); + expect(secondRadio).toHaveFocus(); + + await userEvent.keyboard("{ArrowDown}"); + expect(otherRadio).toHaveFocus(); + + await userEvent.keyboard(" "); + expect(otherRadio).toBeChecked(); + expect(submitButton).toBeDisabled(); + + const otherInput = canvas.getByPlaceholderText("Describe another answer"); + expect(otherInput).toHaveFocus(); + + await userEvent.type(otherInput, "Custom approach"); + expect(submitButton).toBeEnabled(); + + await userEvent.keyboard("{Enter}"); + + if (!args.onSendAskUserQuestionResponse) { + throw new Error("Missing ask-user-question response callback."); + } + expect(args.onSendAskUserQuestionResponse).toHaveBeenCalledWith( + "Other: Custom approach", + ); + expect(canvas.getByText("Submitted answer")).toBeInTheDocument(); + }, +}; + +export const InteractiveWizardStep: Story = { + args: { + status: "completed", + result: JSON.stringify(multipleQuestionsPayload), + isChatCompleted: true, + isLatestAskUserQuestion: true, + onSendAskUserQuestionResponse: fn(), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const nextButton = canvas.getByRole("button", { name: "Next" }); + + expect(canvas.getByText("Question 1 of 2")).toBeInTheDocument(); + expect(nextButton).toBeEnabled(); + expect( + canvas.queryByText(/Which rollout path should we use/i), + ).not.toBeInTheDocument(); + + await userEvent.click( + canvas.getByRole("radio", { name: /incremental migrations/i }), + ); + expect(nextButton).toBeEnabled(); + + await userEvent.click(nextButton); + expect(canvas.getByText("Question 2 of 2")).toBeInTheDocument(); + expect( + canvas.getByText(/Which rollout path should we use/i), + ).toBeInTheDocument(); + expect(canvas.getByRole("button", { name: "Submit" })).toBeEnabled(); + }, +}; + +export const SubmittedWizard: Story = { + args: { + status: "completed", + result: JSON.stringify(multipleQuestionsPayload), + isChatCompleted: true, + isLatestAskUserQuestion: true, + onSendAskUserQuestionResponse: fn(), + }, + play: async ({ canvasElement, args }) => { + const canvas = within(canvasElement); + + await userEvent.click( + canvas.getByRole("radio", { name: /incremental migrations/i }), + ); + await userEvent.click(canvas.getByRole("button", { name: "Next" })); + await userEvent.click(canvas.getByRole("radio", { name: /small beta/i })); + await userEvent.click(canvas.getByRole("button", { name: "Submit" })); + + if (!args.onSendAskUserQuestionResponse) { + throw new Error("Missing ask-user-question response callback."); + } + expect(args.onSendAskUserQuestionResponse).toHaveBeenCalledWith( + submittedWizardResponse, + ); + expect(canvas.queryAllByRole("radio")).toHaveLength(0); + expect( + canvas.queryByRole("button", { name: "Next" }), + ).not.toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Submit" }), + ).not.toBeInTheDocument(); + const submittedAnswer = canvas.getByText("Submitted answer"); + expect(submittedAnswer).toBeInTheDocument(); + expect(submittedAnswer.nextElementSibling?.textContent).toBe( + submittedWizardResponse, + ); + }, +}; + +export const PreviouslyAnsweredSingleQuestion: Story = { + args: { + status: "completed", + result: JSON.stringify(singleQuestionPayload), + isChatCompleted: true, + isLatestAskUserQuestion: false, + previousResponseText: "Single migration", + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + + expect( + canvas.getByText("How should we structure the database migration?"), + ).toBeInTheDocument(); + expect(canvas.queryAllByRole("radio")).toHaveLength(0); + expect(canvas.queryByText("Submitted answer")).not.toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Submit" }), + ).not.toBeInTheDocument(); + }, +}; + +export const PreviouslyAnsweredWizard: Story = { + args: { + status: "completed", + result: JSON.stringify(multipleQuestionsPayload), + isChatCompleted: true, + isLatestAskUserQuestion: false, + previousResponseText: submittedWizardResponse, + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + + expect( + canvas.getByText(/How should we structure the database migration/), + ).toBeInTheDocument(); + expect( + canvas.getByText(/Which rollout path should we use/), + ).toBeInTheDocument(); + expect(canvas.queryAllByRole("radio")).toHaveLength(0); + expect(canvas.queryByText("Submitted answer")).not.toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Next" }), + ).not.toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Submit" }), + ).not.toBeInTheDocument(); + }, +}; + +export const ReadOnlyPreviousCall: Story = { + args: { + status: "completed", + result: JSON.stringify(multipleQuestionsPayload), + isChatCompleted: true, + isLatestAskUserQuestion: false, + onSendAskUserQuestionResponse: fn(), + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const radios = canvas.getAllByRole("radio"); + + expect(radios).toHaveLength(7); + expect(radios[0]).toBeDisabled(); + expect( + canvas.getByText(/How should we structure the database migration/), + ).toBeInTheDocument(); + expect( + canvas.getByText(/Which rollout path should we use/), + ).toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Next" }), + ).not.toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Submit" }), + ).not.toBeInTheDocument(); + }, +}; + +export const ErrorState: Story = { + args: { + status: "completed", + isError: true, + result: "The planning agent could not deliver follow-up questions.", + }, + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + + expect(canvas.getByRole("alert")).toBeInTheDocument(); + expect( + canvas.getByText( + "The planning agent could not deliver follow-up questions.", + ), + ).toBeInTheDocument(); + expect(canvas.getByLabelText("Error")).toBeInTheDocument(); + }, +}; diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.tsx new file mode 100644 index 0000000000..7c1ae7f8c0 --- /dev/null +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/AskUserQuestionTool.tsx @@ -0,0 +1,573 @@ +import { LoaderIcon, TriangleAlertIcon } from "lucide-react"; +import { type FC, type FormEvent, useId, useState } from "react"; +import { Button } from "#/components/Button/Button"; +import { Input } from "#/components/Input/Input"; +import { RadioGroup, RadioGroupItem } from "#/components/RadioGroup/RadioGroup"; +import { cn } from "#/utils/cn"; +import type { ToolStatus } from "./utils"; + +export type AskUserQuestion = { + header: string; + question: string; + options: Array<{ label: string; description: string }>; +}; + +type QuestionAnswer = + | { + kind: "option"; + label: string; + optionIndex: number; + } + | { + kind: "other"; + text: string; + }; + +type AskUserQuestionToolProps = { + questions: AskUserQuestion[]; + status: ToolStatus; + isError: boolean; + errorMessage?: string; + isChatCompleted?: boolean; + isLatestAskUserQuestion?: boolean; + previousResponseText?: string; + onSubmitAnswer?: (message: string) => Promise | void; +}; + +const OTHER_OPTION_VALUE = "other"; + +const getQuestionHeader = ( + question: AskUserQuestion, + questionIndex: number, +): string => question.header || `Question ${questionIndex + 1}`; + +const getQuestionText = (question: AskUserQuestion): string => + question.question || "No question provided."; + +const filterQuestionOptions = (question: AskUserQuestion): AskUserQuestion => ({ + ...question, + options: question.options.filter( + (option) => option.label.trim().toLowerCase() !== "other", + ), +}); + +const formatAnswer = (answer: QuestionAnswer): string => + answer.kind === "other" + ? `Other: ${answer.text.trim()}` + : answer.label || `Option ${answer.optionIndex + 1}`; + +const cloneAnswer = ( + answer: QuestionAnswer | undefined, +): QuestionAnswer | undefined => { + if (!answer) { + return undefined; + } + + if (answer.kind === "other") { + return { kind: "other", text: answer.text }; + } + + return { + kind: "option", + label: answer.label, + optionIndex: answer.optionIndex, + }; +}; + +const isAnswerValid = ( + answer: QuestionAnswer | undefined, +): answer is QuestionAnswer => { + if (!answer) { + return false; + } + + if (answer.kind === "other") { + return answer.text.trim().length > 0; + } + + return answer.label.trim().length > 0; +}; + +const getSelectedValue = ( + answer: QuestionAnswer | undefined, +): string | undefined => { + if (!answer) { + return undefined; + } + + if (answer.kind === "other") { + return OTHER_OPTION_VALUE; + } + + return `option-${answer.optionIndex}`; +}; + +const formatOutgoingMessage = ( + questions: AskUserQuestion[], + answers: readonly QuestionAnswer[], +): string => { + if (questions.length === 1) { + return formatAnswer(answers[0]); + } + + return questions + .map((question, questionIndex) => { + return `${questionIndex + 1}. ${getQuestionHeader(question, questionIndex)}: ${formatAnswer(answers[questionIndex])}`; + }) + .join("\n"); +}; + +export const AskUserQuestionTool: FC = ({ + questions, + status, + isError, + errorMessage, + isChatCompleted = false, + isLatestAskUserQuestion = false, + previousResponseText, + onSubmitAnswer, +}) => { + const idPrefix = useId(); + const filteredQuestions = questions.map(filterQuestionOptions); + const [answers, setAnswers] = useState>( + () => + filteredQuestions.map((question) => { + const firstOption = question.options[0]; + if (!firstOption) { + return undefined; + } + return { + kind: "option" as const, + label: firstOption.label || "Option 1", + optionIndex: 0, + }; + }), + ); + const [currentQuestionIndex, setCurrentQuestionIndex] = useState(0); + const [isSubmitting, setIsSubmitting] = useState(false); + const [submitError, setSubmitError] = useState(); + const [submittedResponseText, setSubmittedResponseText] = useState< + string | null + >(null); + const isRunning = status === "running"; + const displayedSubmittedResponseText = + previousResponseText ?? submittedResponseText; + const hasSubmittedResponse = displayedSubmittedResponseText != null; + const showAnsweredState = status === "completed" && hasSubmittedResponse; + const showSubmittedResponse = showAnsweredState && isLatestAskUserQuestion; + const activeQuestionIndex = Math.min( + currentQuestionIndex, + Math.max(filteredQuestions.length - 1, 0), + ); + const currentAnswer = answers[activeQuestionIndex]; + const isInteractive = + isChatCompleted && + status === "completed" && + isLatestAskUserQuestion && + !hasSubmittedResponse && + Boolean(onSubmitAnswer); + const canAdvanceToNextQuestion = isAnswerValid(currentAnswer); + const canSubmitAllAnswers = filteredQuestions.every((_, questionIndex) => + isAnswerValid(answers[questionIndex]), + ); + + const setAnswerAtIndex = ( + questionIndex: number, + nextAnswer: QuestionAnswer | undefined, + ) => { + setAnswers((currentAnswers) => { + const nextAnswers = [...currentAnswers]; + nextAnswers[questionIndex] = nextAnswer; + return nextAnswers; + }); + setSubmitError(undefined); + }; + + const handleOptionChange = ( + questionIndex: number, + question: AskUserQuestion, + value: string, + ) => { + if (value === OTHER_OPTION_VALUE) { + const previousAnswer = answers[questionIndex]; + setAnswerAtIndex( + questionIndex, + previousAnswer?.kind === "other" + ? previousAnswer + : { kind: "other", text: "" }, + ); + return; + } + + const optionIndex = Number.parseInt(value.replace("option-", ""), 10); + const option = question.options[optionIndex]; + if (!option) { + return; + } + + setAnswerAtIndex(questionIndex, { + kind: "option", + label: option.label || `Option ${optionIndex + 1}`, + optionIndex, + }); + }; + + const handleBack = () => { + setCurrentQuestionIndex((currentIndex) => { + return Math.max(currentIndex - 1, 0); + }); + setSubmitError(undefined); + }; + + const handleNext = () => { + if (!canAdvanceToNextQuestion) { + return; + } + + setCurrentQuestionIndex((currentIndex) => { + return Math.min(currentIndex + 1, filteredQuestions.length - 1); + }); + setSubmitError(undefined); + }; + + const handleSubmit = async () => { + if (!onSubmitAnswer || !isInteractive || !canSubmitAllAnswers) { + return; + } + + const finalizedAnswers = filteredQuestions.map((_, questionIndex) => { + return cloneAnswer(answers[questionIndex]); + }); + if (!finalizedAnswers.every(isAnswerValid)) { + return; + } + + const outgoingMessage = formatOutgoingMessage( + filteredQuestions, + finalizedAnswers, + ); + setIsSubmitting(true); + setSubmitError(undefined); + try { + await onSubmitAnswer(outgoingMessage); + } catch (error) { + setSubmitError( + error instanceof Error + ? error.message + : "Failed to submit your answer.", + ); + setIsSubmitting(false); + return; + } + + setSubmittedResponseText(outgoingMessage); + setIsSubmitting(false); + }; + + const handleFormSubmit = (event: FormEvent) => { + event.preventDefault(); + if (!isInteractive) { + return; + } + + if ( + filteredQuestions.length > 1 && + activeQuestionIndex < filteredQuestions.length - 1 + ) { + handleNext(); + return; + } + + void handleSubmit(); + }; + + if (isError) { + return ( +
+
+ + {errorMessage || "Failed to ask questions"} +
+
+ ); + } + + if (questions.length === 0) { + return ( +
+ {isRunning ? ( +
+ + Asking for clarification... + + +
+ ) : ( +

+ No questions available. +

+ )} +
+ ); + } + + const visibleQuestions = + isInteractive && filteredQuestions.length > 1 + ? [ + { + question: filteredQuestions[activeQuestionIndex], + questionIndex: activeQuestionIndex, + }, + ] + : filteredQuestions.map((question, questionIndex) => ({ + question, + questionIndex, + })); + + const content = ( + <> +
+ {visibleQuestions.map(({ question, questionIndex }) => { + const questionHeader = getQuestionHeader(question, questionIndex); + const questionText = getQuestionText(question); + const questionIdBase = `${idPrefix}-question-${questionIndex}`; + const questionHeaderId = `${questionIdBase}-header`; + const questionTextId = `${questionIdBase}-text`; + const answer = answers[questionIndex]; + const isOtherSelected = answer?.kind === "other"; + const optionCount = question.options.length; + const showProgress = isInteractive && filteredQuestions.length > 1; + + if (showAnsweredState) { + return ( +

+ {questionText} +

+ ); + } + + return ( +
+ {showProgress && ( +

+ Question {questionIndex + 1} of {filteredQuestions.length} +

+ )} +
+

+ {questionHeader} +

+

+ {questionText} +

+
+
+ { + handleOptionChange(questionIndex, question, value); + }} + > + {question.options.map((option, optionIndex) => { + const optionId = `${questionIdBase}-option-${optionIndex}`; + + return ( + + ); + })} + {(() => { + const otherOptionId = `${questionIdBase}-option-${optionCount}`; + return ( +
+ + {isOtherSelected && ( +
+ { + setAnswerAtIndex(questionIndex, { + kind: "other", + text: event.currentTarget.value, + }); + }} + /> +
+ )} +
+ ); + })()} +
+
+
+ ); + })} +
+ + {showSubmittedResponse && ( +
+

+ Submitted answer +

+

+ {displayedSubmittedResponseText || "No answer recorded."} +

+
+ )} + + {submitError && ( +
+ + {submitError} +
+ )} + + {isInteractive && ( +
+ {filteredQuestions.length > 1 && ( + + )} + {filteredQuestions.length > 1 && + activeQuestionIndex < filteredQuestions.length - 1 ? ( + + ) : ( + + )} +
+ )} + + ); + + return ( +
+ {isRunning && ( +
+ + Asking for clarification... + + +
+ )} + + {isInteractive ? ( +
{content}
+ ) : ( + content + )} +
+ ); +}; diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.stories.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.stories.tsx index 35bfd83f45..35e74f123c 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.stories.tsx @@ -1,5 +1,5 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; -import { expect, spyOn, userEvent, within } from "storybook/test"; +import { expect, fn, spyOn, userEvent, within } from "storybook/test"; import { reactRouterParameters } from "storybook-addon-remix-react-router"; import { API } from "#/api/api"; import { Tool } from "./Tool"; @@ -64,6 +64,9 @@ export const Running: Story = { expect( canvas.getByText(`Proposing ${defaultPlanFilename}…`), ).toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Implement plan" }), + ).not.toBeInTheDocument(); }, }; @@ -78,16 +81,23 @@ export const Completed: Story = { file_id: "test-file-id-completed", media_type: "text/markdown", }, + onImplementPlan: fn(), }, beforeEach: () => { spyOn(API.experimental, "getChatFileText").mockResolvedValue(samplePlan); }, - play: async ({ canvasElement }) => { + play: async ({ canvasElement, args }) => { const canvas = within(canvasElement); expect(await canvas.findByText("Implementation Plan")).toBeInTheDocument(); expect( canvas.getByRole("button", { name: "Copy plan" }), ).toBeInTheDocument(); + const implementButton = canvas.getByRole("button", { + name: "Implement plan", + }); + expect(implementButton).toHaveTextContent("Implement"); + await userEvent.click(implementButton); + expect(args.onImplementPlan).toHaveBeenCalledTimes(1); }, }; @@ -153,6 +163,9 @@ export const ErrorState: Story = { canvas.getByText(`Proposed ${defaultPlanFilename}`), ).toBeInTheDocument(); expect(canvas.getByLabelText("Error")).toBeInTheDocument(); + expect( + canvas.queryByRole("button", { name: "Implement plan" }), + ).not.toBeInTheDocument(); }, }; diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.tsx index 6c1be22845..25317b1306 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/ProposePlanTool.tsx @@ -1,7 +1,9 @@ -import { LoaderIcon, TriangleAlertIcon } from "lucide-react"; +import { LoaderIcon, PlayIcon, TriangleAlertIcon } from "lucide-react"; import type React from "react"; +import { useState } from "react"; import { useQuery } from "react-query"; import { API } from "#/api/api"; +import { Button } from "#/components/Button/Button"; import { CopyButton } from "#/components/CopyButton/CopyButton"; import { Tooltip, @@ -19,6 +21,7 @@ export const ProposePlanTool: React.FC<{ status: ToolStatus; isError: boolean; errorMessage?: string; + onImplementPlan?: () => Promise | void; }> = ({ content: inlineContent, fileID, @@ -26,6 +29,7 @@ export const ProposePlanTool: React.FC<{ status, isError, errorMessage, + onImplementPlan, }) => { const hasInlineContent = (inlineContent?.trim().length ?? 0) > 0; const fileQuery = useQuery({ @@ -54,6 +58,27 @@ export const ProposePlanTool: React.FC<{ const filename = (path || "PLAN.md").split("/").pop() || "PLAN.md"; const effectiveError = isError || Boolean(fetchError); const effectiveErrorMessage = errorMessage || fetchError; + const hasDisplayContent = displayContent.trim().length > 0; + const [isSubmitting, setIsSubmitting] = useState(false); + const canImplementPlan = + status === "completed" && + !effectiveError && + !fetchLoading && + hasDisplayContent && + Boolean(onImplementPlan); + + const handleImplementPlanClick = async () => { + if (!onImplementPlan || isSubmitting) { + return; + } + setIsSubmitting(true); + try { + await onImplementPlan(); + setIsSubmitting(false); + } catch { + setIsSubmitting(false); + } + }; return (
@@ -78,11 +103,35 @@ export const ProposePlanTool: React.FC<{ )}
- {displayContent ? ( + {hasDisplayContent ? ( <> {displayContent} -
+
+ {canImplementPlan && ( + + + + + Implement plan + + )}
) : ( diff --git a/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.tsx b/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.tsx index e611c1132c..94b5227aae 100644 --- a/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.tsx +++ b/site/src/pages/AgentsPage/components/ChatElements/tools/Tool.tsx @@ -10,6 +10,10 @@ import { TooltipTrigger, } from "#/components/Tooltip/Tooltip"; import { cn } from "#/utils/cn"; +import { + type AskUserQuestion, + AskUserQuestionTool, +} from "./AskUserQuestionTool"; import { ChatSummarizedTool } from "./ChatSummarizedTool"; import { ComputerTool } from "./ComputerTool"; import { CreateWorkspaceTool } from "./CreateWorkspaceTool"; @@ -72,6 +76,11 @@ interface ToolProps extends Omit, "children"> { mcpServerConfigId?: string; /** Available MCP server configs for icon/name lookup. */ mcpServers?: readonly TypesGen.MCPServerConfig[]; + onImplementPlan?: () => Promise | void; + onSendAskUserQuestionResponse?: (message: string) => Promise | void; + isChatCompleted?: boolean; + isLatestAskUserQuestion?: boolean; + previousResponseText?: string; /** Human-readable intent extracted from the model's tool-call args. */ modelIntent?: string; } @@ -89,6 +98,11 @@ type ToolRendererProps = { computerUseSubagentIds?: Set; showDesktopPreviews?: boolean; subagentStatusOverrides?: Map; + onImplementPlan?: () => Promise | void; + onSendAskUserQuestionResponse?: (message: string) => Promise | void; + isChatCompleted?: boolean; + isLatestAskUserQuestion?: boolean; + previousResponseText?: string; mcpServerConfigId?: string; mcpServers?: readonly TypesGen.MCPServerConfig[]; modelIntent?: string; @@ -98,6 +112,84 @@ type ToolRendererProps = { // Tool-specific renderer functions // --------------------------------------------------------------------------- +const parseAskUserQuestionOptions = ( + value: unknown, +): AskUserQuestion["options"] | null => { + if (!Array.isArray(value)) { + return null; + } + + const options: AskUserQuestion["options"] = []; + for (const option of value) { + const optionRecord = asRecord(option); + if (!optionRecord) { + continue; + } + + options.push({ + label: asString(optionRecord.label).trim(), + description: asString(optionRecord.description).trim(), + }); + } + + return options; +}; + +const parseAskUserQuestions = (value: unknown): AskUserQuestion[] | null => { + if (!Array.isArray(value)) { + return null; + } + + const questions: AskUserQuestion[] = []; + for (const question of value) { + const questionRecord = asRecord(question); + if (!questionRecord) { + continue; + } + + questions.push({ + header: asString(questionRecord.header).trim(), + question: asString(questionRecord.question).trim(), + options: parseAskUserQuestionOptions(questionRecord.options) ?? [], + }); + } + + return questions; +}; + +const parseAskUserQuestionResult = ( + result: unknown, +): AskUserQuestion[] | null => { + const parsedResult = parseArgs(result); + const directQuestions = parsedResult + ? parseAskUserQuestions(parsedResult.questions) + : null; + if (directQuestions) { + return directQuestions; + } + + const resultRecord = asRecord(result); + if (!resultRecord) { + return null; + } + + for (const value of [ + resultRecord.output, + resultRecord.content, + resultRecord.text, + ]) { + const parsedValue = parseArgs(value); + const questions = parsedValue + ? parseAskUserQuestions(parsedValue.questions) + : null; + if (questions) { + return questions; + } + } + + return null; +}; + const ExecuteRenderer: FC = ({ status, args, @@ -493,11 +585,53 @@ const ChatSummarizedRenderer: FC = ({ ); }; +const AskUserQuestionRenderer: FC = ({ + args, + status, + result, + isError, + onSendAskUserQuestionResponse, + isChatCompleted, + isLatestAskUserQuestion, + previousResponseText, +}) => { + const parsedArgs = parseArgs(args); + const questionsFromArgs = parsedArgs + ? parseAskUserQuestions(parsedArgs.questions) + : null; + const questionsFromResult = parseAskUserQuestionResult(result); + const questions = + questionsFromArgs && questionsFromArgs.length > 0 + ? questionsFromArgs + : questionsFromResult && questionsFromResult.length > 0 + ? questionsFromResult + : (questionsFromArgs ?? questionsFromResult ?? []); + const resultRecord = asRecord(result); + const errorMessage = + (resultRecord + ? asString(resultRecord.error || resultRecord.message) + : "") || (typeof result === "string" && isError ? result : ""); + + return ( + + ); +}; + const ProposePlanRenderer: FC = ({ args, status, result, isError, + onImplementPlan, }) => { const parsedArgs = parseArgs(args); const path = parsedArgs ? asString(parsedArgs.path) || "PLAN.md" : "PLAN.md"; @@ -517,6 +651,7 @@ const ProposePlanRenderer: FC = ({ status={status} isError={isError} errorMessage={errorMessage} + onImplementPlan={onImplementPlan} /> ); }; @@ -757,6 +892,7 @@ const toolRenderers: Record> = { close_agent: SubagentRenderer, spawn_computer_use_agent: SubagentRenderer, chat_summarized: ChatSummarizedRenderer, + ask_user_question: AskUserQuestionRenderer, propose_plan: ProposePlanRenderer, computer: ComputerRenderer, }; @@ -780,6 +916,11 @@ export const Tool = memo( subagentStatusOverrides, mcpServerConfigId, mcpServers, + onImplementPlan, + onSendAskUserQuestionResponse, + isChatCompleted, + isLatestAskUserQuestion, + previousResponseText, modelIntent, ref, ...props @@ -812,6 +953,11 @@ export const Tool = memo( subagentStatusOverrides={subagentStatusOverrides} mcpServerConfigId={mcpServerConfigId} mcpServers={mcpServers} + onImplementPlan={onImplementPlan} + onSendAskUserQuestionResponse={onSendAskUserQuestionResponse} + isChatCompleted={isChatCompleted} + isLatestAskUserQuestion={isLatestAskUserQuestion} + previousResponseText={previousResponseText} modelIntent={modelIntent} />
diff --git a/site/src/pages/AgentsPage/components/ChatPageContent.tsx b/site/src/pages/AgentsPage/components/ChatPageContent.tsx index f171826fc7..ae57033de3 100644 --- a/site/src/pages/AgentsPage/components/ChatPageContent.tsx +++ b/site/src/pages/AgentsPage/components/ChatPageContent.tsx @@ -49,6 +49,8 @@ interface ChatPageTimelineProps { fileBlocks?: readonly TypesGen.ChatMessagePart[], ) => void; editingMessageId?: number | null; + onImplementPlan?: () => Promise | void; + onSendAskUserQuestionResponse?: (message: string) => Promise | void; urlTransform?: UrlTransform; mcpServers?: readonly TypesGen.MCPServerConfig[]; } @@ -59,12 +61,17 @@ export const ChatPageTimeline: FC = ({ persistedError, onEditUserMessage, editingMessageId, + onImplementPlan, + onSendAskUserQuestionResponse, urlTransform, mcpServers, }) => { const [chatFullWidth] = useChatFullWidth(); const messagesByID = useChatSelector(store, selectMessagesByID); const orderedMessageIDs = useChatSelector(store, selectOrderedMessageIDs); + const chatStatus = useChatSelector(store, selectChatStatus); + const hasStream = useChatSelector(store, selectHasStreamState); + const isChatCompleted = !hasStream && chatStatus !== "pending"; const messages = orderedMessageIDs .map((messageID) => { @@ -103,6 +110,9 @@ export const ChatPageTimeline: FC = ({ subagentTitles={subagentTitles} onEditUserMessage={onEditUserMessage} editingMessageId={editingMessageId} + onImplementPlan={onImplementPlan} + onSendAskUserQuestionResponse={onSendAskUserQuestionResponse} + isChatCompleted={isChatCompleted} urlTransform={urlTransform} mcpServers={mcpServers} computerUseSubagentIds={computerUseSubagentIds} @@ -149,6 +159,8 @@ interface ChatPageInputProps { modelOptions: readonly ModelSelectorOption[]; modelSelectorPlaceholder: string; modelSelectorHelp?: ReactNode; + planModeEnabled?: boolean; + onPlanModeToggle?: (enabled: boolean) => void; isModelCatalogLoading?: boolean; // Imperative editor handle plus the one-time initial draft, // owned by the conversation component. @@ -184,6 +196,10 @@ interface ChatPageInputProps { onMCPSelectionChange?: (ids: string[]) => void; onMCPAuthComplete?: (serverId: string) => void; lastInjectedContext?: readonly TypesGen.ChatMessagePart[]; + workspaceOptions: readonly TypesGen.Workspace[]; + selectedWorkspaceId: string | null; + onWorkspaceChange: (workspaceId: string | null) => void; + isWorkspaceLoading: boolean; workspace?: TypesGen.Workspace; workspaceAgent?: TypesGen.WorkspaceAgent; chatId?: string; @@ -209,6 +225,8 @@ export const ChatPageInput: FC = ({ modelOptions, modelSelectorPlaceholder, modelSelectorHelp, + planModeEnabled, + onPlanModeToggle, isModelCatalogLoading = false, inputRef, initialValue, @@ -227,6 +245,10 @@ export const ChatPageInput: FC = ({ onMCPSelectionChange, onMCPAuthComplete, lastInjectedContext, + workspaceOptions, + selectedWorkspaceId, + onWorkspaceChange, + isWorkspaceLoading, workspace, workspaceAgent, chatId, @@ -401,7 +423,13 @@ export const ChatPageInput: FC = ({ onModelChange={onModelChange} modelOptions={modelOptions} modelSelectorPlaceholder={modelSelectorPlaceholder} + planModeEnabled={planModeEnabled} + onPlanModeToggle={onPlanModeToggle} isModelCatalogLoading={isModelCatalogLoading} + workspaceOptions={workspaceOptions} + selectedWorkspaceId={selectedWorkspaceId} + onWorkspaceChange={onWorkspaceChange} + isWorkspaceLoading={isWorkspaceLoading} mcpServers={mcpServers} selectedMCPServerIds={selectedMCPServerIds} onMCPSelectionChange={onMCPSelectionChange} diff --git a/site/src/pages/AgentsPage/components/PlanModeInstructionsSettings.tsx b/site/src/pages/AgentsPage/components/PlanModeInstructionsSettings.tsx new file mode 100644 index 0000000000..25d33c10fa --- /dev/null +++ b/site/src/pages/AgentsPage/components/PlanModeInstructionsSettings.tsx @@ -0,0 +1,131 @@ +import { useFormik } from "formik"; +import type { FC } from "react"; +import { useState } from "react"; +import TextareaAutosize from "react-textarea-autosize"; +import type * as TypesGen from "#/api/typesGenerated"; +import { Alert, AlertDescription } from "#/components/Alert/Alert"; +import { Button } from "#/components/Button/Button"; +import { cn } from "#/utils/cn"; +import { countInvisibleCharacters } from "#/utils/invisibleUnicode"; +import { AdminBadge } from "./AdminBadge"; + +interface MutationCallbacks { + onSuccess?: () => void; + onError?: () => void; +} + +interface PlanModeInstructionsSettingsProps { + planModeInstructionsData: + | TypesGen.ChatPlanModeInstructionsResponse + | undefined; + onSavePlanModeInstructions: ( + req: TypesGen.UpdateChatPlanModeInstructionsRequest, + options?: MutationCallbacks, + ) => void; + isSavePlanModeInstructionsError: boolean; + isAnyPromptSaving: boolean; +} + +export const PlanModeInstructionsSettings: FC< + PlanModeInstructionsSettingsProps +> = ({ + planModeInstructionsData, + onSavePlanModeInstructions, + isSavePlanModeInstructionsError, + isAnyPromptSaving, +}) => { + const [ + isPlanModeInstructionsOverflowing, + setIsPlanModeInstructionsOverflowing, + ] = useState(false); + + const hasLoadedPlanModeInstructions = planModeInstructionsData !== undefined; + + const form = useFormik({ + enableReinitialize: true, + initialValues: { + plan_mode_instructions: + planModeInstructionsData?.plan_mode_instructions ?? "", + }, + onSubmit: (values, { resetForm }) => { + onSavePlanModeInstructions(values, { + onSuccess: () => { + resetForm(); + }, + }); + }, + }); + + const planModeInvisibleCharCount = countInvisibleCharacters( + form.values.plan_mode_instructions, + ); + const planModeInvisibleCharWarning = `This text contains ${planModeInvisibleCharCount} invisible Unicode ${planModeInvisibleCharCount !== 1 ? "characters" : "character"} that could hide content. These will be stripped on save.`; + const isPlanModeInstructionsDisabled = + isAnyPromptSaving || !hasLoadedPlanModeInstructions; + + return ( +
+
+

+ Plan mode instructions +

+ +
+

+ Custom instructions applied when the agent enters planning mode. These + supplement the built-in planning behavior. +

+ + setIsPlanModeInstructionsOverflowing(height >= 240) + } + disabled={isPlanModeInstructionsDisabled} + minRows={4} + maxRows={12} + /> + {planModeInvisibleCharCount > 0 && ( + + {planModeInvisibleCharWarning} + + )} +
+ + +
+ {isSavePlanModeInstructionsError && ( +

+ Failed to save plan mode instructions. +

+ )} + + ); +}; diff --git a/site/src/pages/AgentsPage/utils/navigation.ts b/site/src/pages/AgentsPage/utils/navigation.ts new file mode 100644 index 0000000000..fd16ff5b8a --- /dev/null +++ b/site/src/pages/AgentsPage/utils/navigation.ts @@ -0,0 +1,7 @@ +export const buildAgentChatPath = ({ + chatId, +}: Readonly<{ + chatId: string; +}>): string => { + return `/agents/${chatId}`; +};