feat(cli): allow SSH command to connect to running container (#16726)

Fixes https://github.com/coder/coder/issues/16709 and
https://github.com/coder/coder/issues/16420

Adds the capability to`coder ssh` into a running container if `CODER_AGENT_DEVCONTAINERS_ENABLE=true`.

Notes:
* SFTP is currently not supported
* Haven't tested X11 container forwarding
* Haven't tested agent forwarding
This commit is contained in:
Cian Johnston
2025-02-28 09:38:45 +00:00
committed by GitHub
parent 64fec8bf0b
commit ec44f06f5c
8 changed files with 252 additions and 42 deletions
+7 -5
View File
@@ -91,8 +91,8 @@ type Options struct {
Execer agentexec.Execer
ContainerLister agentcontainers.Lister
ExperimentalContainersEnabled bool
ExperimentalConnectionReports bool
ExperimentalConnectionReports bool
ExperimentalDevcontainersEnabled bool
}
type Client interface {
@@ -156,7 +156,7 @@ func New(options Options) Agent {
options.Execer = agentexec.DefaultExecer
}
if options.ContainerLister == nil {
options.ContainerLister = agentcontainers.NewDocker(options.Execer)
options.ContainerLister = agentcontainers.NoopLister{}
}
hardCtx, hardCancel := context.WithCancel(context.Background())
@@ -195,7 +195,7 @@ func New(options Options) Agent {
execer: options.Execer,
lister: options.ContainerLister,
experimentalDevcontainersEnabled: options.ExperimentalContainersEnabled,
experimentalDevcontainersEnabled: options.ExperimentalDevcontainersEnabled,
experimentalConnectionReports: options.ExperimentalConnectionReports,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
@@ -307,6 +307,8 @@ func (a *agent) init() {
return a.reportConnection(id, connectionType, ip)
},
ExperimentalDevContainersEnabled: a.experimentalDevcontainersEnabled,
})
if err != nil {
panic(err)
@@ -335,7 +337,7 @@ func (a *agent) init() {
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
a.reconnectingPTYTimeout,
func(s *reconnectingpty.Server) {
s.ExperimentalContainersEnabled = a.experimentalDevcontainersEnabled
s.ExperimentalDevcontainersEnabled = a.experimentalDevcontainersEnabled
},
)
go a.runLoop()
+1 -1
View File
@@ -1841,7 +1841,7 @@ func TestAgent_ReconnectingPTYContainer(t *testing.T) {
// nolint: dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.ExperimentalContainersEnabled = true
o.ExperimentalDevcontainersEnabled = true
})
ac, err := conn.ReconnectingPTY(ctx, uuid.New(), 80, 80, "/bin/sh", func(arp *workspacesdk.AgentReconnectingPTYInit) {
arp.Container = ct.Container.ID
+58 -12
View File
@@ -29,6 +29,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentrsa"
"github.com/coder/coder/v2/agent/usershell"
@@ -60,6 +61,14 @@ const (
// MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
// This is stripped from any commands being executed, and is counted towards connection stats.
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
// ContainerEnvironmentVariable is used to specify the target container for an SSH connection.
// This is stripped from any commands being executed.
// Only available if CODER_AGENT_DEVCONTAINERS_ENABLE=true.
ContainerEnvironmentVariable = "CODER_CONTAINER"
// ContainerUserEnvironmentVariable is used to specify the container user for
// an SSH connection.
// Only available if CODER_AGENT_DEVCONTAINERS_ENABLE=true.
ContainerUserEnvironmentVariable = "CODER_CONTAINER_USER"
)
// MagicSessionType enums.
@@ -104,6 +113,9 @@ type Config struct {
BlockFileTransfer bool
// ReportConnection.
ReportConnection reportConnectionFunc
// Experimental: allow connecting to running containers if
// CODER_AGENT_DEVCONTAINERS_ENABLE=true.
ExperimentalDevContainersEnabled bool
}
type Server struct {
@@ -324,6 +336,22 @@ func (s *sessionCloseTracker) Close() error {
return s.Session.Close()
}
func extractContainerInfo(env []string) (container, containerUser string, filteredEnv []string) {
for _, kv := range env {
if strings.HasPrefix(kv, ContainerEnvironmentVariable+"=") {
container = strings.TrimPrefix(kv, ContainerEnvironmentVariable+"=")
}
if strings.HasPrefix(kv, ContainerUserEnvironmentVariable+"=") {
containerUser = strings.TrimPrefix(kv, ContainerUserEnvironmentVariable+"=")
}
}
return container, containerUser, slices.DeleteFunc(env, func(kv string) bool {
return strings.HasPrefix(kv, ContainerEnvironmentVariable+"=") || strings.HasPrefix(kv, ContainerUserEnvironmentVariable+"=")
})
}
func (s *Server) sessionHandler(session ssh.Session) {
ctx := session.Context()
id := uuid.New()
@@ -353,6 +381,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
defer s.trackSession(session, false)
reportSession := true
switch magicType {
case MagicSessionTypeVSCode:
s.connCountVSCode.Add(1)
@@ -395,9 +424,22 @@ func (s *Server) sessionHandler(session ssh.Session) {
return
}
container, containerUser, env := extractContainerInfo(env)
if container != "" {
s.logger.Debug(ctx, "container info",
slog.F("container", container),
slog.F("container_user", containerUser),
)
}
switch ss := session.Subsystem(); ss {
case "":
case "sftp":
if s.config.ExperimentalDevContainersEnabled && container != "" {
closeCause("sftp not yet supported with containers")
_ = session.Exit(1)
return
}
err := s.sftpHandler(logger, session)
if err != nil {
closeCause(err.Error())
@@ -422,7 +464,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
env = append(env, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
}
err := s.sessionStart(logger, session, env, magicType)
err := s.sessionStart(logger, session, env, magicType, container, containerUser)
var exitError *exec.ExitError
if xerrors.As(err, &exitError) {
code := exitError.ExitCode()
@@ -495,18 +537,27 @@ func (s *Server) fileTransferBlocked(session ssh.Session) bool {
return false
}
func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []string, magicType MagicSessionType) (retErr error) {
func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []string, magicType MagicSessionType, container, containerUser string) (retErr error) {
ctx := session.Context()
magicTypeLabel := magicTypeMetricLabel(magicType)
sshPty, windowSize, isPty := session.Pty()
ptyLabel := "no"
if isPty {
ptyLabel = "yes"
}
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env, nil)
if err != nil {
ptyLabel := "no"
if isPty {
ptyLabel = "yes"
var ei usershell.EnvInfoer
var err error
if s.config.ExperimentalDevContainersEnabled && container != "" {
ei, err = agentcontainers.EnvInfo(ctx, s.Execer, container, containerUser)
if err != nil {
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, ptyLabel, "container_env_info").Add(1)
return err
}
}
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env, ei)
if err != nil {
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, ptyLabel, "create_command").Add(1)
return err
}
@@ -514,11 +565,6 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []str
if ssh.AgentRequested(session) {
l, err := ssh.NewAgentListener()
if err != nil {
ptyLabel := "no"
if isPty {
ptyLabel = "yes"
}
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, ptyLabel, "listener").Add(1)
return xerrors.Errorf("new agent listener: %w", err)
}
+2 -2
View File
@@ -32,7 +32,7 @@ type Server struct {
reconnectingPTYs sync.Map
timeout time.Duration
ExperimentalContainersEnabled bool
ExperimentalDevcontainersEnabled bool
}
// NewServer returns a new ReconnectingPTY server
@@ -187,7 +187,7 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
}()
var ei usershell.EnvInfoer
if s.ExperimentalContainersEnabled && msg.Container != "" {
if s.ExperimentalDevcontainersEnabled && msg.Container != "" {
dei, err := agentcontainers.EnvInfo(ctx, s.commandCreator.Execer, msg.Container, msg.ContainerUser)
if err != nil {
return xerrors.Errorf("get container env info: %w", err)
+21 -21
View File
@@ -38,24 +38,24 @@ import (
func (r *RootCmd) workspaceAgent() *serpent.Command {
var (
auth string
logDir string
scriptDataDir string
pprofAddress string
noReap bool
sshMaxTimeout time.Duration
tailnetListenPort int64
prometheusAddress string
debugAddress string
slogHumanPath string
slogJSONPath string
slogStackdriverPath string
blockFileTransfer bool
agentHeaderCommand string
agentHeader []string
devcontainersEnabled bool
auth string
logDir string
scriptDataDir string
pprofAddress string
noReap bool
sshMaxTimeout time.Duration
tailnetListenPort int64
prometheusAddress string
debugAddress string
slogHumanPath string
slogJSONPath string
slogStackdriverPath string
blockFileTransfer bool
agentHeaderCommand string
agentHeader []string
experimentalConnectionReports bool
experimentalConnectionReports bool
experimentalDevcontainersEnabled bool
)
cmd := &serpent.Command{
Use: "agent",
@@ -319,7 +319,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
}
var containerLister agentcontainers.Lister
if !devcontainersEnabled {
if !experimentalDevcontainersEnabled {
logger.Info(ctx, "agent devcontainer detection not enabled")
containerLister = &agentcontainers.NoopLister{}
} else {
@@ -358,8 +358,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
Execer: execer,
ContainerLister: containerLister,
ExperimentalContainersEnabled: devcontainersEnabled,
ExperimentalConnectionReports: experimentalConnectionReports,
ExperimentalDevcontainersEnabled: experimentalDevcontainersEnabled,
ExperimentalConnectionReports: experimentalConnectionReports,
})
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
@@ -487,7 +487,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
Default: "false",
Env: "CODER_AGENT_DEVCONTAINERS_ENABLE",
Description: "Allow the agent to automatically detect running devcontainers.",
Value: serpent.BoolOf(&devcontainersEnabled),
Value: serpent.BoolOf(&experimentalDevcontainersEnabled),
},
{
Flag: "experimental-connection-reports-enable",
+3 -1
View File
@@ -9,6 +9,7 @@ import (
"github.com/ory/dockertest/v3/docker"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -88,7 +89,8 @@ func TestExpRpty(t *testing.T) {
})
_ = agenttest.New(t, client.URL, agentToken, func(o *agent.Options) {
o.ExperimentalContainersEnabled = true
o.ExperimentalDevcontainersEnabled = true
o.ContainerLister = agentcontainers.NewDocker(o.Execer)
})
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
+56
View File
@@ -34,6 +34,7 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/cli/cliutil"
"github.com/coder/coder/v2/coderd/autobuild/notify"
@@ -76,6 +77,9 @@ func (r *RootCmd) ssh() *serpent.Command {
appearanceConfig codersdk.AppearanceConfig
networkInfoDir string
networkInfoInterval time.Duration
containerName string
containerUser string
)
client := new(codersdk.Client)
cmd := &serpent.Command{
@@ -282,6 +286,34 @@ func (r *RootCmd) ssh() *serpent.Command {
}
conn.AwaitReachable(ctx)
if containerName != "" {
cts, err := client.WorkspaceAgentListContainers(ctx, workspaceAgent.ID, nil)
if err != nil {
return xerrors.Errorf("list containers: %w", err)
}
if len(cts.Containers) == 0 {
cliui.Info(inv.Stderr, "No containers found!")
cliui.Info(inv.Stderr, "Tip: Agent container integration is experimental and not enabled by default.")
cliui.Info(inv.Stderr, " To enable it, set CODER_AGENT_DEVCONTAINERS_ENABLE=true in your template.")
return nil
}
var found bool
for _, c := range cts.Containers {
if c.FriendlyName == containerName || c.ID == containerName {
found = true
break
}
}
if !found {
availableContainers := make([]string, len(cts.Containers))
for i, c := range cts.Containers {
availableContainers[i] = c.FriendlyName
}
cliui.Errorf(inv.Stderr, "Container not found: %q\nAvailable containers: %v", containerName, availableContainers)
return nil
}
}
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
defer stopPolling()
@@ -454,6 +486,17 @@ func (r *RootCmd) ssh() *serpent.Command {
}
}
if containerName != "" {
for k, v := range map[string]string{
agentssh.ContainerEnvironmentVariable: containerName,
agentssh.ContainerUserEnvironmentVariable: containerUser,
} {
if err := sshSession.Setenv(k, v); err != nil {
return xerrors.Errorf("setenv: %w", err)
}
}
}
err = sshSession.RequestPty("xterm-256color", 128, 128, gossh.TerminalModes{})
if err != nil {
return xerrors.Errorf("request pty: %w", err)
@@ -594,6 +637,19 @@ func (r *RootCmd) ssh() *serpent.Command {
Default: "5s",
Value: serpent.DurationOf(&networkInfoInterval),
},
{
Flag: "container",
FlagShorthand: "c",
Description: "Specifies a container inside the workspace to connect to.",
Value: serpent.StringOf(&containerName),
Hidden: true, // Hidden until this features is at least in beta.
},
{
Flag: "container-user",
Description: "When connecting to a container, specifies the user to connect as.",
Value: serpent.StringOf(&containerUser),
Hidden: true, // Hidden until this features is at least in beta.
},
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
}
return cmd
+104
View File
@@ -24,6 +24,8 @@ import (
"time"
"github.com/google/uuid"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -33,6 +35,7 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest"
agentproto "github.com/coder/coder/v2/agent/proto"
@@ -1924,6 +1927,107 @@ Expire-Date: 0
<-cmdDone
}
func TestSSH_Container(t *testing.T) {
t.Parallel()
if runtime.GOOS != "linux" {
t.Skip("Skipping test on non-Linux platform")
}
t.Run("OK", func(t *testing.T) {
t.Parallel()
client, workspace, agentToken := setupWorkspaceForAgent(t)
ctx := testutil.Context(t, testutil.WaitLong)
pool, err := dockertest.NewPool("")
require.NoError(t, err, "Could not connect to docker")
ct, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "busybox",
Tag: "latest",
Cmd: []string{"sleep", "infnity"},
}, func(config *docker.HostConfig) {
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
require.NoError(t, err, "Could not start container")
// Wait for container to start
require.Eventually(t, func() bool {
ct, ok := pool.ContainerByName(ct.Container.Name)
return ok && ct.Container.State.Running
}, testutil.WaitShort, testutil.IntervalSlow, "Container did not start in time")
t.Cleanup(func() {
err := pool.Purge(ct)
require.NoError(t, err, "Could not stop container")
})
_ = agenttest.New(t, client.URL, agentToken, func(o *agent.Options) {
o.ExperimentalDevcontainersEnabled = true
o.ContainerLister = agentcontainers.NewDocker(o.Execer)
})
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
inv, root := clitest.New(t, "ssh", workspace.Name, "-c", ct.Container.ID)
clitest.SetupConfig(t, client, root)
ptty := ptytest.New(t).Attach(inv)
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.NoError(t, err)
})
ptty.ExpectMatch(" #")
ptty.WriteLine("hostname")
ptty.ExpectMatch(ct.Container.Config.Hostname)
ptty.WriteLine("exit")
<-cmdDone
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
client, workspace, agentToken := setupWorkspaceForAgent(t)
_ = agenttest.New(t, client.URL, agentToken, func(o *agent.Options) {
o.ExperimentalDevcontainersEnabled = true
o.ContainerLister = agentcontainers.NewDocker(o.Execer)
})
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
inv, root := clitest.New(t, "ssh", workspace.Name, "-c", uuid.NewString())
clitest.SetupConfig(t, client, root)
ptty := ptytest.New(t).Attach(inv)
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.NoError(t, err)
})
ptty.ExpectMatch("Container not found:")
<-cmdDone
})
t.Run("NotEnabled", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
client, workspace, agentToken := setupWorkspaceForAgent(t)
_ = agenttest.New(t, client.URL, agentToken)
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
inv, root := clitest.New(t, "ssh", workspace.Name, "-c", uuid.NewString())
clitest.SetupConfig(t, client, root)
ptty := ptytest.New(t).Attach(inv)
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.NoError(t, err)
})
ptty.ExpectMatch("No containers found!")
ptty.ExpectMatch("Tip: Agent container integration is experimental and not enabled by default.")
<-cmdDone
})
}
// tGoContext runs fn in a goroutine passing a context that will be
// canceled on test completion and wait until fn has finished executing.
// Done and cancel are returned for optionally waiting until completion