mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add plan mode with restricted tool boundary (#24236)
> This PR was authored by Mux on behalf of Mike. ## Summary - add persistent plan mode for chats and the chat-specific plan file flow - add structured planning tools such as `ask_user_question` and `propose_plan` - keep `write_file` and `edit_files` constrained to the chat-specific plan file during plan turns - allow shell exploration in plan mode, including subagents, via `execute` and `process_output` - block implementation-oriented, provider-native, MCP, dynamic, and computer-use tools during plan turns - update the chat UI, tests, and docs for the new planning flow
This commit is contained in:
@@ -31,6 +31,7 @@ func (api *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Post("/list-directory", api.HandleLS)
|
||||
r.Get("/resolve-path", api.HandleResolvePath)
|
||||
r.Get("/read-file", api.HandleReadFile)
|
||||
r.Get("/read-file-lines", api.HandleReadFileLines)
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
@@ -328,7 +327,7 @@ func (api *API) writeFile(ctx context.Context, r *http.Request, path string) (HT
|
||||
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
|
||||
}
|
||||
|
||||
resolved, err := api.resolveSymlink(path)
|
||||
resolved, err := api.resolvePath(path)
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, xerrors.Errorf("resolve symlink %q: %w", path, err)
|
||||
}
|
||||
@@ -447,7 +446,7 @@ func (api *API) prepareFileEdit(path string, edits []workspacesdk.FileEdit) (int
|
||||
return http.StatusBadRequest, nil, xerrors.New("must specify at least one edit")
|
||||
}
|
||||
|
||||
resolved, err := api.resolveSymlink(path)
|
||||
resolved, err := api.resolvePath(path)
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, nil, xerrors.Errorf("resolve symlink %q: %w", path, err)
|
||||
}
|
||||
@@ -556,52 +555,6 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode,
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// resolveSymlink resolves a path through any symlinks so that
|
||||
// subsequent operations (such as atomic rename) target the real
|
||||
// file instead of replacing the symlink itself.
|
||||
//
|
||||
// The filesystem must implement afero.Lstater and afero.LinkReader
|
||||
// for resolution to occur; if it does not (e.g. MemMapFs), the
|
||||
// path is returned unchanged.
|
||||
func (api *API) resolveSymlink(path string) (string, error) {
|
||||
const maxDepth = 10
|
||||
|
||||
lstater, hasLstat := api.filesystem.(afero.Lstater)
|
||||
if !hasLstat {
|
||||
return path, nil
|
||||
}
|
||||
reader, hasReadlink := api.filesystem.(afero.LinkReader)
|
||||
if !hasReadlink {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
for range maxDepth {
|
||||
info, _, err := lstater.LstatIfPossible(path)
|
||||
if err != nil {
|
||||
// If the file does not exist yet (new file write),
|
||||
// there is nothing to resolve.
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return path, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink == 0 {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
target, err := reader.ReadlinkIfPossible(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !filepath.IsAbs(target) {
|
||||
target = filepath.Join(filepath.Dir(path), target)
|
||||
}
|
||||
path = target
|
||||
}
|
||||
|
||||
return "", xerrors.Errorf("too many levels of symlinks resolving %q", path)
|
||||
}
|
||||
|
||||
// fuzzyReplace attempts to find `search` inside `content` and replace it
|
||||
// with `replace`. It uses a cascading match strategy inspired by
|
||||
// openai/codex's apply_patch:
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
package agentfiles
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// HandleResolvePath resolves the existing portion of an absolute path through
|
||||
// any symlinks and returns the resulting path. Missing trailing components are
|
||||
// preserved so callers can validate future writes against the real target.
|
||||
func (api *API) HandleResolvePath(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
query := r.URL.Query()
|
||||
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
|
||||
path := parser.String(query, "", "path")
|
||||
parser.ErrorExcessParams(query)
|
||||
if len(parser.Errors) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Query parameters have invalid values.",
|
||||
Validations: parser.Errors,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resolved, err := api.resolvePath(path)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
switch {
|
||||
case !filepath.IsAbs(path):
|
||||
status = http.StatusBadRequest
|
||||
case errors.Is(err, os.ErrPermission):
|
||||
status = http.StatusForbidden
|
||||
}
|
||||
httpapi.Write(ctx, rw, status, codersdk.Response{Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ResolvePathResponse{
|
||||
ResolvedPath: resolved,
|
||||
})
|
||||
}
|
||||
|
||||
// resolvePath resolves any symlinks in the existing portion of path while
|
||||
// preserving missing trailing components.
|
||||
func (api *API) resolvePath(path string) (string, error) {
|
||||
if !filepath.IsAbs(path) {
|
||||
return "", xerrors.Errorf("file path must be absolute: %q", path)
|
||||
}
|
||||
|
||||
path = filepath.Clean(path)
|
||||
|
||||
lstater, hasLstat := api.filesystem.(afero.Lstater)
|
||||
if !hasLstat {
|
||||
return path, nil
|
||||
}
|
||||
targetReader, hasReadlink := api.filesystem.(afero.LinkReader)
|
||||
if !hasReadlink {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
const maxDepth = 40
|
||||
var resolve func(string, int) (string, error)
|
||||
resolve = func(path string, depth int) (string, error) {
|
||||
if depth > maxDepth {
|
||||
return "", xerrors.Errorf("too many levels of symlinks resolving %q", path)
|
||||
}
|
||||
|
||||
info, _, err := lstater.LstatIfPossible(path)
|
||||
switch {
|
||||
case err == nil:
|
||||
if info.Mode()&os.ModeSymlink == 0 {
|
||||
dir := filepath.Dir(path)
|
||||
if dir == path {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
resolvedDir, err := resolve(dir, depth)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(resolvedDir, filepath.Base(path)), nil
|
||||
}
|
||||
|
||||
target, err := targetReader.ReadlinkIfPossible(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !filepath.IsAbs(target) {
|
||||
target = filepath.Join(filepath.Dir(path), target)
|
||||
}
|
||||
return resolve(filepath.Clean(target), depth+1)
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
dir := filepath.Dir(path)
|
||||
if dir == path {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
resolvedDir, err := resolve(dir, depth)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(resolvedDir, filepath.Base(path)), nil
|
||||
default:
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return resolve(path, 0)
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package agentfiles_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestResolvePath_FollowsFileSymlink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlinks are not reliably supported on Windows")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
osFs := afero.NewOsFs()
|
||||
api := agentfiles.NewAPI(logger, osFs, nil)
|
||||
|
||||
realPath := filepath.Join(dir, "real.txt")
|
||||
err := afero.WriteFile(osFs, realPath, []byte("hello"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
linkPath := filepath.Join(dir, "link.txt")
|
||||
err = os.Symlink(realPath, linkPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/resolve-path?path=%s", linkPath), nil)
|
||||
api.Routes().ServeHTTP(w, r)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ResolvePathResponse
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
require.Equal(t, mustEvalSymlinks(t, realPath), resp.ResolvedPath)
|
||||
}
|
||||
|
||||
func TestResolvePath_FollowsSymlinkedParentForMissingFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlinks are not reliably supported on Windows")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
osFs := afero.NewOsFs()
|
||||
api := agentfiles.NewAPI(logger, osFs, nil)
|
||||
|
||||
realPlansDir := filepath.Join(dir, "real-plans")
|
||||
err := os.MkdirAll(realPlansDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
linkPlansDir := filepath.Join(dir, "link-plans")
|
||||
err = os.Symlink(realPlansDir, linkPlansDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
requestedPath := filepath.Join(linkPlansDir, "PLAN.md")
|
||||
resolvedPath := filepath.Join(mustEvalSymlinks(t, realPlansDir), "PLAN.md")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/resolve-path?path=%s", requestedPath), nil)
|
||||
api.Routes().ServeHTTP(w, r)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ResolvePathResponse
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
require.Equal(t, resolvedPath, resp.ResolvedPath)
|
||||
}
|
||||
|
||||
func TestResolvePath_FollowsSymlinkedParentForExistingFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("symlinks are not reliably supported on Windows")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
osFs := afero.NewOsFs()
|
||||
api := agentfiles.NewAPI(logger, osFs, nil)
|
||||
|
||||
realPlansDir := filepath.Join(dir, "real-plans")
|
||||
err := os.MkdirAll(realPlansDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
resolvedPath := filepath.Join(realPlansDir, "PLAN.md")
|
||||
err = afero.WriteFile(osFs, resolvedPath, []byte("plan"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
linkPlansDir := filepath.Join(dir, "link-plans")
|
||||
err = os.Symlink(realPlansDir, linkPlansDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
requestedPath := filepath.Join(linkPlansDir, "PLAN.md")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/resolve-path?path=%s", requestedPath), nil)
|
||||
api.Routes().ServeHTTP(w, r)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ResolvePathResponse
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
require.Equal(t, mustEvalSymlinks(t, resolvedPath), resp.ResolvedPath)
|
||||
}
|
||||
|
||||
func mustEvalSymlinks(t *testing.T, path string) string {
|
||||
t.Helper()
|
||||
resolvedPath, err := filepath.EvalSymlinks(path)
|
||||
require.NoError(t, err)
|
||||
return resolvedPath
|
||||
}
|
||||
Reference in New Issue
Block a user