mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: StatusWriter Unwrap and process output error recovery (#23383)
Add Unwrap() to StatusWriter so http.ResponseController.SetWriteDeadline can reach the underlying net.Conn through the middleware wrapper. Without this, the agent's 20s WriteTimeout killed blocking process output connections. Also add 30s headroom to the write deadline in handleProcessOutput so the response can be written after a full-duration blocking wait. On the tool layer, waitForProcess and the process_output tool now try a non-blocking snapshot on any error, not just context timeout. Transport errors (like the WriteTimeout EOF) previously returned with no process ID and no recovery path. Now if the process finished, the result is returned transparently. If still running, the error includes the process ID and tells the agent to use process_output.
This commit is contained in:
committed by
GitHub
parent
599f21afa3
commit
4aa94fcd4c
@@ -181,7 +181,9 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
|
||||
// WriteTimeout does not kill the connection while
|
||||
// we block.
|
||||
rc := http.NewResponseController(rw)
|
||||
if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration)); err != nil {
|
||||
// Add headroom beyond the wait timeout so there's time to
|
||||
// write the response after the blocking wait completes.
|
||||
if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration + 30*time.Second)); err != nil {
|
||||
api.logger.Error(ctx, "extend write deadline for blocking process output",
|
||||
slog.Error(err),
|
||||
)
|
||||
|
||||
@@ -88,7 +88,7 @@ type ExecuteArgs struct {
|
||||
func Execute(options ExecuteOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"execute",
|
||||
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding. If the command times out, the response includes a background_process_id so you can retrieve output later with process_output.",
|
||||
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding. If the command fails or times out, the response may include a background_process_id; use process_output with that ID to retrieve the result.",
|
||||
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
@@ -211,7 +211,7 @@ func executeForeground(
|
||||
return errorResult(fmt.Sprintf("start process: %v", err))
|
||||
}
|
||||
|
||||
result := waitForProcess(cmdCtx, conn, resp.ID, timeout)
|
||||
result := waitForProcess(cmdCtx, ctx, conn, resp.ID, timeout)
|
||||
result.WallDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Add an advisory note for file-dump commands.
|
||||
@@ -238,8 +238,13 @@ func truncateOutput(output string) string {
|
||||
|
||||
// waitForProcess waits for process completion using the
|
||||
// blocking process output API instead of polling.
|
||||
// waitForProcess blocks until the process exits or the context
|
||||
// expires. On any error (timeout or transport), it tries a
|
||||
// non-blocking snapshot to recover. Total wall time may exceed
|
||||
// timeout by up to snapshotTimeout if recovery is needed.
|
||||
func waitForProcess(
|
||||
ctx context.Context,
|
||||
parentCtx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
processID string,
|
||||
timeout time.Duration,
|
||||
@@ -250,39 +255,64 @@ func waitForProcess(
|
||||
Wait: true,
|
||||
})
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
// Timeout: fetch final snapshot with a fresh
|
||||
// context. The blocking request was canceled
|
||||
// so the response body was lost.
|
||||
origErr := err
|
||||
timedOut := ctx.Err() != nil
|
||||
|
||||
// Fetch a snapshot with a fresh context. The blocking
|
||||
// request may have failed due to a context timeout or
|
||||
// a transport error (e.g. the server's WriteTimeout
|
||||
// killed the connection). Either way, the process may
|
||||
// still have output available.
|
||||
bgCtx, bgCancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
parentCtx,
|
||||
snapshotTimeout,
|
||||
)
|
||||
defer bgCancel()
|
||||
resp, err = conn.ProcessOutput(bgCtx, processID, nil)
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("get process output: %v; use process_output with ID %s to retry", origErr, processID)
|
||||
if timedOut {
|
||||
errMsg = fmt.Sprintf("command timed out after %s; failed to get output: %v", timeout, err)
|
||||
}
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
ExitCode: -1,
|
||||
Error: fmt.Sprintf("command timed out after %s; failed to get output: %v", timeout, err),
|
||||
Error: errMsg,
|
||||
BackgroundProcessID: processID,
|
||||
}
|
||||
}
|
||||
|
||||
// Snapshot succeeded. If the process finished, return
|
||||
// its real result (transparent recovery).
|
||||
if !resp.Running {
|
||||
exitCode := 0
|
||||
if resp.ExitCode != nil {
|
||||
exitCode = *resp.ExitCode
|
||||
}
|
||||
output := truncateOutput(resp.Output)
|
||||
return ExecuteResult{
|
||||
Success: exitCode == 0,
|
||||
Output: output,
|
||||
ExitCode: exitCode,
|
||||
Truncated: resp.Truncated,
|
||||
}
|
||||
}
|
||||
|
||||
// Process still running, return partial output.
|
||||
output := truncateOutput(resp.Output)
|
||||
errMsg := fmt.Sprintf("command timed out after %s", timeout)
|
||||
if !timedOut {
|
||||
errMsg = fmt.Sprintf("get process output: %v (process still running, use process_output to check later)", origErr)
|
||||
}
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
Output: output,
|
||||
ExitCode: -1,
|
||||
Error: fmt.Sprintf("command timed out after %s", timeout),
|
||||
Error: errMsg,
|
||||
Truncated: resp.Truncated,
|
||||
BackgroundProcessID: processID,
|
||||
}
|
||||
}
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("get process output: %v", err),
|
||||
}
|
||||
}
|
||||
|
||||
// The server-side wait may return before the
|
||||
// process exits if maxWaitDuration is shorter than
|
||||
@@ -291,7 +321,7 @@ func waitForProcess(
|
||||
if resp.Running {
|
||||
if ctx.Err() == nil {
|
||||
// Still within the caller's timeout, retry.
|
||||
return waitForProcess(ctx, conn, processID, timeout)
|
||||
return waitForProcess(ctx, parentCtx, conn, processID, timeout)
|
||||
}
|
||||
output := truncateOutput(resp.Output)
|
||||
return ExecuteResult{
|
||||
@@ -407,9 +437,11 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
|
||||
}
|
||||
resp, err := conn.ProcessOutput(ctx, args.ProcessID, opts)
|
||||
if err != nil {
|
||||
// If our wait timed out but the parent is still alive,
|
||||
// fetch a non-blocking snapshot.
|
||||
if ctx.Err() == nil || parentCtx.Err() != nil {
|
||||
// The blocking request may have failed due to a
|
||||
// context timeout or a transport error (e.g.
|
||||
// server WriteTimeout). Try a non-blocking
|
||||
// snapshot if the parent context is still alive.
|
||||
if parentCtx.Err() != nil {
|
||||
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
|
||||
}
|
||||
bgCtx, bgCancel := context.WithTimeout(parentCtx, snapshotTimeout)
|
||||
|
||||
@@ -343,6 +343,11 @@ func TestExecuteTool(t *testing.T) {
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
|
||||
// First call: blocking wait fails.
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("agent disconnected"))
|
||||
// Second call: snapshot fallback also fails.
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("agent disconnected"))
|
||||
@@ -361,6 +366,89 @@ func TestExecuteTool(t *testing.T) {
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
assert.False(t, result.Success)
|
||||
assert.Contains(t, result.Error, "agent disconnected")
|
||||
// Snapshot fallback should provide the process ID
|
||||
// so the agent can retry manually.
|
||||
assert.Equal(t, "proc-1", result.BackgroundProcessID)
|
||||
})
|
||||
|
||||
t.Run("TransportErrorRecoveryProcessDone", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
exitCode := 0
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
|
||||
// Blocking wait fails with transport error.
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("EOF"))
|
||||
// Snapshot fallback finds the process completed.
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{
|
||||
Output: "hello\n",
|
||||
Running: false,
|
||||
ExitCode: &exitCode,
|
||||
}, nil)
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"echo hello"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
// Transparent recovery: success with real output.
|
||||
assert.True(t, result.Success)
|
||||
assert.Equal(t, 0, result.ExitCode)
|
||||
assert.Equal(t, "hello\n", result.Output)
|
||||
assert.Empty(t, result.BackgroundProcessID)
|
||||
})
|
||||
|
||||
t.Run("TransportErrorProcessStillRunning", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
|
||||
// Blocking wait fails with transport error.
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("EOF"))
|
||||
// Snapshot fallback: process still running.
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{
|
||||
Output: "partial output",
|
||||
Running: true,
|
||||
}, nil)
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"sleep 60"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
assert.False(t, result.Success)
|
||||
assert.Contains(t, result.Error, "process still running")
|
||||
assert.Contains(t, result.Error, "process_output")
|
||||
assert.Equal(t, "partial output", result.Output)
|
||||
assert.Equal(t, "proc-1", result.BackgroundProcessID)
|
||||
})
|
||||
|
||||
t.Run("GetWorkspaceConnNil", func(t *testing.T) {
|
||||
|
||||
@@ -90,6 +90,12 @@ func minInt(a, b int) int {
|
||||
return b
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter, allowing
|
||||
// http.ResponseController to reach it for SetWriteDeadline, etc.
|
||||
func (w *StatusWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
|
||||
func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := w.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -117,6 +118,45 @@ func TestStatusWriter(t *testing.T) {
|
||||
require.Equal(t, "hijacked", err.Error())
|
||||
})
|
||||
|
||||
t.Run("Unwrap", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rec := httptest.NewRecorder()
|
||||
w := &tracing.StatusWriter{ResponseWriter: rec}
|
||||
|
||||
got := w.Unwrap()
|
||||
require.Equal(t, rec, got, "Unwrap should return the inner ResponseWriter")
|
||||
})
|
||||
|
||||
t.Run("SetWriteDeadlineThroughMiddleware", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Use a real HTTP server so the ResponseWriter is backed by
|
||||
// a net.Conn that supports SetWriteDeadline.
|
||||
// http.ResponseController reaches it by calling Unwrap() on
|
||||
// each wrapper in the chain.
|
||||
var setDeadlineErr error
|
||||
handlerCalled := false
|
||||
handler := tracing.StatusWriterMiddleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
handlerCalled = true
|
||||
rc := http.NewResponseController(w)
|
||||
setDeadlineErr = rc.SetWriteDeadline(time.Now().Add(time.Minute))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
resp.Body.Close()
|
||||
require.True(t, handlerCalled, "handler must be invoked")
|
||||
require.Equal(t, http.StatusNoContent, resp.StatusCode)
|
||||
// Assert in the test goroutine, not the handler goroutine.
|
||||
require.NoError(t, setDeadlineErr, "SetWriteDeadline should succeed through StatusWriter")
|
||||
})
|
||||
|
||||
t.Run("Middleware", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user