mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat(agent): add process execution API and rewrite execute tool (#22416)
## Summary Adds a new agent-side process management HTTP API and rewrites the chat execute tool to use it instead of SSH sessions. ## What changed ### New agent/agentproc/ package - **headtail.go** — Thread-safe io.Writer with bounded memory (16KB head + 16KB tail ring buffer). Provides LLM-ready output with truncation metadata and long-line truncation at 2048 bytes. - **headtail_test.go** — 16 tests including race detector coverage for concurrent writes. - **process.go** — Manager + Process types for lifecycle management using agentexec.Execer for proper OOM/nice scores. - **api.go** — HTTP API following the agentfiles chi router pattern. 4 endpoints: start, list, output, signal. ### Agent wiring (agent/agent.go, agent/api.go) Mounts the process API at /api/v0/processes, mirroring how agentfiles is mounted. ### SDK (codersdk/workspacesdk/agentconn.go) 4 new AgentConn interface methods + 7 request/response types: - StartProcess, ListProcesses, ProcessOutput, SignalProcess ### Execute tool rewrite (coderd/chatd/chattool/execute.go) - SSH to Agent API: conn.StartProcess() + conn.ProcessOutput() polling - New parameters: workdir, run_in_background - Structured response: success, exit_code, wall_duration_ms, error, truncated, note, background_process_id - Non-interactive env vars: GIT_EDITOR=true, TERM=dumb, NO_COLOR=1, PAGER=cat, etc. - Output truncation: HeadTailBuffer caps at 32KB for LLM consumption - File-dump detection with advisory notes suggesting read_file - Default timeout: 60s to 10s - Foreground polling: 200ms intervals until exit or timeout ## Architecture State lives on the agent, surviving coderd failover and instance changes. Any coderd replica can query any agent via HTTP over tailnet.
This commit is contained in:
+8
-1
@@ -41,6 +41,7 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentproc"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
@@ -302,7 +303,8 @@ type agent struct {
|
||||
containerAPIOptions []agentcontainers.Option
|
||||
containerAPI *agentcontainers.API
|
||||
|
||||
filesAPI *agentfiles.API
|
||||
filesAPI *agentfiles.API
|
||||
processAPI *agentproc.API
|
||||
|
||||
socketServerEnabled bool
|
||||
socketPath string
|
||||
@@ -375,6 +377,7 @@ func (a *agent) init() {
|
||||
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
|
||||
|
||||
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem)
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer)
|
||||
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
@@ -2030,6 +2033,10 @@ func (a *agent) Close() error {
|
||||
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if err := a.processAPI.Close(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if a.boundaryLogProxy != nil {
|
||||
err = a.boundaryLogProxy.Close()
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
package agentproc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// API exposes process-related operations through the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
manager *manager
|
||||
}
|
||||
|
||||
// NewAPI creates a new process API handler.
|
||||
func NewAPI(logger slog.Logger, execer agentexec.Execer) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
manager: newManager(logger, execer),
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the process manager, killing all running
|
||||
// processes.
|
||||
func (api *API) Close() error {
|
||||
return api.manager.Close()
|
||||
}
|
||||
|
||||
// Routes returns the HTTP handler for process-related routes.
|
||||
func (api *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Post("/start", api.handleStartProcess)
|
||||
r.Get("/list", api.handleListProcesses)
|
||||
r.Get("/{id}/output", api.handleProcessOutput)
|
||||
r.Post("/{id}/signal", api.handleSignalProcess)
|
||||
return r
|
||||
}
|
||||
|
||||
// handleStartProcess starts a new process.
|
||||
func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req workspacesdk.StartProcessRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Request body must be valid JSON.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Command == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Command is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
proc, err := api.manager.start(req)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start process.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.StartProcessResponse{
|
||||
ID: proc.id,
|
||||
Started: true,
|
||||
})
|
||||
}
|
||||
|
||||
// handleListProcesses lists all tracked processes.
|
||||
func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
infos := api.manager.list()
|
||||
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListProcessesResponse{
|
||||
Processes: infos,
|
||||
})
|
||||
}
|
||||
|
||||
// handleProcessOutput returns the output of a process.
|
||||
func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
id := chi.URLParam(r, "id")
|
||||
proc, ok := api.manager.get(id)
|
||||
if !ok {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: fmt.Sprintf("Process %q not found.", id),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
output, truncated := proc.output()
|
||||
info := proc.info()
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ProcessOutputResponse{
|
||||
Output: output,
|
||||
Truncated: truncated,
|
||||
Running: info.Running,
|
||||
ExitCode: info.ExitCode,
|
||||
})
|
||||
}
|
||||
|
||||
// handleSignalProcess sends a signal to a running process.
|
||||
func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
var req workspacesdk.SignalProcessRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Request body must be valid JSON.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Signal == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Signal is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Signal != "kill" && req.Signal != "terminate" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Unsupported signal %q. Use \"kill\" or \"terminate\".",
|
||||
req.Signal,
|
||||
),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.manager.signal(id, req.Signal); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errProcessNotFound):
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: fmt.Sprintf("Process %q not found.", id),
|
||||
})
|
||||
case errors.Is(err, errProcessNotRunning):
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Process %q is not running.", id,
|
||||
),
|
||||
})
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to signal process.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Signal %q sent to process %q.", req.Signal, id,
|
||||
),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,636 @@
|
||||
package agentproc_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentproc"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// postStart sends a POST /start request and returns the recorder.
|
||||
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", bytes.NewReader(body))
|
||||
handler.ServeHTTP(w, r)
|
||||
return w
|
||||
}
|
||||
|
||||
// getList sends a GET /list request and returns the recorder.
|
||||
func getList(t *testing.T, handler http.Handler) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/list", nil)
|
||||
handler.ServeHTTP(w, r)
|
||||
return w
|
||||
}
|
||||
|
||||
// getOutput sends a GET /{id}/output request and returns the
|
||||
// recorder.
|
||||
func getOutput(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/%s/output", id), nil)
|
||||
handler.ServeHTTP(w, r)
|
||||
return w
|
||||
}
|
||||
|
||||
// postSignal sends a POST /{id}/signal request and returns
|
||||
// the recorder.
|
||||
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("/%s/signal", id), bytes.NewReader(body))
|
||||
handler.ServeHTTP(w, r)
|
||||
return w
|
||||
}
|
||||
|
||||
// newTestAPI creates a new API with a test logger and default
|
||||
// execer, returning the handler and API.
|
||||
func newTestAPI(t *testing.T) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer)
|
||||
t.Cleanup(func() {
|
||||
_ = api.Close()
|
||||
})
|
||||
return api.Routes()
|
||||
}
|
||||
|
||||
// waitForExit polls the output endpoint until the process is
|
||||
// no longer running or the context expires.
|
||||
func waitForExit(t *testing.T, handler http.Handler, id string) workspacesdk.ProcessOutputResponse {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for process to exit")
|
||||
case <-ticker.C:
|
||||
w := getOutput(t, handler, id)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ProcessOutputResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !resp.Running {
|
||||
return resp
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startAndGetID is a helper that starts a process and returns
|
||||
// the process ID.
|
||||
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) string {
|
||||
t.Helper()
|
||||
|
||||
w := postStart(t, handler, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.StartProcessResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.Started)
|
||||
require.NotEmpty(t, resp.ID)
|
||||
return resp.ID
|
||||
}
|
||||
|
||||
func TestStartProcess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ForegroundCommand", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
w := postStart(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo hello",
|
||||
})
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.StartProcessResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.Started)
|
||||
require.NotEmpty(t, resp.ID)
|
||||
})
|
||||
|
||||
t.Run("BackgroundCommand", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
w := postStart(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo background",
|
||||
Background: true,
|
||||
})
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.StartProcessResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.Started)
|
||||
require.NotEmpty(t, resp.ID)
|
||||
})
|
||||
|
||||
t.Run("EmptyCommand", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
w := postStart(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "",
|
||||
})
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resp.Message, "Command is required")
|
||||
})
|
||||
|
||||
t.Run("MalformedJSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", strings.NewReader("{invalid json"))
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resp.Message, "valid JSON")
|
||||
})
|
||||
|
||||
t.Run("CustomWorkDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Write a marker file to verify the command ran in
|
||||
// the correct directory. Comparing pwd output is
|
||||
// unreliable on Windows where Git Bash returns POSIX
|
||||
// paths.
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "touch marker.txt && ls marker.txt",
|
||||
WorkDir: tmpDir,
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
require.Contains(t, resp.Output, "marker.txt")
|
||||
})
|
||||
|
||||
t.Run("CustomEnv", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
// Use a unique env var name to avoid collisions in
|
||||
// parallel tests.
|
||||
envKey := fmt.Sprintf("TEST_PROC_ENV_%d", time.Now().UnixNano())
|
||||
envVal := "custom_value_12345"
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: fmt.Sprintf("printenv %s", envKey),
|
||||
Env: map[string]string{envKey: envVal},
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
require.Contains(t, strings.TrimSpace(resp.Output), envVal)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListProcesses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoProcesses", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
w := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ListProcessesResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Processes)
|
||||
require.Empty(t, resp.Processes)
|
||||
})
|
||||
|
||||
t.Run("MixedRunningAndExited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
// Start a process that exits quickly.
|
||||
exitedID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo done",
|
||||
})
|
||||
waitForExit(t, handler, exitedID)
|
||||
|
||||
// Start a long-running process.
|
||||
runningID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
// List should contain both.
|
||||
w := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ListProcessesResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Processes, 2)
|
||||
|
||||
procMap := make(map[string]workspacesdk.ProcessInfo)
|
||||
for _, p := range resp.Processes {
|
||||
procMap[p.ID] = p
|
||||
}
|
||||
|
||||
exited, ok := procMap[exitedID]
|
||||
require.True(t, ok, "exited process should be in list")
|
||||
require.False(t, exited.Running)
|
||||
require.NotNil(t, exited.ExitCode)
|
||||
|
||||
running, ok := procMap[runningID]
|
||||
require.True(t, ok, "running process should be in list")
|
||||
require.True(t, running.Running)
|
||||
|
||||
// Clean up the long-running process.
|
||||
sw := postSignal(t, handler, runningID, workspacesdk.SignalProcessRequest{
|
||||
Signal: "kill",
|
||||
})
|
||||
require.Equal(t, http.StatusOK, sw.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProcessOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ExitedProcess", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo hello-output",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
require.Contains(t, resp.Output, "hello-output")
|
||||
})
|
||||
|
||||
t.Run("RunningProcess", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
w := getOutput(t, handler, id)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ProcessOutputResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.Running)
|
||||
|
||||
// Kill and wait for the process so cleanup does
|
||||
// not hang.
|
||||
postSignal(
|
||||
t, handler, id,
|
||||
workspacesdk.SignalProcessRequest{Signal: "kill"},
|
||||
)
|
||||
waitForExit(t, handler, id)
|
||||
})
|
||||
|
||||
t.Run("NonexistentProcess", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
w := getOutput(t, handler, "nonexistent-id-12345")
|
||||
require.Equal(t, http.StatusNotFound, w.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resp.Message, "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSignalProcess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("KillRunning", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
|
||||
Signal: "kill",
|
||||
})
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Verify the process exits.
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
})
|
||||
|
||||
t.Run("TerminateRunning", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("SIGTERM is not supported on Windows")
|
||||
}
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
|
||||
Signal: "terminate",
|
||||
})
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Verify the process exits.
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
})
|
||||
|
||||
t.Run("NonexistentProcess", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
w := postSignal(t, handler, "nonexistent-id-12345", workspacesdk.SignalProcessRequest{
|
||||
Signal: "kill",
|
||||
})
|
||||
require.Equal(t, http.StatusNotFound, w.Code)
|
||||
})
|
||||
|
||||
t.Run("AlreadyExitedProcess", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo done",
|
||||
})
|
||||
|
||||
// Wait for exit first.
|
||||
waitForExit(t, handler, id)
|
||||
|
||||
// Signaling an exited process should return 409
|
||||
// Conflict via the errProcessNotRunning sentinel.
|
||||
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
|
||||
Signal: "kill",
|
||||
})
|
||||
assert.Equal(t, http.StatusConflict, w.Code,
|
||||
"expected 409 for signaling exited process, got %d", w.Code)
|
||||
})
|
||||
|
||||
t.Run("EmptySignal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
|
||||
Signal: "",
|
||||
})
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resp.Message, "Signal is required")
|
||||
|
||||
// Clean up.
|
||||
postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
|
||||
Signal: "kill",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("InvalidSignal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
|
||||
Signal: "SIGFOO",
|
||||
})
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resp.Message, "Unsupported signal")
|
||||
|
||||
// Clean up.
|
||||
postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
|
||||
Signal: "kill",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestProcessLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("StartWaitCheckOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo lifecycle-test && echo second-line",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
require.Contains(t, resp.Output, "lifecycle-test")
|
||||
require.Contains(t, resp.Output, "second-line")
|
||||
})
|
||||
|
||||
t.Run("NonZeroExitCode", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "exit 42",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 42, *resp.ExitCode)
|
||||
})
|
||||
|
||||
t.Run("StartSignalVerifyExit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
// Start a long-running background process.
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
// Verify it's running.
|
||||
w := getOutput(t, handler, id)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var running workspacesdk.ProcessOutputResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&running)
|
||||
require.NoError(t, err)
|
||||
require.True(t, running.Running)
|
||||
|
||||
// Signal it.
|
||||
sw := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
|
||||
Signal: "kill",
|
||||
})
|
||||
require.Equal(t, http.StatusOK, sw.Code)
|
||||
|
||||
// Verify it exits.
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
})
|
||||
|
||||
t.Run("OutputExceedsBuffer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
// Generate output that exceeds MaxHeadBytes +
|
||||
// MaxTailBytes. Each line is ~100 chars, and we
|
||||
// need more than 32KB total (16KB head + 16KB
|
||||
// tail).
|
||||
lineCount := (agentproc.MaxHeadBytes+agentproc.MaxTailBytes)/50 + 500
|
||||
cmd := fmt.Sprintf(
|
||||
"for i in $(seq 1 %d); do echo \"line-$i-padding-to-make-this-longer-than-fifty-characters-total\"; done",
|
||||
lineCount,
|
||||
)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: cmd,
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
|
||||
// The output should be truncated with head/tail
|
||||
// strategy metadata.
|
||||
require.NotNil(t, resp.Truncated, "large output should be truncated")
|
||||
require.Equal(t, "head_tail", resp.Truncated.Strategy)
|
||||
require.Greater(t, resp.Truncated.OmittedBytes, 0)
|
||||
require.Greater(t, resp.Truncated.OriginalBytes, resp.Truncated.RetainedBytes)
|
||||
|
||||
// Verify the output contains the omission marker.
|
||||
require.Contains(t, resp.Output, "... [omitted")
|
||||
})
|
||||
|
||||
t.Run("StderrCaptured", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo stdout-msg && echo stderr-msg >&2",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
// Both stdout and stderr should be captured.
|
||||
require.Contains(t, resp.Output, "stdout-msg")
|
||||
require.Contains(t, resp.Output, "stderr-msg")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,309 @@
|
||||
package agentproc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxHeadBytes is the number of bytes retained from the
|
||||
// beginning of the output for LLM consumption.
|
||||
MaxHeadBytes = 16 << 10 // 16KB
|
||||
|
||||
// MaxTailBytes is the number of bytes retained from the
|
||||
// end of the output for LLM consumption.
|
||||
MaxTailBytes = 16 << 10 // 16KB
|
||||
|
||||
// MaxLineLength is the maximum length of a single line
|
||||
// before it is truncated. This prevents minified files
|
||||
// or other long single-line output from consuming the
|
||||
// entire buffer.
|
||||
MaxLineLength = 2048
|
||||
|
||||
// lineTruncationSuffix is appended to lines that exceed
|
||||
// MaxLineLength.
|
||||
lineTruncationSuffix = " ... [truncated]"
|
||||
)
|
||||
|
||||
// HeadTailBuffer is a thread-safe buffer that captures process
|
||||
// output and provides head+tail truncation for LLM consumption.
|
||||
// It implements io.Writer so it can be used directly as
|
||||
// cmd.Stdout or cmd.Stderr.
|
||||
//
|
||||
// The buffer stores up to MaxHeadBytes from the beginning of
|
||||
// the output and up to MaxTailBytes from the end in a ring
|
||||
// buffer, keeping total memory usage bounded regardless of
|
||||
// how much output is written.
|
||||
type HeadTailBuffer struct {
|
||||
mu sync.Mutex
|
||||
head []byte
|
||||
tail []byte
|
||||
tailPos int
|
||||
tailFull bool
|
||||
headFull bool
|
||||
totalBytes int
|
||||
maxHead int
|
||||
maxTail int
|
||||
}
|
||||
|
||||
// NewHeadTailBuffer creates a new HeadTailBuffer with the
|
||||
// default head and tail sizes.
|
||||
func NewHeadTailBuffer() *HeadTailBuffer {
|
||||
return &HeadTailBuffer{
|
||||
maxHead: MaxHeadBytes,
|
||||
maxTail: MaxTailBytes,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHeadTailBufferSized creates a HeadTailBuffer with custom
|
||||
// head and tail sizes. This is useful for testing truncation
|
||||
// logic with smaller buffers.
|
||||
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
|
||||
return &HeadTailBuffer{
|
||||
maxHead: maxHead,
|
||||
maxTail: maxTail,
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer. It is safe for concurrent use.
|
||||
// All bytes are accepted; the return value always equals
|
||||
// len(p) with a nil error.
|
||||
func (b *HeadTailBuffer) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
n := len(p)
|
||||
b.totalBytes += n
|
||||
|
||||
// Fill head buffer if it is not yet full.
|
||||
if !b.headFull {
|
||||
remaining := b.maxHead - len(b.head)
|
||||
if remaining > 0 {
|
||||
take := remaining
|
||||
if take > len(p) {
|
||||
take = len(p)
|
||||
}
|
||||
b.head = append(b.head, p[:take]...)
|
||||
p = p[take:]
|
||||
if len(b.head) >= b.maxHead {
|
||||
b.headFull = true
|
||||
}
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write remaining bytes into the tail ring buffer.
|
||||
b.writeTail(p)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// writeTail appends data to the tail ring buffer. The caller
|
||||
// must hold b.mu.
|
||||
func (b *HeadTailBuffer) writeTail(p []byte) {
|
||||
if b.maxTail <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Lazily allocate the tail buffer on first use.
|
||||
if b.tail == nil {
|
||||
b.tail = make([]byte, b.maxTail)
|
||||
}
|
||||
|
||||
for len(p) > 0 {
|
||||
// Write as many bytes as fit starting at tailPos.
|
||||
space := b.maxTail - b.tailPos
|
||||
take := space
|
||||
if take > len(p) {
|
||||
take = len(p)
|
||||
}
|
||||
copy(b.tail[b.tailPos:b.tailPos+take], p[:take])
|
||||
p = p[take:]
|
||||
b.tailPos += take
|
||||
if b.tailPos >= b.maxTail {
|
||||
b.tailPos = 0
|
||||
b.tailFull = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tailBytes returns the current tail contents in order. The
|
||||
// caller must hold b.mu.
|
||||
func (b *HeadTailBuffer) tailBytes() []byte {
|
||||
if b.tail == nil {
|
||||
return nil
|
||||
}
|
||||
if !b.tailFull {
|
||||
// Haven't wrapped yet; data is [0, tailPos).
|
||||
return b.tail[:b.tailPos]
|
||||
}
|
||||
// Wrapped: data is [tailPos, maxTail) + [0, tailPos).
|
||||
out := make([]byte, b.maxTail)
|
||||
n := copy(out, b.tail[b.tailPos:])
|
||||
copy(out[n:], b.tail[:b.tailPos])
|
||||
return out
|
||||
}
|
||||
|
||||
// Bytes returns a copy of the raw buffer contents. If no
|
||||
// truncation has occurred the full output is returned;
|
||||
// otherwise the head and tail portions are concatenated.
|
||||
func (b *HeadTailBuffer) Bytes() []byte {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
tail := b.tailBytes()
|
||||
if len(tail) == 0 {
|
||||
out := make([]byte, len(b.head))
|
||||
copy(out, b.head)
|
||||
return out
|
||||
}
|
||||
out := make([]byte, len(b.head)+len(tail))
|
||||
copy(out, b.head)
|
||||
copy(out[len(b.head):], tail)
|
||||
return out
|
||||
}
|
||||
|
||||
// Len returns the number of bytes currently stored in the
|
||||
// buffer.
|
||||
func (b *HeadTailBuffer) Len() int {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
tailLen := 0
|
||||
if b.tailFull {
|
||||
tailLen = b.maxTail
|
||||
} else if b.tail != nil {
|
||||
tailLen = b.tailPos
|
||||
}
|
||||
return len(b.head) + tailLen
|
||||
}
|
||||
|
||||
// TotalWritten returns the total number of bytes written to
|
||||
// the buffer, which may exceed the stored capacity.
|
||||
func (b *HeadTailBuffer) TotalWritten() int {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.totalBytes
|
||||
}
|
||||
|
||||
// Output returns the truncated output suitable for LLM
|
||||
// consumption, along with truncation metadata. If the total
|
||||
// output fits within the head buffer alone, the full output is
|
||||
// returned with nil truncation info. Otherwise the head and
|
||||
// tail are joined with an omission marker and long lines are
|
||||
// truncated.
|
||||
func (b *HeadTailBuffer) Output() (string, *workspacesdk.ProcessTruncation) {
|
||||
b.mu.Lock()
|
||||
head := make([]byte, len(b.head))
|
||||
copy(head, b.head)
|
||||
tail := b.tailBytes()
|
||||
total := b.totalBytes
|
||||
headFull := b.headFull
|
||||
b.mu.Unlock()
|
||||
|
||||
storedLen := len(head) + len(tail)
|
||||
|
||||
// If everything fits, no head/tail split is needed.
|
||||
if !headFull || len(tail) == 0 {
|
||||
out := truncateLines(string(head))
|
||||
if total == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// We have both head and tail data, meaning the total
|
||||
// output exceeded the head capacity. Build the
|
||||
// combined output with an omission marker.
|
||||
omitted := total - storedLen
|
||||
headStr := truncateLines(string(head))
|
||||
tailStr := truncateLines(string(tail))
|
||||
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString(headStr)
|
||||
if omitted > 0 {
|
||||
_, _ = sb.WriteString(fmt.Sprintf(
|
||||
"\n\n... [omitted %d bytes] ...\n\n",
|
||||
omitted,
|
||||
))
|
||||
} else {
|
||||
// Head and tail are contiguous but were stored
|
||||
// separately because the head filled up.
|
||||
_, _ = sb.WriteString("\n")
|
||||
}
|
||||
_, _ = sb.WriteString(tailStr)
|
||||
result := sb.String()
|
||||
|
||||
return result, &workspacesdk.ProcessTruncation{
|
||||
OriginalBytes: total,
|
||||
RetainedBytes: len(result),
|
||||
OmittedBytes: omitted,
|
||||
Strategy: "head_tail",
|
||||
}
|
||||
}
|
||||
|
||||
// truncateLines scans the input line by line and truncates
|
||||
// any line longer than MaxLineLength.
|
||||
func truncateLines(s string) string {
|
||||
if len(s) <= MaxLineLength {
|
||||
// Fast path: if the entire string is shorter than
|
||||
// the max line length, no line can exceed it.
|
||||
return s
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
for len(s) > 0 {
|
||||
idx := strings.IndexByte(s, '\n')
|
||||
var line string
|
||||
if idx == -1 {
|
||||
line = s
|
||||
s = ""
|
||||
} else {
|
||||
line = s[:idx]
|
||||
s = s[idx+1:]
|
||||
}
|
||||
|
||||
if len(line) > MaxLineLength {
|
||||
// Truncate preserving the suffix length so the
|
||||
// total does not exceed a reasonable size.
|
||||
cut := MaxLineLength - len(lineTruncationSuffix)
|
||||
if cut < 0 {
|
||||
cut = 0
|
||||
}
|
||||
_, _ = b.WriteString(line[:cut])
|
||||
_, _ = b.WriteString(lineTruncationSuffix)
|
||||
} else {
|
||||
_, _ = b.WriteString(line)
|
||||
}
|
||||
|
||||
// Re-add the newline unless this was the final
|
||||
// segment without a trailing newline.
|
||||
if idx != -1 {
|
||||
_ = b.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Reset clears the buffer, discarding all data.
|
||||
func (b *HeadTailBuffer) Reset() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.head = nil
|
||||
b.tail = nil
|
||||
b.tailPos = 0
|
||||
b.tailFull = false
|
||||
b.headFull = false
|
||||
b.totalBytes = 0
|
||||
}
|
||||
@@ -0,0 +1,338 @@
|
||||
package agentproc_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentproc"
|
||||
)
|
||||
|
||||
func TestHeadTailBuffer_EmptyBuffer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
out, info := buf.Output()
|
||||
require.Empty(t, out)
|
||||
require.Nil(t, info)
|
||||
require.Equal(t, 0, buf.Len())
|
||||
require.Equal(t, 0, buf.TotalWritten())
|
||||
require.Empty(t, buf.Bytes())
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_SmallOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
data := "hello world\n"
|
||||
n, err := buf.Write([]byte(data))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(data), n)
|
||||
|
||||
out, info := buf.Output()
|
||||
require.Equal(t, data, out)
|
||||
require.Nil(t, info, "small output should not be truncated")
|
||||
require.Equal(t, len(data), buf.Len())
|
||||
require.Equal(t, len(data), buf.TotalWritten())
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_ExactlyHeadSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
|
||||
// Build data that is exactly MaxHeadBytes using short
|
||||
// lines so that line truncation does not apply.
|
||||
line := strings.Repeat("x", 79) + "\n" // 80 bytes per line
|
||||
count := agentproc.MaxHeadBytes / len(line)
|
||||
pad := agentproc.MaxHeadBytes - (count * len(line))
|
||||
data := strings.Repeat(line, count) + strings.Repeat("y", pad)
|
||||
require.Equal(t, agentproc.MaxHeadBytes, len(data),
|
||||
"test data must be exactly MaxHeadBytes")
|
||||
|
||||
n, err := buf.Write([]byte(data))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, agentproc.MaxHeadBytes, n)
|
||||
|
||||
out, info := buf.Output()
|
||||
require.Equal(t, data, out)
|
||||
require.Nil(t, info, "output fitting in head should not be truncated")
|
||||
require.Equal(t, agentproc.MaxHeadBytes, buf.Len())
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_HeadPlusTailNoOmission(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Use a small buffer so we can test the boundary where
|
||||
// head fills and tail starts but nothing is omitted.
|
||||
// With maxHead=10, maxTail=10, writing exactly 20 bytes
|
||||
// means head gets 10, tail gets 10, omitted = 0.
|
||||
buf := agentproc.NewHeadTailBufferSized(10, 10)
|
||||
|
||||
data := "0123456789abcdefghij" // 20 bytes
|
||||
n, err := buf.Write([]byte(data))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 20, n)
|
||||
|
||||
out, info := buf.Output()
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, 0, info.OmittedBytes)
|
||||
require.Equal(t, "head_tail", info.Strategy)
|
||||
// The output should contain both head and tail.
|
||||
require.Contains(t, out, "0123456789")
|
||||
require.Contains(t, out, "abcdefghij")
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_LargeOutputTruncation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Use small head/tail so truncation is easy to verify.
|
||||
buf := agentproc.NewHeadTailBufferSized(10, 10)
|
||||
|
||||
// Write 100 bytes: head=10, tail=10, omitted=80.
|
||||
data := strings.Repeat("A", 50) + strings.Repeat("Z", 50)
|
||||
n, err := buf.Write([]byte(data))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 100, n)
|
||||
|
||||
out, info := buf.Output()
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, 100, info.OriginalBytes)
|
||||
require.Equal(t, 80, info.OmittedBytes)
|
||||
require.Equal(t, "head_tail", info.Strategy)
|
||||
|
||||
// Head should be first 10 bytes (all A's).
|
||||
require.True(t, strings.HasPrefix(out, "AAAAAAAAAA"))
|
||||
// Tail should be last 10 bytes (all Z's).
|
||||
require.True(t, strings.HasSuffix(out, "ZZZZZZZZZZ"))
|
||||
// Omission marker should be present.
|
||||
require.Contains(t, out, "... [omitted 80 bytes] ...")
|
||||
|
||||
require.Equal(t, 20, buf.Len())
|
||||
require.Equal(t, 100, buf.TotalWritten())
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_MultiMBStaysBounded(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
|
||||
// Write 5MB of data in chunks.
|
||||
chunk := []byte(strings.Repeat("x", 4096) + "\n")
|
||||
totalWritten := 0
|
||||
for totalWritten < 5*1024*1024 {
|
||||
n, err := buf.Write(chunk)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(chunk), n)
|
||||
totalWritten += n
|
||||
}
|
||||
|
||||
// Memory should be bounded to head+tail.
|
||||
require.LessOrEqual(t, buf.Len(),
|
||||
agentproc.MaxHeadBytes+agentproc.MaxTailBytes)
|
||||
require.Equal(t, totalWritten, buf.TotalWritten())
|
||||
|
||||
out, info := buf.Output()
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, totalWritten, info.OriginalBytes)
|
||||
require.Greater(t, info.OmittedBytes, 0)
|
||||
require.NotEmpty(t, out)
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_LongLineTruncation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
|
||||
// Write a line longer than MaxLineLength.
|
||||
longLine := strings.Repeat("m", agentproc.MaxLineLength+500)
|
||||
_, err := buf.Write([]byte(longLine + "\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
out, _ := buf.Output()
|
||||
lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
|
||||
require.Len(t, lines, 1)
|
||||
require.LessOrEqual(t, len(lines[0]), agentproc.MaxLineLength)
|
||||
require.True(t, strings.HasSuffix(lines[0], "... [truncated]"))
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_LongLineInTail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Use small buffers so we can force data into the tail.
|
||||
buf := agentproc.NewHeadTailBufferSized(20, 5000)
|
||||
|
||||
// Fill head with short data.
|
||||
_, err := buf.Write([]byte("head data goes here\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now write a very long line into the tail.
|
||||
longLine := strings.Repeat("T", agentproc.MaxLineLength+100)
|
||||
_, err = buf.Write([]byte(longLine + "\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
out, info := buf.Output()
|
||||
require.NotNil(t, info)
|
||||
// The long line in the tail should be truncated.
|
||||
require.Contains(t, out, "... [truncated]")
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_ConcurrentWrites(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
|
||||
const goroutines = 10
|
||||
const writes = 1000
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for g := range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
line := fmt.Sprintf("goroutine-%d: data\n", g)
|
||||
for range writes {
|
||||
_, err := buf.Write([]byte(line))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify totals are consistent.
|
||||
require.Greater(t, buf.TotalWritten(), 0)
|
||||
require.Greater(t, buf.Len(), 0)
|
||||
|
||||
out, _ := buf.Output()
|
||||
require.NotEmpty(t, out)
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_TruncationInfoFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBufferSized(10, 10)
|
||||
|
||||
// Write enough to cause omission.
|
||||
data := strings.Repeat("D", 50)
|
||||
_, err := buf.Write([]byte(data))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, info := buf.Output()
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, 50, info.OriginalBytes)
|
||||
require.Equal(t, 30, info.OmittedBytes)
|
||||
require.Equal(t, "head_tail", info.Strategy)
|
||||
// RetainedBytes is the length of the formatted output
|
||||
// string including the omission marker.
|
||||
require.Greater(t, info.RetainedBytes, 0)
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_MultipleSmallWrites(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
|
||||
// Write one byte at a time.
|
||||
expected := "hello world"
|
||||
for i := range len(expected) {
|
||||
n, err := buf.Write([]byte{expected[i]})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, n)
|
||||
}
|
||||
|
||||
out, info := buf.Output()
|
||||
require.Equal(t, expected, out)
|
||||
require.Nil(t, info)
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_WriteEmptySlice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
n, err := buf.Write([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, n)
|
||||
require.Equal(t, 0, buf.TotalWritten())
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
_, err := buf.Write([]byte("some data"))
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, buf.Len(), 0)
|
||||
|
||||
buf.Reset()
|
||||
|
||||
require.Equal(t, 0, buf.Len())
|
||||
require.Equal(t, 0, buf.TotalWritten())
|
||||
out, info := buf.Output()
|
||||
require.Empty(t, out)
|
||||
require.Nil(t, info)
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_BytesReturnsCopy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
_, err := buf.Write([]byte("original"))
|
||||
require.NoError(t, err)
|
||||
|
||||
b := buf.Bytes()
|
||||
require.Equal(t, []byte("original"), b)
|
||||
|
||||
// Mutating the returned slice should not affect the
|
||||
// buffer.
|
||||
b[0] = 'X'
|
||||
require.Equal(t, []byte("original"), buf.Bytes())
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_RingBufferWraparound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Use a tail of 10 bytes and write enough to wrap
|
||||
// around multiple times.
|
||||
buf := agentproc.NewHeadTailBufferSized(5, 10)
|
||||
|
||||
// Fill head (5 bytes).
|
||||
_, err := buf.Write([]byte("HEADD"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write 25 bytes into tail, wrapping 2.5 times.
|
||||
_, err = buf.Write([]byte("0123456789"))
|
||||
require.NoError(t, err)
|
||||
_, err = buf.Write([]byte("abcdefghij"))
|
||||
require.NoError(t, err)
|
||||
_, err = buf.Write([]byte("ABCDE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
out, info := buf.Output()
|
||||
require.NotNil(t, info)
|
||||
// Tail should contain the last 10 bytes: "fghijABCDE".
|
||||
require.True(t, strings.HasSuffix(out, "fghijABCDE"),
|
||||
"expected tail to be last 10 bytes, got: %q", out)
|
||||
}
|
||||
|
||||
func TestHeadTailBuffer_MultipleLinesTruncated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := agentproc.NewHeadTailBuffer()
|
||||
|
||||
short := "short line\n"
|
||||
long := strings.Repeat("L", agentproc.MaxLineLength+100) + "\n"
|
||||
_, err := buf.Write([]byte(short + long + short))
|
||||
require.NoError(t, err)
|
||||
|
||||
out, _ := buf.Output()
|
||||
lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
|
||||
require.Len(t, lines, 3)
|
||||
require.Equal(t, "short line", lines[0])
|
||||
require.True(t, strings.HasSuffix(lines[1], "... [truncated]"))
|
||||
require.Equal(t, "short line", lines[2])
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
package agentproc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
var (
|
||||
errProcessNotFound = xerrors.New("process not found")
|
||||
errProcessNotRunning = xerrors.New("process is not running")
|
||||
)
|
||||
|
||||
// process represents a running or completed process.
|
||||
type process struct {
|
||||
mu sync.Mutex
|
||||
id string
|
||||
command string
|
||||
workDir string
|
||||
background bool
|
||||
cmd *exec.Cmd
|
||||
cancel context.CancelFunc
|
||||
buf *HeadTailBuffer
|
||||
running bool
|
||||
exitCode *int
|
||||
startedAt int64
|
||||
exitedAt *int64
|
||||
done chan struct{} // closed when process exits
|
||||
}
|
||||
|
||||
// info returns a snapshot of the process state.
|
||||
func (p *process) info() workspacesdk.ProcessInfo {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
return workspacesdk.ProcessInfo{
|
||||
ID: p.id,
|
||||
Command: p.command,
|
||||
WorkDir: p.workDir,
|
||||
Background: p.background,
|
||||
Running: p.running,
|
||||
ExitCode: p.exitCode,
|
||||
StartedAt: p.startedAt,
|
||||
ExitedAt: p.exitedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// output returns the truncated output from the process buffer
|
||||
// along with optional truncation metadata.
|
||||
func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
|
||||
return p.buf.Output()
|
||||
}
|
||||
|
||||
// manager tracks processes spawned by the agent.
|
||||
type manager struct {
|
||||
mu sync.Mutex
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
clock quartz.Clock
|
||||
procs map[string]*process
|
||||
closed bool
|
||||
}
|
||||
|
||||
// newManager creates a new process manager.
|
||||
func newManager(logger slog.Logger, execer agentexec.Execer) *manager {
|
||||
return &manager{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
clock: quartz.NewReal(),
|
||||
procs: make(map[string]*process),
|
||||
}
|
||||
}
|
||||
|
||||
// start spawns a new process. Both foreground and background
|
||||
// processes use a long-lived context so the process survives
|
||||
// the HTTP request lifecycle. The background flag only affects
|
||||
// client-side polling behavior.
|
||||
func (m *manager) start(req workspacesdk.StartProcessRequest) (*process, error) {
|
||||
m.mu.Lock()
|
||||
if m.closed {
|
||||
m.mu.Unlock()
|
||||
return nil, xerrors.New("manager is closed")
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
id := uuid.New().String()
|
||||
|
||||
// Use a cancellable context so Close() can terminate
|
||||
// all processes. context.Background() is the parent so
|
||||
// the process is not tied to any HTTP request.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
|
||||
if req.WorkDir != "" {
|
||||
cmd.Dir = req.WorkDir
|
||||
}
|
||||
cmd.Stdin = nil
|
||||
|
||||
// WaitDelay ensures cmd.Wait returns promptly after
|
||||
// the process is killed, even if child processes are
|
||||
// still holding the stdout/stderr pipes open.
|
||||
cmd.WaitDelay = 5 * time.Second
|
||||
|
||||
buf := NewHeadTailBuffer()
|
||||
cmd.Stdout = buf
|
||||
cmd.Stderr = buf
|
||||
|
||||
if len(req.Env) > 0 {
|
||||
cmd.Env = os.Environ()
|
||||
for k, v := range req.Env {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
return nil, xerrors.Errorf("start process: %w", err)
|
||||
}
|
||||
|
||||
now := m.clock.Now().Unix()
|
||||
proc := &process{
|
||||
id: id,
|
||||
command: req.Command,
|
||||
workDir: req.WorkDir,
|
||||
background: req.Background,
|
||||
cmd: cmd,
|
||||
cancel: cancel,
|
||||
buf: buf,
|
||||
running: true,
|
||||
startedAt: now,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
if m.closed {
|
||||
m.mu.Unlock()
|
||||
// Manager closed between our check and now. Kill the
|
||||
// process we just started.
|
||||
cancel()
|
||||
_ = cmd.Wait()
|
||||
return nil, xerrors.New("manager is closed")
|
||||
}
|
||||
m.procs[id] = proc
|
||||
m.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
exitedAt := m.clock.Now().Unix()
|
||||
|
||||
proc.mu.Lock()
|
||||
proc.running = false
|
||||
proc.exitedAt = &exitedAt
|
||||
code := 0
|
||||
if err != nil {
|
||||
// Extract the exit code from the error.
|
||||
var exitErr *exec.ExitError
|
||||
if xerrors.As(err, &exitErr) {
|
||||
code = exitErr.ExitCode()
|
||||
} else {
|
||||
// Unknown error; use -1 as a sentinel.
|
||||
code = -1
|
||||
m.logger.Warn(
|
||||
context.Background(),
|
||||
"process wait returned non-exit error",
|
||||
slog.F("id", id),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
proc.exitCode = &code
|
||||
proc.mu.Unlock()
|
||||
|
||||
close(proc.done)
|
||||
}()
|
||||
|
||||
return proc, nil
|
||||
}
|
||||
|
||||
// get returns a process by ID.
|
||||
func (m *manager) get(id string) (*process, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
proc, ok := m.procs[id]
|
||||
return proc, ok
|
||||
}
|
||||
|
||||
// list returns info about all tracked processes.
|
||||
func (m *manager) list() []workspacesdk.ProcessInfo {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
infos := make([]workspacesdk.ProcessInfo, 0, len(m.procs))
|
||||
for _, proc := range m.procs {
|
||||
infos = append(infos, proc.info())
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
// signal sends a signal to a running process. It returns
|
||||
// sentinel errors errProcessNotFound and errProcessNotRunning
|
||||
// so callers can distinguish failure modes.
|
||||
func (m *manager) signal(id string, sig string) error {
|
||||
m.mu.Lock()
|
||||
proc, ok := m.procs[id]
|
||||
m.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return errProcessNotFound
|
||||
}
|
||||
|
||||
proc.mu.Lock()
|
||||
defer proc.mu.Unlock()
|
||||
|
||||
if !proc.running {
|
||||
return errProcessNotRunning
|
||||
}
|
||||
|
||||
switch sig {
|
||||
case "kill":
|
||||
if err := proc.cmd.Process.Kill(); err != nil {
|
||||
return xerrors.Errorf("kill process: %w", err)
|
||||
}
|
||||
case "terminate":
|
||||
//nolint:revive // syscall.SIGTERM is portable enough
|
||||
// for our supported platforms.
|
||||
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||
return xerrors.Errorf("terminate process: %w", err)
|
||||
}
|
||||
default:
|
||||
return xerrors.Errorf("unsupported signal %q", sig)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close kills all running processes and prevents new ones from
|
||||
// starting. It cancels each process's context, which causes
|
||||
// CommandContext to kill the process and its pipe goroutines to
|
||||
// drain.
|
||||
func (m *manager) Close() error {
|
||||
m.mu.Lock()
|
||||
if m.closed {
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
m.closed = true
|
||||
procs := make([]*process, 0, len(m.procs))
|
||||
for _, p := range m.procs {
|
||||
procs = append(procs, p)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, p := range procs {
|
||||
p.cancel()
|
||||
}
|
||||
|
||||
// Wait for all processes to exit.
|
||||
for _, p := range procs {
|
||||
<-p.done
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -28,6 +28,7 @@ func (a *agent) apiHandler() http.Handler {
|
||||
})
|
||||
|
||||
r.Mount("/api/v0", a.filesAPI.Routes())
|
||||
r.Mount("/api/v0/processes", a.processAPI.Routes())
|
||||
|
||||
if a.devcontainers {
|
||||
r.Mount("/api/v0/containers", a.containerAPI.Routes())
|
||||
|
||||
@@ -2071,6 +2071,15 @@ func (p *Server) runChat(
|
||||
chattool.Execute(chattool.ExecuteOptions{
|
||||
GetWorkspaceConn: getWorkspaceConn,
|
||||
}),
|
||||
chattool.ProcessOutput(chattool.ProcessToolOptions{
|
||||
GetWorkspaceConn: getWorkspaceConn,
|
||||
}),
|
||||
chattool.ProcessList(chattool.ProcessToolOptions{
|
||||
GetWorkspaceConn: getWorkspaceConn,
|
||||
}),
|
||||
chattool.ProcessSignal(chattool.ProcessToolOptions{
|
||||
GetWorkspaceConn: getWorkspaceConn,
|
||||
}),
|
||||
}
|
||||
tools = append(tools, p.subagentTools(func() database.Chat {
|
||||
return chat
|
||||
|
||||
@@ -2,32 +2,88 @@ package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultExecuteTimeout = 60 * time.Second
|
||||
chatAgentEnvVar = "CODER_CHAT_AGENT"
|
||||
gitAuthRequiredPrefix = "CODER_GITAUTH_REQUIRED:"
|
||||
authRequiredResultReason = "authentication_required"
|
||||
// defaultTimeout is the default timeout for command
|
||||
// execution.
|
||||
defaultTimeout = 10 * time.Second
|
||||
|
||||
// maxOutputToModel is the maximum output sent to the LLM.
|
||||
maxOutputToModel = 32 << 10 // 32KB
|
||||
|
||||
// pollInterval is how often we check for process completion
|
||||
// in foreground mode.
|
||||
pollInterval = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// nonInteractiveEnvVars are set on every process to prevent
|
||||
// interactive prompts that would hang a headless execution.
|
||||
var nonInteractiveEnvVars = map[string]string{
|
||||
"GIT_EDITOR": "true",
|
||||
"GIT_SEQUENCE_EDITOR": "true",
|
||||
"EDITOR": "true",
|
||||
"VISUAL": "true",
|
||||
"GIT_TERMINAL_PROMPT": "0",
|
||||
"NO_COLOR": "1",
|
||||
"TERM": "dumb",
|
||||
"PAGER": "cat",
|
||||
"GIT_PAGER": "cat",
|
||||
}
|
||||
|
||||
// fileDumpPatterns detects commands that dump entire files.
|
||||
// When matched, a note is added suggesting read_file instead.
|
||||
var fileDumpPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`^cat\s+`),
|
||||
regexp.MustCompile(`^(rg|grep)\s+.*--include-all`),
|
||||
regexp.MustCompile(`^(rg|grep)\s+-l\s+`),
|
||||
}
|
||||
|
||||
// ExecuteResult is the structured response from the execute
|
||||
// tool.
|
||||
type ExecuteResult struct {
|
||||
Success bool `json:"success"`
|
||||
Output string `json:"output,omitempty"`
|
||||
ExitCode int `json:"exit_code"`
|
||||
WallDurationMs int64 `json:"wall_duration_ms"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Truncated *workspacesdk.ProcessTruncation `json:"truncated,omitempty"`
|
||||
Note string `json:"note,omitempty"`
|
||||
BackgroundProcessID string `json:"background_process_id,omitempty"`
|
||||
}
|
||||
|
||||
// ExecuteOptions configures the execute tool.
|
||||
type ExecuteOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
DefaultTimeout time.Duration
|
||||
}
|
||||
|
||||
type ExecuteArgs struct {
|
||||
Command string `json:"command"`
|
||||
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
|
||||
// ProcessToolOptions configures a process management tool
|
||||
// (process_output, process_list, or process_signal). Each of
|
||||
// these tools only needs a workspace connection resolver.
|
||||
type ProcessToolOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
}
|
||||
|
||||
// ExecuteArgs are the parameters accepted by the execute tool.
|
||||
type ExecuteArgs struct {
|
||||
Command string `json:"command"`
|
||||
Timeout *string `json:"timeout,omitempty"`
|
||||
WorkDir *string `json:"workdir,omitempty"`
|
||||
RunInBackground *bool `json:"run_in_background,omitempty"`
|
||||
}
|
||||
|
||||
// Execute returns an AgentTool that runs a shell command in the
|
||||
// workspace via the agent HTTP API.
|
||||
func Execute(options ExecuteOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"execute",
|
||||
@@ -49,85 +105,342 @@ func executeTool(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args ExecuteArgs,
|
||||
defaultTimeout time.Duration,
|
||||
optTimeout time.Duration,
|
||||
) fantasy.ToolResponse {
|
||||
if args.Command == "" {
|
||||
return fantasy.NewTextErrorResponse("command is required")
|
||||
}
|
||||
|
||||
timeout := defaultTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = defaultExecuteTimeout
|
||||
// Build the environment map for the process request.
|
||||
env := make(map[string]string, len(nonInteractiveEnvVars)+1)
|
||||
env["CODER_CHAT_AGENT"] = "true"
|
||||
for k, v := range nonInteractiveEnvVars {
|
||||
env[k] = v
|
||||
}
|
||||
if args.TimeoutSeconds != nil {
|
||||
timeout = time.Duration(*args.TimeoutSeconds) * time.Second
|
||||
}
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
output, exitCode, err := runCommand(cmdCtx, conn, args.Command)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error())
|
||||
background := args.RunInBackground != nil && *args.RunInBackground
|
||||
|
||||
var workDir string
|
||||
if args.WorkDir != nil {
|
||||
workDir = *args.WorkDir
|
||||
}
|
||||
return toolResponse(map[string]any{
|
||||
"output": output,
|
||||
"exit_code": exitCode,
|
||||
})
|
||||
|
||||
if background {
|
||||
return executeBackground(ctx, conn, args.Command, workDir, env)
|
||||
}
|
||||
return executeForeground(ctx, conn, args, optTimeout, workDir, env)
|
||||
}
|
||||
|
||||
func runCommand(
|
||||
// executeBackground starts a process in the background and
|
||||
// returns immediately with the process ID.
|
||||
func executeBackground(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
command string,
|
||||
) (string, int, error) {
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
workDir string,
|
||||
env map[string]string,
|
||||
) fantasy.ToolResponse {
|
||||
resp, err := conn.StartProcess(ctx, workspacesdk.StartProcessRequest{
|
||||
Command: command,
|
||||
WorkDir: workDir,
|
||||
Env: env,
|
||||
Background: true,
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
return errorResult(fmt.Sprintf("start background process: %v", err))
|
||||
}
|
||||
defer sshClient.Close()
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
result := ExecuteResult{
|
||||
Success: true,
|
||||
BackgroundProcessID: resp.ID,
|
||||
}
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
return fantasy.NewTextErrorResponse(err.Error())
|
||||
}
|
||||
defer session.Close()
|
||||
if err := session.Setenv(chatAgentEnvVar, "true"); err != nil {
|
||||
return "", 0, xerrors.Errorf("set %s: %w", chatAgentEnvVar, err)
|
||||
return fantasy.NewTextResponse(string(data))
|
||||
}
|
||||
|
||||
// executeForeground starts a process and polls for its
|
||||
// completion, enforcing the configured timeout.
|
||||
func executeForeground(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
args ExecuteArgs,
|
||||
optTimeout time.Duration,
|
||||
workDir string,
|
||||
env map[string]string,
|
||||
) fantasy.ToolResponse {
|
||||
timeout := optTimeout
|
||||
if timeout <= 0 {
|
||||
timeout = defaultTimeout
|
||||
}
|
||||
|
||||
resultCh := make(chan struct {
|
||||
output string
|
||||
exitCode int
|
||||
err error
|
||||
}, 1)
|
||||
|
||||
go func() {
|
||||
output, err := session.CombinedOutput(command)
|
||||
exitCode := 0
|
||||
if args.Timeout != nil {
|
||||
parsed, err := time.ParseDuration(*args.Timeout)
|
||||
if err != nil {
|
||||
var exitErr *ssh.ExitError
|
||||
if xerrors.As(err, &exitErr) {
|
||||
exitCode = exitErr.ExitStatus()
|
||||
} else {
|
||||
exitCode = 1
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("invalid timeout %q: %v", *args.Timeout, err),
|
||||
)
|
||||
}
|
||||
timeout = parsed
|
||||
}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
resp, err := conn.StartProcess(cmdCtx, workspacesdk.StartProcessRequest{
|
||||
Command: args.Command,
|
||||
WorkDir: workDir,
|
||||
Env: env,
|
||||
Background: false,
|
||||
})
|
||||
if err != nil {
|
||||
return errorResult(fmt.Sprintf("start process: %v", err))
|
||||
}
|
||||
|
||||
result := pollProcess(cmdCtx, conn, resp.ID, timeout)
|
||||
result.WallDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Add an advisory note for file-dump commands.
|
||||
if note := detectFileDump(args.Command); note != "" {
|
||||
result.Note = note
|
||||
}
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error())
|
||||
}
|
||||
return fantasy.NewTextResponse(string(data))
|
||||
}
|
||||
|
||||
// truncateOutput safely truncates output to maxOutputToModel,
|
||||
// ensuring the result is valid UTF-8 even if the cut falls in
|
||||
// the middle of a multi-byte character.
|
||||
func truncateOutput(output string) string {
|
||||
if len(output) > maxOutputToModel {
|
||||
output = strings.ToValidUTF8(output[:maxOutputToModel], "")
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
// pollProcess polls for process output until the process exits
|
||||
// or the context times out.
|
||||
func pollProcess(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
processID string,
|
||||
timeout time.Duration,
|
||||
) ExecuteResult {
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Timeout — get whatever output we have. Use a
|
||||
// fresh context since cmdCtx is already canceled.
|
||||
bgCtx, bgCancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
5*time.Second,
|
||||
)
|
||||
outputResp, _ := conn.ProcessOutput(bgCtx, processID)
|
||||
bgCancel()
|
||||
output := truncateOutput(outputResp.Output)
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
Output: output,
|
||||
ExitCode: -1,
|
||||
Error: fmt.Sprintf("command timed out after %s", timeout),
|
||||
Truncated: outputResp.Truncated,
|
||||
}
|
||||
case <-ticker.C:
|
||||
outputResp, err := conn.ProcessOutput(ctx, processID)
|
||||
if err != nil {
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("get process output: %v", err),
|
||||
}
|
||||
}
|
||||
if !outputResp.Running {
|
||||
exitCode := 0
|
||||
if outputResp.ExitCode != nil {
|
||||
exitCode = *outputResp.ExitCode
|
||||
}
|
||||
output := truncateOutput(outputResp.Output)
|
||||
return ExecuteResult{
|
||||
Success: exitCode == 0,
|
||||
Output: output,
|
||||
ExitCode: exitCode,
|
||||
Truncated: outputResp.Truncated,
|
||||
}
|
||||
}
|
||||
}
|
||||
resultCh <- struct {
|
||||
output string
|
||||
exitCode int
|
||||
err error
|
||||
}{
|
||||
output: string(output),
|
||||
exitCode: exitCode,
|
||||
err: err,
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Close()
|
||||
return "", 0, ctx.Err()
|
||||
case result := <-resultCh:
|
||||
return result.output, result.exitCode, result.err
|
||||
}
|
||||
}
|
||||
|
||||
// errorResult builds a ToolResponse from an ExecuteResult with
|
||||
// an error message.
|
||||
func errorResult(msg string) fantasy.ToolResponse {
|
||||
data, err := json.Marshal(ExecuteResult{
|
||||
Success: false,
|
||||
Error: msg,
|
||||
})
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(msg)
|
||||
}
|
||||
return fantasy.NewTextResponse(string(data))
|
||||
}
|
||||
|
||||
// detectFileDump checks whether the command matches a file-dump
|
||||
// pattern and returns an advisory note, or empty string if no
|
||||
// match.
|
||||
func detectFileDump(command string) string {
|
||||
for _, pat := range fileDumpPatterns {
|
||||
if pat.MatchString(command) {
|
||||
return "Consider using read_file instead of " +
|
||||
"dumping file contents with shell commands."
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ProcessOutputArgs are the parameters accepted by the
|
||||
// process_output tool.
|
||||
type ProcessOutputArgs struct {
|
||||
ProcessID string `json:"process_id"`
|
||||
}
|
||||
|
||||
// ProcessOutput returns an AgentTool that retrieves the output
|
||||
// of a background process by its ID.
|
||||
func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"process_output",
|
||||
"Retrieve output from a background process. "+
|
||||
"Use the process_id returned by execute with "+
|
||||
"run_in_background=true. Returns the current output, "+
|
||||
"whether the process is still running, and the exit "+
|
||||
"code if it has finished.",
|
||||
func(ctx context.Context, args ProcessOutputArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
if args.ProcessID == "" {
|
||||
return fantasy.NewTextErrorResponse("process_id is required"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
resp, err := conn.ProcessOutput(ctx, args.ProcessID)
|
||||
if err != nil {
|
||||
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
|
||||
}
|
||||
output := truncateOutput(resp.Output)
|
||||
exitCode := 0
|
||||
if resp.ExitCode != nil {
|
||||
exitCode = *resp.ExitCode
|
||||
}
|
||||
result := ExecuteResult{
|
||||
Success: !resp.Running && exitCode == 0,
|
||||
Output: output,
|
||||
ExitCode: exitCode,
|
||||
Truncated: resp.Truncated,
|
||||
}
|
||||
if resp.Running {
|
||||
// Process is still running — success is not
|
||||
// yet determined.
|
||||
result.Success = true
|
||||
result.Note = "process is still running"
|
||||
}
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(string(data)), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// ProcessList returns an AgentTool that lists all tracked
|
||||
// processes on the workspace agent.
|
||||
func ProcessList(options ProcessToolOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"process_list",
|
||||
"List all tracked processes in the workspace. "+
|
||||
"Returns process IDs, commands, status (running or "+
|
||||
"exited), and exit codes. Use this to discover "+
|
||||
"background processes or check which processes are "+
|
||||
"still running.",
|
||||
func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
resp, err := conn.ListProcesses(ctx)
|
||||
if err != nil {
|
||||
return errorResult(fmt.Sprintf("list processes: %v", err)), nil
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(string(data)), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// ProcessSignalArgs are the parameters accepted by the
|
||||
// process_signal tool.
|
||||
type ProcessSignalArgs struct {
|
||||
ProcessID string `json:"process_id"`
|
||||
Signal string `json:"signal"`
|
||||
}
|
||||
|
||||
// ProcessSignal returns an AgentTool that sends a signal to a
|
||||
// tracked process on the workspace agent.
|
||||
func ProcessSignal(options ProcessToolOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"process_signal",
|
||||
"Send a signal to a background process. "+
|
||||
"Use \"terminate\" (SIGTERM) for graceful shutdown "+
|
||||
"or \"kill\" (SIGKILL) to force stop. Use the "+
|
||||
"process_id returned by execute with "+
|
||||
"run_in_background=true or from process_list.",
|
||||
func(ctx context.Context, args ProcessSignalArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
}
|
||||
if args.ProcessID == "" {
|
||||
return fantasy.NewTextErrorResponse("process_id is required"), nil
|
||||
}
|
||||
if args.Signal != "terminate" && args.Signal != "kill" {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"signal must be \"terminate\" or \"kill\"",
|
||||
), nil
|
||||
}
|
||||
conn, err := options.GetWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
if err := conn.SignalProcess(ctx, args.ProcessID, args.Signal); err != nil {
|
||||
return errorResult(fmt.Sprintf("signal process: %v", err)), nil
|
||||
}
|
||||
data, err := json.Marshal(map[string]any{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf(
|
||||
"signal %q sent to process %s",
|
||||
args.Signal, args.ProcessID,
|
||||
),
|
||||
})
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(string(data)), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -54,13 +54,17 @@ type AgentConn interface {
|
||||
DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||
GetPeerDiagnostics() tailnet.PeerDiagnostics
|
||||
ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error)
|
||||
ListProcesses(ctx context.Context) (ListProcessesResponse, error)
|
||||
ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error)
|
||||
Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error)
|
||||
Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error)
|
||||
ProcessOutput(ctx context.Context, id string) (ProcessOutputResponse, error)
|
||||
PrometheusMetrics(ctx context.Context) ([]byte, error)
|
||||
ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error)
|
||||
DeleteDevcontainer(ctx context.Context, devcontainerID string) error
|
||||
RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error)
|
||||
SignalProcess(ctx context.Context, id string, signal string) error
|
||||
StartProcess(ctx context.Context, req StartProcessRequest) (StartProcessResponse, error)
|
||||
LS(ctx context.Context, path string, req LSRequest) (LSResponse, error)
|
||||
ReadFile(ctx context.Context, path string, offset, limit int64) (io.ReadCloser, string, error)
|
||||
ReadFileLines(ctx context.Context, path string, offset, limit int64, limits ReadFileLinesLimits) (ReadFileLinesResponse, error)
|
||||
@@ -498,6 +502,61 @@ func (c *agentConn) RecreateDevcontainer(ctx context.Context, devcontainerID str
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// StartProcessRequest is the request body for starting a
|
||||
// process on the workspace agent.
|
||||
type StartProcessRequest struct {
|
||||
Command string `json:"command"`
|
||||
WorkDir string `json:"workdir,omitempty"`
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
Background bool `json:"background,omitempty"`
|
||||
}
|
||||
|
||||
// StartProcessResponse is returned when a process is started.
|
||||
type StartProcessResponse struct {
|
||||
ID string `json:"id"`
|
||||
Started bool `json:"started"`
|
||||
}
|
||||
|
||||
// ListProcessesResponse contains information about tracked
|
||||
// processes on the workspace agent.
|
||||
type ListProcessesResponse struct {
|
||||
Processes []ProcessInfo `json:"processes"`
|
||||
}
|
||||
|
||||
// ProcessInfo describes a tracked process on the agent.
|
||||
type ProcessInfo struct {
|
||||
ID string `json:"id"`
|
||||
Command string `json:"command"`
|
||||
WorkDir string `json:"workdir,omitempty"`
|
||||
Background bool `json:"background"`
|
||||
Running bool `json:"running"`
|
||||
ExitCode *int `json:"exit_code,omitempty"`
|
||||
StartedAt int64 `json:"started_at_unix"`
|
||||
ExitedAt *int64 `json:"exited_at_unix,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessOutputResponse contains the output of a process.
|
||||
type ProcessOutputResponse struct {
|
||||
Output string `json:"output"`
|
||||
Truncated *ProcessTruncation `json:"truncated,omitempty"`
|
||||
Running bool `json:"running"`
|
||||
ExitCode *int `json:"exit_code,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessTruncation describes how process output was truncated.
|
||||
type ProcessTruncation struct {
|
||||
OriginalBytes int `json:"original_bytes"`
|
||||
RetainedBytes int `json:"retained_bytes"`
|
||||
OmittedBytes int `json:"omitted_bytes"`
|
||||
Strategy string `json:"strategy"`
|
||||
}
|
||||
|
||||
// SignalProcessRequest is the request body for signaling a
|
||||
// process on the workspace agent.
|
||||
type SignalProcessRequest struct {
|
||||
Signal string `json:"signal"`
|
||||
}
|
||||
|
||||
type LSRequest struct {
|
||||
// e.g. [], ["repos", "coder"],
|
||||
Path []string `json:"path"`
|
||||
@@ -681,6 +740,73 @@ type FileEditRequest struct {
|
||||
Files []FileEdits `json:"files"`
|
||||
}
|
||||
|
||||
// StartProcess starts a new process on the workspace agent.
|
||||
func (c *agentConn) StartProcess(ctx context.Context, req StartProcessRequest) (StartProcessResponse, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/processes/start", req)
|
||||
if err != nil {
|
||||
return StartProcessResponse{}, xerrors.Errorf("do request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return StartProcessResponse{}, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
var resp StartProcessResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// ListProcesses returns information about tracked processes on the agent.
|
||||
func (c *agentConn) ListProcesses(ctx context.Context) (ListProcessesResponse, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/processes/list", nil)
|
||||
if err != nil {
|
||||
return ListProcessesResponse{}, xerrors.Errorf("do request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ListProcessesResponse{}, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
var resp ListProcessesResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// ProcessOutput returns the output of a tracked process on the agent.
|
||||
func (c *agentConn) ProcessOutput(ctx context.Context, id string) (ProcessOutputResponse, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/processes/"+id+"/output", nil)
|
||||
if err != nil {
|
||||
return ProcessOutputResponse{}, xerrors.Errorf("do request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ProcessOutputResponse{}, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
var resp ProcessOutputResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// SignalProcess sends a signal to a tracked process on the agent.
|
||||
func (c *agentConn) SignalProcess(ctx context.Context, id string, signal string) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/processes/"+id+"/signal", SignalProcessRequest{Signal: signal})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("do request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
var m codersdk.Response
|
||||
if err := json.NewDecoder(res.Body).Decode(&m); err != nil {
|
||||
return xerrors.Errorf("decode response body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EditFiles performs search and replace edits on one or more files.
|
||||
func (c *agentConn) EditFiles(ctx context.Context, edits FileEditRequest) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
|
||||
@@ -213,6 +213,21 @@ func (mr *MockAgentConnMockRecorder) ListContainers(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListContainers", reflect.TypeOf((*MockAgentConn)(nil).ListContainers), ctx)
|
||||
}
|
||||
|
||||
// ListProcesses mocks base method.
|
||||
func (m *MockAgentConn) ListProcesses(ctx context.Context) (workspacesdk.ListProcessesResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListProcesses", ctx)
|
||||
ret0, _ := ret[0].(workspacesdk.ListProcessesResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListProcesses indicates an expected call of ListProcesses.
|
||||
func (mr *MockAgentConnMockRecorder) ListProcesses(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListProcesses", reflect.TypeOf((*MockAgentConn)(nil).ListProcesses), ctx)
|
||||
}
|
||||
|
||||
// ListeningPorts mocks base method.
|
||||
func (m *MockAgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -260,6 +275,21 @@ func (mr *MockAgentConnMockRecorder) Ping(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockAgentConn)(nil).Ping), ctx)
|
||||
}
|
||||
|
||||
// ProcessOutput mocks base method.
|
||||
func (m *MockAgentConn) ProcessOutput(ctx context.Context, id string) (workspacesdk.ProcessOutputResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ProcessOutput", ctx, id)
|
||||
ret0, _ := ret[0].(workspacesdk.ProcessOutputResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ProcessOutput indicates an expected call of ProcessOutput.
|
||||
func (mr *MockAgentConnMockRecorder) ProcessOutput(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProcessOutput", reflect.TypeOf((*MockAgentConn)(nil).ProcessOutput), ctx, id)
|
||||
}
|
||||
|
||||
// PrometheusMetrics mocks base method.
|
||||
func (m *MockAgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -401,6 +431,20 @@ func (mr *MockAgentConnMockRecorder) SSHOnPort(ctx, port any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHOnPort", reflect.TypeOf((*MockAgentConn)(nil).SSHOnPort), ctx, port)
|
||||
}
|
||||
|
||||
// SignalProcess mocks base method.
|
||||
func (m *MockAgentConn) SignalProcess(ctx context.Context, id, signal string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignalProcess", ctx, id, signal)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SignalProcess indicates an expected call of SignalProcess.
|
||||
func (mr *MockAgentConnMockRecorder) SignalProcess(ctx, id, signal any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignalProcess", reflect.TypeOf((*MockAgentConn)(nil).SignalProcess), ctx, id, signal)
|
||||
}
|
||||
|
||||
// Speedtest mocks base method.
|
||||
func (m *MockAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -416,6 +460,21 @@ func (mr *MockAgentConnMockRecorder) Speedtest(ctx, direction, duration any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Speedtest", reflect.TypeOf((*MockAgentConn)(nil).Speedtest), ctx, direction, duration)
|
||||
}
|
||||
|
||||
// StartProcess mocks base method.
|
||||
func (m *MockAgentConn) StartProcess(ctx context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StartProcess", ctx, req)
|
||||
ret0, _ := ret[0].(workspacesdk.StartProcessResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// StartProcess indicates an expected call of StartProcess.
|
||||
func (mr *MockAgentConnMockRecorder) StartProcess(ctx, req any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartProcess", reflect.TypeOf((*MockAgentConn)(nil).StartProcess), ctx, req)
|
||||
}
|
||||
|
||||
// TailnetConn mocks base method.
|
||||
func (m *MockAgentConn) TailnetConn() *tailnet.Conn {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Reference in New Issue
Block a user