mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: add agent exec abstraction (#15717)
This commit is contained in:
+9
-1
@@ -33,6 +33,7 @@ import (
|
||||
"tailscale.com/util/clientmetric"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
@@ -80,6 +81,7 @@ type Options struct {
|
||||
ReportMetadataInterval time.Duration
|
||||
ServiceBannerRefreshInterval time.Duration
|
||||
BlockFileTransfer bool
|
||||
Execer agentexec.Execer
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@@ -139,6 +141,10 @@ func New(options Options) Agent {
|
||||
prometheusRegistry = prometheus.NewRegistry()
|
||||
}
|
||||
|
||||
if options.Execer == nil {
|
||||
options.Execer = agentexec.DefaultExecer
|
||||
}
|
||||
|
||||
hardCtx, hardCancel := context.WithCancel(context.Background())
|
||||
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
|
||||
a := &agent{
|
||||
@@ -171,6 +177,7 @@ func New(options Options) Agent {
|
||||
|
||||
prometheusRegistry: prometheusRegistry,
|
||||
metrics: newAgentMetrics(prometheusRegistry),
|
||||
execer: options.Execer,
|
||||
}
|
||||
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
||||
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
||||
@@ -239,6 +246,7 @@ type agent struct {
|
||||
// metrics are prometheus registered metrics that will be collected and
|
||||
// labeled in Coder with the agent + workspace.
|
||||
metrics *agentMetrics
|
||||
execer agentexec.Execer
|
||||
}
|
||||
|
||||
func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
@@ -247,7 +255,7 @@ func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
|
||||
func (a *agent) init() {
|
||||
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
|
||||
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
|
||||
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
|
||||
@@ -17,9 +17,6 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// unset is set to an invalid value for nice and oom scores.
|
||||
const unset = -2000
|
||||
|
||||
// CLI runs the agent-exec command. It should only be called by the cli package.
|
||||
func CLI() error {
|
||||
// We lock the OS thread here to avoid a race condition where the nice priority
|
||||
|
||||
+72
-31
@@ -20,60 +20,101 @@ const (
|
||||
EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT"
|
||||
EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
|
||||
EnvProcNiceScore = "CODER_PROC_NICE_SCORE"
|
||||
|
||||
// unset is set to an invalid value for nice and oom scores.
|
||||
unset = -2000
|
||||
)
|
||||
|
||||
// CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing
|
||||
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd
|
||||
// is returned. All instances of exec.Cmd should flow through this function to ensure
|
||||
// proper resource constraints are applied to the child process.
|
||||
func CommandContext(ctx context.Context, cmd string, args ...string) (*exec.Cmd, error) {
|
||||
cmd, args, err := agentExecCmd(cmd, args...)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("agent exec cmd: %w", err)
|
||||
}
|
||||
return exec.CommandContext(ctx, cmd, args...), nil
|
||||
var DefaultExecer Execer = execer{}
|
||||
|
||||
// Execer defines an abstraction for creating exec.Cmd variants. It's unfortunately
|
||||
// necessary because we need to be able to wrap child processes with "coder agent-exec"
|
||||
// for templates that expect the agent to manage process priority.
|
||||
type Execer interface {
|
||||
// CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing
|
||||
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd
|
||||
// is returned. All instances of exec.Cmd should flow through this function to ensure
|
||||
// proper resource constraints are applied to the child process.
|
||||
CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd
|
||||
// PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing
|
||||
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd
|
||||
// is returned. All instances of pty.Cmd should flow through this function to ensure
|
||||
// proper resource constraints are applied to the child process.
|
||||
PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd
|
||||
}
|
||||
|
||||
// PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing
|
||||
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd
|
||||
// is returned. All instances of pty.Cmd should flow through this function to ensure
|
||||
// proper resource constraints are applied to the child process.
|
||||
func PTYCommandContext(ctx context.Context, cmd string, args ...string) (*pty.Cmd, error) {
|
||||
cmd, args, err := agentExecCmd(cmd, args...)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("agent exec cmd: %w", err)
|
||||
}
|
||||
return pty.CommandContext(ctx, cmd, args...), nil
|
||||
}
|
||||
|
||||
func agentExecCmd(cmd string, args ...string) (string, []string, error) {
|
||||
func NewExecer() (Execer, error) {
|
||||
_, enabled := os.LookupEnv(EnvProcPrioMgmt)
|
||||
if runtime.GOOS != "linux" || !enabled {
|
||||
return cmd, args, nil
|
||||
return DefaultExecer, nil
|
||||
}
|
||||
|
||||
executable, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("get executable: %w", err)
|
||||
return nil, xerrors.Errorf("get executable: %w", err)
|
||||
}
|
||||
|
||||
bin, err := filepath.EvalSymlinks(executable)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("eval symlinks: %w", err)
|
||||
return nil, xerrors.Errorf("eval symlinks: %w", err)
|
||||
}
|
||||
|
||||
oomScore, ok := envValInt(EnvProcOOMScore)
|
||||
if !ok {
|
||||
oomScore = unset
|
||||
}
|
||||
|
||||
niceScore, ok := envValInt(EnvProcNiceScore)
|
||||
if !ok {
|
||||
niceScore = unset
|
||||
}
|
||||
|
||||
return priorityExecer{
|
||||
binPath: bin,
|
||||
oomScore: oomScore,
|
||||
niceScore: niceScore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type execer struct{}
|
||||
|
||||
func (execer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
|
||||
return exec.CommandContext(ctx, cmd, args...)
|
||||
}
|
||||
|
||||
func (execer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
|
||||
return pty.CommandContext(ctx, cmd, args...)
|
||||
}
|
||||
|
||||
type priorityExecer struct {
|
||||
binPath string
|
||||
oomScore int
|
||||
niceScore int
|
||||
}
|
||||
|
||||
func (e priorityExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
|
||||
cmd, args = e.agentExecCmd(cmd, args...)
|
||||
return exec.CommandContext(ctx, cmd, args...)
|
||||
}
|
||||
|
||||
func (e priorityExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
|
||||
cmd, args = e.agentExecCmd(cmd, args...)
|
||||
return pty.CommandContext(ctx, cmd, args...)
|
||||
}
|
||||
|
||||
func (e priorityExecer) agentExecCmd(cmd string, args ...string) (string, []string) {
|
||||
execArgs := []string{"agent-exec"}
|
||||
if score, ok := envValInt(EnvProcOOMScore); ok {
|
||||
execArgs = append(execArgs, oomScoreArg(score))
|
||||
if e.oomScore != unset {
|
||||
execArgs = append(execArgs, oomScoreArg(e.oomScore))
|
||||
}
|
||||
|
||||
if score, ok := envValInt(EnvProcNiceScore); ok {
|
||||
execArgs = append(execArgs, niceScoreArg(score))
|
||||
if e.niceScore != unset {
|
||||
execArgs = append(execArgs, niceScoreArg(e.niceScore))
|
||||
}
|
||||
execArgs = append(execArgs, "--", cmd)
|
||||
execArgs = append(execArgs, args...)
|
||||
|
||||
return bin, execArgs, nil
|
||||
return e.binPath, execArgs
|
||||
}
|
||||
|
||||
// envValInt searches for a key in a list of environment variables and parses it to an int.
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package agentexec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Default", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd := DefaultExecer.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
|
||||
path, err := exec.LookPath("sh")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, path, cmd.Path)
|
||||
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
|
||||
t.Run("Priority", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := priorityExecer{
|
||||
binPath: "/foo/bar/baz",
|
||||
oomScore: unset,
|
||||
niceScore: unset,
|
||||
}
|
||||
|
||||
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.Equal(t, e.binPath, cmd.Path)
|
||||
require.Equal(t, []string{e.binPath, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
|
||||
t.Run("Nice", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := priorityExecer{
|
||||
binPath: "/foo/bar/baz",
|
||||
oomScore: unset,
|
||||
niceScore: 10,
|
||||
}
|
||||
|
||||
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.Equal(t, e.binPath, cmd.Path)
|
||||
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
|
||||
t.Run("OOM", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := priorityExecer{
|
||||
binPath: "/foo/bar/baz",
|
||||
oomScore: 123,
|
||||
niceScore: unset,
|
||||
}
|
||||
|
||||
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.Equal(t, e.binPath, cmd.Path)
|
||||
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
|
||||
t.Run("Both", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := priorityExecer{
|
||||
binPath: "/foo/bar/baz",
|
||||
oomScore: 432,
|
||||
niceScore: 14,
|
||||
}
|
||||
|
||||
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.Equal(t, e.binPath, cmd.Path)
|
||||
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package agentexec_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
)
|
||||
|
||||
//nolint:paralleltest // we need to test environment variables
|
||||
func TestExec(t *testing.T) {
|
||||
//nolint:paralleltest // we need to test environment variables
|
||||
t.Run("NonLinux", func(t *testing.T) {
|
||||
t.Setenv(agentexec.EnvProcPrioMgmt, "true")
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
t.Skip("skipping on linux")
|
||||
}
|
||||
|
||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.NoError(t, err)
|
||||
|
||||
path, err := exec.LookPath("sh")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, path, cmd.Path)
|
||||
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // we need to test environment variables
|
||||
t.Run("Linux", func(t *testing.T) {
|
||||
//nolint:paralleltest // we need to test environment variables
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping on non-linux")
|
||||
}
|
||||
|
||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.NoError(t, err)
|
||||
path, err := exec.LookPath("sh")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, path, cmd.Path)
|
||||
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // we need to test environment variables
|
||||
t.Run("Enabled", func(t *testing.T) {
|
||||
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping on non-linux")
|
||||
}
|
||||
|
||||
executable, err := os.Executable()
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, executable, cmd.Path)
|
||||
require.Equal(t, []string{executable, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
|
||||
t.Run("Nice", func(t *testing.T) {
|
||||
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
|
||||
t.Setenv(agentexec.EnvProcNiceScore, "10")
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping on non-linux")
|
||||
}
|
||||
|
||||
executable, err := os.Executable()
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, executable, cmd.Path)
|
||||
require.Equal(t, []string{executable, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
|
||||
t.Run("OOM", func(t *testing.T) {
|
||||
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
|
||||
t.Setenv(agentexec.EnvProcOOMScore, "123")
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping on non-linux")
|
||||
}
|
||||
|
||||
executable, err := os.Executable()
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, executable, cmd.Path)
|
||||
require.Equal(t, []string{executable, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
|
||||
t.Run("Both", func(t *testing.T) {
|
||||
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
|
||||
t.Setenv(agentexec.EnvProcOOMScore, "432")
|
||||
t.Setenv(agentexec.EnvProcNiceScore, "14")
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping on non-linux")
|
||||
}
|
||||
|
||||
executable, err := os.Executable()
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, executable, cmd.Path)
|
||||
require.Equal(t, []string{executable, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
@@ -160,7 +161,7 @@ func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscript
|
||||
}
|
||||
fs := afero.NewMemMapFs()
|
||||
logger := testutil.Logger(t)
|
||||
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil)
|
||||
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = s.Close()
|
||||
|
||||
@@ -98,6 +98,7 @@ type Server struct {
|
||||
// a lock on mu but protected by closing.
|
||||
wg sync.WaitGroup
|
||||
|
||||
Execer agentexec.Execer
|
||||
logger slog.Logger
|
||||
srv *ssh.Server
|
||||
|
||||
@@ -110,7 +111,7 @@ type Server struct {
|
||||
metrics *sshServerMetrics
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, config *Config) (*Server, error) {
|
||||
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, execer agentexec.Execer, config *Config) (*Server, error) {
|
||||
// Clients' should ignore the host key when connecting.
|
||||
// The agent needs to authenticate with coderd to SSH,
|
||||
// so SSH authentication doesn't improve security.
|
||||
@@ -153,6 +154,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
|
||||
metrics := newSSHServerMetrics(prometheusRegistry)
|
||||
s := &Server{
|
||||
Execer: execer,
|
||||
listeners: make(map[net.Listener]struct{}),
|
||||
fs: fs,
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
@@ -726,10 +728,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
|
||||
}
|
||||
}
|
||||
|
||||
cmd, err := agentexec.PTYCommandContext(ctx, name, args...)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("pty command context: %w", err)
|
||||
}
|
||||
cmd := s.Execer.PTYCommandContext(ctx, name, args...)
|
||||
cmd.Dir = s.config.WorkingDirectory()
|
||||
|
||||
// If the metadata directory doesn't exist, we run the command
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -35,7 +36,7 @@ func Test_sessionStart_orphan(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
logger := testutil.Logger(t)
|
||||
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
||||
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -36,7 +37,7 @@ func TestNewServer_ServeClient(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
logger := testutil.Logger(t)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
@@ -77,7 +78,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
logger := testutil.Logger(t)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = s.Close()
|
||||
@@ -108,7 +109,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
@@ -159,7 +160,7 @@ func TestNewServer_Signal(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
logger := testutil.Logger(t)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
@@ -224,7 +225,7 @@ func TestNewServer_Signal(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
logger := testutil.Logger(t)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -34,7 +35,7 @@ func TestServer_X11(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger := testutil.Logger(t)
|
||||
fs := afero.NewOsFs()
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ type bufferedReconnectingPTY struct {
|
||||
|
||||
// newBuffered starts the buffered pty. If the context ends the process will be
|
||||
// killed.
|
||||
func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *bufferedReconnectingPTY {
|
||||
func newBuffered(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *bufferedReconnectingPTY {
|
||||
rpty := &bufferedReconnectingPTY{
|
||||
activeConns: map[string]net.Conn{},
|
||||
command: cmd,
|
||||
@@ -59,11 +59,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
|
||||
|
||||
// Add TERM then start the command with a pty. pty.Cmd duplicates Path as the
|
||||
// first argument so remove it.
|
||||
cmdWithEnv, err := agentexec.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...)
|
||||
if err != nil {
|
||||
rpty.state.setState(StateDone, xerrors.Errorf("pty command context: %w", err))
|
||||
return rpty
|
||||
}
|
||||
cmdWithEnv := execer.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...)
|
||||
cmdWithEnv.Env = append(rpty.command.Env, "TERM=xterm-256color")
|
||||
cmdWithEnv.Dir = rpty.command.Dir
|
||||
ptty, process, err := pty.Start(cmdWithEnv)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
)
|
||||
@@ -55,7 +56,7 @@ type ReconnectingPTY interface {
|
||||
// close itself (and all connections to it) if nothing is attached for the
|
||||
// duration of the timeout, if the context ends, or the process exits (buffered
|
||||
// backend only).
|
||||
func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) ReconnectingPTY {
|
||||
func New(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) ReconnectingPTY {
|
||||
if options.Timeout == 0 {
|
||||
options.Timeout = 5 * time.Minute
|
||||
}
|
||||
@@ -75,9 +76,9 @@ func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger
|
||||
|
||||
switch backendType {
|
||||
case "screen":
|
||||
return newScreen(ctx, cmd, options, logger)
|
||||
return newScreen(ctx, logger, execer, cmd, options)
|
||||
default:
|
||||
return newBuffered(ctx, cmd, options, logger)
|
||||
return newBuffered(ctx, logger, execer, cmd, options)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
|
||||
// screenReconnectingPTY provides a reconnectable PTY via `screen`.
|
||||
type screenReconnectingPTY struct {
|
||||
execer agentexec.Execer
|
||||
command *pty.Cmd
|
||||
|
||||
// id holds the id of the session for both creating and attaching. This will
|
||||
@@ -59,8 +60,9 @@ type screenReconnectingPTY struct {
|
||||
// spawns the daemon with a hardcoded 24x80 size it is not a very good user
|
||||
// experience. Instead we will let the attach command spawn the daemon on its
|
||||
// own which causes it to spawn with the specified size.
|
||||
func newScreen(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *screenReconnectingPTY {
|
||||
func newScreen(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *screenReconnectingPTY {
|
||||
rpty := &screenReconnectingPTY{
|
||||
execer: execer,
|
||||
command: cmd,
|
||||
metrics: options.Metrics,
|
||||
state: newState(),
|
||||
@@ -210,7 +212,7 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn,
|
||||
logger.Debug(ctx, "spawning screen client", slog.F("screen_id", rpty.id))
|
||||
|
||||
// Wrap the command with screen and tie it to the connection's context.
|
||||
cmd, err := agentexec.PTYCommandContext(ctx, "screen", append([]string{
|
||||
cmd := rpty.execer.PTYCommandContext(ctx, "screen", append([]string{
|
||||
// -S is for setting the session's name.
|
||||
"-S", rpty.id,
|
||||
// -U tells screen to use UTF-8 encoding.
|
||||
@@ -223,9 +225,6 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn,
|
||||
rpty.command.Path,
|
||||
// pty.Cmd duplicates Path as the first argument so remove it.
|
||||
}, rpty.command.Args[1:]...)...)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("pty command context: %w", err)
|
||||
}
|
||||
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
|
||||
cmd.Dir = rpty.command.Dir
|
||||
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
|
||||
@@ -333,7 +332,7 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri
|
||||
run := func() (bool, error) {
|
||||
var stdout bytes.Buffer
|
||||
//nolint:gosec
|
||||
cmd, err := agentexec.CommandContext(ctx, "screen",
|
||||
cmd := rpty.execer.CommandContext(ctx, "screen",
|
||||
// -x targets an attached session.
|
||||
"-x", rpty.id,
|
||||
// -c is the flag for the config file.
|
||||
@@ -341,13 +340,10 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri
|
||||
// -X runs a command in the matching session.
|
||||
"-X", command,
|
||||
)
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("command context: %w", err)
|
||||
}
|
||||
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
|
||||
cmd.Dir = rpty.command.Dir
|
||||
cmd.Stdout = &stdout
|
||||
err = cmd.Run()
|
||||
err := cmd.Run()
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -165,10 +165,15 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
|
||||
return xerrors.Errorf("create command: %w", err)
|
||||
}
|
||||
|
||||
rpty = New(ctx, cmd, &Options{
|
||||
Timeout: s.timeout,
|
||||
Metrics: s.errorsTotal,
|
||||
}, logger.With(slog.F("message_id", msg.ID)))
|
||||
rpty = New(ctx,
|
||||
logger.With(slog.F("message_id", msg.ID)),
|
||||
s.commandCreator.Execer,
|
||||
cmd,
|
||||
&Options{
|
||||
Timeout: s.timeout,
|
||||
Metrics: s.errorsTotal,
|
||||
},
|
||||
)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
@@ -309,6 +309,11 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
)
|
||||
}
|
||||
|
||||
execer, err := agentexec.NewExecer()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent execer: %w", err)
|
||||
}
|
||||
|
||||
agnt := agent.New(agent.Options{
|
||||
Client: client,
|
||||
Logger: logger,
|
||||
@@ -333,6 +338,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
Execer: execer,
|
||||
})
|
||||
|
||||
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
|
||||
|
||||
+2
-2
@@ -503,7 +503,7 @@ func noExecInAgent(m dsl.Matcher) {
|
||||
!m.File().PkgPath.Matches("/agentexec") &&
|
||||
!m.File().Name.Matches(`_test\.go$`),
|
||||
).
|
||||
Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using agentexec.CommandContext instead.")
|
||||
Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using an agentexec.Execer instead.")
|
||||
}
|
||||
|
||||
// noPTYInAgent ensures that packages under agent/ don't use pty.Command or
|
||||
@@ -521,5 +521,5 @@ func noPTYInAgent(m dsl.Matcher) {
|
||||
!m.File().PkgPath.Matches(`/agentexec`) &&
|
||||
!m.File().Name.Matches(`_test\.go$`),
|
||||
).
|
||||
Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using agentexec.PTYCommandContext instead.")
|
||||
Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using an agentexec.Execer instead.")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user