mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
7ffeac711c
The web terminal was rendering Claude Code and Codex incorrectly because xterm's custom glyph renderer draws block and quadrant characters with its own geometry. The reconnecting PTY screen backend also exposed `screen.xterm-256color` to the user's shell, which made tmux rendering issues harder to reason about. This PR: * Disables xterm custom glyph rendering so the selected terminal font draws block and quadrant glyphs. * Adds a tiny Powerline-only terminal symbol fallback font so common prompt separators still render when custom glyphs are disabled. * Configures the screen backend to keep the inner shell `TERM` aligned with the browser terminal emulator, including background color erase behavior. * Tightens reconnecting PTY tests around prompt synchronization and `TERM` assertions. <!-- linear:table-colwidths:200,200 --> | Before | After | | -- | -- | | <img src="https://uploads.linear.app/e62091d9-44f5-421c-8e5c-df481fc99003/3c45efce-9d7e-43b4-b24f-88d4d23d294a/ba68155e-949e-4961-b0b2-124757cb07bb?signature=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJwYXRoIjoiL2U2MjA5MWQ5LTQ0ZjUtNDIxYy04ZTVjLWRmNDgxZmM5OTAwMy8zYzQ1ZWZjZS05ZDdlLTQzYjQtYjI0Zi04OGQ0ZDIzZDI5NGEvYmE2ODE1NWUtOTQ5ZS00OTYxLWIwYjItMTI0NzU3Y2IwN2JiIiwiaWF0IjoxNzc4MTgxNjUwLCJleHAiOjE4MDk3NTIyMTB9.45f1ZzBpWOF5OCJV0xHfICdpyRQ1UoGMbJjLYPqeAkg " alt="Before: Claude Code logo rendering is distorted in the web terminal outside and inside tmux" width="640" /> | <img src="https://uploads.linear.app/e62091d9-44f5-421c-8e5c-df481fc99003/26b0a109-5e21-4000-b1b5-ddac87c409d4/46a301c2-a815-419a-92d2-c51cecdefe40?signature=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJwYXRoIjoiL2U2MjA5MWQ5LTQ0ZjUtNDIxYy04ZTVjLWRmNDgxZmM5OTAwMy8yNmIwYTEwOS01ZTIxLTQwMDAtYjFiNS1kZGFjODdjNDA5ZDQvNDZhMzAxYzItYTgxNS00MTlhLTkyZDItYzUxY2VjZGVmZTQwIiwiaWF0IjoxNzc4MTgxNjUwLCJleHAiOjE4MDk3NTIyMTB9.SQVwUbtaf2OrpjRJPkRH3uc0nPqad0bNBVvcRyuR6NQ " alt="After: Claude Code logo renders correctly in the web terminal outside and inside tmux" width="640" /> | ## 
Validation * `go test ./agent -run '^TestAgent_ReconnectingPTY$' -count=1` * `pnpm --dir site test -- src/theme/constants.test.ts` * `pnpm --dir site lint:types` * `pnpm --dir site check` * `pnpm --dir site build` * `git commit` pre-commit hook passed * `git push` pre-push hook ran and printed the repo CI monitoring hint > Mux worked on this PR on Mike's behalf. --------- Co-authored-by: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com>
4186 lines
128 KiB
Go
4186 lines
128 KiB
Go
package agent_test
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"net/netip"
|
||
"os"
|
||
"os/exec"
|
||
"os/user"
|
||
"path"
|
||
"path/filepath"
|
||
"regexp"
|
||
"runtime"
|
||
"slices"
|
||
"strconv"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/bramvdbogaerde/go-scp"
|
||
"github.com/google/uuid"
|
||
"github.com/ory/dockertest/v3"
|
||
"github.com/ory/dockertest/v3/docker"
|
||
"github.com/pion/udp"
|
||
"github.com/pkg/sftp"
|
||
"github.com/prometheus/client_golang/prometheus"
|
||
promgo "github.com/prometheus/client_model/go"
|
||
"github.com/spf13/afero"
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"go.uber.org/goleak"
|
||
"golang.org/x/crypto/ssh"
|
||
"golang.org/x/xerrors"
|
||
"tailscale.com/net/speedtest"
|
||
"tailscale.com/tailcfg"
|
||
|
||
"cdr.dev/slog/v3"
|
||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||
"github.com/coder/coder/v2/agent"
|
||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||
"github.com/coder/coder/v2/agent/agentssh"
|
||
"github.com/coder/coder/v2/agent/agenttest"
|
||
"github.com/coder/coder/v2/agent/proto"
|
||
"github.com/coder/coder/v2/agent/usershell"
|
||
"github.com/coder/coder/v2/codersdk"
|
||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||
"github.com/coder/coder/v2/cryptorand"
|
||
"github.com/coder/coder/v2/pty/ptytest"
|
||
"github.com/coder/coder/v2/tailnet"
|
||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||
"github.com/coder/coder/v2/testutil"
|
||
"github.com/coder/quartz"
|
||
)
|
||
|
||
func TestMain(m *testing.M) {
|
||
if os.Getenv("CODER_TEST_RUN_SUB_AGENT_MAIN") == "1" {
|
||
// If we're running as a subagent, we don't want to run the main tests.
|
||
// Instead, we just run the subagent tests.
|
||
exit := runSubAgentMain()
|
||
os.Exit(exit)
|
||
}
|
||
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
|
||
}
|
||
|
||
// sshPorts lists both agent SSH listener ports; the SSH tests below loop over
// it so every scenario is exercised on each port.
var sshPorts = []uint16{
	workspacesdk.AgentSSHPort,
	workspacesdk.AgentStandardSSHPort,
}
|
||
|
||
// TestAgent_CloseWhileStarting is a regression test for https://github.com/coder/coder/issues/17328
|
||
func TestAgent_ImmediateClose(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
logger := slogtest.Make(t, &slogtest.Options{
|
||
// Agent can drop errors when shutting down, and some, like the
|
||
// fasthttplistener connection closed error, are unexported.
|
||
IgnoreErrors: true,
|
||
}).Leveled(slog.LevelDebug)
|
||
manifest := agentsdk.Manifest{
|
||
AgentID: uuid.New(),
|
||
AgentName: "test-agent",
|
||
WorkspaceName: "test-workspace",
|
||
WorkspaceID: uuid.New(),
|
||
}
|
||
|
||
coordinator := tailnet.NewCoordinator(logger)
|
||
t.Cleanup(func() {
|
||
_ = coordinator.Close()
|
||
})
|
||
statsCh := make(chan *proto.Stats, 50)
|
||
fs := afero.NewMemMapFs()
|
||
client := agenttest.NewClient(t, logger.Named("agenttest"), manifest.AgentID, manifest, statsCh, coordinator)
|
||
t.Cleanup(client.Close)
|
||
|
||
options := agent.Options{
|
||
Client: client,
|
||
Filesystem: fs,
|
||
Logger: logger.Named("agent"),
|
||
ReconnectingPTYTimeout: 0,
|
||
EnvironmentVariables: map[string]string{},
|
||
}
|
||
|
||
agentUnderTest := agent.New(options)
|
||
t.Cleanup(func() {
|
||
_ = agentUnderTest.Close()
|
||
})
|
||
|
||
// wait until the agent has connected and is starting to find races in the startup code
|
||
_ = testutil.TryReceive(ctx, t, client.GetStartup())
|
||
t.Log("Closing Agent")
|
||
err := agentUnderTest.Close()
|
||
require.NoError(t, err)
|
||
}
|
||
|
||
// NOTE(Cian): I noticed that these tests would fail when my default shell was zsh.
|
||
// Writing "exit 0" to stdin before closing fixed the issue for me.
|
||
|
||
func TestAgent_Stats_SSH(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
for _, port := range sshPorts {
|
||
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
//nolint:dogsled
|
||
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
|
||
sshClient, err := conn.SSHClientOnPort(ctx, port)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
defer session.Close()
|
||
stdin, err := session.StdinPipe()
|
||
require.NoError(t, err)
|
||
err = session.Shell()
|
||
require.NoError(t, err)
|
||
|
||
var s *proto.Stats
|
||
// We are looking for four different stats to be reported. They might not all
|
||
// arrive at the same time, so we loop until we've seen them all.
|
||
var connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen bool
|
||
require.Eventuallyf(t, func() bool {
|
||
var ok bool
|
||
s, ok = <-stats
|
||
if !ok {
|
||
return false
|
||
}
|
||
if s.ConnectionCount > 0 {
|
||
connectionCountSeen = true
|
||
}
|
||
if s.RxBytes > 0 {
|
||
rxBytesSeen = true
|
||
}
|
||
if s.TxBytes > 0 {
|
||
txBytesSeen = true
|
||
}
|
||
if s.SessionCountSsh == 1 {
|
||
sessionCountSSHSeen = true
|
||
}
|
||
return connectionCountSeen && rxBytesSeen && txBytesSeen && sessionCountSSHSeen
|
||
}, testutil.WaitLong, testutil.IntervalFast,
|
||
"never saw all stats: %+v, saw connectionCount: %t, rxBytes: %t, txBytes: %t, sessionCountSsh: %t",
|
||
s, connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen,
|
||
)
|
||
_, err = stdin.Write([]byte("exit 0\n"))
|
||
require.NoError(t, err, "writing exit to stdin")
|
||
_ = stdin.Close()
|
||
err = session.Wait()
|
||
require.NoError(t, err, "waiting for session to exit")
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
//nolint:dogsled
|
||
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
|
||
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "bash")
|
||
require.NoError(t, err)
|
||
defer ptyConn.Close()
|
||
|
||
data, err := json.Marshal(workspacesdk.ReconnectingPTYRequest{
|
||
Data: "echo test\r\n",
|
||
})
|
||
require.NoError(t, err)
|
||
_, err = ptyConn.Write(data)
|
||
require.NoError(t, err)
|
||
|
||
var s *proto.Stats
|
||
// We are looking for four different stats to be reported. They might not all
|
||
// arrive at the same time, so we loop until we've seen them all.
|
||
var connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountReconnectingPTYSeen bool
|
||
require.Eventuallyf(t, func() bool {
|
||
var ok bool
|
||
s, ok = <-stats
|
||
if !ok {
|
||
return false
|
||
}
|
||
if s.ConnectionCount > 0 {
|
||
connectionCountSeen = true
|
||
}
|
||
if s.RxBytes > 0 {
|
||
rxBytesSeen = true
|
||
}
|
||
if s.TxBytes > 0 {
|
||
txBytesSeen = true
|
||
}
|
||
if s.SessionCountReconnectingPty == 1 {
|
||
sessionCountReconnectingPTYSeen = true
|
||
}
|
||
return connectionCountSeen && rxBytesSeen && txBytesSeen && sessionCountReconnectingPTYSeen
|
||
}, testutil.WaitLong, testutil.IntervalFast,
|
||
"never saw all stats: %+v, saw connectionCount: %t, rxBytes: %t, txBytes: %t, sessionCountReconnectingPTY: %t",
|
||
s, connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountReconnectingPTYSeen,
|
||
)
|
||
}
|
||
|
||
func TestAgent_Stats_Magic(t *testing.T) {
|
||
t.Parallel()
|
||
t.Run("StripsEnvironmentVariable", func(t *testing.T) {
|
||
t.Parallel()
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
//nolint:dogsled
|
||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
|
||
defer session.Close()
|
||
|
||
command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'"
|
||
expected := ""
|
||
if runtime.GOOS == "windows" {
|
||
expected = "%" + agentssh.MagicSessionTypeEnvironmentVariable + "%"
|
||
command = "cmd.exe /c echo " + expected
|
||
}
|
||
output, err := session.Output(command)
|
||
require.NoError(t, err)
|
||
require.Equal(t, expected, strings.TrimSpace(string(output)))
|
||
})
|
||
|
||
t.Run("TracksVSCode", func(t *testing.T) {
|
||
t.Parallel()
|
||
if runtime.GOOS == "windows" {
|
||
t.Skip("Sleeping for infinity doesn't work on Windows")
|
||
}
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
//nolint:dogsled
|
||
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
|
||
defer session.Close()
|
||
stdin, err := session.StdinPipe()
|
||
require.NoError(t, err)
|
||
err = session.Shell()
|
||
require.NoError(t, err)
|
||
require.Eventuallyf(t, func() bool {
|
||
s, ok := <-stats
|
||
t.Logf("got stats: ok=%t, ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountVSCode=%d, ConnectionMedianLatencyMS=%f",
|
||
ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVscode, s.ConnectionMedianLatencyMs)
|
||
return ok &&
|
||
// Ensure that the connection didn't count as a "normal" SSH session.
|
||
// This was a special one, so it should be labeled specially in the stats!
|
||
s.SessionCountVscode == 1 &&
|
||
// Ensure that connection latency is being counted!
|
||
// If it isn't, it's set to -1.
|
||
s.ConnectionMedianLatencyMs >= 0
|
||
}, testutil.WaitLong, testutil.IntervalFast,
|
||
"never saw stats",
|
||
)
|
||
|
||
_, err = stdin.Write([]byte("exit 0\n"))
|
||
require.NoError(t, err, "writing exit to stdin")
|
||
_ = stdin.Close()
|
||
err = session.Wait()
|
||
require.NoError(t, err)
|
||
|
||
assertConnectionReport(t, agentClient, proto.Connection_VSCODE, 0, "")
|
||
})
|
||
|
||
t.Run("TracksJetBrains", func(t *testing.T) {
|
||
t.Parallel()
|
||
if runtime.GOOS != "linux" {
|
||
t.Skip("JetBrains tracking is only supported on Linux")
|
||
}
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
|
||
// JetBrains tracking works by looking at the process name listening on the
|
||
// forwarded port. If the process's command line includes the magic string
|
||
// we are looking for, then we assume it is a JetBrains editor. So when we
|
||
// connect to the port we must ensure the process includes that magic string
|
||
// to fool the agent into thinking this is JetBrains. To do this we need to
|
||
// spawn an external process (in this case a simple echo server) so we can
|
||
// control the process name. The -D here is just to mimic how Java options
|
||
// are set but is not necessary as the agent looks only for the magic
|
||
// string itself anywhere in the command.
|
||
_, b, _, ok := runtime.Caller(0)
|
||
require.True(t, ok)
|
||
dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go")
|
||
echoServerCmd := exec.Command("go", "run", dir,
|
||
"-D", agentssh.MagicProcessCmdlineJetBrains)
|
||
stdout, err := echoServerCmd.StdoutPipe()
|
||
require.NoError(t, err)
|
||
err = echoServerCmd.Start()
|
||
require.NoError(t, err)
|
||
defer echoServerCmd.Process.Kill()
|
||
|
||
// The echo server prints its port as the first line.
|
||
sc := bufio.NewScanner(stdout)
|
||
sc.Scan()
|
||
remotePort := sc.Text()
|
||
|
||
//nolint:dogsled
|
||
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
|
||
tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort))
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() {
|
||
// always close on failure of test
|
||
_ = conn.Close()
|
||
_ = tunneledConn.Close()
|
||
})
|
||
|
||
require.Eventuallyf(t, func() bool {
|
||
s, ok := <-stats
|
||
t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d",
|
||
ok, s.ConnectionCount, s.SessionCountJetbrains)
|
||
return ok && s.SessionCountJetbrains == 1
|
||
}, testutil.WaitLong, testutil.IntervalFast,
|
||
"never saw stats with conn open",
|
||
)
|
||
|
||
// Kill the server and connection after checking for the echo.
|
||
requireEcho(t, tunneledConn)
|
||
_ = echoServerCmd.Process.Kill()
|
||
_ = tunneledConn.Close()
|
||
|
||
require.Eventuallyf(t, func() bool {
|
||
s, ok := <-stats
|
||
t.Logf("got stats after disconnect %t, %d",
|
||
ok, s.SessionCountJetbrains)
|
||
return ok &&
|
||
s.SessionCountJetbrains == 0
|
||
}, testutil.WaitLong, testutil.IntervalFast,
|
||
"never saw stats after conn closes",
|
||
)
|
||
|
||
assertConnectionReport(t, agentClient, proto.Connection_JETBRAINS, 0, "")
|
||
})
|
||
}
|
||
|
||
func TestAgent_SessionExec(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
for _, port := range sshPorts {
|
||
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
|
||
|
||
command := "echo test"
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe /c echo test"
|
||
}
|
||
output, err := session.Output(command)
|
||
require.NoError(t, err)
|
||
require.Equal(t, "test", strings.TrimSpace(string(output)))
|
||
})
|
||
}
|
||
}
|
||
|
||
//nolint:tparallel // Sub tests need to run sequentially.
|
||
func TestAgent_Session_EnvironmentVariables(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tmpdir := t.TempDir()
|
||
|
||
// Defined by the coder script runner, hardcoded here since we don't
|
||
// have a reference to it.
|
||
scriptBinDir := filepath.Join(tmpdir, "coder-script-data", "bin")
|
||
|
||
manifest := agentsdk.Manifest{
|
||
EnvironmentVariables: map[string]string{
|
||
"MY_MANIFEST": "true",
|
||
"MY_OVERRIDE": "false",
|
||
"MY_SESSION_MANIFEST": "false",
|
||
},
|
||
}
|
||
banner := codersdk.ServiceBannerConfig{}
|
||
session := setupSSHSession(t, manifest, banner, nil, func(_ *agenttest.Client, opts *agent.Options) {
|
||
opts.ScriptDataDir = tmpdir
|
||
opts.EnvironmentVariables["MY_OVERRIDE"] = "true"
|
||
})
|
||
|
||
err := session.Setenv("MY_SESSION_MANIFEST", "true")
|
||
require.NoError(t, err)
|
||
err = session.Setenv("MY_SESSION", "true")
|
||
require.NoError(t, err)
|
||
|
||
command := "sh"
|
||
echoEnv := func(t *testing.T, w io.Writer, env string) {
|
||
if runtime.GOOS == "windows" {
|
||
_, err := fmt.Fprintf(w, "echo %%%s%%\r\n", env)
|
||
require.NoError(t, err)
|
||
} else {
|
||
_, err := fmt.Fprintf(w, "echo $%s\n", env)
|
||
require.NoError(t, err)
|
||
}
|
||
}
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe"
|
||
}
|
||
stdin, err := session.StdinPipe()
|
||
require.NoError(t, err)
|
||
defer stdin.Close()
|
||
stdout, err := session.StdoutPipe()
|
||
require.NoError(t, err)
|
||
|
||
err = session.Start(command)
|
||
require.NoError(t, err)
|
||
|
||
// Context is fine here since we're not doing a parallel subtest.
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
go func() {
|
||
<-ctx.Done()
|
||
_ = session.Close()
|
||
}()
|
||
|
||
s := bufio.NewScanner(stdout)
|
||
|
||
//nolint:paralleltest // These tests need to run sequentially.
|
||
for k, partialV := range map[string]string{
|
||
"CODER": "true", // From the agent.
|
||
"MY_MANIFEST": "true", // From the manifest.
|
||
"MY_OVERRIDE": "true", // From the agent environment variables option, overrides manifest.
|
||
"MY_SESSION_MANIFEST": "false", // From the manifest, overrides session env.
|
||
"MY_SESSION": "true", // From the session.
|
||
"PATH": scriptBinDir + string(filepath.ListSeparator),
|
||
} {
|
||
t.Run(k, func(t *testing.T) {
|
||
echoEnv(t, stdin, k)
|
||
// Windows is unreliable, so keep scanning until we find a match.
|
||
for s.Scan() {
|
||
got := strings.TrimSpace(s.Text())
|
||
t.Logf("%s=%s", k, got)
|
||
if strings.Contains(got, partialV) {
|
||
break
|
||
}
|
||
}
|
||
if err := s.Err(); !errors.Is(err, io.EOF) {
|
||
require.NoError(t, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestAgent_Session_SecretInjection(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
manifest := agentsdk.Manifest{
|
||
EnvironmentVariables: map[string]string{
|
||
"SHOULD_BE_OVERRIDDEN": "manifest-value",
|
||
},
|
||
}
|
||
secrets := []agentsdk.WorkspaceSecret{
|
||
{EnvName: "MY_SECRET_ENV", Value: []byte("env-secret-value")},
|
||
{FilePath: "/tmp/secret-file", Value: []byte("file-secret-content")},
|
||
{EnvName: "BOTH_ENV", FilePath: "/tmp/both-file", Value: []byte("both-value")},
|
||
{EnvName: "SHOULD_BE_OVERRIDDEN", Value: []byte("secret-wins")},
|
||
}
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
//nolint:dogsled
|
||
conn, _, _, fs, _ := setupAgentWithSecrets(t, manifest, secrets, 0)
|
||
|
||
// Verify file injection via the agent's filesystem.
|
||
content, err := afero.ReadFile(fs, "/tmp/secret-file")
|
||
require.NoError(t, err)
|
||
require.Equal(t, "file-secret-content", string(content))
|
||
|
||
content, err = afero.ReadFile(fs, "/tmp/both-file")
|
||
require.NoError(t, err)
|
||
require.Equal(t, "both-value", string(content))
|
||
|
||
// Verify env var injection via an SSH session.
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() { _ = sshClient.Close() })
|
||
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() { _ = session.Close() })
|
||
|
||
command := "sh"
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe"
|
||
}
|
||
|
||
stdin, err := session.StdinPipe()
|
||
require.NoError(t, err)
|
||
defer stdin.Close()
|
||
stdout, err := session.StdoutPipe()
|
||
require.NoError(t, err)
|
||
|
||
err = session.Start(command)
|
||
require.NoError(t, err)
|
||
|
||
go func() {
|
||
<-ctx.Done()
|
||
_ = session.Close()
|
||
}()
|
||
|
||
s := bufio.NewScanner(stdout)
|
||
|
||
echoEnv := func(t *testing.T, w io.Writer, env string) {
|
||
t.Helper()
|
||
if runtime.GOOS == "windows" {
|
||
_, err := fmt.Fprintf(w, "echo %%%s%%\r\n", env)
|
||
require.NoError(t, err)
|
||
} else {
|
||
_, err := fmt.Fprintf(w, "echo $%s\n", env)
|
||
require.NoError(t, err)
|
||
}
|
||
}
|
||
|
||
for k, partialV := range map[string]string{
|
||
"MY_SECRET_ENV": "env-secret-value",
|
||
"BOTH_ENV": "both-value",
|
||
"SHOULD_BE_OVERRIDDEN": "secret-wins",
|
||
} {
|
||
echoEnv(t, stdin, k)
|
||
found := false
|
||
for s.Scan() {
|
||
got := strings.TrimSpace(s.Text())
|
||
t.Logf("%s=%s", k, got)
|
||
if strings.Contains(got, partialV) {
|
||
found = true
|
||
break
|
||
}
|
||
}
|
||
require.True(t, found, "env %s not found in output", k)
|
||
if err := s.Err(); !errors.Is(err, io.EOF) {
|
||
require.NoError(t, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestAgent_StartupScript_SecretInjection(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
if runtime.GOOS == "windows" {
|
||
t.Skip("startup script test uses sh syntax")
|
||
}
|
||
|
||
tmpDir := t.TempDir()
|
||
secretFilePath := filepath.Join(tmpDir, "secret-file")
|
||
envProofPath := filepath.Join(tmpDir, "env-proof")
|
||
fileProofPath := filepath.Join(tmpDir, "file-proof")
|
||
|
||
// The startup script reads the secret env var and the secret file,
|
||
// writing both to proof files so we can verify they were available
|
||
// at script execution time.
|
||
script := fmt.Sprintf(
|
||
"echo \"$MY_STARTUP_SECRET\" > %s && cat %s > %s",
|
||
envProofPath, secretFilePath, fileProofPath,
|
||
)
|
||
|
||
manifest := agentsdk.Manifest{
|
||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||
Script: script,
|
||
Timeout: 30 * time.Second,
|
||
RunOnStart: true,
|
||
}},
|
||
}
|
||
secrets := []agentsdk.WorkspaceSecret{
|
||
{EnvName: "MY_STARTUP_SECRET", Value: []byte("startup-env-value")},
|
||
{FilePath: secretFilePath, Value: []byte("startup-file-content")},
|
||
}
|
||
|
||
// Use the real OS filesystem so that both writeSecretFiles and
|
||
// the startup script operate on the same filesystem.
|
||
//nolint:dogsled
|
||
_, client, _, _, _ := setupAgentWithSecrets(t, manifest, secrets, 0, func(_ *agenttest.Client, opts *agent.Options) {
|
||
opts.Filesystem = afero.NewOsFs()
|
||
})
|
||
|
||
// Wait for the startup script to complete.
|
||
var got []codersdk.WorkspaceAgentLifecycle
|
||
assert.Eventually(t, func() bool {
|
||
got = client.GetLifecycleStates()
|
||
return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady
|
||
}, testutil.WaitLong, testutil.IntervalMedium)
|
||
require.Contains(t, got, codersdk.WorkspaceAgentLifecycleReady, "agent never reached ready")
|
||
|
||
// Verify the startup script could read the secret env var.
|
||
envProof, err := os.ReadFile(envProofPath)
|
||
require.NoError(t, err)
|
||
require.Equal(t, "startup-env-value", strings.TrimSpace(string(envProof)))
|
||
|
||
// Verify the startup script could read the secret file.
|
||
fileProof, err := os.ReadFile(fileProofPath)
|
||
require.NoError(t, err)
|
||
require.Equal(t, "startup-file-content", string(fileProof))
|
||
}
|
||
|
||
func TestAgent_GitSSH(t *testing.T) {
|
||
t.Parallel()
|
||
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
|
||
command := "sh -c 'echo $GIT_SSH_COMMAND'"
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
|
||
}
|
||
output, err := session.Output(command)
|
||
require.NoError(t, err)
|
||
require.True(t, strings.HasSuffix(strings.TrimSpace(string(output)), "gitssh --"))
|
||
}
|
||
|
||
func TestAgent_SessionTTYShell(t *testing.T) {
|
||
t.Parallel()
|
||
if runtime.GOOS == "windows" {
|
||
// This might be our implementation, or ConPTY itself.
|
||
// It's difficult to find extensive tests for it, so
|
||
// it seems like it could be either.
|
||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||
}
|
||
|
||
for _, port := range sshPorts {
|
||
t.Run(fmt.Sprintf("(%d)", port), func(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||
|
||
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
|
||
command := "sh"
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe"
|
||
}
|
||
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||
require.NoError(t, err)
|
||
ptty := ptytest.New(t)
|
||
session.Stdout = ptty.Output()
|
||
session.Stderr = ptty.Output()
|
||
session.Stdin = ptty.Input()
|
||
err = session.Start(command)
|
||
require.NoError(t, err)
|
||
_ = ptty.Peek(ctx, 1) // wait for the prompt
|
||
ptty.WriteLine("echo test")
|
||
ptty.ExpectMatch("test")
|
||
ptty.WriteLine("exit")
|
||
err = session.Wait()
|
||
require.NoError(t, err)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestAgent_SessionTTYExitCode(t *testing.T) {
|
||
t.Parallel()
|
||
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
|
||
command := "areallynotrealcommand"
|
||
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||
require.NoError(t, err)
|
||
ptty := ptytest.New(t)
|
||
session.Stdout = ptty.Output()
|
||
session.Stderr = ptty.Output()
|
||
session.Stdin = ptty.Input()
|
||
err = session.Start(command)
|
||
require.NoError(t, err)
|
||
err = session.Wait()
|
||
exitErr := &ssh.ExitError{}
|
||
require.True(t, xerrors.As(err, &exitErr))
|
||
if runtime.GOOS == "windows" {
|
||
assert.Equal(t, 1, exitErr.ExitStatus())
|
||
} else {
|
||
assert.Equal(t, 127, exitErr.ExitStatus())
|
||
}
|
||
}
|
||
|
||
func TestAgent_Session_TTY_MOTD(t *testing.T) {
|
||
t.Parallel()
|
||
if runtime.GOOS == "windows" {
|
||
// This might be our implementation, or ConPTY itself.
|
||
// It's difficult to find extensive tests for it, so
|
||
// it seems like it could be either.
|
||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||
}
|
||
|
||
u, err := user.Current()
|
||
require.NoError(t, err, "get current user")
|
||
|
||
name := filepath.Join(u.HomeDir, "motd")
|
||
|
||
wantMOTD := "Welcome to your Coder workspace!"
|
||
wantServiceBanner := "Service banner text goes here"
|
||
|
||
tests := []struct {
|
||
name string
|
||
manifest agentsdk.Manifest
|
||
banner codersdk.ServiceBannerConfig
|
||
expected []string
|
||
unexpected []string
|
||
expectedRe *regexp.Regexp
|
||
}{
|
||
{
|
||
name: "WithoutServiceBanner",
|
||
manifest: agentsdk.Manifest{MOTDFile: name},
|
||
banner: codersdk.ServiceBannerConfig{},
|
||
expected: []string{wantMOTD},
|
||
unexpected: []string{wantServiceBanner},
|
||
},
|
||
{
|
||
name: "WithServiceBanner",
|
||
manifest: agentsdk.Manifest{MOTDFile: name},
|
||
banner: codersdk.ServiceBannerConfig{
|
||
Enabled: true,
|
||
Message: wantServiceBanner,
|
||
},
|
||
expected: []string{wantMOTD, wantServiceBanner},
|
||
},
|
||
{
|
||
name: "ServiceBannerDisabled",
|
||
manifest: agentsdk.Manifest{MOTDFile: name},
|
||
banner: codersdk.ServiceBannerConfig{
|
||
Enabled: false,
|
||
Message: wantServiceBanner,
|
||
},
|
||
expected: []string{wantMOTD},
|
||
unexpected: []string{wantServiceBanner},
|
||
},
|
||
{
|
||
name: "ServiceBannerOnly",
|
||
manifest: agentsdk.Manifest{},
|
||
banner: codersdk.ServiceBannerConfig{
|
||
Enabled: true,
|
||
Message: wantServiceBanner,
|
||
},
|
||
expected: []string{wantServiceBanner},
|
||
unexpected: []string{wantMOTD},
|
||
},
|
||
{
|
||
name: "None",
|
||
manifest: agentsdk.Manifest{},
|
||
banner: codersdk.ServiceBannerConfig{},
|
||
unexpected: []string{wantServiceBanner, wantMOTD},
|
||
},
|
||
{
|
||
name: "CarriageReturns",
|
||
manifest: agentsdk.Manifest{},
|
||
banner: codersdk.ServiceBannerConfig{
|
||
Enabled: true,
|
||
Message: "service\n\nbanner\nhere",
|
||
},
|
||
expected: []string{"service\r\n\r\nbanner\r\nhere\r\n\r\n"},
|
||
unexpected: []string{},
|
||
},
|
||
{
|
||
name: "Trim",
|
||
// Enable motd since it will be printed after the banner,
|
||
// this ensures that we can test for an exact mount of
|
||
// newlines.
|
||
manifest: agentsdk.Manifest{
|
||
MOTDFile: name,
|
||
},
|
||
banner: codersdk.ServiceBannerConfig{
|
||
Enabled: true,
|
||
Message: "\n\n\n\n\n\nbanner\n\n\n\n\n\n",
|
||
},
|
||
expectedRe: regexp.MustCompile(`([^\n\r]|^)banner\r\n\r\n[^\r\n]`),
|
||
},
|
||
}
|
||
|
||
for _, test := range tests {
|
||
t.Run(test.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
session := setupSSHSession(t, test.manifest, test.banner, func(fs afero.Fs) {
|
||
err := fs.MkdirAll(filepath.Dir(name), 0o700)
|
||
require.NoError(t, err)
|
||
err = afero.WriteFile(fs, name, []byte(wantMOTD), 0o600)
|
||
require.NoError(t, err)
|
||
})
|
||
testSessionOutput(t, session, test.expected, test.unexpected, test.expectedRe)
|
||
})
|
||
}
|
||
}
|
||
|
||
//nolint:tparallel // Sub tests need to run sequentially.
|
||
func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
|
||
t.Parallel()
|
||
if runtime.GOOS == "windows" {
|
||
// This might be our implementation, or ConPTY itself.
|
||
// It's difficult to find extensive tests for it, so
|
||
// it seems like it could be either.
|
||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||
}
|
||
|
||
// Only the banner updates dynamically; the MOTD file does not.
|
||
wantServiceBanner := "Service banner text goes here"
|
||
|
||
tests := []struct {
|
||
banner codersdk.ServiceBannerConfig
|
||
expected []string
|
||
unexpected []string
|
||
}{
|
||
{
|
||
banner: codersdk.ServiceBannerConfig{},
|
||
expected: []string{},
|
||
unexpected: []string{wantServiceBanner},
|
||
},
|
||
{
|
||
banner: codersdk.ServiceBannerConfig{
|
||
Enabled: true,
|
||
Message: wantServiceBanner,
|
||
},
|
||
expected: []string{wantServiceBanner},
|
||
},
|
||
{
|
||
banner: codersdk.ServiceBannerConfig{
|
||
Enabled: false,
|
||
Message: wantServiceBanner,
|
||
},
|
||
expected: []string{},
|
||
unexpected: []string{wantServiceBanner},
|
||
},
|
||
{
|
||
banner: codersdk.ServiceBannerConfig{
|
||
Enabled: true,
|
||
Message: wantServiceBanner,
|
||
},
|
||
expected: []string{wantServiceBanner},
|
||
unexpected: []string{},
|
||
},
|
||
{
|
||
banner: codersdk.ServiceBannerConfig{},
|
||
unexpected: []string{wantServiceBanner},
|
||
},
|
||
}
|
||
|
||
setSBInterval := func(_ *agenttest.Client, opts *agent.Options) {
|
||
opts.ServiceBannerRefreshInterval = testutil.IntervalFast
|
||
}
|
||
//nolint:dogsled // Allow the blank identifiers.
|
||
conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, setSBInterval)
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
//nolint:paralleltest // These tests need to swap the banner func.
|
||
for _, port := range sshPorts {
|
||
sshClient, err := conn.SSHClientOnPort(ctx, port)
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() {
|
||
_ = sshClient.Close()
|
||
})
|
||
|
||
for i, test := range tests {
|
||
t.Run(fmt.Sprintf("(:%d)/%d", port, i), func(t *testing.T) {
|
||
// Set new banner func and wait for the agent to call it to update the
|
||
// banner. We wait for two calls to ensure the value has been stored:
|
||
// the second call can only begin after the first iteration of
|
||
// fetchServiceBannerLoop completes (call + store), so after
|
||
// receiving two signals at least one store has happened.
|
||
ready := make(chan struct{}, 2)
|
||
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
|
||
select {
|
||
case ready <- struct{}{}:
|
||
default:
|
||
}
|
||
return []codersdk.BannerConfig{test.banner}, nil
|
||
})
|
||
testutil.TryReceive(ctx, t, ready)
|
||
testutil.TryReceive(ctx, t, ready)
|
||
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() {
|
||
_ = session.Close()
|
||
})
|
||
|
||
testSessionOutput(t, session, test.expected, test.unexpected, nil)
|
||
})
|
||
}
|
||
}
|
||
}
|
||
|
||
//nolint:paralleltest // This test sets an environment variable.
|
||
func TestAgent_Session_TTY_QuietLogin(t *testing.T) {
|
||
if runtime.GOOS == "windows" {
|
||
// This might be our implementation, or ConPTY itself.
|
||
// It's difficult to find extensive tests for it, so
|
||
// it seems like it could be either.
|
||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||
}
|
||
|
||
wantNotMOTD := "Welcome to your Coder workspace!"
|
||
wantMaybeServiceBanner := "Service banner text goes here"
|
||
|
||
u, err := user.Current()
|
||
require.NoError(t, err, "get current user")
|
||
|
||
name := filepath.Join(u.HomeDir, "motd")
|
||
|
||
// Neither banner nor MOTD should show if not a login shell.
|
||
t.Run("NotLogin", func(t *testing.T) {
|
||
session := setupSSHSession(t, agentsdk.Manifest{
|
||
MOTDFile: name,
|
||
}, codersdk.ServiceBannerConfig{
|
||
Enabled: true,
|
||
Message: wantMaybeServiceBanner,
|
||
}, func(fs afero.Fs) {
|
||
err := afero.WriteFile(fs, name, []byte(wantNotMOTD), 0o600)
|
||
require.NoError(t, err, "write motd file")
|
||
})
|
||
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||
require.NoError(t, err)
|
||
|
||
wantEcho := "foobar"
|
||
command := "echo " + wantEcho
|
||
output, err := session.Output(command)
|
||
require.NoError(t, err)
|
||
|
||
require.Contains(t, string(output), wantEcho, "should show echo")
|
||
require.NotContains(t, string(output), wantNotMOTD, "should not show motd")
|
||
require.NotContains(t, string(output), wantMaybeServiceBanner, "should not show service banner")
|
||
})
|
||
|
||
// Only the MOTD should be silenced when hushlogin is present.
|
||
t.Run("Hushlogin", func(t *testing.T) {
|
||
session := setupSSHSession(t, agentsdk.Manifest{
|
||
MOTDFile: name,
|
||
}, codersdk.ServiceBannerConfig{
|
||
Enabled: true,
|
||
Message: wantMaybeServiceBanner,
|
||
}, func(fs afero.Fs) {
|
||
err := afero.WriteFile(fs, name, []byte(wantNotMOTD), 0o600)
|
||
require.NoError(t, err, "write motd file")
|
||
|
||
// Create hushlogin to silence motd.
|
||
err = afero.WriteFile(fs, name, []byte{}, 0o600)
|
||
require.NoError(t, err, "write hushlogin file")
|
||
})
|
||
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||
require.NoError(t, err)
|
||
|
||
ptty := ptytest.New(t)
|
||
var stdout bytes.Buffer
|
||
session.Stdout = &stdout
|
||
session.Stderr = ptty.Output()
|
||
session.Stdin = ptty.Input()
|
||
err = session.Shell()
|
||
require.NoError(t, err)
|
||
|
||
ptty.WriteLine("exit 0")
|
||
err = session.Wait()
|
||
require.NoError(t, err)
|
||
|
||
require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd")
|
||
require.Contains(t, stdout.String(), wantMaybeServiceBanner, "should show service banner")
|
||
})
|
||
}
|
||
|
||
func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
|
||
t.Parallel()
|
||
// This test is here to prevent regressions where quickly executing
|
||
// commands (with TTY) don't sync their output to the SSH session.
|
||
//
|
||
// See: https://github.com/coder/coder/issues/6656
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
//nolint:dogsled
|
||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
|
||
ptty := ptytest.New(t)
|
||
|
||
var stdout bytes.Buffer
|
||
// NOTE(mafredri): Increase iterations to increase chance of failure,
|
||
// assuming bug is present. Limiting GOMAXPROCS further
|
||
// increases the chance of failure.
|
||
// Using 1000 iterations is basically a guaranteed failure (but let's
|
||
// not increase test times needlessly).
|
||
// Limit GOMAXPROCS (e.g. `export GOMAXPROCS=1`) to further increase
|
||
// chance of failure. Also -race helps.
|
||
for i := 0; i < 5; i++ {
|
||
func() {
|
||
stdout.Reset()
|
||
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
defer session.Close()
|
||
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||
require.NoError(t, err)
|
||
|
||
session.Stdout = &stdout
|
||
session.Stderr = ptty.Output()
|
||
session.Stdin = ptty.Input()
|
||
err = session.Start("echo wazzup")
|
||
require.NoError(t, err)
|
||
|
||
err = session.Wait()
|
||
require.NoError(t, err)
|
||
require.Contains(t, stdout.String(), "wazzup", "should output greeting")
|
||
}()
|
||
}
|
||
}
|
||
|
||
func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// This test is here to prevent regressions where a command (with or
|
||
// without) a large amount of output would not be fully copied to the
|
||
// SSH session. On unix systems, this was fixed by duplicating the file
|
||
// descriptor of the PTY master and using it for copying the output.
|
||
//
|
||
// See: https://github.com/coder/coder/issues/6656
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
//nolint:dogsled
|
||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
|
||
ptty := ptytest.New(t)
|
||
|
||
var stdout bytes.Buffer
|
||
// NOTE(mafredri): Increase iterations to increase chance of failure,
|
||
// assuming bug is present.
|
||
// Using 10 iterations is basically a guaranteed failure (but let's
|
||
// not increase test times needlessly). Run with -race and do not
|
||
// limit parallelism (`export GOMAXPROCS=10`) to increase the chance
|
||
// of failure.
|
||
for i := 0; i < 1; i++ {
|
||
func() {
|
||
stdout.Reset()
|
||
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
defer session.Close()
|
||
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||
require.NoError(t, err)
|
||
|
||
session.Stdout = &stdout
|
||
session.Stderr = ptty.Output()
|
||
session.Stdin = ptty.Input()
|
||
want := strings.Repeat("wazzup", 1024+1) // ~6KB, +1 because 1024 is a common buffer size.
|
||
err = session.Start("echo " + want)
|
||
require.NoError(t, err)
|
||
|
||
err = session.Wait()
|
||
require.NoError(t, err)
|
||
require.Contains(t, stdout.String(), want, "should output entire greeting")
|
||
}()
|
||
}
|
||
}
|
||
|
||
func TestAgent_TCPLocalForwarding(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
|
||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer rl.Close()
|
||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||
require.True(t, valid)
|
||
remotePort := tcpAddr.Port
|
||
go echoOnce(t, rl)
|
||
|
||
sshClient := setupAgentSSHClient(ctx, t)
|
||
|
||
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||
require.NoError(t, err)
|
||
defer conn.Close()
|
||
requireEcho(t, conn)
|
||
}
|
||
|
||
func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
sshClient := setupAgentSSHClient(ctx, t)
|
||
|
||
localhost := netip.MustParseAddr("127.0.0.1")
|
||
var randomPort uint16
|
||
var ll net.Listener
|
||
var err error
|
||
for {
|
||
randomPort = testutil.RandomPortNoListen(t)
|
||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||
ll, err = sshClient.ListenTCP(addr)
|
||
if err != nil {
|
||
t.Logf("error remote forwarding: %s", err.Error())
|
||
select {
|
||
case <-ctx.Done():
|
||
t.Fatal("timed out getting random listener")
|
||
default:
|
||
continue
|
||
}
|
||
}
|
||
break
|
||
}
|
||
defer ll.Close()
|
||
go echoOnce(t, ll)
|
||
|
||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
|
||
require.NoError(t, err)
|
||
defer conn.Close()
|
||
requireEcho(t, conn)
|
||
}
|
||
|
||
func TestAgent_TCPLocalForwardingBlocked(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
|
||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer rl.Close()
|
||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||
require.True(t, valid)
|
||
remotePort := tcpAddr.Port
|
||
|
||
//nolint:dogsled
|
||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.BlockLocalPortForwarding = true
|
||
})
|
||
sshClient, err := agentConn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
|
||
_, err = sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||
require.ErrorContains(t, err, "administratively prohibited")
|
||
}
|
||
|
||
func TestAgent_TCPRemoteForwardingBlocked(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
|
||
//nolint:dogsled
|
||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.BlockReversePortForwarding = true
|
||
})
|
||
sshClient, err := agentConn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
|
||
localhost := netip.MustParseAddr("127.0.0.1")
|
||
randomPort := testutil.RandomPortNoListen(t)
|
||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||
_, err = sshClient.ListenTCP(addr)
|
||
require.ErrorContains(t, err, "tcpip-forward request denied by peer")
|
||
}
|
||
|
||
func TestAgent_UnixLocalForwardingBlocked(t *testing.T) {
|
||
t.Parallel()
|
||
if runtime.GOOS == "windows" {
|
||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||
}
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
tmpdir := testutil.TempDirUnixSocket(t)
|
||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||
|
||
l, err := net.Listen("unix", remoteSocketPath)
|
||
require.NoError(t, err)
|
||
defer l.Close()
|
||
|
||
//nolint:dogsled
|
||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.BlockLocalPortForwarding = true
|
||
})
|
||
sshClient, err := agentConn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
|
||
_, err = sshClient.Dial("unix", remoteSocketPath)
|
||
require.ErrorContains(t, err, "administratively prohibited")
|
||
}
|
||
|
||
func TestAgent_UnixRemoteForwardingBlocked(t *testing.T) {
|
||
t.Parallel()
|
||
if runtime.GOOS == "windows" {
|
||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||
}
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
tmpdir := testutil.TempDirUnixSocket(t)
|
||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||
|
||
//nolint:dogsled
|
||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.BlockReversePortForwarding = true
|
||
})
|
||
sshClient, err := agentConn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
|
||
_, err = sshClient.ListenUnix(remoteSocketPath)
|
||
require.ErrorContains(t, err, "streamlocal-forward@openssh.com request denied by peer")
|
||
}
|
||
|
||
// TestAgent_LocalBlockedDoesNotAffectReverse verifies that blocking
|
||
// local port forwarding does not prevent reverse port forwarding from
|
||
// working. A field-name transposition at any plumbing hop would cause
|
||
// both directions to be blocked when only one flag is set.
|
||
func TestAgent_LocalBlockedDoesNotAffectReverse(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
|
||
//nolint:dogsled
|
||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.BlockLocalPortForwarding = true
|
||
})
|
||
sshClient, err := agentConn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
|
||
// Reverse forwarding must still work.
|
||
localhost := netip.MustParseAddr("127.0.0.1")
|
||
var ll net.Listener
|
||
for {
|
||
randomPort := testutil.RandomPortNoListen(t)
|
||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||
ll, err = sshClient.ListenTCP(addr)
|
||
if err != nil {
|
||
t.Logf("error remote forwarding: %s", err.Error())
|
||
select {
|
||
case <-ctx.Done():
|
||
t.Fatal("timed out getting random listener")
|
||
default:
|
||
continue
|
||
}
|
||
}
|
||
break
|
||
}
|
||
_ = ll.Close()
|
||
}
|
||
|
||
// TestAgent_ReverseBlockedDoesNotAffectLocal verifies that blocking
|
||
// reverse port forwarding does not prevent local port forwarding from
|
||
// working.
|
||
func TestAgent_ReverseBlockedDoesNotAffectLocal(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
|
||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer rl.Close()
|
||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||
require.True(t, valid)
|
||
remotePort := tcpAddr.Port
|
||
go echoOnce(t, rl)
|
||
|
||
//nolint:dogsled
|
||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.BlockReversePortForwarding = true
|
||
})
|
||
sshClient, err := agentConn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
|
||
// Local forwarding must still work.
|
||
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||
require.NoError(t, err)
|
||
defer conn.Close()
|
||
requireEcho(t, conn)
|
||
}
|
||
|
||
func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||
t.Parallel()
|
||
if runtime.GOOS == "windows" {
|
||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||
}
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
tmpdir := testutil.TempDirUnixSocket(t)
|
||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||
|
||
l, err := net.Listen("unix", remoteSocketPath)
|
||
require.NoError(t, err)
|
||
defer l.Close()
|
||
go echoOnce(t, l)
|
||
|
||
sshClient := setupAgentSSHClient(ctx, t)
|
||
|
||
conn, err := sshClient.Dial("unix", remoteSocketPath)
|
||
require.NoError(t, err)
|
||
defer conn.Close()
|
||
_, err = conn.Write([]byte("test"))
|
||
require.NoError(t, err)
|
||
b := make([]byte, 4)
|
||
_, err = conn.Read(b)
|
||
require.NoError(t, err)
|
||
require.Equal(t, "test", string(b))
|
||
_ = conn.Close()
|
||
}
|
||
|
||
func TestAgent_UnixRemoteForwarding(t *testing.T) {
|
||
t.Parallel()
|
||
if runtime.GOOS == "windows" {
|
||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||
}
|
||
|
||
tmpdir := testutil.TempDirUnixSocket(t)
|
||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
sshClient := setupAgentSSHClient(ctx, t)
|
||
|
||
l, err := sshClient.ListenUnix(remoteSocketPath)
|
||
require.NoError(t, err)
|
||
defer l.Close()
|
||
go echoOnce(t, l)
|
||
|
||
conn, err := net.Dial("unix", remoteSocketPath)
|
||
require.NoError(t, err)
|
||
defer conn.Close()
|
||
requireEcho(t, conn)
|
||
}
|
||
|
||
func TestAgent_SFTP(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
t.Run("DefaultWorkingDirectory", func(t *testing.T) {
|
||
t.Parallel()
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
u, err := user.Current()
|
||
require.NoError(t, err, "get current user")
|
||
home := u.HomeDir
|
||
if runtime.GOOS == "windows" {
|
||
home = "/" + strings.ReplaceAll(home, "\\", "/")
|
||
}
|
||
//nolint:dogsled
|
||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
client, err := sftp.NewClient(sshClient)
|
||
require.NoError(t, err)
|
||
defer client.Close()
|
||
wd, err := client.Getwd()
|
||
require.NoError(t, err, "get working directory")
|
||
require.Equal(t, home, wd, "working directory should be user home")
|
||
tempFile := filepath.Join(t.TempDir(), "sftp")
|
||
// SFTP only accepts unix-y paths.
|
||
remoteFile := filepath.ToSlash(tempFile)
|
||
if !path.IsAbs(remoteFile) {
|
||
// On Windows, e.g. "/C:/Users/...".
|
||
remoteFile = path.Join("/", remoteFile)
|
||
}
|
||
file, err := client.Create(remoteFile)
|
||
require.NoError(t, err)
|
||
err = file.Close()
|
||
require.NoError(t, err)
|
||
_, err = os.Stat(tempFile)
|
||
require.NoError(t, err)
|
||
|
||
// Close the client to trigger disconnect event.
|
||
_ = client.Close()
|
||
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
|
||
})
|
||
|
||
t.Run("CustomWorkingDirectory", func(t *testing.T) {
|
||
t.Parallel()
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
// Create a custom directory for the agent to use.
|
||
customDir := t.TempDir()
|
||
expectedDir := customDir
|
||
if runtime.GOOS == "windows" {
|
||
expectedDir = "/" + strings.ReplaceAll(customDir, "\\", "/")
|
||
}
|
||
|
||
//nolint:dogsled
|
||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Directory: customDir,
|
||
}, 0)
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
client, err := sftp.NewClient(sshClient)
|
||
require.NoError(t, err)
|
||
defer client.Close()
|
||
wd, err := client.Getwd()
|
||
require.NoError(t, err, "get working directory")
|
||
require.Equal(t, expectedDir, wd, "working directory should be custom directory")
|
||
|
||
// Close the client to trigger disconnect event.
|
||
_ = client.Close()
|
||
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
|
||
})
|
||
}
|
||
|
||
func TestAgent_SCP(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
//nolint:dogsled
|
||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
scpClient, err := scp.NewClientBySSH(sshClient)
|
||
require.NoError(t, err)
|
||
defer scpClient.Close()
|
||
tempFile := filepath.Join(t.TempDir(), "scp")
|
||
content := "hello world"
|
||
err = scpClient.CopyFile(context.Background(), strings.NewReader(content), tempFile, "0755")
|
||
require.NoError(t, err)
|
||
_, err = os.Stat(tempFile)
|
||
require.NoError(t, err)
|
||
|
||
// Close the client to trigger disconnect event.
|
||
scpClient.Close()
|
||
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
|
||
}
|
||
|
||
func TestAgent_FileTransferBlocked(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
assertFileTransferBlocked := func(t *testing.T, errorMessage string) {
|
||
// NOTE: Checking content of the error message is flaky. Most likely there is a race condition, which results
|
||
// in stopping the client in different phases, and returning different errors:
|
||
// - client read the full error message: File transfer has been disabled.
|
||
// - client's stream was terminated before reading the error message: EOF
|
||
// - client just read the error code (Windows): Process exited with status 65
|
||
isErr := strings.Contains(errorMessage, agentssh.BlockedFileTransferErrorMessage) ||
|
||
strings.Contains(errorMessage, "EOF") ||
|
||
strings.Contains(errorMessage, "Process exited with status 65")
|
||
require.True(t, isErr, "Message: "+errorMessage)
|
||
}
|
||
|
||
t.Run("SFTP", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
//nolint:dogsled
|
||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.BlockFileTransfer = true
|
||
})
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
_, err = sftp.NewClient(sshClient)
|
||
require.Error(t, err)
|
||
assertFileTransferBlocked(t, err.Error())
|
||
|
||
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
|
||
})
|
||
|
||
t.Run("SCP with go-scp package", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
//nolint:dogsled
|
||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.BlockFileTransfer = true
|
||
})
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
scpClient, err := scp.NewClientBySSH(sshClient)
|
||
require.NoError(t, err)
|
||
defer scpClient.Close()
|
||
tempFile := filepath.Join(t.TempDir(), "scp")
|
||
err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755")
|
||
require.Error(t, err)
|
||
assertFileTransferBlocked(t, err.Error())
|
||
|
||
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
|
||
})
|
||
|
||
t.Run("Forbidden commands", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
for _, c := range agentssh.BlockedFileTransferCommands {
|
||
t.Run(c, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
//nolint:dogsled
|
||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.BlockFileTransfer = true
|
||
})
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
defer session.Close()
|
||
|
||
stdout, err := session.StdoutPipe()
|
||
require.NoError(t, err)
|
||
|
||
//nolint:govet // we don't need `c := c` in Go 1.22
|
||
err = session.Start(c)
|
||
require.NoError(t, err)
|
||
defer session.Close()
|
||
|
||
msg, err := io.ReadAll(stdout)
|
||
require.NoError(t, err)
|
||
assertFileTransferBlocked(t, string(msg))
|
||
|
||
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
|
||
})
|
||
}
|
||
})
|
||
}
|
||
|
||
func TestAgent_EnvironmentVariables(t *testing.T) {
|
||
t.Parallel()
|
||
key := "EXAMPLE"
|
||
value := "value"
|
||
session := setupSSHSession(t, agentsdk.Manifest{
|
||
EnvironmentVariables: map[string]string{
|
||
key: value,
|
||
},
|
||
}, codersdk.ServiceBannerConfig{}, nil)
|
||
command := "sh -c 'echo $" + key + "'"
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe /c echo %" + key + "%"
|
||
}
|
||
output, err := session.Output(command)
|
||
require.NoError(t, err)
|
||
require.Equal(t, value, strings.TrimSpace(string(output)))
|
||
}
|
||
|
||
func TestAgent_EnvironmentVariableExpansion(t *testing.T) {
|
||
t.Parallel()
|
||
key := "EXAMPLE"
|
||
session := setupSSHSession(t, agentsdk.Manifest{
|
||
EnvironmentVariables: map[string]string{
|
||
key: "$SOMETHINGNOTSET",
|
||
},
|
||
}, codersdk.ServiceBannerConfig{}, nil)
|
||
command := "sh -c 'echo $" + key + "'"
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe /c echo %" + key + "%"
|
||
}
|
||
output, err := session.Output(command)
|
||
require.NoError(t, err)
|
||
expect := ""
|
||
if runtime.GOOS == "windows" {
|
||
expect = "%EXAMPLE%"
|
||
}
|
||
// Output should be empty, because the variable is not set!
|
||
require.Equal(t, expect, strings.TrimSpace(string(output)))
|
||
}
|
||
|
||
func TestAgent_CoderEnvVars(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
for _, key := range []string{"CODER", "CODER_WORKSPACE_NAME", "CODER_WORKSPACE_OWNER_NAME", "CODER_WORKSPACE_AGENT_NAME"} {
|
||
t.Run(key, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
|
||
command := "sh -c 'echo $" + key + "'"
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe /c echo %" + key + "%"
|
||
}
|
||
output, err := session.Output(command)
|
||
require.NoError(t, err)
|
||
require.NotEmpty(t, strings.TrimSpace(string(output)))
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestAgent_SSHConnectionEnvVars(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// Note: the SSH_TTY environment variable should only be set for TTYs.
|
||
// For some reason this test produces a TTY locally and a non-TTY in CI
|
||
// so we don't test for the absence of SSH_TTY.
|
||
for _, key := range []string{"SSH_CONNECTION", "SSH_CLIENT"} {
|
||
t.Run(key, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
|
||
command := "sh -c 'echo $" + key + "'"
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe /c echo %" + key + "%"
|
||
}
|
||
output, err := session.Output(command)
|
||
require.NoError(t, err)
|
||
require.NotEmpty(t, strings.TrimSpace(string(output)))
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestAgent_SSHConnectionLoginVars(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
envInfo := usershell.SystemEnvInfo{}
|
||
u, err := envInfo.User()
|
||
require.NoError(t, err, "get current user")
|
||
shell, err := envInfo.Shell(u.Username)
|
||
require.NoError(t, err, "get current shell")
|
||
|
||
tests := []struct {
|
||
key string
|
||
want string
|
||
}{
|
||
{
|
||
key: "USER",
|
||
want: u.Username,
|
||
},
|
||
{
|
||
key: "LOGNAME",
|
||
want: u.Username,
|
||
},
|
||
{
|
||
key: "SHELL",
|
||
want: shell,
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.key, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
|
||
command := "sh -c 'echo $" + tt.key + "'"
|
||
if runtime.GOOS == "windows" {
|
||
command = "cmd.exe /c echo %" + tt.key + "%"
|
||
}
|
||
output, err := session.Output(command)
|
||
require.NoError(t, err)
|
||
require.Equal(t, tt.want, strings.TrimSpace(string(output)))
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestAgent_Metadata(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
echoHello := "echo 'hello'"
|
||
|
||
t.Run("Once", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
//nolint:dogsled
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
||
{
|
||
Key: "greeting1",
|
||
Interval: 0,
|
||
Script: echoHello,
|
||
},
|
||
{
|
||
Key: "greeting2",
|
||
Interval: 1,
|
||
Script: echoHello,
|
||
},
|
||
},
|
||
}, 0, func(_ *agenttest.Client, opts *agent.Options) {
|
||
opts.ReportMetadataInterval = testutil.IntervalFast
|
||
})
|
||
|
||
var gotMd map[string]agentsdk.Metadata
|
||
require.Eventually(t, func() bool {
|
||
gotMd = client.GetMetadata()
|
||
return len(gotMd) == 2
|
||
}, testutil.WaitShort, testutil.IntervalFast/2)
|
||
|
||
collectedAt1 := gotMd["greeting1"].CollectedAt
|
||
collectedAt2 := gotMd["greeting2"].CollectedAt
|
||
|
||
require.Eventually(t, func() bool {
|
||
gotMd = client.GetMetadata()
|
||
if len(gotMd) != 2 {
|
||
panic("unexpected number of metadata")
|
||
}
|
||
return !gotMd["greeting2"].CollectedAt.Equal(collectedAt2)
|
||
}, testutil.WaitShort, testutil.IntervalFast/2)
|
||
|
||
require.Equal(t, gotMd["greeting1"].CollectedAt, collectedAt1, "metadata should not be collected again")
|
||
})
|
||
|
||
t.Run("Many", func(t *testing.T) {
|
||
t.Parallel()
|
||
//nolint:dogsled
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
||
{
|
||
Key: "greeting",
|
||
Interval: 1,
|
||
Timeout: 100,
|
||
Script: echoHello,
|
||
},
|
||
},
|
||
}, 0, func(_ *agenttest.Client, opts *agent.Options) {
|
||
opts.ReportMetadataInterval = testutil.IntervalFast
|
||
})
|
||
|
||
var gotMd map[string]agentsdk.Metadata
|
||
require.Eventually(t, func() bool {
|
||
gotMd = client.GetMetadata()
|
||
return len(gotMd) == 1
|
||
}, testutil.WaitShort, testutil.IntervalFast/2)
|
||
|
||
collectedAt1 := gotMd["greeting"].CollectedAt
|
||
require.Equal(t, "hello", strings.TrimSpace(gotMd["greeting"].Value))
|
||
|
||
if !assert.Eventually(t, func() bool {
|
||
gotMd = client.GetMetadata()
|
||
return gotMd["greeting"].CollectedAt.After(collectedAt1)
|
||
}, testutil.WaitShort, testutil.IntervalFast/2) {
|
||
t.Fatalf("expected metadata to be collected again")
|
||
}
|
||
})
|
||
}
|
||
|
||
func TestAgentMetadata_Timing(t *testing.T) {
|
||
if runtime.GOOS == "windows" {
|
||
// Shell scripting in Windows is a pain, and we have already tested
|
||
// that the OS logic works in the simpler tests.
|
||
t.SkipNow()
|
||
}
|
||
testutil.SkipIfNotTiming(t)
|
||
t.Parallel()
|
||
|
||
dir := t.TempDir()
|
||
|
||
const reportInterval = 2
|
||
const intervalUnit = 100 * time.Millisecond
|
||
var (
|
||
greetingPath = filepath.Join(dir, "greeting")
|
||
script = "echo hello | tee -a " + greetingPath
|
||
)
|
||
//nolint:dogsled
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
||
{
|
||
Key: "greeting",
|
||
Interval: reportInterval,
|
||
Script: script,
|
||
},
|
||
{
|
||
Key: "bad",
|
||
Interval: reportInterval,
|
||
Script: "exit 1",
|
||
},
|
||
},
|
||
}, 0, func(_ *agenttest.Client, opts *agent.Options) {
|
||
opts.ReportMetadataInterval = intervalUnit
|
||
})
|
||
|
||
require.Eventually(t, func() bool {
|
||
return len(client.GetMetadata()) == 2
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
for start := time.Now(); time.Since(start) < testutil.WaitMedium; time.Sleep(testutil.IntervalMedium) {
|
||
md := client.GetMetadata()
|
||
require.Len(t, md, 2, "got: %+v", md)
|
||
|
||
require.Equal(t, "hello\n", md["greeting"].Value)
|
||
require.Equal(t, "run cmd: exit status 1", md["bad"].Error)
|
||
|
||
greetingByt, err := os.ReadFile(greetingPath)
|
||
require.NoError(t, err)
|
||
|
||
var (
|
||
numGreetings = bytes.Count(greetingByt, []byte("hello"))
|
||
idealNumGreetings = time.Since(start) / (reportInterval * intervalUnit)
|
||
// We allow a 50% error margin because the report loop may backlog
|
||
// in CI and other toasters. In production, there is no hard
|
||
// guarantee on timing either, and the frontend gives similar
|
||
// wiggle room to the staleness of the value.
|
||
upperBound = int(idealNumGreetings) + 1
|
||
lowerBound = (int(idealNumGreetings) / 2)
|
||
)
|
||
|
||
if idealNumGreetings < 50 {
|
||
// There is an insufficient sample size.
|
||
continue
|
||
}
|
||
|
||
t.Logf("numGreetings: %d, idealNumGreetings: %d", numGreetings, idealNumGreetings)
|
||
// The report loop may slow down on load, but it should never, ever
|
||
// speed up.
|
||
if numGreetings > upperBound {
|
||
t.Fatalf("too many greetings: %d > %d in %v", numGreetings, upperBound, time.Since(start))
|
||
} else if numGreetings < lowerBound {
|
||
t.Fatalf("too few greetings: %d < %d", numGreetings, lowerBound)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestAgent_Lifecycle(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
t.Run("StartTimeout", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||
Script: "sleep 3",
|
||
Timeout: time.Millisecond,
|
||
RunOnStart: true,
|
||
}},
|
||
}, 0)
|
||
|
||
want := []codersdk.WorkspaceAgentLifecycle{
|
||
codersdk.WorkspaceAgentLifecycleStarting,
|
||
codersdk.WorkspaceAgentLifecycleStartTimeout,
|
||
}
|
||
|
||
var got []codersdk.WorkspaceAgentLifecycle
|
||
assert.Eventually(t, func() bool {
|
||
got = client.GetLifecycleStates()
|
||
return slices.Contains(got, want[len(want)-1])
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
require.Equal(t, want, got[:len(want)])
|
||
})
|
||
|
||
t.Run("StartError", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||
Script: "false",
|
||
Timeout: 30 * time.Second,
|
||
RunOnStart: true,
|
||
}},
|
||
}, 0)
|
||
|
||
want := []codersdk.WorkspaceAgentLifecycle{
|
||
codersdk.WorkspaceAgentLifecycleStarting,
|
||
codersdk.WorkspaceAgentLifecycleStartError,
|
||
}
|
||
|
||
var got []codersdk.WorkspaceAgentLifecycle
|
||
assert.Eventually(t, func() bool {
|
||
got = client.GetLifecycleStates()
|
||
return slices.Contains(got, want[len(want)-1])
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
require.Equal(t, want, got[:len(want)])
|
||
})
|
||
|
||
t.Run("Ready", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||
Script: "echo foo",
|
||
Timeout: 30 * time.Second,
|
||
RunOnStart: true,
|
||
}},
|
||
}, 0)
|
||
|
||
want := []codersdk.WorkspaceAgentLifecycle{
|
||
codersdk.WorkspaceAgentLifecycleStarting,
|
||
codersdk.WorkspaceAgentLifecycleReady,
|
||
}
|
||
|
||
var got []codersdk.WorkspaceAgentLifecycle
|
||
assert.Eventually(t, func() bool {
|
||
got = client.GetLifecycleStates()
|
||
return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
require.Equal(t, want, got)
|
||
})
|
||
|
||
t.Run("ShuttingDown", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
|
||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||
Script: "sleep 3",
|
||
Timeout: 30 * time.Second,
|
||
RunOnStop: true,
|
||
}},
|
||
}, 0)
|
||
|
||
assert.Eventually(t, func() bool {
|
||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
// Start close asynchronously so that we an inspect the state.
|
||
done := make(chan struct{})
|
||
go func() {
|
||
defer close(done)
|
||
err := closer.Close()
|
||
assert.NoError(t, err)
|
||
}()
|
||
t.Cleanup(func() {
|
||
<-done
|
||
})
|
||
|
||
want := []codersdk.WorkspaceAgentLifecycle{
|
||
codersdk.WorkspaceAgentLifecycleStarting,
|
||
codersdk.WorkspaceAgentLifecycleReady,
|
||
codersdk.WorkspaceAgentLifecycleShuttingDown,
|
||
}
|
||
|
||
var got []codersdk.WorkspaceAgentLifecycle
|
||
assert.Eventually(t, func() bool {
|
||
got = client.GetLifecycleStates()
|
||
return slices.Contains(got, want[len(want)-1])
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
require.Equal(t, want, got[:len(want)])
|
||
})
|
||
|
||
t.Run("ShutdownTimeout", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
|
||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||
Script: "sleep 3",
|
||
Timeout: time.Millisecond,
|
||
RunOnStop: true,
|
||
}},
|
||
}, 0)
|
||
|
||
assert.Eventually(t, func() bool {
|
||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
// Start close asynchronously so that we an inspect the state.
|
||
done := make(chan struct{})
|
||
go func() {
|
||
defer close(done)
|
||
err := closer.Close()
|
||
assert.NoError(t, err)
|
||
}()
|
||
t.Cleanup(func() {
|
||
<-done
|
||
})
|
||
|
||
want := []codersdk.WorkspaceAgentLifecycle{
|
||
codersdk.WorkspaceAgentLifecycleStarting,
|
||
codersdk.WorkspaceAgentLifecycleReady,
|
||
codersdk.WorkspaceAgentLifecycleShuttingDown,
|
||
codersdk.WorkspaceAgentLifecycleShutdownTimeout,
|
||
}
|
||
|
||
var got []codersdk.WorkspaceAgentLifecycle
|
||
assert.Eventually(t, func() bool {
|
||
got = client.GetLifecycleStates()
|
||
return slices.Contains(got, want[len(want)-1])
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
require.Equal(t, want, got[:len(want)])
|
||
})
|
||
|
||
t.Run("ShutdownError", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
|
||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||
Script: "false",
|
||
Timeout: 30 * time.Second,
|
||
RunOnStop: true,
|
||
}},
|
||
}, 0)
|
||
|
||
assert.Eventually(t, func() bool {
|
||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
// Start close asynchronously so that we an inspect the state.
|
||
done := make(chan struct{})
|
||
go func() {
|
||
defer close(done)
|
||
err := closer.Close()
|
||
assert.NoError(t, err)
|
||
}()
|
||
t.Cleanup(func() {
|
||
<-done
|
||
})
|
||
|
||
want := []codersdk.WorkspaceAgentLifecycle{
|
||
codersdk.WorkspaceAgentLifecycleStarting,
|
||
codersdk.WorkspaceAgentLifecycleReady,
|
||
codersdk.WorkspaceAgentLifecycleShuttingDown,
|
||
codersdk.WorkspaceAgentLifecycleShutdownError,
|
||
}
|
||
|
||
var got []codersdk.WorkspaceAgentLifecycle
|
||
assert.Eventually(t, func() bool {
|
||
got = client.GetLifecycleStates()
|
||
return slices.Contains(got, want[len(want)-1])
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
require.Equal(t, want, got[:len(want)])
|
||
})
|
||
|
||
t.Run("ShutdownScriptOnce", func(t *testing.T) {
|
||
t.Parallel()
|
||
logger := testutil.Logger(t)
|
||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||
expected := "this-is-shutdown"
|
||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||
statsCh := make(chan *proto.Stats, 50)
|
||
|
||
client := agenttest.NewClient(t,
|
||
logger,
|
||
uuid.New(),
|
||
agentsdk.Manifest{
|
||
DERPMap: derpMap,
|
||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||
ID: uuid.New(),
|
||
LogPath: "coder-startup-script.log",
|
||
Script: "echo 1",
|
||
RunOnStart: true,
|
||
}, {
|
||
ID: uuid.New(),
|
||
LogPath: "coder-shutdown-script.log",
|
||
Script: "echo " + expected,
|
||
RunOnStop: true,
|
||
}},
|
||
},
|
||
statsCh,
|
||
tailnet.NewCoordinator(logger),
|
||
)
|
||
defer client.Close()
|
||
|
||
fs := afero.NewMemMapFs()
|
||
agent := agent.New(agent.Options{
|
||
Client: client,
|
||
Logger: logger.Named("agent"),
|
||
Filesystem: fs,
|
||
})
|
||
|
||
// agent.Close() loads the shutdown script from the agent metadata.
|
||
// The metadata is populated just before execution of the startup script, so it's mandatory to wait
|
||
// until the startup starts.
|
||
require.Eventually(t, func() bool {
|
||
outputPath := filepath.Join(os.TempDir(), "coder-startup-script.log")
|
||
content, err := afero.ReadFile(fs, outputPath)
|
||
if err != nil {
|
||
t.Logf("read file %q: %s", outputPath, err)
|
||
return false
|
||
}
|
||
return len(content) > 0 // something is in the startup log file
|
||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||
|
||
// In order to avoid shutting down the agent before it is fully started and triggering
|
||
// errors, we'll wait until the agent is fully up. It's a bit hokey, but among the last things the agent starts
|
||
// is the stats reporting, so getting a stats report is a good indication the agent is fully up.
|
||
_ = testutil.TryReceive(ctx, t, statsCh)
|
||
|
||
err := agent.Close()
|
||
require.NoError(t, err, "agent should be closed successfully")
|
||
|
||
outputPath := filepath.Join(os.TempDir(), "coder-shutdown-script.log")
|
||
logFirstRead, err := afero.ReadFile(fs, outputPath)
|
||
require.NoError(t, err, "log file should be present")
|
||
require.Equal(t, expected, string(bytes.TrimSpace(logFirstRead)))
|
||
|
||
// Make sure that script can't be executed twice.
|
||
err = agent.Close()
|
||
require.NoError(t, err, "don't need to close the agent twice, no effect")
|
||
|
||
logSecondRead, err := afero.ReadFile(fs, outputPath)
|
||
require.NoError(t, err, "log file should be present")
|
||
require.Equal(t, string(bytes.TrimSpace(logFirstRead)), string(bytes.TrimSpace(logSecondRead)))
|
||
})
|
||
}
|
||
|
||
func TestAgent_Startup(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
t.Run("EmptyDirectory", func(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitShort)
|
||
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Directory: "",
|
||
}, 0)
|
||
startup := testutil.TryReceive(ctx, t, client.GetStartup())
|
||
require.Equal(t, "", startup.GetExpandedDirectory())
|
||
})
|
||
|
||
t.Run("HomeDirectory", func(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitShort)
|
||
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Directory: "~",
|
||
}, 0)
|
||
startup := testutil.TryReceive(ctx, t, client.GetStartup())
|
||
homeDir, err := os.UserHomeDir()
|
||
require.NoError(t, err)
|
||
require.Equal(t, homeDir, startup.GetExpandedDirectory())
|
||
})
|
||
|
||
t.Run("NotAbsoluteDirectory", func(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitShort)
|
||
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Directory: "coder/coder",
|
||
}, 0)
|
||
startup := testutil.TryReceive(ctx, t, client.GetStartup())
|
||
homeDir, err := os.UserHomeDir()
|
||
require.NoError(t, err)
|
||
require.Equal(t, filepath.Join(homeDir, "coder/coder"), startup.GetExpandedDirectory())
|
||
})
|
||
|
||
t.Run("HomeEnvironmentVariable", func(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitShort)
|
||
|
||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
Directory: "$HOME",
|
||
}, 0)
|
||
startup := testutil.TryReceive(ctx, t, client.GetStartup())
|
||
homeDir, err := os.UserHomeDir()
|
||
require.NoError(t, err)
|
||
require.Equal(t, homeDir, startup.GetExpandedDirectory())
|
||
})
|
||
}
|
||
|
||
//nolint:paralleltest // This test sets an environment variable.
|
||
func TestAgent_ReconnectingPTY(t *testing.T) {
|
||
if runtime.GOOS == "windows" {
|
||
// This might be our implementation, or ConPTY itself.
|
||
// It's difficult to find extensive tests for it, so
|
||
// it seems like it could be either.
|
||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||
}
|
||
|
||
backends := []string{"Buffered", "Screen"}
|
||
|
||
_, err := exec.LookPath("screen")
|
||
hasScreen := err == nil
|
||
|
||
// Make sure UTF-8 works even with LANG set to something like C.
|
||
t.Setenv("LANG", "C")
|
||
|
||
for _, backendType := range backends {
|
||
t.Run(backendType, func(t *testing.T) {
|
||
if backendType == "Screen" {
|
||
if runtime.GOOS != "linux" {
|
||
t.Skipf("`screen` is not supported on %s", runtime.GOOS)
|
||
} else if !hasScreen {
|
||
t.Skip("`screen` not found")
|
||
}
|
||
} else if hasScreen && runtime.GOOS == "linux" {
|
||
// Set up a PATH that does not have screen in it.
|
||
bashPath, err := exec.LookPath("bash")
|
||
require.NoError(t, err)
|
||
dir, err := os.MkdirTemp("/tmp", "coder-test-reconnecting-pty-PATH")
|
||
require.NoError(t, err, "create temp dir for reconnecting pty PATH")
|
||
err = os.Symlink(bashPath, filepath.Join(dir, "bash"))
|
||
require.NoError(t, err, "symlink bash into reconnecting pty PATH")
|
||
t.Setenv("PATH", dir)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
//nolint:dogsled
|
||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
idConnectionReport := uuid.New()
|
||
id := uuid.New()
|
||
|
||
// Test that the connection is reported. This must be tested in the
|
||
// first connection because we care about verifying all of these.
|
||
netConn0, err := conn.ReconnectingPTY(ctx, idConnectionReport, 80, 80, "bash --norc")
|
||
require.NoError(t, err)
|
||
_ = netConn0.Close()
|
||
assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "")
|
||
|
||
// --norc disables executing .bashrc, which is often used to customize the bash prompt
|
||
netConn1, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
|
||
require.NoError(t, err)
|
||
defer netConn1.Close()
|
||
tr1 := testutil.NewTerminalReader(t, netConn1)
|
||
|
||
// A second simultaneous connection.
|
||
netConn2, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
|
||
require.NoError(t, err)
|
||
defer netConn2.Close()
|
||
tr2 := testutil.NewTerminalReader(t, netConn2)
|
||
|
||
matchPrompt := func(line string) bool {
|
||
return strings.Contains(line, "$ ") || strings.Contains(line, "# ")
|
||
}
|
||
matchEchoCommand := func(line string) bool {
|
||
return strings.Contains(line, "echo test")
|
||
}
|
||
matchEchoOutput := func(line string) bool {
|
||
return strings.Contains(line, "test") && !strings.Contains(line, "echo")
|
||
}
|
||
matchExitCommand := func(line string) bool {
|
||
return strings.Contains(line, "exit")
|
||
}
|
||
matchExitOutput := func(line string) bool {
|
||
return strings.Contains(line, "exit") || strings.Contains(line, "logout")
|
||
}
|
||
|
||
// Wait for the prompt before writing commands. If the command
|
||
// arrives before the prompt is written, screen will sometimes put
|
||
// the command output on the same line as the command and the test
|
||
// will flake.
|
||
require.NoError(t, tr1.ReadUntil(ctx, matchPrompt), "find prompt")
|
||
require.NoError(t, tr2.ReadUntil(ctx, matchPrompt), "find prompt")
|
||
|
||
data, err := json.Marshal(workspacesdk.ReconnectingPTYRequest{
|
||
Data: "printf '%s\\n' \"$TERM\"\r",
|
||
})
|
||
require.NoError(t, err)
|
||
_, err = netConn1.Write(data)
|
||
require.NoError(t, err)
|
||
require.NoError(t, tr1.ReadUntilString(ctx, "xterm-256color"), "find TERM output")
|
||
require.NoError(t, tr2.ReadUntilString(ctx, "xterm-256color"), "find TERM output")
|
||
require.NoError(t, tr1.ReadUntil(ctx, matchPrompt), "find prompt")
|
||
require.NoError(t, tr2.ReadUntil(ctx, matchPrompt), "find prompt")
|
||
|
||
data, err = json.Marshal(workspacesdk.ReconnectingPTYRequest{
|
||
Data: "echo test\r",
|
||
})
|
||
require.NoError(t, err)
|
||
_, err = netConn1.Write(data)
|
||
require.NoError(t, err)
|
||
|
||
// Once for typing the command...
|
||
require.NoError(t, tr1.ReadUntil(ctx, matchEchoCommand), "find echo command")
|
||
// And another time for the actual output.
|
||
require.NoError(t, tr1.ReadUntil(ctx, matchEchoOutput), "find echo output")
|
||
|
||
// Same for the other connection.
|
||
require.NoError(t, tr2.ReadUntil(ctx, matchEchoCommand), "find echo command")
|
||
require.NoError(t, tr2.ReadUntil(ctx, matchEchoOutput), "find echo output")
|
||
|
||
_ = netConn1.Close()
|
||
_ = netConn2.Close()
|
||
netConn3, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
|
||
require.NoError(t, err)
|
||
defer netConn3.Close()
|
||
tr3 := testutil.NewTerminalReader(t, netConn3)
|
||
|
||
// Same output again!
|
||
require.NoError(t, tr3.ReadUntil(ctx, matchEchoCommand), "find echo command")
|
||
require.NoError(t, tr3.ReadUntil(ctx, matchEchoOutput), "find echo output")
|
||
|
||
// Exit should cause the connection to close.
|
||
data, err = json.Marshal(workspacesdk.ReconnectingPTYRequest{
|
||
Data: "exit\r",
|
||
})
|
||
require.NoError(t, err)
|
||
_, err = netConn3.Write(data)
|
||
require.NoError(t, err)
|
||
|
||
// Once for the input and again for the output.
|
||
require.NoError(t, tr3.ReadUntil(ctx, matchExitCommand), "find exit command")
|
||
require.NoError(t, tr3.ReadUntil(ctx, matchExitOutput), "find exit output")
|
||
|
||
// Wait for the connection to close.
|
||
require.ErrorIs(t, tr3.ReadUntil(ctx, nil), io.EOF)
|
||
|
||
// Try a non-shell command. It should output then immediately exit.
|
||
netConn4, err := conn.ReconnectingPTY(ctx, uuid.New(), 80, 80, "echo test")
|
||
require.NoError(t, err)
|
||
defer netConn4.Close()
|
||
|
||
tr4 := testutil.NewTerminalReader(t, netConn4)
|
||
require.NoError(t, tr4.ReadUntil(ctx, matchEchoOutput), "find echo output")
|
||
require.ErrorIs(t, tr4.ReadUntil(ctx, nil), io.EOF)
|
||
|
||
// Ensure that UTF-8 is supported. Avoid the terminal emulator because it
|
||
// does not appear to support UTF-8, just make sure the bytes that come
|
||
// back have the character in it.
|
||
netConn5, err := conn.ReconnectingPTY(ctx, uuid.New(), 80, 80, "echo ❯")
|
||
require.NoError(t, err)
|
||
defer netConn5.Close()
|
||
|
||
bytes, err := io.ReadAll(netConn5)
|
||
require.NoError(t, err)
|
||
require.Contains(t, string(bytes), "❯")
|
||
})
|
||
}
|
||
}
|
||
|
||
// This tests end-to-end functionality of connecting to a running container
|
||
// and executing a command. It creates a real Docker container and runs a
|
||
// command. As such, it does not run by default in CI.
|
||
// You can run it manually as follows:
|
||
//
|
||
// CODER_TEST_USE_DOCKER=1 go test -count=1 ./agent -run TestAgent_ReconnectingPTYContainer
|
||
func TestAgent_ReconnectingPTYContainer(t *testing.T) {
|
||
t.Parallel()
|
||
if os.Getenv("CODER_TEST_USE_DOCKER") != "1" {
|
||
t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
|
||
}
|
||
if _, err := exec.LookPath("devcontainer"); err != nil {
|
||
t.Skip("This test requires the devcontainer CLI: npm install -g @devcontainers/cli")
|
||
}
|
||
|
||
pool, err := dockertest.NewPool("")
|
||
require.NoError(t, err, "Could not connect to docker")
|
||
ct, err := pool.RunWithOptions(&dockertest.RunOptions{
|
||
Repository: "busybox",
|
||
Tag: "latest",
|
||
Cmd: []string{"sleep", "infnity"},
|
||
}, func(config *docker.HostConfig) {
|
||
config.AutoRemove = true
|
||
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
|
||
})
|
||
require.NoError(t, err, "Could not start container")
|
||
defer func() {
|
||
err := pool.Purge(ct)
|
||
require.NoError(t, err, "Could not stop container")
|
||
}()
|
||
// Wait for container to start
|
||
require.Eventually(t, func() bool {
|
||
ct, ok := pool.ContainerByName(ct.Container.Name)
|
||
return ok && ct.Container.State.Running
|
||
}, testutil.WaitShort, testutil.IntervalSlow, "Container did not start in time")
|
||
|
||
// nolint: dogsled
|
||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.Devcontainers = true
|
||
o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
|
||
agentcontainers.WithContainerLabelIncludeFilter("this.label.does.not.exist.ignore.devcontainers", "true"),
|
||
)
|
||
})
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
ac, err := conn.ReconnectingPTY(ctx, uuid.New(), 80, 80, "/bin/sh", func(arp *workspacesdk.AgentReconnectingPTYInit) {
|
||
arp.Container = ct.Container.ID
|
||
})
|
||
require.NoError(t, err, "failed to create ReconnectingPTY")
|
||
defer ac.Close()
|
||
tr := testutil.NewTerminalReader(t, ac)
|
||
|
||
require.NoError(t, tr.ReadUntil(ctx, func(line string) bool {
|
||
return strings.Contains(line, "#") || strings.Contains(line, "$")
|
||
}), "find prompt")
|
||
|
||
require.NoError(t, json.NewEncoder(ac).Encode(workspacesdk.ReconnectingPTYRequest{
|
||
Data: "hostname\r",
|
||
}), "write hostname")
|
||
require.NoError(t, tr.ReadUntil(ctx, func(line string) bool {
|
||
return strings.Contains(line, "hostname")
|
||
}), "find hostname command")
|
||
|
||
require.NoError(t, tr.ReadUntil(ctx, func(line string) bool {
|
||
return strings.Contains(line, ct.Container.Config.Hostname)
|
||
}), "find hostname output")
|
||
require.NoError(t, json.NewEncoder(ac).Encode(workspacesdk.ReconnectingPTYRequest{
|
||
Data: "exit\r",
|
||
}), "write exit command")
|
||
|
||
// Wait for the connection to close.
|
||
require.ErrorIs(t, tr.ReadUntil(ctx, nil), io.EOF)
|
||
}
|
||
|
||
// subAgentRequestPayload is the JSON body that the fake sub-agent
// (runSubAgentMain) POSTs to the test server. It carries the agent token the
// sub-agent was given and the working directory it was started in. The
// field order fixes the order of keys in the encoded JSON.
type subAgentRequestPayload struct {
	Token     string `json:"token"`     // value of CODER_AGENT_TOKEN
	Directory string `json:"directory"` // sub-agent's working directory
}
|
||
|
||
// runSubAgentMain is the main function for the sub-agent that connects
|
||
// to the control plane. It reads the CODER_AGENT_URL and
|
||
// CODER_AGENT_TOKEN environment variables, sends the token, and exits
|
||
// with a status code based on the response.
|
||
func runSubAgentMain() int {
|
||
url := os.Getenv("CODER_AGENT_URL")
|
||
token := os.Getenv("CODER_AGENT_TOKEN")
|
||
if url == "" || token == "" {
|
||
_, _ = fmt.Fprintln(os.Stderr, "CODER_AGENT_URL and CODER_AGENT_TOKEN must be set")
|
||
return 10
|
||
}
|
||
|
||
dir, err := os.Getwd()
|
||
if err != nil {
|
||
_, _ = fmt.Fprintf(os.Stderr, "failed to get current working directory: %v\n", err)
|
||
return 1
|
||
}
|
||
payload := subAgentRequestPayload{
|
||
Token: token,
|
||
Directory: dir,
|
||
}
|
||
b, err := json.Marshal(payload)
|
||
if err != nil {
|
||
_, _ = fmt.Fprintf(os.Stderr, "failed to marshal payload: %v\n", err)
|
||
return 1
|
||
}
|
||
|
||
req, err := http.NewRequest("POST", url, bytes.NewReader(b))
|
||
if err != nil {
|
||
_, _ = fmt.Fprintf(os.Stderr, "failed to create request: %v\n", err)
|
||
return 1
|
||
}
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
req = req.WithContext(ctx)
|
||
client := &http.Client{}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
_, _ = fmt.Fprintf(os.Stderr, "agent connection failed: %v\n", err)
|
||
return 11
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode != http.StatusOK {
|
||
_, _ = fmt.Fprintf(os.Stderr, "agent exiting with non-zero exit code %d\n", resp.StatusCode)
|
||
return 12
|
||
}
|
||
_, _ = fmt.Println("sub-agent connected successfully")
|
||
return 0
|
||
}
|
||
|
||
// This tests end-to-end functionality of auto-starting a devcontainer.
|
||
// It runs "devcontainer up" which creates a real Docker container. As
|
||
// such, it does not run by default in CI.
|
||
//
|
||
// You can run it manually as follows:
|
||
//
|
||
// CODER_TEST_USE_DOCKER=1 go test -count=1 ./agent -run TestAgent_DevcontainerAutostart
|
||
//
|
||
//nolint:paralleltest // This test sets an environment variable.
|
||
func TestAgent_DevcontainerAutostart(t *testing.T) {
|
||
if os.Getenv("CODER_TEST_USE_DOCKER") != "1" {
|
||
t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
|
||
}
|
||
if _, err := exec.LookPath("devcontainer"); err != nil {
|
||
t.Skip("This test requires the devcontainer CLI: npm install -g @devcontainers/cli")
|
||
}
|
||
|
||
// This HTTP handler handles requests from runSubAgentMain which
|
||
// acts as a fake sub-agent. We want to verify that the sub-agent
|
||
// connects and sends its token. We use a channel to signal
|
||
// that the sub-agent has connected successfully and then we wait
|
||
// until we receive another signal to return from the handler. This
|
||
// keeps the agent "alive" for as long as we want.
|
||
subAgentConnected := make(chan subAgentRequestPayload, 1)
|
||
subAgentReady := make(chan struct{}, 1)
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/api/v2/workspaceagents/me/") {
|
||
return
|
||
}
|
||
|
||
t.Logf("Sub-agent request received: %s %s", r.Method, r.URL.Path)
|
||
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// Read the token from the request body.
|
||
var payload subAgentRequestPayload
|
||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||
http.Error(w, "Failed to read token", http.StatusBadRequest)
|
||
t.Logf("Failed to read token: %v", err)
|
||
return
|
||
}
|
||
defer r.Body.Close()
|
||
|
||
t.Logf("Sub-agent request payload received: %+v", payload)
|
||
|
||
// Signal that the sub-agent has connected successfully.
|
||
select {
|
||
case <-t.Context().Done():
|
||
t.Logf("Test context done, not processing sub-agent request")
|
||
return
|
||
case subAgentConnected <- payload:
|
||
}
|
||
|
||
// Wait for the signal to return from the handler.
|
||
select {
|
||
case <-t.Context().Done():
|
||
t.Logf("Test context done, not waiting for sub-agent ready")
|
||
return
|
||
case <-subAgentReady:
|
||
}
|
||
|
||
w.WriteHeader(http.StatusOK)
|
||
}))
|
||
defer srv.Close()
|
||
|
||
pool, err := dockertest.NewPool("")
|
||
require.NoError(t, err, "Could not connect to docker")
|
||
|
||
// Prepare temporary devcontainer for test (mywork).
|
||
devcontainerID := uuid.New()
|
||
tmpdir := t.TempDir()
|
||
t.Setenv("HOME", tmpdir)
|
||
tempWorkspaceFolder := filepath.Join(tmpdir, "mywork")
|
||
unexpandedWorkspaceFolder := filepath.Join("~", "mywork")
|
||
t.Logf("Workspace folder: %s", tempWorkspaceFolder)
|
||
t.Logf("Unexpanded workspace folder: %s", unexpandedWorkspaceFolder)
|
||
devcontainerPath := filepath.Join(tempWorkspaceFolder, ".devcontainer")
|
||
err = os.MkdirAll(devcontainerPath, 0o755)
|
||
require.NoError(t, err, "create devcontainer directory")
|
||
devcontainerFile := filepath.Join(devcontainerPath, "devcontainer.json")
|
||
err = os.WriteFile(devcontainerFile, []byte(`{
|
||
"name": "mywork",
|
||
"image": "ubuntu:latest",
|
||
"cmd": ["sleep", "infinity"],
|
||
"runArgs": ["--network=host", "--label=`+agentcontainers.DevcontainerIsTestRunLabel+`=true"]
|
||
}`), 0o600)
|
||
require.NoError(t, err, "write devcontainer.json")
|
||
|
||
manifest := agentsdk.Manifest{
|
||
// Set up pre-conditions for auto-starting a devcontainer, the script
|
||
// is expected to be prepared by the provisioner normally.
|
||
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
|
||
{
|
||
ID: devcontainerID,
|
||
Name: "test",
|
||
// Use an unexpanded path to test the expansion.
|
||
WorkspaceFolder: unexpandedWorkspaceFolder,
|
||
},
|
||
},
|
||
Scripts: []codersdk.WorkspaceAgentScript{
|
||
{
|
||
ID: devcontainerID,
|
||
LogSourceID: agentsdk.ExternalLogSourceID,
|
||
RunOnStart: true,
|
||
Script: "echo this-will-be-replaced",
|
||
DisplayName: "Dev Container (test)",
|
||
},
|
||
},
|
||
}
|
||
mClock := quartz.NewMock(t)
|
||
mClock.Set(time.Now())
|
||
tickerFuncTrap := mClock.Trap().TickerFunc("agentcontainers")
|
||
|
||
//nolint:dogsled
|
||
_, agentClient, _, _, _ := setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.Devcontainers = true
|
||
o.DevcontainerAPIOptions = append(
|
||
o.DevcontainerAPIOptions,
|
||
// Only match this specific dev container.
|
||
agentcontainers.WithClock(mClock),
|
||
agentcontainers.WithContainerLabelIncludeFilter("devcontainer.local_folder", tempWorkspaceFolder),
|
||
agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerIsTestRunLabel, "true"),
|
||
agentcontainers.WithSubAgentURL(srv.URL),
|
||
// The agent will copy "itself", but in the case of this test, the
|
||
// agent is actually this test binary. So we'll tell the test binary
|
||
// to execute the sub-agent main function via this env.
|
||
agentcontainers.WithSubAgentEnv("CODER_TEST_RUN_SUB_AGENT_MAIN=1"),
|
||
)
|
||
})
|
||
|
||
t.Logf("Waiting for container with label: devcontainer.local_folder=%s", tempWorkspaceFolder)
|
||
|
||
var container docker.APIContainers
|
||
require.Eventually(t, func() bool {
|
||
containers, err := pool.Client.ListContainers(docker.ListContainersOptions{All: true})
|
||
if err != nil {
|
||
t.Logf("Error listing containers: %v", err)
|
||
return false
|
||
}
|
||
|
||
for _, c := range containers {
|
||
t.Logf("Found container: %s with labels: %v", c.ID[:12], c.Labels)
|
||
if labelValue, ok := c.Labels["devcontainer.local_folder"]; ok {
|
||
if labelValue == tempWorkspaceFolder {
|
||
t.Logf("Found matching container: %s", c.ID[:12])
|
||
container = c
|
||
return true
|
||
}
|
||
}
|
||
}
|
||
|
||
return false
|
||
}, testutil.WaitSuperLong, testutil.IntervalMedium, "no container with workspace folder label found")
|
||
defer func() {
|
||
// We can't rely on pool here because the container is not
|
||
// managed by it (it is managed by @devcontainer/cli).
|
||
err := pool.Client.RemoveContainer(docker.RemoveContainerOptions{
|
||
ID: container.ID,
|
||
RemoveVolumes: true,
|
||
Force: true,
|
||
})
|
||
assert.NoError(t, err, "remove container")
|
||
}()
|
||
|
||
containerInfo, err := pool.Client.InspectContainer(container.ID)
|
||
require.NoError(t, err, "inspect container")
|
||
t.Logf("Container state: status: %v", containerInfo.State.Status)
|
||
require.True(t, containerInfo.State.Running, "container should be running")
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
|
||
// Ensure the container update routine runs.
|
||
tickerFuncTrap.MustWait(ctx).MustRelease(ctx)
|
||
tickerFuncTrap.Close()
|
||
|
||
// Since the agent does RefreshContainers, and the ticker function
|
||
// is set to skip instead of queue, we must advance the clock
|
||
// multiple times to ensure that the sub-agent is created.
|
||
var subAgents []*proto.SubAgent
|
||
for {
|
||
_, next := mClock.AdvanceNext()
|
||
next.MustWait(ctx)
|
||
|
||
// Verify that a subagent was created.
|
||
subAgents = agentClient.GetSubAgents()
|
||
if len(subAgents) > 0 {
|
||
t.Logf("Found sub-agents: %d", len(subAgents))
|
||
break
|
||
}
|
||
}
|
||
require.Len(t, subAgents, 1, "expected one sub agent")
|
||
|
||
subAgent := subAgents[0]
|
||
subAgentID, err := uuid.FromBytes(subAgent.GetId())
|
||
require.NoError(t, err, "failed to parse sub-agent ID")
|
||
t.Logf("Connecting to sub-agent: %s (ID: %s)", subAgent.Name, subAgentID)
|
||
|
||
gotDir, err := agentClient.GetSubAgentDirectory(subAgentID)
|
||
require.NoError(t, err, "failed to get sub-agent directory")
|
||
require.Equal(t, "/workspaces/mywork", gotDir, "sub-agent directory should match")
|
||
|
||
subAgentToken, err := uuid.FromBytes(subAgent.GetAuthToken())
|
||
require.NoError(t, err, "failed to parse sub-agent token")
|
||
|
||
payload := testutil.RequireReceive(ctx, t, subAgentConnected)
|
||
require.Equal(t, subAgentToken.String(), payload.Token, "sub-agent token should match")
|
||
require.Equal(t, "/workspaces/mywork", payload.Directory, "sub-agent directory should match")
|
||
|
||
// Allow the subagent to exit.
|
||
close(subAgentReady)
|
||
}
|
||
|
||
// TestAgent_DevcontainerRecreate tests that RecreateDevcontainer
|
||
// recreates a devcontainer and emits logs.
|
||
//
|
||
// This tests end-to-end functionality of auto-starting a devcontainer.
|
||
// It runs "devcontainer up" which creates a real Docker container. As
|
||
// such, it does not run by default in CI.
|
||
//
|
||
// You can run it manually as follows:
|
||
//
|
||
// CODER_TEST_USE_DOCKER=1 go test -count=1 ./agent -run TestAgent_DevcontainerRecreate
|
||
func TestAgent_DevcontainerRecreate(t *testing.T) {
|
||
if os.Getenv("CODER_TEST_USE_DOCKER") != "1" {
|
||
t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
|
||
}
|
||
t.Parallel()
|
||
|
||
pool, err := dockertest.NewPool("")
|
||
require.NoError(t, err, "Could not connect to docker")
|
||
|
||
// Prepare temporary devcontainer for test (mywork).
|
||
devcontainerID := uuid.New()
|
||
devcontainerLogSourceID := uuid.New()
|
||
workspaceFolder := filepath.Join(t.TempDir(), "mywork")
|
||
t.Logf("Workspace folder: %s", workspaceFolder)
|
||
devcontainerPath := filepath.Join(workspaceFolder, ".devcontainer")
|
||
err = os.MkdirAll(devcontainerPath, 0o755)
|
||
require.NoError(t, err, "create devcontainer directory")
|
||
devcontainerFile := filepath.Join(devcontainerPath, "devcontainer.json")
|
||
err = os.WriteFile(devcontainerFile, []byte(`{
|
||
"name": "mywork",
|
||
"image": "busybox:latest",
|
||
"cmd": ["sleep", "infinity"],
|
||
"runArgs": ["--label=`+agentcontainers.DevcontainerIsTestRunLabel+`=true"]
|
||
}`), 0o600)
|
||
require.NoError(t, err, "write devcontainer.json")
|
||
|
||
manifest := agentsdk.Manifest{
|
||
// Set up pre-conditions for auto-starting a devcontainer, the
|
||
// script is used to extract the log source ID.
|
||
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
|
||
{
|
||
ID: devcontainerID,
|
||
Name: "test",
|
||
WorkspaceFolder: workspaceFolder,
|
||
},
|
||
},
|
||
Scripts: []codersdk.WorkspaceAgentScript{
|
||
{
|
||
ID: devcontainerID,
|
||
LogSourceID: devcontainerLogSourceID,
|
||
},
|
||
},
|
||
}
|
||
|
||
//nolint:dogsled
|
||
conn, client, _, _, _ := setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.Devcontainers = true
|
||
o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
|
||
agentcontainers.WithContainerLabelIncludeFilter("devcontainer.local_folder", workspaceFolder),
|
||
agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerIsTestRunLabel, "true"),
|
||
)
|
||
})
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
|
||
// We enabled autostart for the devcontainer, so ready is a good
|
||
// indication that the devcontainer is up and running. Importantly,
|
||
// this also means that the devcontainer startup is no longer
|
||
// producing logs that may interfere with the recreate logs.
|
||
testutil.Eventually(ctx, t, func(context.Context) bool {
|
||
states := client.GetLifecycleStates()
|
||
return slices.Contains(states, codersdk.WorkspaceAgentLifecycleReady)
|
||
}, testutil.IntervalMedium, "devcontainer not ready")
|
||
|
||
t.Logf("Looking for container with label: devcontainer.local_folder=%s", workspaceFolder)
|
||
|
||
var container codersdk.WorkspaceAgentContainer
|
||
testutil.Eventually(ctx, t, func(context.Context) bool {
|
||
resp, err := conn.ListContainers(ctx)
|
||
if err != nil {
|
||
t.Logf("Error listing containers: %v", err)
|
||
return false
|
||
}
|
||
for _, c := range resp.Containers {
|
||
t.Logf("Found container: %s with labels: %v", c.ID[:12], c.Labels)
|
||
if v, ok := c.Labels["devcontainer.local_folder"]; ok && v == workspaceFolder {
|
||
t.Logf("Found matching container: %s", c.ID[:12])
|
||
container = c
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}, testutil.IntervalMedium, "no container with workspace folder label found")
|
||
defer func(container codersdk.WorkspaceAgentContainer) {
|
||
// We can't rely on pool here because the container is not
|
||
// managed by it (it is managed by @devcontainer/cli).
|
||
err := pool.Client.RemoveContainer(docker.RemoveContainerOptions{
|
||
ID: container.ID,
|
||
RemoveVolumes: true,
|
||
Force: true,
|
||
})
|
||
assert.Error(t, err, "container should be removed by recreate")
|
||
}(container)
|
||
|
||
ctx = testutil.Context(t, testutil.WaitLong) // Reset context.
|
||
|
||
// Capture logs via ScriptLogger.
|
||
logsCh := make(chan *proto.BatchCreateLogsRequest, 1)
|
||
client.SetLogsChannel(logsCh)
|
||
|
||
// Invoke recreate to trigger the destruction and recreation of the
|
||
// devcontainer, we do it in a goroutine so we can process logs
|
||
// concurrently.
|
||
go func(container codersdk.WorkspaceAgentContainer) {
|
||
_, err := conn.RecreateDevcontainer(ctx, devcontainerID.String())
|
||
assert.NoError(t, err, "recreate devcontainer should succeed")
|
||
}(container)
|
||
|
||
t.Logf("Checking recreate logs for outcome...")
|
||
|
||
// Wait for the logs to be emitted, the @devcontainer/cli up command
|
||
// will emit a log with the outcome at the end suggesting we did
|
||
// receive all the logs.
|
||
waitForOutcomeLoop:
|
||
for {
|
||
batch := testutil.RequireReceive(ctx, t, logsCh)
|
||
|
||
if bytes.Equal(batch.LogSourceId, devcontainerLogSourceID[:]) {
|
||
for _, log := range batch.Logs {
|
||
t.Logf("Received log: %s", log.Output)
|
||
if strings.Contains(log.Output, "\"outcome\"") {
|
||
break waitForOutcomeLoop
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
t.Logf("Checking there's a new container with label: devcontainer.local_folder=%s", workspaceFolder)
|
||
|
||
// Make sure the container exists and isn't the same as the old one.
|
||
testutil.Eventually(ctx, t, func(context.Context) bool {
|
||
resp, err := conn.ListContainers(ctx)
|
||
if err != nil {
|
||
t.Logf("Error listing containers: %v", err)
|
||
return false
|
||
}
|
||
for _, c := range resp.Containers {
|
||
t.Logf("Found container: %s with labels: %v", c.ID[:12], c.Labels)
|
||
if v, ok := c.Labels["devcontainer.local_folder"]; ok && v == workspaceFolder {
|
||
if c.ID == container.ID {
|
||
t.Logf("Found same container: %s", c.ID[:12])
|
||
return false
|
||
}
|
||
t.Logf("Found new container: %s", c.ID[:12])
|
||
container = c
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}, testutil.IntervalMedium, "new devcontainer not found")
|
||
defer func(container codersdk.WorkspaceAgentContainer) {
|
||
// We can't rely on pool here because the container is not
|
||
// managed by it (it is managed by @devcontainer/cli).
|
||
err := pool.Client.RemoveContainer(docker.RemoveContainerOptions{
|
||
ID: container.ID,
|
||
RemoveVolumes: true,
|
||
Force: true,
|
||
})
|
||
assert.NoError(t, err, "remove container")
|
||
}(container)
|
||
}
|
||
|
||
func TestAgent_DevcontainersDisabledForSubAgent(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// Create a manifest with a ParentID to make this a sub agent.
|
||
manifest := agentsdk.Manifest{
|
||
AgentID: uuid.New(),
|
||
ParentID: uuid.New(),
|
||
}
|
||
|
||
// Setup the agent with devcontainers enabled initially.
|
||
//nolint:dogsled
|
||
conn, _, _, _, _ := setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.Devcontainers = true
|
||
})
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||
defer cancel()
|
||
|
||
var err error
|
||
// setupAgent only waits for tailnet reachability, not for the HTTP API
|
||
// listener to serve the expected sub-agent rejection response.
|
||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||
_, err = conn.ListContainers(ctx)
|
||
if err != nil {
|
||
t.Logf("Error listing containers: %v", err)
|
||
}
|
||
return err != nil && strings.Contains(err.Error(), "Dev Container feature not supported.")
|
||
}, testutil.IntervalFast, "containers endpoint should reject devcontainers inside sub agents")
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "Dev Container feature not supported.")
|
||
require.Contains(t, err.Error(), "Dev Container integration inside other Dev Containers is explicitly not supported.")
|
||
}
|
||
|
||
// TestAgent_DevcontainerPrebuildClaim tests that we correctly handle
|
||
// the claiming process for running devcontainers.
|
||
//
|
||
// You can run it manually as follows:
|
||
//
|
||
// CODER_TEST_USE_DOCKER=1 go test -count=1 ./agent -run TestAgent_DevcontainerPrebuildClaim
|
||
//
|
||
//nolint:paralleltest // This test sets an environment variable.
|
||
func TestAgent_DevcontainerPrebuildClaim(t *testing.T) {
|
||
if os.Getenv("CODER_TEST_USE_DOCKER") != "1" {
|
||
t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
|
||
}
|
||
if _, err := exec.LookPath("devcontainer"); err != nil {
|
||
t.Skip("This test requires the devcontainer CLI: npm install -g @devcontainers/cli")
|
||
}
|
||
|
||
pool, err := dockertest.NewPool("")
|
||
require.NoError(t, err, "Could not connect to docker")
|
||
|
||
var (
|
||
ctx = testutil.Context(t, testutil.WaitShort)
|
||
|
||
devcontainerID = uuid.New()
|
||
devcontainerLogSourceID = uuid.New()
|
||
|
||
workspaceFolder = filepath.Join(t.TempDir(), "project")
|
||
devcontainerPath = filepath.Join(workspaceFolder, ".devcontainer")
|
||
devcontainerConfig = filepath.Join(devcontainerPath, "devcontainer.json")
|
||
)
|
||
|
||
// Given: A devcontainer project.
|
||
t.Logf("Workspace folder: %s", workspaceFolder)
|
||
|
||
err = os.MkdirAll(devcontainerPath, 0o755)
|
||
require.NoError(t, err, "create dev container directory")
|
||
|
||
// Given: This devcontainer project specifies an app that uses the owner name and workspace name.
|
||
err = os.WriteFile(devcontainerConfig, []byte(`{
|
||
"name": "project",
|
||
"image": "busybox:latest",
|
||
"cmd": ["sleep", "infinity"],
|
||
"runArgs": ["--label=`+agentcontainers.DevcontainerIsTestRunLabel+`=true"],
|
||
"customizations": {
|
||
"coder": {
|
||
"apps": [{
|
||
"slug": "zed",
|
||
"url": "zed://ssh/${localEnv:CODER_WORKSPACE_AGENT_NAME}.${localEnv:CODER_WORKSPACE_NAME}.${localEnv:CODER_WORKSPACE_OWNER_NAME}.coder${containerWorkspaceFolder}"
|
||
}]
|
||
}
|
||
}
|
||
}`), 0o600)
|
||
require.NoError(t, err, "write devcontainer config")
|
||
|
||
// Given: A manifest with a prebuild username and workspace name.
|
||
manifest := agentsdk.Manifest{
|
||
OwnerName: "prebuilds",
|
||
WorkspaceName: "prebuilds-xyz-123",
|
||
|
||
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
|
||
{ID: devcontainerID, Name: "test", WorkspaceFolder: workspaceFolder},
|
||
},
|
||
Scripts: []codersdk.WorkspaceAgentScript{
|
||
{ID: devcontainerID, LogSourceID: devcontainerLogSourceID},
|
||
},
|
||
}
|
||
|
||
// When: We create an agent with devcontainers enabled.
|
||
//nolint:dogsled
|
||
conn, client, _, _, _ := setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.Devcontainers = true
|
||
o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
|
||
agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerLocalFolderLabel, workspaceFolder),
|
||
agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerIsTestRunLabel, "true"),
|
||
)
|
||
})
|
||
|
||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||
}, testutil.IntervalMedium, "agent not ready")
|
||
|
||
var dcPrebuild codersdk.WorkspaceAgentDevcontainer
|
||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||
resp, err := conn.ListContainers(ctx)
|
||
require.NoError(t, err)
|
||
|
||
for _, dc := range resp.Devcontainers {
|
||
if dc.Container == nil {
|
||
continue
|
||
}
|
||
|
||
v, ok := dc.Container.Labels[agentcontainers.DevcontainerLocalFolderLabel]
|
||
if ok && v == workspaceFolder {
|
||
dcPrebuild = dc
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
}, testutil.IntervalMedium, "devcontainer not found")
|
||
defer func() {
|
||
pool.Client.RemoveContainer(docker.RemoveContainerOptions{
|
||
ID: dcPrebuild.Container.ID,
|
||
RemoveVolumes: true,
|
||
Force: true,
|
||
})
|
||
}()
|
||
|
||
// Then: We expect a sub agent to have been created.
|
||
subAgents := client.GetSubAgents()
|
||
require.Len(t, subAgents, 1)
|
||
|
||
subAgent := subAgents[0]
|
||
subAgentID, err := uuid.FromBytes(subAgent.GetId())
|
||
require.NoError(t, err)
|
||
|
||
// And: We expect there to be 1 app.
|
||
subAgentApps, err := client.GetSubAgentApps(subAgentID)
|
||
require.NoError(t, err)
|
||
require.Len(t, subAgentApps, 1)
|
||
|
||
// And: This app should contain the prebuild workspace name and owner name.
|
||
subAgentApp := subAgentApps[0]
|
||
require.Equal(t, "zed://ssh/project.prebuilds-xyz-123.prebuilds.coder/workspaces/project", subAgentApp.GetUrl())
|
||
|
||
// Given: We close the client and connection
|
||
client.Close()
|
||
conn.Close()
|
||
|
||
// Given: A new manifest with a regular user owner name and workspace name.
|
||
manifest = agentsdk.Manifest{
|
||
OwnerName: "user",
|
||
WorkspaceName: "user-workspace",
|
||
|
||
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
|
||
{ID: devcontainerID, Name: "test", WorkspaceFolder: workspaceFolder},
|
||
},
|
||
Scripts: []codersdk.WorkspaceAgentScript{
|
||
{ID: devcontainerID, LogSourceID: devcontainerLogSourceID},
|
||
},
|
||
}
|
||
|
||
// When: We create an agent with devcontainers enabled.
|
||
//nolint:dogsled
|
||
conn, client, _, _, _ = setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.Devcontainers = true
|
||
o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
|
||
agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerLocalFolderLabel, workspaceFolder),
|
||
agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerIsTestRunLabel, "true"),
|
||
)
|
||
})
|
||
|
||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||
}, testutil.IntervalMedium, "agent not ready")
|
||
|
||
var dcClaimed codersdk.WorkspaceAgentDevcontainer
|
||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||
resp, err := conn.ListContainers(ctx)
|
||
require.NoError(t, err)
|
||
|
||
for _, dc := range resp.Devcontainers {
|
||
if dc.Container == nil {
|
||
continue
|
||
}
|
||
|
||
v, ok := dc.Container.Labels[agentcontainers.DevcontainerLocalFolderLabel]
|
||
if ok && v == workspaceFolder {
|
||
dcClaimed = dc
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
}, testutil.IntervalMedium, "devcontainer not found")
|
||
defer func() {
|
||
if dcClaimed.Container.ID != dcPrebuild.Container.ID {
|
||
pool.Client.RemoveContainer(docker.RemoveContainerOptions{
|
||
ID: dcClaimed.Container.ID,
|
||
RemoveVolumes: true,
|
||
Force: true,
|
||
})
|
||
}
|
||
}()
|
||
|
||
// Then: We expect the claimed devcontainer and prebuild devcontainer
|
||
// to be using the same underlying container.
|
||
require.Equal(t, dcPrebuild.Container.ID, dcClaimed.Container.ID)
|
||
|
||
// And: We expect there to be a sub agent created.
|
||
subAgents = client.GetSubAgents()
|
||
require.Len(t, subAgents, 1)
|
||
|
||
subAgent = subAgents[0]
|
||
subAgentID, err = uuid.FromBytes(subAgent.GetId())
|
||
require.NoError(t, err)
|
||
|
||
// And: We expect there to be an app.
|
||
subAgentApps, err = client.GetSubAgentApps(subAgentID)
|
||
require.NoError(t, err)
|
||
require.Len(t, subAgentApps, 1)
|
||
|
||
// And: We expect this app to have the user's owner name and workspace name.
|
||
subAgentApp = subAgentApps[0]
|
||
require.Equal(t, "zed://ssh/project.user-workspace.user.coder/workspaces/project", subAgentApp.GetUrl())
|
||
}
|
||
|
||
func TestAgent_Dial(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
cases := []struct {
|
||
name string
|
||
setup func(t testing.TB) net.Listener
|
||
}{
|
||
{
|
||
name: "TCP",
|
||
setup: func(t testing.TB) net.Listener {
|
||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err, "create TCP listener")
|
||
return l
|
||
},
|
||
},
|
||
{
|
||
name: "UDP",
|
||
setup: func(t testing.TB) net.Listener {
|
||
addr := net.UDPAddr{
|
||
IP: net.ParseIP("127.0.0.1"),
|
||
Port: 0,
|
||
}
|
||
l, err := udp.Listen("udp", &addr)
|
||
require.NoError(t, err, "create UDP listener")
|
||
return l
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, c := range cases {
|
||
t.Run(c.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// The purpose of this test is to ensure that a client can dial a
|
||
// listener in the workspace over tailnet.
|
||
//
|
||
// The OS sometimes drops packets if the system can't keep up with
|
||
// them. For TCP packets, it's typically fine due to
|
||
// retransmissions, but for UDP packets, it can fail this test.
|
||
//
|
||
// The OS gets involved for the Wireguard traffic (either via DERP
|
||
// or direct UDP), and also for the traffic between the agent and
|
||
// the listener in the "workspace".
|
||
//
|
||
// To avoid this, we'll retry this test up to 3 times.
|
||
//nolint:gocritic // This test is flaky due to uncontrollable OS packet drops under heavy load.
|
||
testutil.RunRetry(t, 3, func(t testing.TB) {
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
|
||
l := c.setup(t)
|
||
done := make(chan struct{})
|
||
defer func() {
|
||
l.Close()
|
||
<-done
|
||
}()
|
||
|
||
go func() {
|
||
defer close(done)
|
||
for range 2 {
|
||
c, err := l.Accept()
|
||
if assert.NoError(t, err, "accept connection") {
|
||
testAccept(ctx, t, c)
|
||
_ = c.Close()
|
||
}
|
||
}
|
||
}()
|
||
|
||
agentID := uuid.UUID{0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8}
|
||
//nolint:dogsled
|
||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
AgentID: agentID,
|
||
}, 0)
|
||
require.True(t, agentConn.AwaitReachable(ctx))
|
||
conn, err := agentConn.DialContext(ctx, l.Addr().Network(), l.Addr().String())
|
||
require.NoError(t, err)
|
||
testDial(ctx, t, conn)
|
||
err = conn.Close()
|
||
require.NoError(t, err)
|
||
|
||
// also connect via the CoderServicePrefix, to test that we can reach the agent on this
|
||
// IP. This will be required for CoderVPN.
|
||
_, rawPort, _ := net.SplitHostPort(l.Addr().String())
|
||
port, _ := strconv.ParseUint(rawPort, 10, 16)
|
||
ipp := netip.AddrPortFrom(tailnet.CoderServicePrefix.AddrFromUUID(agentID), uint16(port))
|
||
|
||
switch l.Addr().Network() {
|
||
case "tcp":
|
||
conn, err = agentConn.TailnetConn().DialContextTCP(ctx, ipp)
|
||
case "udp":
|
||
conn, err = agentConn.TailnetConn().DialContextUDP(ctx, ipp)
|
||
default:
|
||
t.Fatalf("unknown network: %s", l.Addr().Network())
|
||
}
|
||
require.NoError(t, err)
|
||
testDial(ctx, t, conn)
|
||
err = conn.Close()
|
||
require.NoError(t, err)
|
||
})
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestAgent_UpdatedDERP checks that agents can handle their DERP map being
|
||
// updated, and that clients can also handle it.
|
||
func TestAgent_UpdatedDERP(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
logger := testutil.Logger(t)
|
||
|
||
originalDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||
require.NotNil(t, originalDerpMap)
|
||
|
||
coordinator := tailnet.NewCoordinator(logger)
|
||
// use t.Cleanup so the coordinator closing doesn't deadlock with in-memory
|
||
// coordination
|
||
t.Cleanup(func() {
|
||
_ = coordinator.Close()
|
||
})
|
||
agentID := uuid.New()
|
||
statsCh := make(chan *proto.Stats, 50)
|
||
fs := afero.NewMemMapFs()
|
||
client := agenttest.NewClient(t,
|
||
logger.Named("agent"),
|
||
agentID,
|
||
agentsdk.Manifest{
|
||
DERPMap: originalDerpMap,
|
||
// Force DERP.
|
||
DisableDirectConnections: true,
|
||
},
|
||
statsCh,
|
||
coordinator,
|
||
)
|
||
t.Cleanup(func() {
|
||
t.Log("closing client")
|
||
client.Close()
|
||
})
|
||
uut := agent.New(agent.Options{
|
||
Client: client,
|
||
Filesystem: fs,
|
||
Logger: logger.Named("agent"),
|
||
ReconnectingPTYTimeout: time.Minute,
|
||
})
|
||
t.Cleanup(func() {
|
||
t.Log("closing agent")
|
||
_ = uut.Close()
|
||
})
|
||
|
||
// Setup a client connection.
|
||
newClientConn := func(derpMap *tailcfg.DERPMap, name string) workspacesdk.AgentConn {
|
||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||
Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()},
|
||
DERPMap: derpMap,
|
||
Logger: logger.Named(name),
|
||
})
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() {
|
||
t.Logf("closing conn %s", name)
|
||
_ = conn.Close()
|
||
})
|
||
testCtx, testCtxCancel := context.WithCancel(context.Background())
|
||
t.Cleanup(testCtxCancel)
|
||
clientID := uuid.New()
|
||
ctrl := tailnet.NewTunnelSrcCoordController(logger, conn)
|
||
ctrl.AddDestination(agentID)
|
||
auth := tailnet.ClientCoordinateeAuth{AgentID: agentID}
|
||
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, coordinator))
|
||
t.Cleanup(func() {
|
||
t.Logf("closing coordination %s", name)
|
||
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
|
||
defer ccancel()
|
||
err := coordination.Close(cctx)
|
||
if err != nil {
|
||
t.Logf("error closing in-memory coordination: %s", err.Error())
|
||
}
|
||
t.Logf("closed coordination %s", name)
|
||
})
|
||
// Force DERP.
|
||
conn.SetBlockEndpoints(true)
|
||
|
||
sdkConn := workspacesdk.NewAgentConn(conn, workspacesdk.AgentConnOptions{
|
||
AgentID: agentID,
|
||
CloseFunc: func() error { return workspacesdk.ErrSkipClose },
|
||
})
|
||
t.Cleanup(func() {
|
||
t.Logf("closing sdkConn %s", name)
|
||
_ = sdkConn.Close()
|
||
})
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
if !sdkConn.AwaitReachable(ctx) {
|
||
t.Fatal("agent not reachable")
|
||
}
|
||
|
||
return sdkConn
|
||
}
|
||
conn1 := newClientConn(originalDerpMap, "client1")
|
||
|
||
// Change the DERP map.
|
||
newDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||
require.NotNil(t, newDerpMap)
|
||
|
||
// Change the region ID.
|
||
newDerpMap.Regions[2] = newDerpMap.Regions[1]
|
||
delete(newDerpMap.Regions, 1)
|
||
newDerpMap.Regions[2].RegionID = 2
|
||
for _, node := range newDerpMap.Regions[2].Nodes {
|
||
node.RegionID = 2
|
||
}
|
||
|
||
// Push a new DERP map to the agent.
|
||
err := client.PushDERPMapUpdate(newDerpMap)
|
||
require.NoError(t, err)
|
||
t.Log("pushed DERPMap update to agent")
|
||
|
||
require.Eventually(t, func() bool {
|
||
conn := uut.TailnetConn()
|
||
if conn == nil {
|
||
return false
|
||
}
|
||
regionIDs := conn.DERPMap().RegionIDs()
|
||
preferredDERP := conn.Node().PreferredDERP
|
||
t.Logf("agent Conn DERPMap with regionIDs %v, PreferredDERP %d", regionIDs, preferredDERP)
|
||
return len(regionIDs) == 1 && regionIDs[0] == 2 && preferredDERP == 2
|
||
}, testutil.WaitLong, testutil.IntervalFast)
|
||
t.Log("agent got the new DERPMap")
|
||
|
||
// Connect from a second client and make sure it uses the new DERP map.
|
||
conn2 := newClientConn(newDerpMap, "client2")
|
||
require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs())
|
||
t.Log("conn2 got the new DERPMap")
|
||
|
||
// If the first client gets a DERP map update, it should be able to
|
||
// reconnect just fine.
|
||
conn1.TailnetConn().SetDERPMap(newDerpMap)
|
||
require.Equal(t, []int{2}, conn1.TailnetConn().DERPMap().RegionIDs())
|
||
t.Log("set the new DERPMap on conn1")
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
require.True(t, conn1.AwaitReachable(ctx))
|
||
t.Log("conn1 reached agent with new DERP")
|
||
}
|
||
|
||
func TestAgent_Speedtest(t *testing.T) {
|
||
t.Parallel()
|
||
t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
|
||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||
//nolint:dogsled
|
||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||
DERPMap: derpMap,
|
||
}, 0, func(client *agenttest.Client, options *agent.Options) {
|
||
options.Logger = logger.Named("agent")
|
||
})
|
||
defer conn.Close()
|
||
res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond)
|
||
require.NoError(t, err)
|
||
t.Logf("%.2f MBits/s", res[len(res)-1].MBitsPerSecond())
|
||
}
|
||
|
||
func TestAgent_Reconnect(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
logger := testutil.Logger(t)
|
||
// After the agent is disconnected from a coordinator, it's supposed
|
||
// to reconnect!
|
||
fCoordinator := tailnettest.NewFakeCoordinator()
|
||
|
||
agentID := uuid.New()
|
||
statsCh := make(chan *proto.Stats, 50)
|
||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||
client := agenttest.NewClient(t,
|
||
logger,
|
||
agentID,
|
||
agentsdk.Manifest{
|
||
DERPMap: derpMap,
|
||
Directory: "/test/workspace",
|
||
},
|
||
statsCh,
|
||
fCoordinator,
|
||
)
|
||
defer client.Close()
|
||
|
||
closer := agent.New(agent.Options{
|
||
Client: client,
|
||
Logger: logger.Named("agent"),
|
||
})
|
||
defer closer.Close()
|
||
|
||
// Each iteration forces the agent to reconnect by closing
|
||
// the current coordinate call while the tracked HTTP server
|
||
// goroutine (from connection 1's createTailnet) is still
|
||
// alive, widening the race window.
|
||
const reconnections = 5
|
||
for i := range reconnections {
|
||
call := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||
require.Equal(t, i+1, client.GetNumRefreshTokenCalls())
|
||
close(call.Resps) // hang up — triggers reconnect
|
||
}
|
||
// Verify final reconnect succeeds.
|
||
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||
require.Equal(t, reconnections+1, client.GetNumRefreshTokenCalls())
|
||
closer.Close()
|
||
}
|
||
|
||
func TestAgent_ReconnectNoLifecycleReemit(t *testing.T) {
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
logger := testutil.Logger(t)
|
||
|
||
fCoordinator := tailnettest.NewFakeCoordinator()
|
||
agentID := uuid.New()
|
||
statsCh := make(chan *proto.Stats, 50)
|
||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||
|
||
client := agenttest.NewClient(t,
|
||
logger,
|
||
agentID,
|
||
agentsdk.Manifest{
|
||
DERPMap: derpMap,
|
||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||
Script: "echo hello",
|
||
Timeout: 30 * time.Second,
|
||
RunOnStart: true,
|
||
}},
|
||
},
|
||
statsCh,
|
||
fCoordinator,
|
||
)
|
||
defer client.Close()
|
||
|
||
closer := agent.New(agent.Options{
|
||
Client: client,
|
||
Logger: logger.Named("agent"),
|
||
})
|
||
defer closer.Close()
|
||
|
||
// Wait for the agent to reach Ready state.
|
||
require.Eventually(t, func() bool {
|
||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||
}, testutil.WaitShort, testutil.IntervalFast)
|
||
|
||
statesBefore := slices.Clone(client.GetLifecycleStates())
|
||
|
||
// Disconnect by closing the coordinator response channel.
|
||
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||
close(call1.Resps)
|
||
|
||
// Wait for reconnect.
|
||
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||
|
||
// Wait for a stats report as a deterministic steady-state proof.
|
||
testutil.RequireReceive(ctx, t, statsCh)
|
||
|
||
statesAfter := client.GetLifecycleStates()
|
||
require.Equal(t, statesBefore, statesAfter,
|
||
"lifecycle states should not be re-reported after reconnect")
|
||
|
||
closer.Close()
|
||
}
|
||
|
||
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
|
||
t.Parallel()
|
||
logger := testutil.Logger(t)
|
||
coordinator := tailnet.NewCoordinator(logger)
|
||
defer coordinator.Close()
|
||
|
||
client := agenttest.NewClient(t,
|
||
logger,
|
||
uuid.New(),
|
||
agentsdk.Manifest{
|
||
GitAuthConfigs: 1,
|
||
DERPMap: &tailcfg.DERPMap{},
|
||
},
|
||
make(chan *proto.Stats, 50),
|
||
coordinator,
|
||
)
|
||
defer client.Close()
|
||
filesystem := afero.NewMemMapFs()
|
||
closer := agent.New(agent.Options{
|
||
Client: client,
|
||
Logger: logger.Named("agent"),
|
||
Filesystem: filesystem,
|
||
})
|
||
defer closer.Close()
|
||
|
||
home, err := os.UserHomeDir()
|
||
require.NoError(t, err)
|
||
name := filepath.Join(home, ".vscode-server", "data", "Machine", "settings.json")
|
||
require.Eventually(t, func() bool {
|
||
_, err := filesystem.Stat(name)
|
||
return err == nil
|
||
}, testutil.WaitShort, testutil.IntervalFast)
|
||
}
|
||
|
||
func TestAgent_DebugServer(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
logDir := t.TempDir()
|
||
logPath := filepath.Join(logDir, "coder-agent.log")
|
||
randLogStr, err := cryptorand.String(32)
|
||
require.NoError(t, err)
|
||
require.NoError(t, os.WriteFile(logPath, []byte(randLogStr), 0o600))
|
||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||
//nolint:dogsled
|
||
conn, _, _, _, agnt := setupAgentWithSecrets(t, agentsdk.Manifest{
|
||
DERPMap: derpMap,
|
||
}, []agentsdk.WorkspaceSecret{
|
||
{EnvName: "DEBUG_SECRET", Value: []byte("super-secret-value-12345")},
|
||
}, 0, func(c *agenttest.Client, o *agent.Options) {
|
||
o.LogDir = logDir
|
||
})
|
||
|
||
awaitReachableCtx := testutil.Context(t, testutil.WaitLong)
|
||
ok := conn.AwaitReachable(awaitReachableCtx)
|
||
require.True(t, ok)
|
||
_ = conn.Close()
|
||
|
||
srv := httptest.NewServer(agnt.HTTPDebug())
|
||
t.Cleanup(srv.Close)
|
||
|
||
t.Run("MagicsockDebug", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"/debug/magicsock", nil)
|
||
require.NoError(t, err)
|
||
|
||
res, err := srv.Client().Do(req)
|
||
require.NoError(t, err)
|
||
defer res.Body.Close()
|
||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||
|
||
resBody, err := io.ReadAll(res.Body)
|
||
require.NoError(t, err)
|
||
require.Contains(t, string(resBody), "<h1>magicsock</h1>")
|
||
})
|
||
|
||
t.Run("MagicsockDebugLogging", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
t.Run("Enable", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"/debug/magicsock/debug-logging/t", nil)
|
||
require.NoError(t, err)
|
||
|
||
res, err := srv.Client().Do(req)
|
||
require.NoError(t, err)
|
||
defer res.Body.Close()
|
||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||
|
||
resBody, err := io.ReadAll(res.Body)
|
||
require.NoError(t, err)
|
||
require.Contains(t, string(resBody), "updated magicsock debug logging to true")
|
||
})
|
||
|
||
t.Run("Disable", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"/debug/magicsock/debug-logging/0", nil)
|
||
require.NoError(t, err)
|
||
|
||
res, err := srv.Client().Do(req)
|
||
require.NoError(t, err)
|
||
defer res.Body.Close()
|
||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||
|
||
resBody, err := io.ReadAll(res.Body)
|
||
require.NoError(t, err)
|
||
require.Contains(t, string(resBody), "updated magicsock debug logging to false")
|
||
})
|
||
|
||
t.Run("Invalid", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"/debug/magicsock/debug-logging/blah", nil)
|
||
require.NoError(t, err)
|
||
|
||
res, err := srv.Client().Do(req)
|
||
require.NoError(t, err)
|
||
defer res.Body.Close()
|
||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||
|
||
resBody, err := io.ReadAll(res.Body)
|
||
require.NoError(t, err)
|
||
require.Contains(t, string(resBody), `invalid state "blah", must be a boolean`)
|
||
})
|
||
})
|
||
|
||
t.Run("Manifest", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"/debug/manifest", nil)
|
||
require.NoError(t, err)
|
||
|
||
res, err := srv.Client().Do(req)
|
||
require.NoError(t, err)
|
||
defer res.Body.Close()
|
||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||
|
||
var v agentsdk.Manifest
|
||
require.NoError(t, json.NewDecoder(res.Body).Decode(&v))
|
||
require.NotNil(t, v)
|
||
})
|
||
|
||
t.Run("ManifestSecretsStripped", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"/debug/manifest", nil)
|
||
require.NoError(t, err)
|
||
|
||
res, err := srv.Client().Do(req)
|
||
require.NoError(t, err)
|
||
defer res.Body.Close()
|
||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||
|
||
body, err := io.ReadAll(res.Body)
|
||
require.NoError(t, err)
|
||
|
||
// The response must not contain the secret value.
|
||
require.NotContains(t, string(body), "super-secret-value-12345")
|
||
|
||
// Confirm we can decode as a Manifest. The SDK type
|
||
// intentionally has no Secrets field, so there is nothing
|
||
// to leak through JSON encoding.
|
||
var v agentsdk.Manifest
|
||
require.NoError(t, json.Unmarshal(body, &v))
|
||
})
|
||
|
||
t.Run("Logs", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
ctx := testutil.Context(t, testutil.WaitLong)
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"/debug/logs", nil)
|
||
require.NoError(t, err)
|
||
|
||
res, err := srv.Client().Do(req)
|
||
require.NoError(t, err)
|
||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||
defer res.Body.Close()
|
||
resBody, err := io.ReadAll(res.Body)
|
||
require.NoError(t, err)
|
||
require.NotEmpty(t, string(resBody))
|
||
require.Contains(t, string(resBody), randLogStr)
|
||
})
|
||
}
|
||
|
||
func TestAgent_ScriptLogging(t *testing.T) {
|
||
if runtime.GOOS == "windows" {
|
||
t.Skip("bash scripts only")
|
||
}
|
||
t.Parallel()
|
||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||
|
||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||
logsCh := make(chan *proto.BatchCreateLogsRequest, 100)
|
||
lsStart := uuid.UUID{0x11}
|
||
lsStop := uuid.UUID{0x22}
|
||
//nolint:dogsled
|
||
_, _, _, _, agnt := setupAgent(
|
||
t,
|
||
agentsdk.Manifest{
|
||
DERPMap: derpMap,
|
||
Scripts: []codersdk.WorkspaceAgentScript{
|
||
{
|
||
LogSourceID: lsStart,
|
||
RunOnStart: true,
|
||
Script: `#!/bin/sh
|
||
i=0
|
||
while [ $i -ne 5 ]
|
||
do
|
||
i=$(($i+1))
|
||
echo "start $i"
|
||
done
|
||
`,
|
||
},
|
||
{
|
||
LogSourceID: lsStop,
|
||
RunOnStop: true,
|
||
Script: `#!/bin/sh
|
||
i=0
|
||
while [ $i -ne 3000 ]
|
||
do
|
||
i=$(($i+1))
|
||
echo "stop $i"
|
||
done
|
||
`, // send a lot of stop logs to make sure we don't truncate shutdown logs before closing the API conn
|
||
},
|
||
},
|
||
},
|
||
0,
|
||
func(cl *agenttest.Client, _ *agent.Options) {
|
||
cl.SetLogsChannel(logsCh)
|
||
},
|
||
)
|
||
|
||
n := 1
|
||
for n <= 5 {
|
||
logs := testutil.TryReceive(ctx, t, logsCh)
|
||
require.NotNil(t, logs)
|
||
for _, l := range logs.GetLogs() {
|
||
require.Equal(t, fmt.Sprintf("start %d", n), l.GetOutput())
|
||
n++
|
||
}
|
||
}
|
||
|
||
err := agnt.Close()
|
||
require.NoError(t, err)
|
||
|
||
n = 1
|
||
for n <= 3000 {
|
||
logs := testutil.TryReceive(ctx, t, logsCh)
|
||
require.NotNil(t, logs)
|
||
for _, l := range logs.GetLogs() {
|
||
require.Equal(t, fmt.Sprintf("stop %d", n), l.GetOutput())
|
||
n++
|
||
}
|
||
t.Logf("got %d stop logs", n-1)
|
||
}
|
||
}
|
||
|
||
// setupAgentSSHClient creates an agent, dials it, and sets up an ssh.Client for it
|
||
func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client {
|
||
//nolint: dogsled
|
||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||
sshClient, err := agentConn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() { sshClient.Close() })
|
||
return sshClient
|
||
}
|
||
|
||
func setupSSHSession(
|
||
t *testing.T,
|
||
manifest agentsdk.Manifest,
|
||
banner codersdk.BannerConfig,
|
||
prepareFS func(fs afero.Fs),
|
||
opts ...func(*agenttest.Client, *agent.Options),
|
||
) *ssh.Session {
|
||
return setupSSHSessionOnPort(t, manifest, banner, prepareFS, workspacesdk.AgentSSHPort, opts...)
|
||
}
|
||
|
||
func setupSSHSessionOnPort(
|
||
t *testing.T,
|
||
manifest agentsdk.Manifest,
|
||
banner codersdk.BannerConfig,
|
||
prepareFS func(fs afero.Fs),
|
||
port uint16,
|
||
opts ...func(*agenttest.Client, *agent.Options),
|
||
) *ssh.Session {
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
opts = append(opts, func(c *agenttest.Client, o *agent.Options) {
|
||
c.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
|
||
return []codersdk.BannerConfig{banner}, nil
|
||
})
|
||
})
|
||
//nolint:dogsled
|
||
conn, _, _, fs, _ := setupAgent(t, manifest, 0, opts...)
|
||
if prepareFS != nil {
|
||
prepareFS(fs)
|
||
}
|
||
sshClient, err := conn.SSHClientOnPort(ctx, port)
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() {
|
||
_ = sshClient.Close()
|
||
})
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() {
|
||
_ = session.Close()
|
||
})
|
||
return session
|
||
}
|
||
|
||
func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
|
||
workspacesdk.AgentConn,
|
||
*agenttest.Client,
|
||
<-chan *proto.Stats,
|
||
afero.Fs,
|
||
agent.Agent,
|
||
) {
|
||
return setupAgentWithSecrets(t, metadata, nil, ptyTimeout, opts...)
|
||
}
|
||
|
||
// setupAgentWithSecrets is like setupAgent but also injects user
|
||
// secrets into the agent's proto manifest. Separate from setupAgent
|
||
// because agentsdk.Manifest intentionally does not carry secrets; see
|
||
// the Manifest doc comment in codersdk/agentsdk.
|
||
func setupAgentWithSecrets(t testing.TB, metadata agentsdk.Manifest, secrets []agentsdk.WorkspaceSecret, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
|
||
workspacesdk.AgentConn,
|
||
*agenttest.Client,
|
||
<-chan *proto.Stats,
|
||
afero.Fs,
|
||
agent.Agent,
|
||
) {
|
||
logger := slogtest.Make(t, &slogtest.Options{
|
||
// Agent can drop errors when shutting down, and some, like the
|
||
// fasthttplistener connection closed error, are unexported.
|
||
IgnoreErrors: true,
|
||
}).Leveled(slog.LevelDebug)
|
||
if metadata.DERPMap == nil {
|
||
metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
|
||
}
|
||
if metadata.AgentID == uuid.Nil {
|
||
metadata.AgentID = uuid.New()
|
||
}
|
||
if metadata.AgentName == "" {
|
||
metadata.AgentName = "test-agent"
|
||
}
|
||
if metadata.WorkspaceName == "" {
|
||
metadata.WorkspaceName = "test-workspace"
|
||
}
|
||
if metadata.OwnerName == "" {
|
||
metadata.OwnerName = "test-user"
|
||
}
|
||
if metadata.WorkspaceID == uuid.Nil {
|
||
metadata.WorkspaceID = uuid.New()
|
||
}
|
||
coordinator := tailnet.NewCoordinator(logger)
|
||
t.Cleanup(func() {
|
||
_ = coordinator.Close()
|
||
})
|
||
statsCh := make(chan *proto.Stats, 50)
|
||
fs := afero.NewMemMapFs()
|
||
c := agenttest.NewClientWithSecrets(t, logger.Named("agenttest"), metadata.AgentID, metadata, secrets, statsCh, coordinator)
|
||
t.Cleanup(c.Close)
|
||
|
||
options := agent.Options{
|
||
Client: c,
|
||
Filesystem: fs,
|
||
Logger: logger.Named("agent"),
|
||
ReconnectingPTYTimeout: ptyTimeout,
|
||
EnvironmentVariables: map[string]string{},
|
||
}
|
||
|
||
for _, opt := range opts {
|
||
opt(c, &options)
|
||
}
|
||
|
||
agnt := agent.New(options)
|
||
t.Cleanup(func() {
|
||
_ = agnt.Close()
|
||
})
|
||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.TailscaleServicePrefix.RandomAddr(), 128)},
|
||
DERPMap: metadata.DERPMap,
|
||
Logger: logger.Named("client"),
|
||
})
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() {
|
||
_ = conn.Close()
|
||
})
|
||
testCtx, testCtxCancel := context.WithCancel(context.Background())
|
||
t.Cleanup(testCtxCancel)
|
||
clientID := uuid.New()
|
||
ctrl := tailnet.NewTunnelSrcCoordController(logger, conn)
|
||
ctrl.AddDestination(metadata.AgentID)
|
||
auth := tailnet.ClientCoordinateeAuth{AgentID: metadata.AgentID}
|
||
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(
|
||
logger, clientID, auth, coordinator))
|
||
t.Cleanup(func() {
|
||
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
|
||
defer ccancel()
|
||
err := coordination.Close(cctx)
|
||
if err != nil {
|
||
t.Logf("error closing in-mem coordination: %s", err.Error())
|
||
}
|
||
})
|
||
agentConn := workspacesdk.NewAgentConn(conn, workspacesdk.AgentConnOptions{
|
||
AgentID: metadata.AgentID,
|
||
})
|
||
t.Cleanup(func() {
|
||
_ = agentConn.Close()
|
||
})
|
||
// Ideally we wouldn't wait too long here, but sometimes the the
|
||
// networking needs more time to resolve itself.
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
if !agentConn.AwaitReachable(ctx) {
|
||
t.Fatal("agent not reachable")
|
||
}
|
||
return agentConn, c, statsCh, fs, agnt
|
||
}
|
||
|
||
// dialTestPayload is the fixed message testDial and testAccept exchange.
var dialTestPayload = []byte("dean-was-here123")
|
||
|
||
func testDial(ctx context.Context, t testing.TB, c net.Conn) {
|
||
t.Helper()
|
||
|
||
if deadline, ok := ctx.Deadline(); ok {
|
||
err := c.SetDeadline(deadline)
|
||
assert.NoError(t, err)
|
||
defer func() {
|
||
err := c.SetDeadline(time.Time{})
|
||
assert.NoError(t, err)
|
||
}()
|
||
}
|
||
|
||
assertWritePayload(t, c, dialTestPayload)
|
||
assertReadPayload(t, c, dialTestPayload)
|
||
}
|
||
|
||
func testAccept(ctx context.Context, t testing.TB, c net.Conn) {
|
||
t.Helper()
|
||
defer c.Close()
|
||
|
||
if deadline, ok := ctx.Deadline(); ok {
|
||
err := c.SetDeadline(deadline)
|
||
assert.NoError(t, err)
|
||
defer func() {
|
||
err := c.SetDeadline(time.Time{})
|
||
assert.NoError(t, err)
|
||
}()
|
||
}
|
||
|
||
assertReadPayload(t, c, dialTestPayload)
|
||
assertWritePayload(t, c, dialTestPayload)
|
||
}
|
||
|
||
func assertReadPayload(t testing.TB, r io.Reader, payload []byte) {
|
||
t.Helper()
|
||
b := make([]byte, len(payload)+16)
|
||
n, err := r.Read(b)
|
||
assert.NoError(t, err, "read payload")
|
||
assert.Equal(t, len(payload), n, "read payload length does not match")
|
||
assert.Equal(t, payload, b[:n])
|
||
}
|
||
|
||
func assertWritePayload(t testing.TB, w io.Writer, payload []byte) {
|
||
t.Helper()
|
||
n, err := w.Write(payload)
|
||
assert.NoError(t, err, "write payload")
|
||
assert.Equal(t, len(payload), n, "written payload length does not match")
|
||
}
|
||
|
||
func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected []string, expectedRe *regexp.Regexp) {
|
||
t.Helper()
|
||
|
||
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||
require.NoError(t, err)
|
||
|
||
ptty := ptytest.New(t)
|
||
var stdout bytes.Buffer
|
||
session.Stdout = &stdout
|
||
session.Stderr = ptty.Output()
|
||
session.Stdin = ptty.Input()
|
||
err = session.Shell()
|
||
require.NoError(t, err)
|
||
|
||
ptty.WriteLine("exit 0")
|
||
|
||
waitErr := make(chan error, 1)
|
||
go func() {
|
||
waitErr <- session.Wait()
|
||
}()
|
||
select {
|
||
case err = <-waitErr:
|
||
require.NoError(t, err)
|
||
case <-time.After(testutil.WaitLong):
|
||
require.Fail(t, "timed out waiting for session to exit")
|
||
}
|
||
|
||
for _, unexpected := range unexpected {
|
||
require.NotContains(t, stdout.String(), unexpected, "should not show output")
|
||
}
|
||
for _, expect := range expected {
|
||
require.Contains(t, stdout.String(), expect, "should show output")
|
||
}
|
||
if expectedRe != nil {
|
||
require.Regexp(t, expectedRe, stdout.String())
|
||
}
|
||
}
|
||
|
||
func TestAgent_Metrics_SSH(t *testing.T) {
|
||
t.Parallel()
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
registry := prometheus.NewRegistry()
|
||
|
||
//nolint:dogsled
|
||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||
o.PrometheusRegistry = registry
|
||
})
|
||
|
||
sshClient, err := conn.SSHClient(ctx)
|
||
require.NoError(t, err)
|
||
defer sshClient.Close()
|
||
session, err := sshClient.NewSession()
|
||
require.NoError(t, err)
|
||
defer session.Close()
|
||
stdin, err := session.StdinPipe()
|
||
require.NoError(t, err)
|
||
err = session.Shell()
|
||
require.NoError(t, err)
|
||
|
||
expected := []struct {
|
||
Name string
|
||
Type proto.Stats_Metric_Type
|
||
CheckFn func(float64) error
|
||
Labels []*proto.Stats_Metric_Label
|
||
}{
|
||
{
|
||
Name: "agent_reconnecting_pty_connections_total",
|
||
Type: proto.Stats_Metric_COUNTER,
|
||
CheckFn: func(v float64) error {
|
||
if v == 0 {
|
||
return nil
|
||
}
|
||
return xerrors.Errorf("expected 0, got %f", v)
|
||
},
|
||
},
|
||
{
|
||
Name: "agent_sessions_total",
|
||
Type: proto.Stats_Metric_COUNTER,
|
||
CheckFn: func(v float64) error {
|
||
if v == 1 {
|
||
return nil
|
||
}
|
||
return xerrors.Errorf("expected 1, got %f", v)
|
||
},
|
||
Labels: []*proto.Stats_Metric_Label{
|
||
{
|
||
Name: "magic_type",
|
||
Value: "ssh",
|
||
},
|
||
{
|
||
Name: "pty",
|
||
Value: "no",
|
||
},
|
||
},
|
||
},
|
||
{
|
||
Name: "agent_ssh_server_failed_connections_total",
|
||
Type: proto.Stats_Metric_COUNTER,
|
||
CheckFn: func(v float64) error {
|
||
if v == 0 {
|
||
return nil
|
||
}
|
||
return xerrors.Errorf("expected 0, got %f", v)
|
||
},
|
||
},
|
||
{
|
||
Name: "agent_ssh_server_sftp_connections_total",
|
||
Type: proto.Stats_Metric_COUNTER,
|
||
CheckFn: func(v float64) error {
|
||
if v == 0 {
|
||
return nil
|
||
}
|
||
return xerrors.Errorf("expected 0, got %f", v)
|
||
},
|
||
},
|
||
{
|
||
Name: "agent_ssh_server_sftp_server_errors_total",
|
||
Type: proto.Stats_Metric_COUNTER,
|
||
CheckFn: func(v float64) error {
|
||
if v == 0 {
|
||
return nil
|
||
}
|
||
return xerrors.Errorf("expected 0, got %f", v)
|
||
},
|
||
},
|
||
{
|
||
Name: "coderd_agentstats_currently_reachable_peers",
|
||
Type: proto.Stats_Metric_GAUGE,
|
||
CheckFn: func(float64) error {
|
||
// We can't reliably ping a peer here, and networking is out of
|
||
// scope of this test, so we just test that the metric exists
|
||
// with the correct labels.
|
||
return nil
|
||
},
|
||
Labels: []*proto.Stats_Metric_Label{
|
||
{
|
||
Name: "connection_type",
|
||
Value: "derp",
|
||
},
|
||
},
|
||
},
|
||
{
|
||
Name: "coderd_agentstats_currently_reachable_peers",
|
||
Type: proto.Stats_Metric_GAUGE,
|
||
CheckFn: func(float64) error {
|
||
return nil
|
||
},
|
||
Labels: []*proto.Stats_Metric_Label{
|
||
{
|
||
Name: "connection_type",
|
||
Value: "p2p",
|
||
},
|
||
},
|
||
},
|
||
{
|
||
Name: "coderd_agentstats_startup_script_seconds",
|
||
Type: proto.Stats_Metric_GAUGE,
|
||
CheckFn: func(f float64) error {
|
||
if f >= 0 {
|
||
return nil
|
||
}
|
||
return xerrors.Errorf("expected >= 0, got %f", f)
|
||
},
|
||
Labels: []*proto.Stats_Metric_Label{
|
||
{
|
||
Name: "success",
|
||
Value: "true",
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
var actual []*promgo.MetricFamily
|
||
assert.Eventually(t, func() bool {
|
||
actual, err = registry.Gather()
|
||
if err != nil {
|
||
return false
|
||
}
|
||
count := 0
|
||
for _, m := range actual {
|
||
count += len(m.GetMetric())
|
||
}
|
||
return count == len(expected)
|
||
}, testutil.WaitLong, testutil.IntervalFast)
|
||
|
||
i := 0
|
||
for _, mf := range actual {
|
||
for _, m := range mf.GetMetric() {
|
||
assert.Equal(t, expected[i].Name, mf.GetName())
|
||
assert.Equal(t, expected[i].Type.String(), mf.GetType().String())
|
||
if expected[i].Type == proto.Stats_Metric_GAUGE {
|
||
assert.NoError(t, expected[i].CheckFn(m.GetGauge().GetValue()), "check fn for %s failed", expected[i].Name)
|
||
} else if expected[i].Type == proto.Stats_Metric_COUNTER {
|
||
assert.NoError(t, expected[i].CheckFn(m.GetCounter().GetValue()), "check fn for %s failed", expected[i].Name)
|
||
}
|
||
for j, lbl := range expected[i].Labels {
|
||
assert.Equal(t, m.GetLabel()[j], &promgo.LabelPair{
|
||
Name: &lbl.Name,
|
||
Value: &lbl.Value,
|
||
})
|
||
}
|
||
i++
|
||
}
|
||
}
|
||
|
||
_, err = stdin.Write([]byte("exit 0\n"))
|
||
require.NoError(t, err, "writing exit to stdin")
|
||
_ = stdin.Close()
|
||
err = session.Wait()
|
||
require.NoError(t, err, "waiting for session to exit")
|
||
}
|
||
|
||
// echoOnce accepts a single connection, reads 4 bytes and echos them back
|
||
func echoOnce(t *testing.T, ll net.Listener) {
|
||
t.Helper()
|
||
conn, err := ll.Accept()
|
||
if err != nil {
|
||
return
|
||
}
|
||
defer conn.Close()
|
||
b := make([]byte, 4)
|
||
_, err = conn.Read(b)
|
||
if !assert.NoError(t, err) {
|
||
return
|
||
}
|
||
_, err = conn.Write(b)
|
||
if !assert.NoError(t, err) {
|
||
return
|
||
}
|
||
}
|
||
|
||
// requireEcho sends 4 bytes and requires the read response to match what was sent.
|
||
func requireEcho(t *testing.T, conn net.Conn) {
|
||
t.Helper()
|
||
_, err := conn.Write([]byte("test"))
|
||
require.NoError(t, err)
|
||
b := make([]byte, 4)
|
||
_, err = conn.Read(b)
|
||
require.NoError(t, err)
|
||
require.Equal(t, "test", string(b))
|
||
}
|
||
|
||
func assertConnectionReport(t testing.TB, agentClient *agenttest.Client, connectionType proto.Connection_Type, status int, reason string) {
|
||
t.Helper()
|
||
|
||
var reports []*proto.ReportConnectionRequest
|
||
if !assert.Eventually(t, func() bool {
|
||
reports = agentClient.GetConnectionReports()
|
||
return len(reports) >= 2
|
||
}, testutil.WaitMedium, testutil.IntervalFast, "waiting for 2 connection reports or more; got %d", len(reports)) {
|
||
return
|
||
}
|
||
|
||
assert.Len(t, reports, 2, "want 2 connection reports")
|
||
|
||
assert.Equal(t, proto.Connection_CONNECT, reports[0].GetConnection().GetAction(), "first report should be connect")
|
||
assert.Equal(t, proto.Connection_DISCONNECT, reports[1].GetConnection().GetAction(), "second report should be disconnect")
|
||
assert.Equal(t, connectionType, reports[0].GetConnection().GetType(), "connect type should be %s", connectionType)
|
||
assert.Equal(t, connectionType, reports[1].GetConnection().GetType(), "disconnect type should be %s", connectionType)
|
||
t1 := reports[0].GetConnection().GetTimestamp().AsTime()
|
||
t2 := reports[1].GetConnection().GetTimestamp().AsTime()
|
||
assert.True(t, t1.Before(t2) || t1.Equal(t2), "connect timestamp should be before or equal to disconnect timestamp")
|
||
assert.NotEmpty(t, reports[0].GetConnection().GetIp(), "connect ip should not be empty")
|
||
assert.NotEmpty(t, reports[1].GetConnection().GetIp(), "disconnect ip should not be empty")
|
||
assert.Equal(t, 0, int(reports[0].GetConnection().GetStatusCode()), "connect status code should be 0")
|
||
assert.Equal(t, status, int(reports[1].GetConnection().GetStatusCode()), "disconnect status code should be %d", status)
|
||
assert.Equal(t, "", reports[0].GetConnection().GetReason(), "connect reason should be empty")
|
||
if reason != "" {
|
||
assert.Contains(t, reports[1].GetConnection().GetReason(), reason, "disconnect reason should contain %s", reason)
|
||
} else {
|
||
t.Logf("connection report disconnect reason: %s", reports[1].GetConnection().GetReason())
|
||
}
|
||
}
|