feat: add workspace SSH execution tool for AI SDK (#18924)

# Add SSH Command Execution Tool for Coder Workspaces

This PR adds a new AI tool, `coder_workspace_bash`, that executes commands in Coder workspaces via SSH. The tool provides functionality similar to the `coder ssh <workspace> <command>` CLI command.

Key features:
- Executes commands in workspaces via SSH and returns the output and exit code
- Automatically starts workspaces if they're stopped
- Waits for the agent to be ready before executing commands
- Trims leading and trailing whitespace from command output
- Supports various workspace identifier formats:
  - `workspace` (uses current user)
  - `owner/workspace`
  - `owner--workspace`
  - `workspace.agent` (specific agent)
  - `owner/workspace.agent`

The implementation includes:
- A new tool definition with schema and handler
- Helper functions for workspace and agent discovery
- Workspace name normalization to handle different input formats
- Comprehensive test coverage including integration tests

This tool enables AI assistants to execute commands in user workspaces, making it possible to automate tasks and provide more interactive assistance.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Introduced the ability to execute bash commands inside a Coder workspace via SSH, supporting multiple workspace identification formats.
* **Tests**
  * Added comprehensive unit and integration tests for executing bash commands in workspaces, including input validation, output handling, and error scenarios.
* **Chores**
  * Registered the new bash execution tool in the global tools list.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Thomas Kosiewski
2025-07-21 21:24:00 +02:00
committed by GitHub
parent 75c124013f
commit 326c02459f
4 changed files with 533 additions and 2 deletions
+295
View File
@@ -0,0 +1,295 @@
package toolsdk
import (
"context"
"errors"
"fmt"
"io"
"strings"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"github.com/coder/aisdk-go"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// WorkspaceBashArgs holds the arguments accepted by the WorkspaceBash tool.
// Both fields are required; the handler rejects an empty value for either.
type WorkspaceBashArgs struct {
// Workspace identifies the target in the form [owner/]workspace[.agent].
Workspace string `json:"workspace"`
// Command is the command line executed in the workspace over SSH.
Command string `json:"command"`
}
// WorkspaceBashResult is the value returned by the WorkspaceBash tool.
type WorkspaceBashResult struct {
// Output is the combined stdout and stderr of the command with leading and
// trailing whitespace removed.
Output string `json:"output"`
// ExitCode is the command's exit status: 0 on success, the remote exit
// status when SSH reports one, and 1 for any other execution error.
ExitCode int `json:"exit_code"`
}
// WorkspaceBash is the AI tool that runs a single command inside a workspace
// over SSH and reports its combined output together with the exit code.
//
// The handler validates its arguments, resolves the workspace and agent
// (starting the workspace when it is stopped), waits for the agent, dials it
// and runs the command in a fresh SSH session. A Go error is returned only
// when the command could not be attempted (bad arguments, workspace lookup,
// agent readiness, dialing or SSH setup). Once the command has been handed to
// the session, failures are reported through WorkspaceBashResult.ExitCode and
// the returned error is nil.
var WorkspaceBash = Tool[WorkspaceBashArgs, WorkspaceBashResult]{
	Tool: aisdk.Tool{
		Name: ToolNameWorkspaceBash,
		Description: `Execute a bash command in a Coder workspace.
This tool provides the same functionality as the 'coder ssh <workspace> <command>' CLI command.
It automatically starts the workspace if it's stopped and waits for the agent to be ready.
The output is trimmed of leading and trailing whitespace.
The workspace parameter supports various formats:
- workspace (uses current user)
- owner/workspace
- owner--workspace
- workspace.agent (specific agent)
- owner/workspace.agent
Examples:
- workspace: "my-workspace", command: "ls -la"
- workspace: "john/dev-env", command: "git status"
- workspace: "my-workspace.main", command: "docker ps"`,
		Schema: aisdk.Schema{
			Properties: map[string]any{
				"workspace": map[string]any{
					"type":        "string",
					"description": "The workspace name in format [owner/]workspace[.agent]. If owner is not specified, the authenticated user is used.",
				},
				"command": map[string]any{
					"type":        "string",
					"description": "The bash command to execute in the workspace.",
				},
			},
			Required: []string{"workspace", "command"},
		},
	},
	Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (WorkspaceBashResult, error) {
		// Validate before touching deps so that bad input fails fast.
		if args.Workspace == "" {
			return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty")
		}
		if args.Command == "" {
			return WorkspaceBashResult{}, xerrors.New("command cannot be empty")
		}

		// Normalize workspace input to handle various formats.
		workspaceName := NormalizeWorkspaceInput(args.Workspace)

		// Find the workspace and agent; this also starts a stopped workspace.
		_, workspaceAgent, err := findWorkspaceAndAgent(ctx, deps.coderClient, workspaceName)
		if err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("failed to find workspace: %w", err)
		}

		// Wait for the agent to be ready. Progress output is of no use to the
		// tool caller, so it is discarded. io.Discard is used instead of a nil
		// writer because writing to a nil io.Writer panics.
		err = cliui.Agent(ctx, io.Discard, workspaceAgent.ID, cliui.AgentOptions{
			FetchInterval: 0,
			Fetch:         deps.coderClient.WorkspaceAgent,
			FetchLogs:     deps.coderClient.WorkspaceAgentLogsAfter,
			Wait:          true, // Always wait for startup scripts
		})
		if err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("agent not ready: %w", err)
		}

		// Create workspace SDK client for the agent connection and dial it.
		wsClient := workspacesdk.New(deps.coderClient)
		conn, err := wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
			BlockEndpoints: false,
		})
		if err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("failed to dial agent: %w", err)
		}
		defer conn.Close()

		// Wait for the connection to be reachable.
		if !conn.AwaitReachable(ctx) {
			return WorkspaceBashResult{}, xerrors.New("agent connection not reachable")
		}

		// Create the SSH client and a session for this single command.
		sshClient, err := conn.SSHClient(ctx)
		if err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("failed to create SSH client: %w", err)
		}
		defer sshClient.Close()

		session, err := sshClient.NewSession()
		if err != nil {
			return WorkspaceBashResult{}, xerrors.Errorf("failed to create SSH session: %w", err)
		}
		defer session.Close()

		// Execute the command and capture stdout and stderr together.
		output, err := session.CombinedOutput(args.Command)
		outputStr := strings.TrimSpace(string(output))
		if err != nil {
			// An SSH exit error carries the remote exit status.
			var exitErr *gossh.ExitError
			if errors.As(err, &exitErr) {
				return WorkspaceBashResult{
					Output:   outputStr,
					ExitCode: exitErr.ExitStatus(),
				}, nil
			}
			// For other errors, return exit code 1.
			return WorkspaceBashResult{
				Output:   outputStr,
				ExitCode: 1,
			}, nil
		}
		return WorkspaceBashResult{
			Output:   outputStr,
			ExitCode: 0,
		}, nil
	},
}
// findWorkspaceAndAgent resolves a "[owner/]workspace[.agent]" identifier to a
// workspace and one of its agents.
//
// When the latest build is not a start transition the workspace is started
// first, provided it is cleanly stopped; deleted workspaces, failed builds and
// any other status are rejected with an error. After a start build the
// workspace is fetched again so the returned value reflects the new build.
func findWorkspaceAndAgent(ctx context.Context, client *codersdk.Client, workspaceName string) (codersdk.Workspace, codersdk.WorkspaceAgent, error) {
	var (
		noWorkspace codersdk.Workspace
		noAgent     codersdk.WorkspaceAgent
	)

	// The first dot-separated segment is the workspace, the second (if any)
	// is the agent. Further segments are ignored.
	agentName := ""
	if segments := strings.Split(workspaceName, "."); len(segments) > 1 {
		workspaceName, agentName = segments[0], segments[1]
	}

	ws, err := namedWorkspace(ctx, client, workspaceName)
	if err != nil {
		return noWorkspace, noAgent, err
	}

	if ws.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
		// Only a cleanly stopped workspace can be auto-started.
		switch {
		case ws.LatestBuild.Transition == codersdk.WorkspaceTransitionDelete:
			return noWorkspace, noAgent, xerrors.Errorf("workspace %q is deleted", ws.Name)
		case ws.LatestBuild.Job.Status == codersdk.ProvisionerJobFailed:
			return noWorkspace, noAgent, xerrors.Errorf("workspace %q is in failed state", ws.Name)
		case ws.LatestBuild.Status != codersdk.WorkspaceStatusStopped:
			return noWorkspace, noAgent, xerrors.Errorf("workspace must be started; was unable to autostart as the last build job is %q, expected %q",
				ws.LatestBuild.Status, codersdk.WorkspaceStatusStopped)
		}

		startBuild, err := client.CreateWorkspaceBuild(ctx, ws.ID, codersdk.CreateWorkspaceBuildRequest{
			Transition: codersdk.WorkspaceTransitionStart,
		})
		if err != nil {
			return noWorkspace, noAgent, xerrors.Errorf("failed to start workspace: %w", err)
		}

		// Block until the start build finishes; its log output is discarded.
		if startBuild.Job.CompletedAt == nil {
			if err := cliui.WorkspaceBuild(ctx, io.Discard, client, startBuild.ID); err != nil {
				return noWorkspace, noAgent, xerrors.Errorf("failed to wait for build completion: %w", err)
			}
		}

		// Re-read the workspace so LatestBuild points at the new build.
		ws, err = client.Workspace(ctx, ws.ID)
		if err != nil {
			return noWorkspace, noAgent, err
		}
	}

	agent, err := getWorkspaceAgent(ws, agentName)
	if err != nil {
		return noWorkspace, noAgent, err
	}
	return ws, agent, nil
}
// getWorkspaceAgent picks an agent from the workspace's latest build.
//
// With a non-empty agentName the agent whose name or ID string matches is
// returned. With an empty agentName the workspace must have exactly one
// agent. Every failure mode returns an error; the "not found" and "multiple
// agents" errors list the available agent names.
func getWorkspaceAgent(workspace codersdk.Workspace, agentName string) (codersdk.WorkspaceAgent, error) {
	var (
		candidates []codersdk.WorkspaceAgent
		names      []string
	)
	// Flatten the agents of every resource into a single list.
	for _, res := range workspace.LatestBuild.Resources {
		for _, a := range res.Agents {
			candidates = append(candidates, a)
			names = append(names, a.Name)
		}
	}

	switch {
	case len(candidates) == 0:
		return codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name)
	case agentName != "":
		for i := range candidates {
			if candidates[i].Name == agentName || candidates[i].ID.String() == agentName {
				return candidates[i], nil
			}
		}
		return codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q, available agents: %v", agentName, names)
	case len(candidates) == 1:
		return candidates[0], nil
	default:
		return codersdk.WorkspaceAgent{}, xerrors.Errorf("multiple agents found, please specify the agent name, available agents: %v", names)
	}
}
// namedWorkspace looks up a workspace from an "owner/name", "owner--name" or
// bare "name" identifier. A bare name is resolved against the owner "me".
func namedWorkspace(ctx context.Context, client *codersdk.Client, identifier string) (codersdk.Workspace, error) {
	// Preferred form: split on the first slash.
	owner, name, hasSlash := strings.Cut(identifier, "/")
	if !hasSlash {
		owner, name = "me", identifier
		// Fall back to the "--" separator, which is only honored when the
		// identifier contains no slash.
		if dashOwner, dashName, ok := strings.Cut(identifier, "--"); ok {
			owner, name = dashOwner, dashName
		}
	}
	return client.WorkspaceByOwnerAndName(ctx, owner, name, codersdk.WorkspaceOptions{})
}
// NormalizeWorkspaceInput rewrites a workspace identifier into the canonical
// "[owner/]workspace[.agent]" form.
//
// Accepted inputs and their results:
//   - workspace               → workspace
//   - workspace.agent         → workspace.agent
//   - owner/workspace         → owner/workspace
//   - owner--workspace        → owner/workspace
//   - owner/workspace.agent   → owner/workspace.agent
//   - owner--workspace.agent  → owner/workspace.agent
//   - agent.workspace.owner   → owner/workspace.agent (Coder Connect format)
func NormalizeWorkspaceInput(input string) string {
	// The Coder Connect form is recognized by having exactly three
	// dot-separated segments and neither of the owner separators.
	dotsOnly := !strings.Contains(input, "/") && !strings.Contains(input, "--")
	if segments := strings.Split(input, "."); dotsOnly && len(segments) == 3 {
		agent, workspace, owner := segments[0], segments[1], segments[2]
		return fmt.Sprintf("%s/%s.%s", owner, workspace, agent)
	}

	// Every other form only needs "--" turned into "/".
	return strings.ReplaceAll(input, "--", "/")
}
+161
View File
@@ -0,0 +1,161 @@
package toolsdk_test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk/toolsdk"
)
// TestWorkspaceBash covers the parts of the WorkspaceBash tool that need no
// running server: argument validation, tool metadata and the generic wrapper.
func TestWorkspaceBash(t *testing.T) {
	t.Parallel()

	// Arguments the handler must reject before it touches its dependencies,
	// which is why a zero-value Deps is sufficient below.
	invalidArgs := []struct {
		name    string
		args    toolsdk.WorkspaceBashArgs
		wantErr string
	}{
		{
			name:    "EmptyWorkspace",
			args:    toolsdk.WorkspaceBashArgs{Workspace: "", Command: "echo test"},
			wantErr: "workspace name cannot be empty",
		},
		{
			name:    "EmptyCommand",
			args:    toolsdk.WorkspaceBashArgs{Workspace: "test-workspace", Command: ""},
			wantErr: "command cannot be empty",
		},
	}

	t.Run("ValidateArgs", func(t *testing.T) {
		t.Parallel()
		for _, tc := range invalidArgs {
			_, err := toolsdk.WorkspaceBash.Handler(context.Background(), toolsdk.Deps{}, tc.args)
			require.Error(t, err)
			require.Contains(t, err.Error(), tc.wantErr)
		}
	})

	t.Run("ErrorScenarios", func(t *testing.T) {
		t.Parallel()
		// Same inputs as above, reported as individual sequential subtests.
		for _, tc := range invalidArgs {
			t.Run(tc.name, func(t *testing.T) {
				_, err := toolsdk.WorkspaceBash.Handler(context.Background(), toolsdk.Deps{}, tc.args)
				require.Error(t, err)
				require.Contains(t, err.Error(), tc.wantErr)
			})
		}
	})

	t.Run("ToolMetadata", func(t *testing.T) {
		t.Parallel()
		r := require.New(t)
		bash := toolsdk.WorkspaceBash

		r.Equal(toolsdk.ToolNameWorkspaceBash, bash.Name)
		r.NotEmpty(bash.Description)
		r.Contains(bash.Description, "Execute a bash command in a Coder workspace")
		r.Contains(bash.Description, "output is trimmed of leading and trailing whitespace")

		// Both parameters must be required and described in the schema.
		r.Contains(bash.Schema.Required, "workspace")
		r.Contains(bash.Schema.Required, "command")
		r.Contains(bash.Schema.Properties, "workspace")
		r.Contains(bash.Schema.Properties, "command")
	})

	t.Run("GenericTool", func(t *testing.T) {
		t.Parallel()
		r := require.New(t)
		generic := toolsdk.WorkspaceBash.Generic()

		r.Equal(toolsdk.ToolNameWorkspaceBash, generic.Name)
		r.NotEmpty(generic.Description)
		r.NotNil(generic.Handler)
		r.False(generic.UserClientOptional)
	})
}
// TestNormalizeWorkspaceInput checks every identifier format that
// NormalizeWorkspaceInput documents, one parallel subtest per format.
func TestNormalizeWorkspaceInput(t *testing.T) {
	t.Parallel()

	type normalizeCase struct {
		name, input, expected string
	}
	cases := []normalizeCase{
		{name: "SimpleWorkspace", input: "workspace", expected: "workspace"},
		{name: "WorkspaceWithAgent", input: "workspace.agent", expected: "workspace.agent"},
		{name: "OwnerAndWorkspace", input: "owner/workspace", expected: "owner/workspace"},
		{name: "OwnerDashWorkspace", input: "owner--workspace", expected: "owner/workspace"},
		{name: "OwnerWorkspaceAgent", input: "owner/workspace.agent", expected: "owner/workspace.agent"},
		{name: "OwnerDashWorkspaceAgent", input: "owner--workspace.agent", expected: "owner/workspace.agent"},
		// Special Coder Connect reverse format.
		{name: "CoderConnectFormat", input: "agent.workspace.owner", expected: "owner/workspace.agent"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()
			got := toolsdk.NormalizeWorkspaceInput(c.input)
			require.Equal(t, c.expected, got, "Input %q should normalize to %q but got %q", c.input, c.expected, got)
		})
	}
}
// TestAllToolsIncludesBash guards the registration of the WorkspaceBash tool
// in the toolsdk.All slice.
func TestAllToolsIncludesBash(t *testing.T) {
	t.Parallel()

	// Scan by index and stop at the first tool carrying the bash tool name.
	registered := false
	for i := 0; i < len(toolsdk.All) && !registered; i++ {
		registered = toolsdk.All[i].Name == toolsdk.ToolNameWorkspaceBash
	}
	require.True(t, registered, "WorkspaceBash tool should be included in toolsdk.All")
}
+2
View File
@@ -33,6 +33,7 @@ const (
ToolNameUploadTarFile = "coder_upload_tar_file"
ToolNameCreateTemplate = "coder_create_template"
ToolNameDeleteTemplate = "coder_delete_template"
ToolNameWorkspaceBash = "coder_workspace_bash"
)
func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
@@ -183,6 +184,7 @@ var All = []GenericTool{
ReportTask.Generic(),
UploadTarFile.Generic(),
UpdateTemplateActiveVersion.Generic(),
WorkspaceBash.Generic(),
}
type ReportTaskArgs struct {
+75 -2
View File
@@ -16,6 +16,7 @@ import (
"github.com/coder/aisdk-go"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
@@ -27,11 +28,32 @@ import (
"github.com/coder/coder/v2/testutil"
)
// setupWorkspaceForAgent starts a test coderd backed by a database, creates
// the first user plus a second user named "myuser", and inserts a workspace
// build with an agent for a workspace named "myworkspace" owned by that
// second user. It returns the second user's client, the workspace row and the
// agent token.
// nolint:gocritic // This is in a test package and does not end up in the build
func setupWorkspaceForAgent(t *testing.T) (*codersdk.Client, database.WorkspaceTable, string) {
	t.Helper()

	adminClient, db := coderdtest.NewWithDatabase(t, nil)
	adminClient.SetLogger(testutil.Logger(t).Named("client"))
	admin := coderdtest.CreateFirstUser(t, adminClient)

	// The fixed username lets callers address the workspace as "myuser/<name>".
	withUsername := func(req *codersdk.CreateUserRequestWithOrgs) {
		req.Username = "myuser"
	}
	memberClient, member := coderdtest.CreateAnotherUserMutators(t, adminClient, admin.OrganizationID, nil, withUsername)

	// nolint:gocritic // This is in a test package and does not end up in the build
	built := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		Name:           "myworkspace",
		OrganizationID: admin.OrganizationID,
		OwnerID:        member.ID,
	}).WithAgent().Do()

	return memberClient, built.Workspace, built.AgentToken
}
// These tests are dependent on the state of the coder server.
// Running them in parallel is prone to racy behavior.
// nolint:tparallel,paralleltest
func TestTools(t *testing.T) {
// Given: a running coderd instance
// Given: a running coderd instance using SSH test setup pattern
setupCtx := testutil.Context(t, testutil.WaitShort)
client, store := coderdtest.NewWithDatabase(t, nil)
owner := coderdtest.CreateFirstUser(t, client)
@@ -373,6 +395,57 @@ func TestTools(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, res.ID, "expected a workspace ID")
})
t.Run("WorkspaceSSHExec", func(t *testing.T) {
// Setup workspace exactly like main SSH tests
client, workspace, agentToken := setupWorkspaceForAgent(t)
// Start agent and wait for it to be ready (following main SSH test pattern)
_ = agenttest.New(t, client.URL, agentToken)
// Wait for workspace agents to be ready like main SSH tests do
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
// Create tool dependencies using client
tb, err := toolsdk.NewDeps(client)
require.NoError(t, err)
// Test basic command execution
result, err := testTool(t, toolsdk.WorkspaceBash, tb, toolsdk.WorkspaceBashArgs{
Workspace: workspace.Name,
Command: "echo 'hello world'",
})
require.NoError(t, err)
require.Equal(t, 0, result.ExitCode)
require.Equal(t, "hello world", result.Output)
// Test output trimming
result, err = testTool(t, toolsdk.WorkspaceBash, tb, toolsdk.WorkspaceBashArgs{
Workspace: workspace.Name,
Command: "echo -e '\\n test with whitespace \\n'",
})
require.NoError(t, err)
require.Equal(t, 0, result.ExitCode)
require.Equal(t, "test with whitespace", result.Output) // Should be trimmed
// Test non-zero exit code
result, err = testTool(t, toolsdk.WorkspaceBash, tb, toolsdk.WorkspaceBashArgs{
Workspace: workspace.Name,
Command: "exit 42",
})
require.NoError(t, err)
require.Equal(t, 42, result.ExitCode)
require.Empty(t, result.Output)
// Test with workspace owner format - using the myuser from setup
result, err = testTool(t, toolsdk.WorkspaceBash, tb, toolsdk.WorkspaceBashArgs{
Workspace: "myuser/" + workspace.Name,
Command: "echo 'owner format works'",
})
require.NoError(t, err)
require.Equal(t, 0, result.ExitCode)
require.Equal(t, "owner format works", result.Output)
})
}
// TestedTools keeps track of which tools have been tested.
@@ -386,7 +459,7 @@ func testTool[Arg, Ret any](t *testing.T, tool toolsdk.Tool[Arg, Ret], tb toolsd
defer func() { testedTools.Store(tool.Tool.Name, true) }()
toolArgs, err := json.Marshal(args)
require.NoError(t, err, "failed to marshal args")
result, err := tool.Generic().Handler(context.Background(), tb, toolArgs)
result, err := tool.Generic().Handler(t.Context(), tb, toolArgs)
var ret Ret
require.NoError(t, json.Unmarshal(result, &ret), "failed to unmarshal result %q", string(result))
return ret, err