Files
coder/codersdk/toolsdk/bash.go
T
Thomas Kosiewski 398e80f003 feat: add timeout support to workspace bash tool (#19035)
# Add timeout support to workspace bash tool

This PR adds a timeout feature to the workspace bash tool, allowing
users to specify a maximum execution time for commands. Key changes
include:

- Added a `timeout_ms` parameter to control command execution time
(defaults to 60 seconds, with a maximum of 5 minutes)
- Implemented a new `executeCommandWithTimeout` function that properly
handles command timeouts
- Added proper output capturing during timeout scenarios, returning all
output collected before the timeout
- Updated documentation to explain the timeout feature and provide usage
examples
- Added comprehensive tests for the timeout functionality, including
integration tests

When a command times out, the tool now returns all captured output up to
that point along with a cancellation message, making it clear to users
what happened.

Signed-off-by: Thomas Kosiewski <tk@coder.com>
2025-07-28 11:25:43 +02:00

426 lines
13 KiB
Go

package toolsdk
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"github.com/coder/aisdk-go"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type (
	// WorkspaceBashArgs are the arguments accepted by the WorkspaceBash tool.
	WorkspaceBashArgs struct {
		// Workspace identifies the target workspace, e.g. "workspace",
		// "owner/workspace", "owner--workspace" or "owner/workspace.agent".
		Workspace string `json:"workspace"`
		// Command is the bash command to run inside the workspace.
		Command string `json:"command"`
		// TimeoutMs bounds command execution, in milliseconds. Values <= 0
		// select the tool's default of 60000ms.
		TimeoutMs int `json:"timeout_ms,omitempty"`
	}

	// WorkspaceBashResult is the value returned by the WorkspaceBash tool.
	WorkspaceBashResult struct {
		// Output is the combined stdout and stderr of the command, trimmed of
		// leading and trailing whitespace.
		Output string `json:"output"`
		// ExitCode is the command's exit status: 0 on success, 124 when the
		// command timed out, and 1 when no exit status could be determined.
		ExitCode int `json:"exit_code"`
	}
)
// WorkspaceBash is the tool that runs a bash command inside a Coder workspace
// over SSH. It auto-starts a stopped workspace and waits for the agent to be
// ready before running the command. The command itself is bounded by
// timeout_ms (default 60000ms, clamped to the documented maximum of 300000ms)
// and the whole handler by a 5 minute budget. When the command times out, the
// output captured so far is returned with a cancellation message and exit
// code 124.
var WorkspaceBash = Tool[WorkspaceBashArgs, WorkspaceBashResult]{
	Tool: aisdk.Tool{
		Name: ToolNameWorkspaceBash,
		Description: `Execute a bash command in a Coder workspace.
This tool provides the same functionality as the 'coder ssh <workspace> <command>' CLI command.
It automatically starts the workspace if it's stopped and waits for the agent to be ready.
The output is trimmed of leading and trailing whitespace.
The workspace parameter supports various formats:
- workspace (uses current user)
- owner/workspace
- owner--workspace
- workspace.agent (specific agent)
- owner/workspace.agent
The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms).
If the command times out, all output captured up to that point is returned with a cancellation message.
Examples:
- workspace: "my-workspace", command: "ls -la"
- workspace: "john/dev-env", command: "git status", timeout_ms: 30000
- workspace: "my-workspace.main", command: "docker ps"`,
		Schema: aisdk.Schema{
			Properties: map[string]any{
				"workspace": map[string]any{
					"type":        "string",
					"description": "The workspace name in format [owner/]workspace[.agent]. If owner is not specified, the authenticated user is used.",
				},
				"command": map[string]any{
					"type":        "string",
					"description": "The bash command to execute in the workspace.",
				},
				"timeout_ms": map[string]any{
					"type":        "integer",
					"description": "Command timeout in milliseconds. Defaults to 60000ms (60 seconds) if not specified.",
					"default":     60000,
					"minimum":     1,
				},
			},
			Required: []string{"workspace", "command"},
		},
	},
	Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (res WorkspaceBashResult, err error) {
		const (
			// Default and maximum command timeouts, as stated in the tool
			// description above.
			defaultTimeoutMs = 60000
			maxTimeoutMs     = 300000
			// Exit code reported for timed-out commands (the same status
			// GNU timeout(1) uses).
			timeoutExitCode = 124
		)

		if args.Workspace == "" {
			return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty")
		}
		if args.Command == "" {
			return WorkspaceBashResult{}, xerrors.New("command cannot be empty")
		}

		// Overall budget for the handler: workspace start, agent wait,
		// dialing and the command itself all share it.
		ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min"))
		defer cancel()

		// Normalize workspace input to handle various formats.
		workspaceName := NormalizeWorkspaceInput(args.Workspace)

		// Find the workspace and agent, auto-starting the workspace if needed.
		_, workspaceAgent, err := findWorkspaceAndAgent(ctx, deps.coderClient, workspaceName)
		if err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("failed to find workspace: %w", err)
		}

		// Wait for the agent to be ready.
		if err := cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{
			FetchInterval: 0,
			Fetch:         deps.coderClient.WorkspaceAgent,
			FetchLogs:     deps.coderClient.WorkspaceAgentLogsAfter,
			Wait:          true, // Always wait for startup scripts
		}); err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("agent not ready: %w", err)
		}

		// Dial the agent through the workspace SDK.
		wsClient := workspacesdk.New(deps.coderClient)
		conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
			BlockEndpoints: false,
		})
		if err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("failed to dial agent: %w", err)
		}
		defer conn.Close()

		if !conn.AwaitReachable(ctx) {
			return WorkspaceBashResult{}, xerrors.New("agent connection not reachable")
		}

		sshClient, err := conn.SSHClient(ctx)
		if err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("failed to create SSH client: %w", err)
		}
		defer sshClient.Close()

		session, err := sshClient.NewSession()
		if err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("failed to create SSH session: %w", err)
		}
		defer session.Close()

		// Resolve the command timeout: default when unset, and enforce the
		// documented maximum. Clamping also keeps the millisecond-to-Duration
		// conversion below from overflowing on absurdly large inputs.
		timeoutMs := args.TimeoutMs
		if timeoutMs <= 0 {
			timeoutMs = defaultTimeoutMs
		}
		if timeoutMs > maxTimeoutMs {
			timeoutMs = maxTimeoutMs
		}
		ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
		defer cancel()

		output, err := executeCommandWithTimeout(ctx, session, args.Command)
		outputStr := strings.TrimSpace(string(output))
		if err == nil {
			return WorkspaceBashResult{
				Output:   outputStr,
				ExitCode: 0,
			}, nil
		}

		// Use ctx.Err() rather than context.Cause(ctx). If the command
		// deadline falls after the 5 minute handler deadline, WithTimeout
		// returns a plain cancelable child, and its cause is the handler's
		// custom error rather than context.DeadlineExceeded. A Cause-based
		// check would then miss the timeout; Err() is DeadlineExceeded either way.
		if errors.Is(ctx.Err(), context.DeadlineExceeded) {
			return WorkspaceBashResult{
				Output:   outputStr + "\nCommand canceled due to timeout",
				ExitCode: timeoutExitCode,
			}, nil
		}

		// The command ran but failed. Report the remote exit status when the
		// SSH layer provides one; any other error is reported as exit code 1.
		exitCode := 1
		var exitErr *gossh.ExitError
		if errors.As(err, &exitErr) {
			exitCode = exitErr.ExitStatus()
		}
		return WorkspaceBashResult{
			Output:   outputStr,
			ExitCode: exitCode,
		}, nil
	},
}
// findWorkspaceAndAgent resolves workspaceName, in the form
// [owner/]workspace[.agent], to a workspace and one of its agents.
//
// If the workspace's latest build is not a start transition, the workspace is
// started and the build is awaited, but only when it is cleanly stopped.
// Deleted workspaces, failed builds and any other non-stopped status are
// rejected with an error instead.
func findWorkspaceAndAgent(ctx context.Context, client *codersdk.Client, workspaceName string) (codersdk.Workspace, codersdk.WorkspaceAgent, error) {
	// The segment after the first "." names the agent; any later segments
	// are ignored.
	var agentName string
	if segments := strings.Split(workspaceName, "."); len(segments) > 1 {
		workspaceName, agentName = segments[0], segments[1]
	}

	workspace, err := namedWorkspace(ctx, client, workspaceName)
	if err != nil {
		return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
	}

	if latest := workspace.LatestBuild; latest.Transition != codersdk.WorkspaceTransitionStart {
		// Only a cleanly stopped workspace may be auto-started.
		switch {
		case latest.Transition == codersdk.WorkspaceTransitionDelete:
			return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is deleted", workspace.Name)
		case latest.Job.Status == codersdk.ProvisionerJobFailed:
			return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is in failed state", workspace.Name)
		case latest.Status != codersdk.WorkspaceStatusStopped:
			return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace must be started; was unable to autostart as the last build job is %q, expected %q",
				latest.Status, codersdk.WorkspaceStatusStopped)
		}

		build, err := client.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{
			Transition: codersdk.WorkspaceTransitionStart,
		})
		if err != nil {
			return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("failed to start workspace: %w", err)
		}

		// Block until the start build finishes, unless it already has.
		if build.Job.CompletedAt == nil {
			if err := cliui.WorkspaceBuild(ctx, io.Discard, client, build.ID); err != nil {
				return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("failed to wait for build completion: %w", err)
			}
		}

		// Re-fetch so the agent lookup below sees the new build's resources.
		if workspace, err = client.Workspace(ctx, workspace.ID); err != nil {
			return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
		}
	}

	agent, err := getWorkspaceAgent(workspace, agentName)
	if err != nil {
		return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
	}
	return workspace, agent, nil
}
// getWorkspaceAgent picks an agent from the workspace's latest build.
//
// When agentName is non-empty, the first agent whose name or ID string equals
// it is returned. When agentName is empty, the sole agent is returned, and
// having more than one agent is an error. A workspace with no agents is
// always an error. Errors list the available agent names.
func getWorkspaceAgent(workspace codersdk.Workspace, agentName string) (codersdk.WorkspaceAgent, error) {
	var candidates []codersdk.WorkspaceAgent
	for _, res := range workspace.LatestBuild.Resources {
		candidates = append(candidates, res.Agents...)
	}

	var names []string
	for _, candidate := range candidates {
		names = append(names, candidate.Name)
	}

	switch {
	case len(candidates) == 0:
		return codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name)
	case agentName != "":
		for _, candidate := range candidates {
			if candidate.Name == agentName || candidate.ID.String() == agentName {
				return candidate, nil
			}
		}
		return codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q, available agents: %v", agentName, names)
	case len(candidates) == 1:
		return candidates[0], nil
	default:
		return codersdk.WorkspaceAgent{}, xerrors.Errorf("multiple agents found, please specify the agent name, available agents: %v", names)
	}
}
// namedWorkspace gets a workspace by owner/name or just name
func namedWorkspace(ctx context.Context, client *codersdk.Client, identifier string) (codersdk.Workspace, error) {
// Parse owner and workspace name
parts := strings.SplitN(identifier, "/", 2)
var owner, workspaceName string
if len(parts) == 2 {
owner = parts[0]
workspaceName = parts[1]
} else {
owner = "me"
workspaceName = identifier
}
// Handle -- separator format (convert to / format)
if strings.Contains(identifier, "--") && !strings.Contains(identifier, "/") {
dashParts := strings.SplitN(identifier, "--", 2)
if len(dashParts) == 2 {
owner = dashParts[0]
workspaceName = dashParts[1]
}
}
return client.WorkspaceByOwnerAndName(ctx, owner, workspaceName, codersdk.WorkspaceOptions{})
}
// NormalizeWorkspaceInput rewrites user-supplied workspace identifiers into
// the canonical "[owner/]workspace[.agent]" form. Supported inputs:
//   - workspace              → workspace
//   - workspace.agent        → workspace.agent
//   - owner/workspace        → owner/workspace
//   - owner--workspace       → owner/workspace
//   - owner/workspace.agent  → owner/workspace.agent
//   - owner--workspace.agent → owner/workspace.agent
//   - agent.workspace.owner  → owner/workspace.agent (Coder Connect format)
func NormalizeWorkspaceInput(input string) string {
	// The Coder Connect form is exactly three dot-separated labels with no
	// other separators present.
	isConnectHost := strings.Count(input, ".") == 2 &&
		!strings.Contains(input, "/") &&
		!strings.Contains(input, "--")
	if !isConnectHost {
		// Treat "--" as an alternative spelling of the owner separator.
		return strings.ReplaceAll(input, "--", "/")
	}

	// Exactly two dots guarantee exactly three labels.
	labels := strings.Split(input, ".")
	agent, workspace, owner := labels[0], labels[1], labels[2]
	return fmt.Sprintf("%s/%s.%s", owner, workspace, agent)
}
// executeCommandWithTimeout runs command on session and returns its combined
// stdout and stderr together with the result of session.Wait.
//
// If ctx is already done, the command is not started and context.Cause(ctx)
// is returned. If ctx becomes done while the command runs, the session is
// closed to stop it, and there is a 50ms grace period for trailing output
// and completion. If the command finishes within that window, its own
// output and Wait error are returned. Otherwise the output captured so far
// is returned with context.Cause(ctx).
func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) {
	// Starting a command that would be killed immediately can only leave
	// partial side effects behind.
	if ctx.Err() != nil {
		return nil, context.Cause(ctx)
	}

	stdoutPipe, err := session.StdoutPipe()
	if err != nil {
		return nil, xerrors.Errorf("failed to create stdout pipe: %w", err)
	}
	stderrPipe, err := session.StderrPipe()
	if err != nil {
		return nil, xerrors.Errorf("failed to create stderr pipe: %w", err)
	}

	if err := session.Start(command); err != nil {
		return nil, xerrors.Errorf("failed to start command: %w", err)
	}

	// Both streams are copied concurrently into one buffer, so writes must
	// be serialized.
	var (
		output bytes.Buffer
		mu     sync.Mutex
	)
	writer := &syncWriter{w: &output, mu: &mu}

	// Buffered so the goroutine can always deliver its result and exit, even
	// after we have returned on timeout.
	done := make(chan error, 1)
	go func() {
		var wg sync.WaitGroup
		copyStream := func(r io.Reader) {
			defer wg.Done()
			// Best effort: a read error just ends this stream. The command's
			// status is reported by session.Wait below.
			_, _ = io.Copy(writer, r)
		}
		wg.Add(2)
		go copyStream(stdoutPipe)
		go copyStream(stderrPipe)
		// Let both streams reach EOF before waiting so all output is captured.
		wg.Wait()
		done <- session.Wait()
	}()

	select {
	case err := <-done:
		return writer.Bytes(), err
	case <-ctx.Done():
		// Closing the session stops the command. Its error is irrelevant
		// because the cancellation cause is what gets reported.
		_ = session.Close()

		grace := time.NewTimer(50 * time.Millisecond)
		defer grace.Stop()
		select {
		case err := <-done:
			// The command finished during the grace period.
			return writer.Bytes(), err
		case <-grace.C:
			return writer.Bytes(), context.Cause(ctx)
		}
	}
}
// syncWriter is an io.Writer that serializes writes to a shared bytes.Buffer
// with a mutex, so several goroutines can write to it at once.
type syncWriter struct {
	w  *bytes.Buffer
	mu *sync.Mutex
}

// Write appends p to the underlying buffer while holding the lock.
func (sw *syncWriter) Write(p []byte) (n int, err error) {
	sw.mu.Lock()
	defer sw.mu.Unlock()
	return sw.w.Write(p)
}

// Bytes returns a snapshot of everything written so far.
//
// The snapshot is copied while the lock is held. It therefore stays stable
// even if writers keep appending afterwards, for example copy goroutines
// that are still draining after a timeout. Callers may also modify it
// without touching the buffer. Returning w.Bytes() directly would hand out
// an alias of the live buffer.
func (sw *syncWriter) Bytes() []byte {
	sw.mu.Lock()
	defer sw.mu.Unlock()
	return bytes.Clone(sw.w.Bytes())
}