diff --git a/agent/agentfiles/files.go b/agent/agentfiles/files.go index 3f3bea27c2..028ac9697e 100644 --- a/agent/agentfiles/files.go +++ b/agent/agentfiles/files.go @@ -42,6 +42,14 @@ type ReadFileLinesResponse struct { type HTTPResponseCode = int +// pendingEdit holds the computed result of a file edit, ready to +// be written to disk. +type pendingEdit struct { + path string + content string + mode os.FileMode +} + func (api *API) HandleReadFile(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -368,17 +376,23 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) { return } + // Phase 1: compute all edits in memory. If any file fails + // (bad path, search miss, permission error), bail before + // writing anything. + var pending []pendingEdit var combinedErr error status := http.StatusOK for _, edit := range req.Files { - s, err := api.editFile(r.Context(), edit.Path, edit.Edits) - // Keep the highest response status, so 500 will be preferred over 400, etc. + s, p, err := api.prepareFileEdit(edit.Path, edit.Edits) if s > status { status = s } if err != nil { combinedErr = errors.Join(combinedErr, err) } + if p != nil { + pending = append(pending, *p) + } } if combinedErr != nil { @@ -388,6 +402,20 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) { return } + // Phase 2: write all files via atomicWrite. A failure here + // (e.g. disk full) can leave earlier files committed. True + // cross-file atomicity would require filesystem transactions. + for _, p := range pending { + mode := p.mode + s, err := api.atomicWrite(ctx, p.path, &mode, strings.NewReader(p.content)) + if err != nil { + httpapi.Write(ctx, rw, s, codersdk.Response{ + Message: err.Error(), + }) + return + } + } + // Track edited paths for git watch. if api.pathStore != nil { if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok { @@ -404,22 +432,24 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) { }) } -func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.FileEdit) (int, error) { +// prepareFileEdit validates, reads, and computes edits for a single +// file without writing anything to disk. +func (api *API) prepareFileEdit(path string, edits []workspacesdk.FileEdit) (int, *pendingEdit, error) { if path == "" { - return http.StatusBadRequest, xerrors.New("\"path\" is required") + return http.StatusBadRequest, nil, xerrors.New("\"path\" is required") } if !filepath.IsAbs(path) { - return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path) + return http.StatusBadRequest, nil, xerrors.Errorf("file path must be absolute: %q", path) } if len(edits) == 0 { - return http.StatusBadRequest, xerrors.New("must specify at least one edit") + return http.StatusBadRequest, nil, xerrors.New("must specify at least one edit") } resolved, err := api.resolveSymlink(path) if err != nil { - return http.StatusInternalServerError, xerrors.Errorf("resolve symlink %q: %w", path, err) + return http.StatusInternalServerError, nil, xerrors.Errorf("resolve symlink %q: %w", path, err) } path = resolved @@ -432,22 +462,22 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk. case errors.Is(err, os.ErrPermission): status = http.StatusForbidden } - return status, err + return status, nil, err } defer f.Close() stat, err := f.Stat() if err != nil { - return http.StatusInternalServerError, err + return http.StatusInternalServerError, nil, err } if stat.IsDir() { - return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path) + return http.StatusBadRequest, nil, xerrors.Errorf("open %s: not a file", path) } data, err := io.ReadAll(f) if err != nil { - return http.StatusInternalServerError, xerrors.Errorf("read %s: %w", path, err) + return http.StatusInternalServerError, nil, xerrors.Errorf("read %s: %w", path, err) } content := string(data) @@ -455,12 +485,15 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk. var err error content, err = fuzzyReplace(content, edit) if err != nil { - return http.StatusBadRequest, xerrors.Errorf("edit %s: %w", path, err) + return http.StatusBadRequest, nil, xerrors.Errorf("edit %s: %w", path, err) } } - m := stat.Mode() - return api.atomicWrite(ctx, path, &m, strings.NewReader(content)) + return 0, &pendingEdit{ + path: path, + content: content, + mode: stat.Mode(), + }, nil } // atomicWrite writes content from r to path via a temp file in the diff --git a/agent/agentfiles/files_test.go b/agent/agentfiles/files_test.go index f30bb740f1..10908c17d0 100644 --- a/agent/agentfiles/files_test.go +++ b/agent/agentfiles/files_test.go @@ -969,8 +969,10 @@ func TestEditFiles(t *testing.T) { }, }, }, + // No files should be modified when any edit fails + // (atomic multi-file semantics). expected: map[string]string{ - filepath.Join(tmpdir, "file8"): "edited8 8", + filepath.Join(tmpdir, "file8"): "file 8", }, // Higher status codes will override lower ones, so in this case the 404 // takes priority over the 403. @@ -980,8 +982,44 @@ func TestEditFiles(t *testing.T) { "file9: file does not exist", }, }, + { + // Valid edits on files A and C, but file B has a + // search miss. None should be written. + name: "AtomicMultiFile_OneFailsNoneWritten", + contents: map[string]string{ + filepath.Join(tmpdir, "atomic-a"): "aaa", + filepath.Join(tmpdir, "atomic-b"): "bbb", + filepath.Join(tmpdir, "atomic-c"): "ccc", + }, + edits: []workspacesdk.FileEdits{ + { + Path: filepath.Join(tmpdir, "atomic-a"), + Edits: []workspacesdk.FileEdit{ + {Search: "aaa", Replace: "AAA"}, + }, + }, + { + Path: filepath.Join(tmpdir, "atomic-b"), + Edits: []workspacesdk.FileEdit{ + {Search: "NOTFOUND", Replace: "XXX"}, + }, + }, + { + Path: filepath.Join(tmpdir, "atomic-c"), + Edits: []workspacesdk.FileEdit{ + {Search: "ccc", Replace: "CCC"}, + }, + }, + }, + errCode: http.StatusBadRequest, + errors: []string{"search string not found"}, + expected: map[string]string{ + filepath.Join(tmpdir, "atomic-a"): "aaa", + filepath.Join(tmpdir, "atomic-b"): "bbb", + filepath.Join(tmpdir, "atomic-c"): "ccc", + }, + }, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel()