From 4aa94fcd4c20572c6967e21e362d3d34555278f5 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 20 Mar 2026 22:00:55 +0200 Subject: [PATCH] fix: StatusWriter Unwrap and process output error recovery (#23383) Add Unwrap() to StatusWriter so http.ResponseController.SetWriteDeadline can reach the underlying net.Conn through the middleware wrapper. Without this, the agent's 20s WriteTimeout killed blocking process output connections. Also add 30s headroom to the write deadline in handleProcessOutput so the response can be written after a full-duration blocking wait. On the tool layer, waitForProcess and the process_output tool now try a non-blocking snapshot on any error, not just context timeout. Transport errors (like the WriteTimeout EOF) previously returned with no process ID and no recovery path. Now if the process finished, the result is returned transparently. If still running, the error includes the process ID and tells the agent to use process_output. --- agent/agentproc/api.go | 4 +- coderd/chatd/chattool/execute.go | 90 ++++++++++++++++++--------- coderd/chatd/chattool/execute_test.go | 88 ++++++++++++++++++++++++++ coderd/tracing/status_writer.go | 6 ++ coderd/tracing/status_writer_test.go | 40 ++++++++++++ 5 files changed, 198 insertions(+), 30 deletions(-) diff --git a/agent/agentproc/api.go b/agent/agentproc/api.go index bf974b6307..c2b8d072c1 100644 --- a/agent/agentproc/api.go +++ b/agent/agentproc/api.go @@ -181,7 +181,9 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) { // WriteTimeout does not kill the connection while // we block. rc := http.NewResponseController(rw) - if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration)); err != nil { + // Add headroom beyond the wait timeout so there's time to + // write the response after the blocking wait completes. + if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration + 30*time.Second)); err != nil { api.logger.Error(ctx, "extend write deadline for blocking process output", slog.Error(err), ) diff --git a/coderd/chatd/chattool/execute.go b/coderd/chatd/chattool/execute.go index 4b64df1e89..08d6e73f12 100644 --- a/coderd/chatd/chattool/execute.go +++ b/coderd/chatd/chattool/execute.go @@ -88,7 +88,7 @@ type ExecuteArgs struct { func Execute(options ExecuteOptions) fantasy.AgentTool { return fantasy.NewAgentTool( "execute", - "Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding. If the command times out, the response includes a background_process_id so you can retrieve output later with process_output.", + "Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding. If the command fails or times out, the response may include a background_process_id; use process_output with that ID to retrieve the result.", func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { if options.GetWorkspaceConn == nil { return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil @@ -211,7 +211,7 @@ func executeForeground( return errorResult(fmt.Sprintf("start process: %v", err)) } - result := waitForProcess(cmdCtx, conn, resp.ID, timeout) + result := waitForProcess(cmdCtx, ctx, conn, resp.ID, timeout) result.WallDurationMs = time.Since(start).Milliseconds() // Add an advisory note for file-dump commands. @@ -238,8 +238,13 @@ func truncateOutput(output string) string { // waitForProcess waits for process completion using the // blocking process output API instead of polling. +// waitForProcess blocks until the process exits or the context +// expires. On any error (timeout or transport), it tries a +// non-blocking snapshot to recover. Total wall time may exceed +// timeout by up to snapshotTimeout if recovery is needed. func waitForProcess( ctx context.Context, + parentCtx context.Context, conn workspacesdk.AgentConn, processID string, timeout time.Duration, @@ -250,37 +255,62 @@ func waitForProcess( Wait: true, }) if err != nil { - if ctx.Err() != nil { - // Timeout: fetch final snapshot with a fresh - // context. The blocking request was canceled - // so the response body was lost. - bgCtx, bgCancel := context.WithTimeout( - context.Background(), - snapshotTimeout, - ) - defer bgCancel() - resp, err = conn.ProcessOutput(bgCtx, processID, nil) - if err != nil { - return ExecuteResult{ - Success: false, - ExitCode: -1, - Error: fmt.Sprintf("command timed out after %s; failed to get output: %v", timeout, err), - BackgroundProcessID: processID, - } + origErr := err + timedOut := ctx.Err() != nil + + // Fetch a snapshot with a fresh context. The blocking + // request may have failed due to a context timeout or + // a transport error (e.g. the server's WriteTimeout + // killed the connection). Either way, the process may + // still have output available. + bgCtx, bgCancel := context.WithTimeout( + parentCtx, + snapshotTimeout, + ) + defer bgCancel() + resp, err = conn.ProcessOutput(bgCtx, processID, nil) + if err != nil { + errMsg := fmt.Sprintf("get process output: %v; use process_output with ID %s to retry", origErr, processID) + if timedOut { + errMsg = fmt.Sprintf("command timed out after %s; failed to get output: %v", timeout, err) } - output := truncateOutput(resp.Output) return ExecuteResult{ Success: false, - Output: output, ExitCode: -1, - Error: fmt.Sprintf("command timed out after %s", timeout), - Truncated: resp.Truncated, + Error: errMsg, BackgroundProcessID: processID, } } + + // Snapshot succeeded. If the process finished, return + // its real result (transparent recovery). + if !resp.Running { + exitCode := 0 + if resp.ExitCode != nil { + exitCode = *resp.ExitCode + } + output := truncateOutput(resp.Output) + return ExecuteResult{ + Success: exitCode == 0, + Output: output, + ExitCode: exitCode, + Truncated: resp.Truncated, + } + } + + // Process still running, return partial output. + output := truncateOutput(resp.Output) + errMsg := fmt.Sprintf("command timed out after %s", timeout) + if !timedOut { + errMsg = fmt.Sprintf("get process output: %v (process still running, use process_output to check later)", origErr) + } return ExecuteResult{ - Success: false, - Error: fmt.Sprintf("get process output: %v", err), + Success: false, + Output: output, + ExitCode: -1, + Error: errMsg, + Truncated: resp.Truncated, + BackgroundProcessID: processID, } } @@ -291,7 +321,7 @@ func waitForProcess( if resp.Running { if ctx.Err() == nil { // Still within the caller's timeout, retry. - return waitForProcess(ctx, conn, processID, timeout) + return waitForProcess(ctx, parentCtx, conn, processID, timeout) } output := truncateOutput(resp.Output) return ExecuteResult{ @@ -407,9 +437,11 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool { } resp, err := conn.ProcessOutput(ctx, args.ProcessID, opts) if err != nil { - // If our wait timed out but the parent is still alive, - // fetch a non-blocking snapshot. - if ctx.Err() == nil || parentCtx.Err() != nil { + // The blocking request may have failed due to a + // context timeout or a transport error (e.g. + // server WriteTimeout). Try a non-blocking + // snapshot if the parent context is still alive. + if parentCtx.Err() != nil { return errorResult(fmt.Sprintf("get process output: %v", err)), nil } bgCtx, bgCancel := context.WithTimeout(parentCtx, snapshotTimeout) diff --git a/coderd/chatd/chattool/execute_test.go b/coderd/chatd/chattool/execute_test.go index 3d22cee96f..0b6e2c159c 100644 --- a/coderd/chatd/chattool/execute_test.go +++ b/coderd/chatd/chattool/execute_test.go @@ -343,6 +343,11 @@ func TestExecuteTool(t *testing.T) { mockConn.EXPECT(). StartProcess(gomock.Any(), gomock.Any()). Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil) + // First call: blocking wait fails. + mockConn.EXPECT(). + ProcessOutput(gomock.Any(), "proc-1", gomock.Any()). + Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("agent disconnected")) + // Second call: snapshot fallback also fails. mockConn.EXPECT(). ProcessOutput(gomock.Any(), "proc-1", gomock.Any()). Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("agent disconnected")) @@ -361,6 +366,89 @@ func TestExecuteTool(t *testing.T) { require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) assert.False(t, result.Success) assert.Contains(t, result.Error, "agent disconnected") + // Snapshot fallback should provide the process ID + // so the agent can retry manually. + assert.Equal(t, "proc-1", result.BackgroundProcessID) + }) + + t.Run("TransportErrorRecoveryProcessDone", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + exitCode := 0 + mockConn.EXPECT(). + StartProcess(gomock.Any(), gomock.Any()). + Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil) + // Blocking wait fails with transport error. + mockConn.EXPECT(). + ProcessOutput(gomock.Any(), "proc-1", gomock.Any()). + Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("EOF")) + // Snapshot fallback finds the process completed. + mockConn.EXPECT(). + ProcessOutput(gomock.Any(), "proc-1", gomock.Any()). + Return(workspacesdk.ProcessOutputResponse{ + Output: "hello\n", + Running: false, + ExitCode: &exitCode, + }, nil) + + tool := newExecuteTool(t, mockConn) + ctx := testutil.Context(t, testutil.WaitMedium) + resp, err := tool.Run(ctx, fantasy.ToolCall{ + ID: "call-1", + Name: "execute", + Input: `{"command":"echo hello"}`, + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + + var result chattool.ExecuteResult + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + // Transparent recovery: success with real output. + assert.True(t, result.Success) + assert.Equal(t, 0, result.ExitCode) + assert.Equal(t, "hello\n", result.Output) + assert.Empty(t, result.BackgroundProcessID) + }) + + t.Run("TransportErrorProcessStillRunning", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + mockConn.EXPECT(). + StartProcess(gomock.Any(), gomock.Any()). + Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil) + // Blocking wait fails with transport error. + mockConn.EXPECT(). + ProcessOutput(gomock.Any(), "proc-1", gomock.Any()). + Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("EOF")) + // Snapshot fallback: process still running. + mockConn.EXPECT(). + ProcessOutput(gomock.Any(), "proc-1", gomock.Any()). + Return(workspacesdk.ProcessOutputResponse{ + Output: "partial output", + Running: true, + }, nil) + + tool := newExecuteTool(t, mockConn) + ctx := testutil.Context(t, testutil.WaitMedium) + resp, err := tool.Run(ctx, fantasy.ToolCall{ + ID: "call-1", + Name: "execute", + Input: `{"command":"sleep 60"}`, + }) + require.NoError(t, err) + assert.False(t, resp.IsError) + + var result chattool.ExecuteResult + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + assert.False(t, result.Success) + assert.Contains(t, result.Error, "process still running") + assert.Contains(t, result.Error, "process_output") + assert.Equal(t, "partial output", result.Output) + assert.Equal(t, "proc-1", result.BackgroundProcessID) }) t.Run("GetWorkspaceConnNil", func(t *testing.T) { diff --git a/coderd/tracing/status_writer.go b/coderd/tracing/status_writer.go index e9337c20e0..2dddd758c5 100644 --- a/coderd/tracing/status_writer.go +++ b/coderd/tracing/status_writer.go @@ -90,6 +90,12 @@ func minInt(a, b int) int { return b } +// Unwrap returns the underlying ResponseWriter, allowing +// http.ResponseController to reach it for SetWriteDeadline, etc. +func (w *StatusWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { hijacker, ok := w.ResponseWriter.(http.Hijacker) if !ok { diff --git a/coderd/tracing/status_writer_test.go b/coderd/tracing/status_writer_test.go index 6aff7b915c..98bf37f41e 100644 --- a/coderd/tracing/status_writer_test.go +++ b/coderd/tracing/status_writer_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/require" "golang.org/x/xerrors" @@ -117,6 +118,45 @@ func TestStatusWriter(t *testing.T) { require.Equal(t, "hijacked", err.Error()) }) + t.Run("Unwrap", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + w := &tracing.StatusWriter{ResponseWriter: rec} + + got := w.Unwrap() + require.Equal(t, rec, got, "Unwrap should return the inner ResponseWriter") + }) + + t.Run("SetWriteDeadlineThroughMiddleware", func(t *testing.T) { + t.Parallel() + + // Use a real HTTP server so the ResponseWriter is backed by + // a net.Conn that supports SetWriteDeadline. + // http.ResponseController reaches it by calling Unwrap() on + // each wrapper in the chain. + var setDeadlineErr error + handlerCalled := false + handler := tracing.StatusWriterMiddleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + handlerCalled = true + rc := http.NewResponseController(w) + setDeadlineErr = rc.SetWriteDeadline(time.Now().Add(time.Minute)) + w.WriteHeader(http.StatusNoContent) + })) + + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, srv.URL, nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + resp.Body.Close() + require.True(t, handlerCalled, "handler must be invoked") + require.Equal(t, http.StatusNoContent, resp.StatusCode) + // Assert in the test goroutine, not the handler goroutine. + require.NoError(t, setDeadlineErr, "SetWriteDeadline should succeed through StatusWriter") + }) + t.Run("Middleware", func(t *testing.T) { t.Parallel()