fix: add agent exec abstraction (#15717)

This commit is contained in:
Jon Ayers
2024-12-04 23:30:25 +02:00
committed by GitHub
parent 6c9ccca687
commit ce573b9faa
16 changed files with 210 additions and 192 deletions
+9 -1
View File
@@ -33,6 +33,7 @@ import (
"tailscale.com/util/clientmetric"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/proto"
@@ -80,6 +81,7 @@ type Options struct {
ReportMetadataInterval time.Duration
ServiceBannerRefreshInterval time.Duration
BlockFileTransfer bool
Execer agentexec.Execer
}
type Client interface {
@@ -139,6 +141,10 @@ func New(options Options) Agent {
prometheusRegistry = prometheus.NewRegistry()
}
if options.Execer == nil {
options.Execer = agentexec.DefaultExecer
}
hardCtx, hardCancel := context.WithCancel(context.Background())
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
a := &agent{
@@ -171,6 +177,7 @@ func New(options Options) Agent {
prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry),
execer: options.Execer,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@@ -239,6 +246,7 @@ type agent struct {
// metrics are prometheus registered metrics that will be collected and
// labeled in Coder with the agent + workspace.
metrics *agentMetrics
execer agentexec.Execer
}
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -247,7 +255,7 @@ func (a *agent) TailnetConn() *tailnet.Conn {
func (a *agent) init() {
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
-3
View File
@@ -17,9 +17,6 @@ import (
"golang.org/x/xerrors"
)
// unset is set to an invalid value for nice and oom scores.
const unset = -2000
// CLI runs the agent-exec command. It should only be called by the cli package.
func CLI() error {
// We lock the OS thread here to avoid a race condition where the nice priority
+72 -31
View File
@@ -20,60 +20,101 @@ const (
EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT"
EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
EnvProcNiceScore = "CODER_PROC_NICE_SCORE"
// unset is set to an invalid value for nice and oom scores.
unset = -2000
)
// CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd
// is returned. All instances of exec.Cmd should flow through this function to ensure
// proper resource constraints are applied to the child process.
func CommandContext(ctx context.Context, cmd string, args ...string) (*exec.Cmd, error) {
cmd, args, err := agentExecCmd(cmd, args...)
if err != nil {
return nil, xerrors.Errorf("agent exec cmd: %w", err)
}
return exec.CommandContext(ctx, cmd, args...), nil
var DefaultExecer Execer = execer{}
// Execer defines an abstraction for creating exec.Cmd variants. It's unfortunately
// necessary because we need to be able to wrap child processes with "coder agent-exec"
// for templates that expect the agent to manage process priority.
type Execer interface {
// CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd
// is returned. All instances of exec.Cmd should flow through this function to ensure
// proper resource constraints are applied to the child process.
CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd
// PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd
// is returned. All instances of pty.Cmd should flow through this function to ensure
// proper resource constraints are applied to the child process.
PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd
}
// PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd
// is returned. All instances of pty.Cmd should flow through this function to ensure
// proper resource constraints are applied to the child process.
func PTYCommandContext(ctx context.Context, cmd string, args ...string) (*pty.Cmd, error) {
cmd, args, err := agentExecCmd(cmd, args...)
if err != nil {
return nil, xerrors.Errorf("agent exec cmd: %w", err)
}
return pty.CommandContext(ctx, cmd, args...), nil
}
func agentExecCmd(cmd string, args ...string) (string, []string, error) {
func NewExecer() (Execer, error) {
_, enabled := os.LookupEnv(EnvProcPrioMgmt)
if runtime.GOOS != "linux" || !enabled {
return cmd, args, nil
return DefaultExecer, nil
}
executable, err := os.Executable()
if err != nil {
return "", nil, xerrors.Errorf("get executable: %w", err)
return nil, xerrors.Errorf("get executable: %w", err)
}
bin, err := filepath.EvalSymlinks(executable)
if err != nil {
return "", nil, xerrors.Errorf("eval symlinks: %w", err)
return nil, xerrors.Errorf("eval symlinks: %w", err)
}
oomScore, ok := envValInt(EnvProcOOMScore)
if !ok {
oomScore = unset
}
niceScore, ok := envValInt(EnvProcNiceScore)
if !ok {
niceScore = unset
}
return priorityExecer{
binPath: bin,
oomScore: oomScore,
niceScore: niceScore,
}, nil
}
type execer struct{}
func (execer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
return exec.CommandContext(ctx, cmd, args...)
}
func (execer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
return pty.CommandContext(ctx, cmd, args...)
}
type priorityExecer struct {
binPath string
oomScore int
niceScore int
}
func (e priorityExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
cmd, args = e.agentExecCmd(cmd, args...)
return exec.CommandContext(ctx, cmd, args...)
}
func (e priorityExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
cmd, args = e.agentExecCmd(cmd, args...)
return pty.CommandContext(ctx, cmd, args...)
}
func (e priorityExecer) agentExecCmd(cmd string, args ...string) (string, []string) {
execArgs := []string{"agent-exec"}
if score, ok := envValInt(EnvProcOOMScore); ok {
execArgs = append(execArgs, oomScoreArg(score))
if e.oomScore != unset {
execArgs = append(execArgs, oomScoreArg(e.oomScore))
}
if score, ok := envValInt(EnvProcNiceScore); ok {
execArgs = append(execArgs, niceScoreArg(score))
if e.niceScore != unset {
execArgs = append(execArgs, niceScoreArg(e.niceScore))
}
execArgs = append(execArgs, "--", cmd)
execArgs = append(execArgs, args...)
return bin, execArgs, nil
return e.binPath, execArgs
}
// envValInt searches for a key in a list of environment variables and parses it to an int.
+84
View File
@@ -0,0 +1,84 @@
package agentexec
import (
"context"
"os/exec"
"testing"
"github.com/stretchr/testify/require"
)
func TestExecer(t *testing.T) {
t.Parallel()
t.Run("Default", func(t *testing.T) {
t.Parallel()
cmd := DefaultExecer.CommandContext(context.Background(), "sh", "-c", "sleep")
path, err := exec.LookPath("sh")
require.NoError(t, err)
require.Equal(t, path, cmd.Path)
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Priority", func(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
e := priorityExecer{
binPath: "/foo/bar/baz",
oomScore: unset,
niceScore: unset,
}
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
require.Equal(t, e.binPath, cmd.Path)
require.Equal(t, []string{e.binPath, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Nice", func(t *testing.T) {
t.Parallel()
e := priorityExecer{
binPath: "/foo/bar/baz",
oomScore: unset,
niceScore: 10,
}
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
require.Equal(t, e.binPath, cmd.Path)
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("OOM", func(t *testing.T) {
t.Parallel()
e := priorityExecer{
binPath: "/foo/bar/baz",
oomScore: 123,
niceScore: unset,
}
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
require.Equal(t, e.binPath, cmd.Path)
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Both", func(t *testing.T) {
t.Parallel()
e := priorityExecer{
binPath: "/foo/bar/baz",
oomScore: 432,
niceScore: 14,
}
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
require.Equal(t, e.binPath, cmd.Path)
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args)
})
})
}
-119
View File
@@ -1,119 +0,0 @@
package agentexec_test
import (
"context"
"os"
"os/exec"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentexec"
)
//nolint:paralleltest // we need to test environment variables
func TestExec(t *testing.T) {
//nolint:paralleltest // we need to test environment variables
t.Run("NonLinux", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "true")
if runtime.GOOS == "linux" {
t.Skip("skipping on linux")
}
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
path, err := exec.LookPath("sh")
require.NoError(t, err)
require.Equal(t, path, cmd.Path)
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
})
//nolint:paralleltest // we need to test environment variables
t.Run("Linux", func(t *testing.T) {
//nolint:paralleltest // we need to test environment variables
t.Run("Disabled", func(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
path, err := exec.LookPath("sh")
require.NoError(t, err)
require.Equal(t, path, cmd.Path)
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
})
//nolint:paralleltest // we need to test environment variables
t.Run("Enabled", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
executable, err := os.Executable()
require.NoError(t, err)
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
require.Equal(t, executable, cmd.Path)
require.Equal(t, []string{executable, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Nice", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
t.Setenv(agentexec.EnvProcNiceScore, "10")
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
executable, err := os.Executable()
require.NoError(t, err)
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
require.Equal(t, executable, cmd.Path)
require.Equal(t, []string{executable, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("OOM", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
t.Setenv(agentexec.EnvProcOOMScore, "123")
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
executable, err := os.Executable()
require.NoError(t, err)
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
require.Equal(t, executable, cmd.Path)
require.Equal(t, []string{executable, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Both", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
t.Setenv(agentexec.EnvProcOOMScore, "432")
t.Setenv(agentexec.EnvProcNiceScore, "14")
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
executable, err := os.Executable()
require.NoError(t, err)
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
require.Equal(t, executable, cmd.Path)
require.Equal(t, []string{executable, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args)
})
})
}
+2 -1
View File
@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest"
@@ -160,7 +161,7 @@ func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscript
}
fs := afero.NewMemMapFs()
logger := testutil.Logger(t)
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil)
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = s.Close()
+4 -5
View File
@@ -98,6 +98,7 @@ type Server struct {
// a lock on mu but protected by closing.
wg sync.WaitGroup
Execer agentexec.Execer
logger slog.Logger
srv *ssh.Server
@@ -110,7 +111,7 @@ type Server struct {
metrics *sshServerMetrics
}
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, config *Config) (*Server, error) {
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, execer agentexec.Execer, config *Config) (*Server, error) {
// Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
@@ -153,6 +154,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
metrics := newSSHServerMetrics(prometheusRegistry)
s := &Server{
Execer: execer,
listeners: make(map[net.Listener]struct{}),
fs: fs,
conns: make(map[net.Conn]struct{}),
@@ -726,10 +728,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
}
}
cmd, err := agentexec.PTYCommandContext(ctx, name, args...)
if err != nil {
return nil, xerrors.Errorf("pty command context: %w", err)
}
cmd := s.Execer.PTYCommandContext(ctx, name, args...)
cmd.Dir = s.config.WorkingDirectory()
// If the metadata directory doesn't exist, we run the command
+2 -1
View File
@@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/pty"
"github.com/coder/coder/v2/testutil"
)
@@ -35,7 +36,7 @@ func Test_sessionStart_orphan(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := testutil.Logger(t)
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
+6 -5
View File
@@ -22,6 +22,7 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
@@ -36,7 +37,7 @@ func TestNewServer_ServeClient(t *testing.T) {
ctx := context.Background()
logger := testutil.Logger(t)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
@@ -77,7 +78,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
ctx := context.Background()
logger := testutil.Logger(t)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = s.Close()
@@ -108,7 +109,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
@@ -159,7 +160,7 @@ func TestNewServer_Signal(t *testing.T) {
ctx := context.Background()
logger := testutil.Logger(t)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
@@ -224,7 +225,7 @@ func TestNewServer_Signal(t *testing.T) {
ctx := context.Background()
logger := testutil.Logger(t)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
+2 -1
View File
@@ -21,6 +21,7 @@ import (
"github.com/stretchr/testify/require"
gossh "golang.org/x/crypto/ssh"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/testutil"
)
@@ -34,7 +35,7 @@ func TestServer_X11(t *testing.T) {
ctx := context.Background()
logger := testutil.Logger(t)
fs := afero.NewOsFs()
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
require.NoError(t, err)
defer s.Close()
+2 -6
View File
@@ -40,7 +40,7 @@ type bufferedReconnectingPTY struct {
// newBuffered starts the buffered pty. If the context ends the process will be
// killed.
func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *bufferedReconnectingPTY {
func newBuffered(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *bufferedReconnectingPTY {
rpty := &bufferedReconnectingPTY{
activeConns: map[string]net.Conn{},
command: cmd,
@@ -59,11 +59,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
// Add TERM then start the command with a pty. pty.Cmd duplicates Path as the
// first argument so remove it.
cmdWithEnv, err := agentexec.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...)
if err != nil {
rpty.state.setState(StateDone, xerrors.Errorf("pty command context: %w", err))
return rpty
}
cmdWithEnv := execer.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...)
cmdWithEnv.Env = append(rpty.command.Env, "TERM=xterm-256color")
cmdWithEnv.Dir = rpty.command.Dir
ptty, process, err := pty.Start(cmdWithEnv)
+4 -3
View File
@@ -14,6 +14,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/pty"
)
@@ -55,7 +56,7 @@ type ReconnectingPTY interface {
// close itself (and all connections to it) if nothing is attached for the
// duration of the timeout, if the context ends, or the process exits (buffered
// backend only).
func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) ReconnectingPTY {
func New(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) ReconnectingPTY {
if options.Timeout == 0 {
options.Timeout = 5 * time.Minute
}
@@ -75,9 +76,9 @@ func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger
switch backendType {
case "screen":
return newScreen(ctx, cmd, options, logger)
return newScreen(ctx, logger, execer, cmd, options)
default:
return newBuffered(ctx, cmd, options, logger)
return newBuffered(ctx, logger, execer, cmd, options)
}
}
+6 -10
View File
@@ -25,6 +25,7 @@ import (
// screenReconnectingPTY provides a reconnectable PTY via `screen`.
type screenReconnectingPTY struct {
execer agentexec.Execer
command *pty.Cmd
// id holds the id of the session for both creating and attaching. This will
@@ -59,8 +60,9 @@ type screenReconnectingPTY struct {
// spawns the daemon with a hardcoded 24x80 size it is not a very good user
// experience. Instead we will let the attach command spawn the daemon on its
// own which causes it to spawn with the specified size.
func newScreen(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *screenReconnectingPTY {
func newScreen(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *screenReconnectingPTY {
rpty := &screenReconnectingPTY{
execer: execer,
command: cmd,
metrics: options.Metrics,
state: newState(),
@@ -210,7 +212,7 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn,
logger.Debug(ctx, "spawning screen client", slog.F("screen_id", rpty.id))
// Wrap the command with screen and tie it to the connection's context.
cmd, err := agentexec.PTYCommandContext(ctx, "screen", append([]string{
cmd := rpty.execer.PTYCommandContext(ctx, "screen", append([]string{
// -S is for setting the session's name.
"-S", rpty.id,
// -U tells screen to use UTF-8 encoding.
@@ -223,9 +225,6 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn,
rpty.command.Path,
// pty.Cmd duplicates Path as the first argument so remove it.
}, rpty.command.Args[1:]...)...)
if err != nil {
return nil, nil, xerrors.Errorf("pty command context: %w", err)
}
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
cmd.Dir = rpty.command.Dir
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
@@ -333,7 +332,7 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri
run := func() (bool, error) {
var stdout bytes.Buffer
//nolint:gosec
cmd, err := agentexec.CommandContext(ctx, "screen",
cmd := rpty.execer.CommandContext(ctx, "screen",
// -x targets an attached session.
"-x", rpty.id,
// -c is the flag for the config file.
@@ -341,13 +340,10 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri
// -X runs a command in the matching session.
"-X", command,
)
if err != nil {
return false, xerrors.Errorf("command context: %w", err)
}
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
cmd.Dir = rpty.command.Dir
cmd.Stdout = &stdout
err = cmd.Run()
err := cmd.Run()
if err == nil {
return true, nil
}
+9 -4
View File
@@ -165,10 +165,15 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
return xerrors.Errorf("create command: %w", err)
}
rpty = New(ctx, cmd, &Options{
Timeout: s.timeout,
Metrics: s.errorsTotal,
}, logger.With(slog.F("message_id", msg.ID)))
rpty = New(ctx,
logger.With(slog.F("message_id", msg.ID)),
s.commandCreator.Execer,
cmd,
&Options{
Timeout: s.timeout,
Metrics: s.errorsTotal,
},
)
done := make(chan struct{})
go func() {
+6
View File
@@ -309,6 +309,11 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
)
}
execer, err := agentexec.NewExecer()
if err != nil {
return xerrors.Errorf("create agent execer: %w", err)
}
agnt := agent.New(agent.Options{
Client: client,
Logger: logger,
@@ -333,6 +338,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
PrometheusRegistry: prometheusRegistry,
BlockFileTransfer: blockFileTransfer,
Execer: execer,
})
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
+2 -2
View File
@@ -503,7 +503,7 @@ func noExecInAgent(m dsl.Matcher) {
!m.File().PkgPath.Matches("/agentexec") &&
!m.File().Name.Matches(`_test\.go$`),
).
Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using agentexec.CommandContext instead.")
Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using an agentexec.Execer instead.")
}
// noPTYInAgent ensures that packages under agent/ don't use pty.Command or
@@ -521,5 +521,5 @@ func noPTYInAgent(m dsl.Matcher) {
!m.File().PkgPath.Matches(`/agentexec`) &&
!m.File().Name.Matches(`_test\.go$`),
).
Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using agentexec.PTYCommandContext instead.")
Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using an agentexec.Execer instead.")
}