Files
coder/agent/agent_test.go
T

4173 lines
127 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package agent_test
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"os"
"os/exec"
"os/user"
"path"
"path/filepath"
"regexp"
"runtime"
"slices"
"strconv"
"strings"
"testing"
"time"
"github.com/bramvdbogaerde/go-scp"
"github.com/google/uuid"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/pion/udp"
"github.com/pkg/sftp"
"github.com/prometheus/client_golang/prometheus"
promgo "github.com/prometheus/client_model/go"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/agent/usershell"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestMain(m *testing.M) {
if os.Getenv("CODER_TEST_RUN_SUB_AGENT_MAIN") == "1" {
// If we're running as a subagent, we don't want to run the main tests.
// Instead, we just run the subagent tests.
exit := runSubAgentMain()
os.Exit(exit)
}
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
}
// sshPorts lists both agent SSH ports (workspacesdk.AgentSSHPort and
// workspacesdk.AgentStandardSSHPort); SSH tests iterate over each of them.
var sshPorts = []uint16{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort}
// TestAgent_ImmediateClose is a regression test for https://github.com/coder/coder/issues/17328.
// It closes the agent as soon as it has connected and reported startup, to
// shake out races in the startup code.
func TestAgent_ImmediateClose(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	logger := slogtest.Make(t, &slogtest.Options{
		// Agent can drop errors when shutting down, and some, like the
		// fasthttplistener connection closed error, are unexported.
		IgnoreErrors: true,
	}).Leveled(slog.LevelDebug)
	manifest := agentsdk.Manifest{
		AgentID:       uuid.New(),
		AgentName:     "test-agent",
		WorkspaceName: "test-workspace",
		WorkspaceID:   uuid.New(),
	}
	coordinator := tailnet.NewCoordinator(logger)
	t.Cleanup(func() {
		_ = coordinator.Close()
	})
	statsCh := make(chan *proto.Stats, 50)
	fs := afero.NewMemMapFs()
	client := agenttest.NewClient(t, logger.Named("agenttest"), manifest.AgentID, manifest, statsCh, coordinator)
	t.Cleanup(client.Close)
	options := agent.Options{
		Client:                 client,
		Filesystem:             fs,
		Logger:                 logger.Named("agent"),
		ReconnectingPTYTimeout: 0,
		EnvironmentVariables:   map[string]string{},
	}
	agentUnderTest := agent.New(options)
	t.Cleanup(func() {
		_ = agentUnderTest.Close()
	})
	// wait until the agent has connected and is starting to find races in the startup code
	_ = testutil.TryReceive(ctx, t, client.GetStartup())
	t.Log("Closing Agent")
	err := agentUnderTest.Close()
	require.NoError(t, err)
}
// TestAgent_Stats_SSH verifies that an interactive SSH shell session on each
// agent SSH port shows up in the stats the agent reports.
//
// NOTE(Cian): I noticed that these tests would fail when my default shell was zsh.
// Writing "exit 0" to stdin before closing fixed the issue for me.
func TestAgent_Stats_SSH(t *testing.T) {
	t.Parallel()
	for _, port := range sshPorts {
		t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
			t.Parallel()
			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
			defer cancel()
			//nolint:dogsled
			conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
			sshClient, err := conn.SSHClientOnPort(ctx, port)
			require.NoError(t, err)
			defer sshClient.Close()
			session, err := sshClient.NewSession()
			require.NoError(t, err)
			defer session.Close()
			stdin, err := session.StdinPipe()
			require.NoError(t, err)
			err = session.Shell()
			require.NoError(t, err)

			// We are looking for four different stats to be reported. They might not all
			// arrive at the same time, so we loop until we've seen them all. The failure
			// message is built when the deadline hits, so it shows what was actually
			// observed. Format args passed to require.Eventuallyf would be evaluated
			// before the condition ever ran.
			var s *proto.Stats
			var connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen bool
			for !(connectionCountSeen && rxBytesSeen && txBytesSeen && sessionCountSSHSeen) {
				select {
				case <-ctx.Done():
					t.Fatalf("never saw all stats: %+v, saw connectionCount: %t, rxBytes: %t, txBytes: %t, sessionCountSsh: %t",
						s, connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen)
				case got, ok := <-stats:
					require.True(t, ok, "stats channel closed")
					s = got
					if s.ConnectionCount > 0 {
						connectionCountSeen = true
					}
					if s.RxBytes > 0 {
						rxBytesSeen = true
					}
					if s.TxBytes > 0 {
						txBytesSeen = true
					}
					if s.SessionCountSsh == 1 {
						sessionCountSSHSeen = true
					}
				}
			}

			_, err = stdin.Write([]byte("exit 0\n"))
			require.NoError(t, err, "writing exit to stdin")
			_ = stdin.Close()
			err = session.Wait()
			require.NoError(t, err, "waiting for session to exit")
		})
	}
}
// TestAgent_Stats_ReconnectingPTY verifies that activity on a reconnecting PTY
// shows up in the stats the agent reports.
func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	//nolint:dogsled
	conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
	ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "bash")
	require.NoError(t, err)
	defer ptyConn.Close()
	data, err := json.Marshal(workspacesdk.ReconnectingPTYRequest{
		Data: "echo test\r\n",
	})
	require.NoError(t, err)
	_, err = ptyConn.Write(data)
	require.NoError(t, err)

	// We are looking for four different stats to be reported. They might not all
	// arrive at the same time, so we loop until we've seen them all. The failure
	// message is built when the deadline hits, so it shows what was actually
	// observed. Format args passed to require.Eventuallyf would be evaluated
	// before the condition ever ran.
	var s *proto.Stats
	var connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountReconnectingPTYSeen bool
	for !(connectionCountSeen && rxBytesSeen && txBytesSeen && sessionCountReconnectingPTYSeen) {
		select {
		case <-ctx.Done():
			t.Fatalf("never saw all stats: %+v, saw connectionCount: %t, rxBytes: %t, txBytes: %t, sessionCountReconnectingPTY: %t",
				s, connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountReconnectingPTYSeen)
		case got, ok := <-stats:
			require.True(t, ok, "stats channel closed")
			s = got
			if s.ConnectionCount > 0 {
				connectionCountSeen = true
			}
			if s.RxBytes > 0 {
				rxBytesSeen = true
			}
			if s.TxBytes > 0 {
				txBytesSeen = true
			}
			if s.SessionCountReconnectingPty == 1 {
				sessionCountReconnectingPTYSeen = true
			}
		}
	}
}
// TestAgent_Stats_Magic covers sessions tagged through the magic session-type
// environment variable (or, for JetBrains, a forwarded port whose listening
// process has the magic string in its command line) and checks how they are
// surfaced in stats and connection reports.
func TestAgent_Stats_Magic(t *testing.T) {
	t.Parallel()
	t.Run("StripsEnvironmentVariable", func(t *testing.T) {
		t.Parallel()
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()
		//nolint:dogsled
		conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
		sshClient, err := conn.SSHClient(ctx)
		require.NoError(t, err)
		defer sshClient.Close()
		session, err := sshClient.NewSession()
		require.NoError(t, err)
		defer session.Close()
		// If the env request were rejected, the empty-echo assertion below would
		// pass without the variable ever being sent, so fail loudly instead.
		err = session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
		require.NoError(t, err)
		command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'"
		expected := ""
		if runtime.GOOS == "windows" {
			expected = "%" + agentssh.MagicSessionTypeEnvironmentVariable + "%"
			command = "cmd.exe /c echo " + expected
		}
		output, err := session.Output(command)
		require.NoError(t, err)
		require.Equal(t, expected, strings.TrimSpace(string(output)))
	})
	t.Run("TracksVSCode", func(t *testing.T) {
		t.Parallel()
		if runtime.GOOS == "windows" {
			t.Skip("Sleeping for infinity doesn't work on Windows")
		}
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()
		//nolint:dogsled
		conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
		sshClient, err := conn.SSHClient(ctx)
		require.NoError(t, err)
		defer sshClient.Close()
		session, err := sshClient.NewSession()
		require.NoError(t, err)
		defer session.Close()
		err = session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
		require.NoError(t, err)
		stdin, err := session.StdinPipe()
		require.NoError(t, err)
		err = session.Shell()
		require.NoError(t, err)
		require.Eventuallyf(t, func() bool {
			s, ok := <-stats
			if !ok {
				// A closed channel yields a nil *proto.Stats; don't dereference it.
				t.Log("stats channel closed")
				return false
			}
			t.Logf("got stats: ok=%t, ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountVSCode=%d, ConnectionMedianLatencyMS=%f",
				ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVscode, s.ConnectionMedianLatencyMs)
			// Ensure that the connection didn't count as a "normal" SSH session.
			// This was a special one, so it should be labeled specially in the stats!
			return s.SessionCountVscode == 1 &&
				// Ensure that connection latency is being counted!
				// If it isn't, it's set to -1.
				s.ConnectionMedianLatencyMs >= 0
		}, testutil.WaitLong, testutil.IntervalFast,
			"never saw stats",
		)
		_, err = stdin.Write([]byte("exit 0\n"))
		require.NoError(t, err, "writing exit to stdin")
		_ = stdin.Close()
		err = session.Wait()
		require.NoError(t, err)
		assertConnectionReport(t, agentClient, proto.Connection_VSCODE, 0, "")
	})
	t.Run("TracksJetBrains", func(t *testing.T) {
		t.Parallel()
		if runtime.GOOS != "linux" {
			t.Skip("JetBrains tracking is only supported on Linux")
		}
		ctx := testutil.Context(t, testutil.WaitLong)
		// JetBrains tracking works by looking at the process name listening on the
		// forwarded port. If the process's command line includes the magic string
		// we are looking for, then we assume it is a JetBrains editor. So when we
		// connect to the port we must ensure the process includes that magic string
		// to fool the agent into thinking this is JetBrains. To do this we need to
		// spawn an external process (in this case a simple echo server) so we can
		// control the process name. The -D here is just to mimic how Java options
		// are set but is not necessary as the agent looks only for the magic
		// string itself anywhere in the command.
		_, b, _, ok := runtime.Caller(0)
		require.True(t, ok)
		dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go")
		echoServerCmd := exec.Command("go", "run", dir,
			"-D", agentssh.MagicProcessCmdlineJetBrains)
		stdout, err := echoServerCmd.StdoutPipe()
		require.NoError(t, err)
		err = echoServerCmd.Start()
		require.NoError(t, err)
		defer func() {
			_ = echoServerCmd.Process.Kill()
			// Reap the process and release the stdout pipe.
			_ = echoServerCmd.Wait()
		}()
		// The echo server prints its port as the first line.
		sc := bufio.NewScanner(stdout)
		if !sc.Scan() {
			t.Fatalf("echo server exited before printing its port: %v", sc.Err())
		}
		remotePort := sc.Text()
		//nolint:dogsled
		conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
		sshClient, err := conn.SSHClient(ctx)
		require.NoError(t, err)
		tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort))
		require.NoError(t, err)
		t.Cleanup(func() {
			// always close on failure of test
			_ = conn.Close()
			_ = tunneledConn.Close()
		})
		require.Eventuallyf(t, func() bool {
			s, ok := <-stats
			if !ok {
				t.Log("stats channel closed")
				return false
			}
			t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d",
				ok, s.ConnectionCount, s.SessionCountJetbrains)
			return s.SessionCountJetbrains == 1
		}, testutil.WaitLong, testutil.IntervalFast,
			"never saw stats with conn open",
		)
		// Kill the server and connection after checking for the echo.
		requireEcho(t, tunneledConn)
		_ = echoServerCmd.Process.Kill()
		_ = tunneledConn.Close()
		require.Eventuallyf(t, func() bool {
			s, ok := <-stats
			if !ok {
				t.Log("stats channel closed")
				return false
			}
			t.Logf("got stats after disconnect %t, %d",
				ok, s.SessionCountJetbrains)
			return s.SessionCountJetbrains == 0
		}, testutil.WaitLong, testutil.IntervalFast,
			"never saw stats after conn closes",
		)
		assertConnectionReport(t, agentClient, proto.Connection_JETBRAINS, 0, "")
	})
}
// TestAgent_SessionExec runs a non-interactive command over SSH on every agent
// SSH port and checks its output.
func TestAgent_SessionExec(t *testing.T) {
	t.Parallel()
	for _, port := range sshPorts {
		t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
			t.Parallel()
			cmd := "echo test"
			if runtime.GOOS == "windows" {
				cmd = "cmd.exe /c echo test"
			}

			sess := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
			out, err := sess.Output(cmd)
			require.NoError(t, err)
			require.Equal(t, "test", strings.TrimSpace(string(out)))
		})
	}
}
//nolint:tparallel // Sub tests need to run sequentially.
func TestAgent_Session_EnvironmentVariables(t *testing.T) {
	t.Parallel()
	tmpdir := t.TempDir()
	// Defined by the coder script runner, hardcoded here since we don't
	// have a reference to it.
	scriptBinDir := filepath.Join(tmpdir, "coder-script-data", "bin")
	manifest := agentsdk.Manifest{
		EnvironmentVariables: map[string]string{
			"MY_MANIFEST":         "true",
			"MY_OVERRIDE":         "false",
			"MY_SESSION_MANIFEST": "false",
		},
	}
	banner := codersdk.ServiceBannerConfig{}
	session := setupSSHSession(t, manifest, banner, nil, func(_ *agenttest.Client, opts *agent.Options) {
		opts.ScriptDataDir = tmpdir
		opts.EnvironmentVariables["MY_OVERRIDE"] = "true"
	})
	err := session.Setenv("MY_SESSION_MANIFEST", "true")
	require.NoError(t, err)
	err = session.Setenv("MY_SESSION", "true")
	require.NoError(t, err)
	command := "sh"
	// echoEnv writes a shell command that prints the named variable.
	echoEnv := func(t *testing.T, w io.Writer, env string) {
		t.Helper()
		if runtime.GOOS == "windows" {
			_, err := fmt.Fprintf(w, "echo %%%s%%\r\n", env)
			require.NoError(t, err)
		} else {
			_, err := fmt.Fprintf(w, "echo $%s\n", env)
			require.NoError(t, err)
		}
	}
	if runtime.GOOS == "windows" {
		command = "cmd.exe"
	}
	stdin, err := session.StdinPipe()
	require.NoError(t, err)
	defer stdin.Close()
	stdout, err := session.StdoutPipe()
	require.NoError(t, err)
	err = session.Start(command)
	require.NoError(t, err)
	// Context is fine here since we're not doing a parallel subtest.
	ctx := testutil.Context(t, testutil.WaitLong)
	go func() {
		<-ctx.Done()
		_ = session.Close()
	}()
	s := bufio.NewScanner(stdout)
	//nolint:paralleltest // These tests need to run sequentially.
	for k, partialV := range map[string]string{
		"CODER":               "true",  // From the agent.
		"MY_MANIFEST":         "true",  // From the manifest.
		"MY_OVERRIDE":         "true",  // From the agent environment variables option, overrides manifest.
		"MY_SESSION_MANIFEST": "false", // From the manifest, overrides session env.
		"MY_SESSION":          "true",  // From the session.
		"PATH":                scriptBinDir + string(filepath.ListSeparator),
	} {
		t.Run(k, func(t *testing.T) {
			echoEnv(t, stdin, k)
			// Windows is unreliable, so keep scanning until we find a match.
			found := false
			for s.Scan() {
				got := strings.TrimSpace(s.Text())
				t.Logf("%s=%s", k, got)
				if strings.Contains(got, partialV) {
					found = true
					break
				}
			}
			if err := s.Err(); !errors.Is(err, io.EOF) {
				require.NoError(t, err)
			}
			// Scanner.Err is nil at EOF, so without this check a variable
			// that never matched would let the subtest pass silently.
			require.True(t, found, "env %s not found in output", k)
		})
	}
}
// TestAgent_Session_SecretInjection checks that workspace secrets are written
// to the agent filesystem (FilePath secrets) and exported into SSH sessions
// (EnvName secrets). A secret with the same name as a manifest variable is
// expected to take precedence.
func TestAgent_Session_SecretInjection(t *testing.T) {
	t.Parallel()

	manifest := agentsdk.Manifest{
		EnvironmentVariables: map[string]string{
			"SHOULD_BE_OVERRIDDEN": "manifest-value",
		},
	}
	secrets := []agentsdk.WorkspaceSecret{
		{EnvName: "MY_SECRET_ENV", Value: []byte("env-secret-value")},
		{FilePath: "/tmp/secret-file", Value: []byte("file-secret-content")},
		{EnvName: "BOTH_ENV", FilePath: "/tmp/both-file", Value: []byte("both-value")},
		{EnvName: "SHOULD_BE_OVERRIDDEN", Value: []byte("secret-wins")},
	}

	ctx := testutil.Context(t, testutil.WaitLong)
	//nolint:dogsled
	conn, _, _, fs, _ := setupAgentWithSecrets(t, manifest, secrets, 0)

	// File secrets must land on the agent's filesystem.
	for _, want := range []struct{ path, content string }{
		{path: "/tmp/secret-file", content: "file-secret-content"},
		{path: "/tmp/both-file", content: "both-value"},
	} {
		data, err := afero.ReadFile(fs, want.path)
		require.NoError(t, err)
		require.Equal(t, want.content, string(data))
	}

	// Env secrets are checked by echoing them from an interactive shell.
	sshClient, err := conn.SSHClient(ctx)
	require.NoError(t, err)
	t.Cleanup(func() { _ = sshClient.Close() })
	session, err := sshClient.NewSession()
	require.NoError(t, err)
	t.Cleanup(func() { _ = session.Close() })

	shell, echoFormat := "sh", "echo $%s\n"
	if runtime.GOOS == "windows" {
		shell, echoFormat = "cmd.exe", "echo %%%s%%\r\n"
	}

	stdin, err := session.StdinPipe()
	require.NoError(t, err)
	defer stdin.Close()
	stdout, err := session.StdoutPipe()
	require.NoError(t, err)
	require.NoError(t, session.Start(shell))

	// Unblock a stuck scanner by closing the session once the test times out.
	go func() {
		<-ctx.Done()
		_ = session.Close()
	}()

	scanner := bufio.NewScanner(stdout)
	for name, wantSubstr := range map[string]string{
		"MY_SECRET_ENV":        "env-secret-value",
		"BOTH_ENV":             "both-value",
		"SHOULD_BE_OVERRIDDEN": "secret-wins",
	} {
		_, err := fmt.Fprintf(stdin, echoFormat, name)
		require.NoError(t, err)

		matched := false
		for !matched && scanner.Scan() {
			line := strings.TrimSpace(scanner.Text())
			t.Logf("%s=%s", name, line)
			matched = strings.Contains(line, wantSubstr)
		}
		require.True(t, matched, "env %s not found in output", name)
		if scanErr := scanner.Err(); !errors.Is(scanErr, io.EOF) {
			require.NoError(t, scanErr)
		}
	}
}
// TestAgent_StartupScript_SecretInjection checks that a startup script can see
// both an env-var secret and a file secret. The script copies each into a
// "proof" file that the test reads once the agent reports Ready.
func TestAgent_StartupScript_SecretInjection(t *testing.T) {
	t.Parallel()
	if runtime.GOOS == "windows" {
		t.Skip("startup script test uses sh syntax")
	}

	dir := t.TempDir()
	var (
		secretFilePath = filepath.Join(dir, "secret-file")
		envProofPath   = filepath.Join(dir, "env-proof")
		fileProofPath  = filepath.Join(dir, "file-proof")
	)

	manifest := agentsdk.Manifest{
		Scripts: []codersdk.WorkspaceAgentScript{{
			Script: fmt.Sprintf(
				"echo \"$MY_STARTUP_SECRET\" > %s && cat %s > %s",
				envProofPath, secretFilePath, fileProofPath,
			),
			Timeout:    30 * time.Second,
			RunOnStart: true,
		}},
	}
	secrets := []agentsdk.WorkspaceSecret{
		{EnvName: "MY_STARTUP_SECRET", Value: []byte("startup-env-value")},
		{FilePath: secretFilePath, Value: []byte("startup-file-content")},
	}

	// Run the agent on the real OS filesystem so secret files and the script
	// share one filesystem.
	useOsFs := func(_ *agenttest.Client, opts *agent.Options) {
		opts.Filesystem = afero.NewOsFs()
	}
	//nolint:dogsled
	_, client, _, _, _ := setupAgentWithSecrets(t, manifest, secrets, 0, useOsFs)

	// Poll lifecycle states until the most recent one is Ready.
	var states []codersdk.WorkspaceAgentLifecycle
	assert.Eventually(t, func() bool {
		states = client.GetLifecycleStates()
		n := len(states)
		return n > 0 && states[n-1] == codersdk.WorkspaceAgentLifecycleReady
	}, testutil.WaitLong, testutil.IntervalMedium)
	require.Contains(t, states, codersdk.WorkspaceAgentLifecycleReady, "agent never reached ready")

	envProof, err := os.ReadFile(envProofPath)
	require.NoError(t, err)
	require.Equal(t, "startup-env-value", strings.TrimSpace(string(envProof)))

	fileProof, err := os.ReadFile(fileProofPath)
	require.NoError(t, err)
	require.Equal(t, "startup-file-content", string(fileProof))
}
// TestAgent_GitSSH checks that SSH sessions get a GIT_SSH_COMMAND ending in
// "gitssh --".
func TestAgent_GitSSH(t *testing.T) {
	t.Parallel()
	session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
	command := "sh -c 'echo $GIT_SSH_COMMAND'"
	if runtime.GOOS == "windows" {
		command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
	}
	output, err := session.Output(command)
	require.NoError(t, err)
	// Report the actual value on failure; a bare require.True hides it.
	const wantSuffix = "gitssh --"
	got := strings.TrimSpace(string(output))
	require.Truef(t, strings.HasSuffix(got, wantSuffix), "GIT_SSH_COMMAND %q should end with %q", got, wantSuffix)
}
// TestAgent_SessionTTYShell starts an interactive sh with a PTY on each agent
// SSH port, runs a command, and exits the shell cleanly.
func TestAgent_SessionTTYShell(t *testing.T) {
	t.Parallel()
	if runtime.GOOS == "windows" {
		// This might be our implementation, or ConPTY itself.
		// It's difficult to find extensive tests for it, so
		// it seems like it could be either.
		t.Skip("ConPTY appears to be inconsistent on Windows.")
	}
	for _, port := range sshPorts {
		t.Run(fmt.Sprintf("(%d)", port), func(t *testing.T) {
			t.Parallel()
			ctx := testutil.Context(t, testutil.WaitMedium)
			session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
			// Windows is skipped above, so a POSIX shell is always used here.
			command := "sh"
			err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
			require.NoError(t, err)
			ptty := ptytest.New(t)
			session.Stdout = ptty.Output()
			session.Stderr = ptty.Output()
			session.Stdin = ptty.Input()
			err = session.Start(command)
			require.NoError(t, err)
			_ = ptty.Peek(ctx, 1) // wait for the prompt
			ptty.WriteLine("echo test")
			ptty.ExpectMatch("test")
			ptty.WriteLine("exit")
			err = session.Wait()
			require.NoError(t, err)
		})
	}
}
// TestAgent_SessionTTYExitCode runs a nonexistent command under a PTY and
// checks that its exit status is propagated to the SSH client.
func TestAgent_SessionTTYExitCode(t *testing.T) {
	t.Parallel()
	session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
	require.NoError(t, session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}))

	ptty := ptytest.New(t)
	session.Stdin = ptty.Input()
	session.Stdout = ptty.Output()
	session.Stderr = ptty.Output()

	const missingCommand = "areallynotrealcommand"
	require.NoError(t, session.Start(missingCommand))

	var exitErr *ssh.ExitError
	require.True(t, xerrors.As(session.Wait(), &exitErr))

	// Expected status: 1 on Windows, 127 elsewhere.
	wantStatus := 127
	if runtime.GOOS == "windows" {
		wantStatus = 1
	}
	assert.Equal(t, wantStatus, exitErr.ExitStatus())
}
// TestAgent_Session_TTY_MOTD checks which combinations of MOTD file and
// service banner are printed for a TTY login session, including how banner
// newlines are converted and trimmed.
func TestAgent_Session_TTY_MOTD(t *testing.T) {
	t.Parallel()
	if runtime.GOOS == "windows" {
		// This might be our implementation, or ConPTY itself.
		// It's difficult to find extensive tests for it, so
		// it seems like it could be either.
		t.Skip("ConPTY appears to be inconsistent on Windows.")
	}
	u, err := user.Current()
	require.NoError(t, err, "get current user")
	name := filepath.Join(u.HomeDir, "motd")
	wantMOTD := "Welcome to your Coder workspace!"
	wantServiceBanner := "Service banner text goes here"
	tests := []struct {
		name       string
		manifest   agentsdk.Manifest
		banner     codersdk.ServiceBannerConfig
		expected   []string
		unexpected []string
		expectedRe *regexp.Regexp
	}{
		{
			name:       "WithoutServiceBanner",
			manifest:   agentsdk.Manifest{MOTDFile: name},
			banner:     codersdk.ServiceBannerConfig{},
			expected:   []string{wantMOTD},
			unexpected: []string{wantServiceBanner},
		},
		{
			name:     "WithServiceBanner",
			manifest: agentsdk.Manifest{MOTDFile: name},
			banner: codersdk.ServiceBannerConfig{
				Enabled: true,
				Message: wantServiceBanner,
			},
			expected: []string{wantMOTD, wantServiceBanner},
		},
		{
			name:     "ServiceBannerDisabled",
			manifest: agentsdk.Manifest{MOTDFile: name},
			banner: codersdk.ServiceBannerConfig{
				Enabled: false,
				Message: wantServiceBanner,
			},
			expected:   []string{wantMOTD},
			unexpected: []string{wantServiceBanner},
		},
		{
			name:     "ServiceBannerOnly",
			manifest: agentsdk.Manifest{},
			banner: codersdk.ServiceBannerConfig{
				Enabled: true,
				Message: wantServiceBanner,
			},
			expected:   []string{wantServiceBanner},
			unexpected: []string{wantMOTD},
		},
		{
			name:       "None",
			manifest:   agentsdk.Manifest{},
			banner:     codersdk.ServiceBannerConfig{},
			unexpected: []string{wantServiceBanner, wantMOTD},
		},
		{
			name:     "CarriageReturns",
			manifest: agentsdk.Manifest{},
			banner: codersdk.ServiceBannerConfig{
				Enabled: true,
				Message: "service\n\nbanner\nhere",
			},
			expected:   []string{"service\r\n\r\nbanner\r\nhere\r\n\r\n"},
			unexpected: []string{},
		},
		{
			name: "Trim",
			// Enable motd since it will be printed after the banner,
			// this ensures that we can test for an exact amount of
			// newlines.
			manifest: agentsdk.Manifest{
				MOTDFile: name,
			},
			banner: codersdk.ServiceBannerConfig{
				Enabled: true,
				Message: "\n\n\n\n\n\nbanner\n\n\n\n\n\n",
			},
			expectedRe: regexp.MustCompile(`([^\n\r]|^)banner\r\n\r\n[^\r\n]`),
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			t.Parallel()
			// The MOTD file is written for every case; only cases whose
			// manifest sets MOTDFile point the agent at it.
			session := setupSSHSession(t, test.manifest, test.banner, func(fs afero.Fs) {
				err := fs.MkdirAll(filepath.Dir(name), 0o700)
				require.NoError(t, err)
				err = afero.WriteFile(fs, name, []byte(wantMOTD), 0o600)
				require.NoError(t, err)
			})
			testSessionOutput(t, session, test.expected, test.unexpected, test.expectedRe)
		})
	}
}
//nolint:tparallel // Sub tests need to run sequentially.
// TestAgent_Session_TTY_MOTD_Update swaps the announcement-banner func on the
// agent test client between subtests and checks that new SSH sessions on each
// agent SSH port reflect the most recent banner.
func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
	t.Parallel()
	if runtime.GOOS == "windows" {
		// This might be our implementation, or ConPTY itself.
		// It's difficult to find extensive tests for it, so
		// it seems like it could be either.
		t.Skip("ConPTY appears to be inconsistent on Windows.")
	}
	// Only the banner updates dynamically; the MOTD file does not.
	wantServiceBanner := "Service banner text goes here"
	// The cases run in order, so each one is a transition from the previous
	// banner state.
	tests := []struct {
		banner     codersdk.ServiceBannerConfig
		expected   []string
		unexpected []string
	}{
		{
			banner:     codersdk.ServiceBannerConfig{},
			expected:   []string{},
			unexpected: []string{wantServiceBanner},
		},
		{
			banner: codersdk.ServiceBannerConfig{
				Enabled: true,
				Message: wantServiceBanner,
			},
			expected: []string{wantServiceBanner},
		},
		{
			banner: codersdk.ServiceBannerConfig{
				Enabled: false,
				Message: wantServiceBanner,
			},
			expected:   []string{},
			unexpected: []string{wantServiceBanner},
		},
		{
			banner: codersdk.ServiceBannerConfig{
				Enabled: true,
				Message: wantServiceBanner,
			},
			expected:   []string{wantServiceBanner},
			unexpected: []string{},
		},
		{
			banner:     codersdk.ServiceBannerConfig{},
			unexpected: []string{wantServiceBanner},
		},
	}
	// Make the agent re-fetch banners quickly so updates are observed fast.
	setSBInterval := func(_ *agenttest.Client, opts *agent.Options) {
		opts.ServiceBannerRefreshInterval = testutil.IntervalFast
	}
	//nolint:dogsled // Allow the blank identifiers.
	conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, setSBInterval)
	// One deadline is shared by all sequential subtests.
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	//nolint:paralleltest // These tests need to swap the banner func.
	for _, port := range sshPorts {
		sshClient, err := conn.SSHClientOnPort(ctx, port)
		require.NoError(t, err)
		t.Cleanup(func() {
			_ = sshClient.Close()
		})
		for i, test := range tests {
			t.Run(fmt.Sprintf("(:%d)/%d", port, i), func(t *testing.T) {
				// Set new banner func and wait for the agent to call it to update the
				// banner. We wait for two calls to ensure the value has been stored:
				// the second call can only begin after the first iteration of
				// fetchServiceBannerLoop completes (call + store), so after
				// receiving two signals at least one store has happened.
				ready := make(chan struct{}, 2)
				client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
					// Non-blocking send: calls beyond the buffer capacity are
					// not signalled.
					select {
					case ready <- struct{}{}:
					default:
					}
					return []codersdk.BannerConfig{test.banner}, nil
				})
				testutil.TryReceive(ctx, t, ready)
				testutil.TryReceive(ctx, t, ready)
				// A fresh session is needed to see the banner captured above.
				session, err := sshClient.NewSession()
				require.NoError(t, err)
				t.Cleanup(func() {
					_ = session.Close()
				})
				testSessionOutput(t, session, test.expected, test.unexpected, nil)
			})
		}
	}
}
//nolint:paralleltest // This test sets an environment variable.
func TestAgent_Session_TTY_QuietLogin(t *testing.T) {
	if runtime.GOOS == "windows" {
		// This might be our implementation, or ConPTY itself.
		// It's difficult to find extensive tests for it, so
		// it seems like it could be either.
		t.Skip("ConPTY appears to be inconsistent on Windows.")
	}
	wantNotMOTD := "Welcome to your Coder workspace!"
	wantMaybeServiceBanner := "Service banner text goes here"
	u, err := user.Current()
	require.NoError(t, err, "get current user")
	name := filepath.Join(u.HomeDir, "motd")
	// .hushlogin lives in the home directory. Prefer $HOME (os.UserHomeDir)
	// and fall back to the user record.
	// NOTE(review): this mirrors the lookup order assumed for the agent's
	// quiet-login check; confirm against agentssh.
	hushloginHome := u.HomeDir
	if home, homeErr := os.UserHomeDir(); homeErr == nil {
		hushloginHome = home
	}
	hushlogin := filepath.Join(hushloginHome, ".hushlogin")
	// Neither banner nor MOTD should show if not a login shell.
	t.Run("NotLogin", func(t *testing.T) {
		session := setupSSHSession(t, agentsdk.Manifest{
			MOTDFile: name,
		}, codersdk.ServiceBannerConfig{
			Enabled: true,
			Message: wantMaybeServiceBanner,
		}, func(fs afero.Fs) {
			err := afero.WriteFile(fs, name, []byte(wantNotMOTD), 0o600)
			require.NoError(t, err, "write motd file")
		})
		err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
		require.NoError(t, err)
		wantEcho := "foobar"
		command := "echo " + wantEcho
		output, err := session.Output(command)
		require.NoError(t, err)
		require.Contains(t, string(output), wantEcho, "should show echo")
		require.NotContains(t, string(output), wantNotMOTD, "should not show motd")
		require.NotContains(t, string(output), wantMaybeServiceBanner, "should not show service banner")
	})
	// Only the MOTD should be silenced when hushlogin is present.
	t.Run("Hushlogin", func(t *testing.T) {
		session := setupSSHSession(t, agentsdk.Manifest{
			MOTDFile: name,
		}, codersdk.ServiceBannerConfig{
			Enabled: true,
			Message: wantMaybeServiceBanner,
		}, func(fs afero.Fs) {
			err := afero.WriteFile(fs, name, []byte(wantNotMOTD), 0o600)
			require.NoError(t, err, "write motd file")
			// Create hushlogin to silence motd. The MOTD file keeps its
			// content, so the assertion below only passes if hushlogin is
			// honored (previously the empty file overwrote the MOTD instead).
			err = afero.WriteFile(fs, hushlogin, []byte{}, 0o600)
			require.NoError(t, err, "write hushlogin file")
		})
		err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
		require.NoError(t, err)
		ptty := ptytest.New(t)
		var stdout bytes.Buffer
		session.Stdout = &stdout
		session.Stderr = ptty.Output()
		session.Stdin = ptty.Input()
		err = session.Shell()
		require.NoError(t, err)
		ptty.WriteLine("exit 0")
		err = session.Wait()
		require.NoError(t, err)
		require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd")
		require.Contains(t, stdout.String(), wantMaybeServiceBanner, "should show service banner")
	})
}
// TestAgent_Session_TTY_FastCommandHasOutput guards against regressions where
// quickly executing commands (with TTY) don't sync their output to the SSH
// session.
//
// See: https://github.com/coder/coder/issues/6656
func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	//nolint:dogsled
	conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
	sshClient, err := conn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()

	ptty := ptytest.New(t)
	var stdout bytes.Buffer

	// runOnce runs one short-lived TTY command in a fresh session and checks
	// that its output arrived. The session is closed before it returns.
	runOnce := func() {
		stdout.Reset()
		session, err := sshClient.NewSession()
		require.NoError(t, err)
		defer session.Close()

		require.NoError(t, session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}))
		session.Stdout = &stdout
		session.Stderr = ptty.Output()
		session.Stdin = ptty.Input()

		require.NoError(t, session.Start("echo wazzup"))
		require.NoError(t, session.Wait())
		require.Contains(t, stdout.String(), "wazzup", "should output greeting")
	}

	// NOTE(mafredri): Increase iterations to increase chance of failure,
	// assuming bug is present. Using 1000 iterations is basically a
	// guaranteed failure (but let's not increase test times needlessly).
	// Limit GOMAXPROCS (e.g. `export GOMAXPROCS=1`) to further increase
	// chance of failure. Also -race helps.
	const iterations = 5
	for n := 0; n < iterations; n++ {
		runOnce()
	}
}
// TestAgent_Session_TTY_HugeOutputIsNotLost guards against regressions where a
// command with a large amount of output (with TTY) would not be fully copied to
// the SSH session. On unix systems, this was fixed by duplicating the file
// descriptor of the PTY master and using it for copying the output.
//
// See: https://github.com/coder/coder/issues/6656
func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	//nolint:dogsled
	conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
	sshClient, err := conn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()

	ptty := ptytest.New(t)
	var stdout bytes.Buffer

	// runOnce echoes a ~6KB string through a fresh TTY session and checks
	// that all of it arrived. The session is closed before it returns.
	runOnce := func() {
		stdout.Reset()
		session, err := sshClient.NewSession()
		require.NoError(t, err)
		defer session.Close()

		require.NoError(t, session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}))
		session.Stdout = &stdout
		session.Stderr = ptty.Output()
		session.Stdin = ptty.Input()

		want := strings.Repeat("wazzup", 1024+1) // ~6KB, +1 because 1024 is a common buffer size.
		require.NoError(t, session.Start("echo "+want))
		require.NoError(t, session.Wait())
		require.Contains(t, stdout.String(), want, "should output entire greeting")
	}

	// NOTE(mafredri): Increase iterations to increase chance of failure,
	// assuming bug is present. Using 10 iterations is basically a guaranteed
	// failure (but let's not increase test times needlessly). Run with -race
	// and do not limit parallelism (`export GOMAXPROCS=10`) to increase the
	// chance of failure.
	const iterations = 1
	for n := 0; n < iterations; n++ {
		runOnce()
	}
}
// TestAgent_TCPLocalForwarding dials a local echo listener through the
// agent's SSH server (direct-tcpip) and checks the echo round-trips.
func TestAgent_TCPLocalForwarding(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()
	addr, isTCP := listener.Addr().(*net.TCPAddr)
	require.True(t, isTCP)
	go echoOnce(t, listener)

	sshClient := setupAgentSSHClient(ctx, t)
	forwarded, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", addr.Port))
	require.NoError(t, err)
	defer forwarded.Close()
	requireEcho(t, forwarded)
}
// TestAgent_TCPRemoteForwarding checks that a TCP listener requested through
// the agent's SSH client (reverse forwarding) accepts a local connection and
// echoes data back.
func TestAgent_TCPRemoteForwarding(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)
	sshClient := setupAgentSSHClient(ctx, t)

	loopback := netip.MustParseAddr("127.0.0.1")
	var (
		port     uint16
		listener net.Listener
		err      error
	)
	// Keep trying fresh random ports until ListenTCP succeeds or ctx
	// expires (presumably a chosen port can become busy before the agent
	// binds it).
	for {
		port = testutil.RandomPortNoListen(t)
		bindAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(loopback, port))
		listener, err = sshClient.ListenTCP(bindAddr)
		if err == nil {
			break
		}
		t.Logf("error remote forwarding: %s", err.Error())
		if ctx.Err() != nil {
			t.Fatal("timed out getting random listener")
		}
	}
	defer listener.Close()
	go echoOnce(t, listener)

	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
	require.NoError(t, err)
	defer conn.Close()
	requireEcho(t, conn)
}
// TestAgent_TCPLocalForwardingBlocked verifies that, with
// BlockLocalPortForwarding set, dialing a TCP address through the agent's
// SSH client is refused even though a listener exists at that address.
func TestAgent_TCPLocalForwardingBlocked(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()
	listenAddr, ok := listener.Addr().(*net.TCPAddr)
	require.True(t, ok)

	blockLocal := func(_ *agenttest.Client, o *agent.Options) {
		o.BlockLocalPortForwarding = true
	}
	//nolint:dogsled
	agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, blockLocal)
	sshClient, err := agentConn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()

	_, err = sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenAddr.Port))
	require.ErrorContains(t, err, "administratively prohibited")
}
// TestAgent_TCPRemoteForwardingBlocked verifies that, with
// BlockReversePortForwarding set, a TCP listen request through the agent's
// SSH client is denied.
func TestAgent_TCPRemoteForwardingBlocked(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)

	//nolint:dogsled
	agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.BlockReversePortForwarding = true
	})
	sshClient, err := agentConn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()

	bindAddrPort := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), testutil.RandomPortNoListen(t))
	_, err = sshClient.ListenTCP(net.TCPAddrFromAddrPort(bindAddrPort))
	require.ErrorContains(t, err, "tcpip-forward request denied by peer")
}
// TestAgent_UnixLocalForwardingBlocked verifies that BlockLocalPortForwarding
// also refuses dials to a unix socket through the agent's SSH client.
func TestAgent_UnixLocalForwardingBlocked(t *testing.T) {
	t.Parallel()
	if runtime.GOOS == "windows" {
		t.Skip("unix domain sockets are not fully supported on Windows")
	}
	ctx := testutil.Context(t, testutil.WaitLong)

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "remote-socket")
	listener, err := net.Listen("unix", socketPath)
	require.NoError(t, err)
	defer listener.Close()

	blockLocal := func(_ *agenttest.Client, o *agent.Options) {
		o.BlockLocalPortForwarding = true
	}
	//nolint:dogsled
	agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, blockLocal)
	sshClient, err := agentConn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()

	_, err = sshClient.Dial("unix", socketPath)
	require.ErrorContains(t, err, "administratively prohibited")
}
// TestAgent_UnixRemoteForwardingBlocked verifies that
// BlockReversePortForwarding also denies unix socket listen requests made
// through the agent's SSH client.
func TestAgent_UnixRemoteForwardingBlocked(t *testing.T) {
	t.Parallel()
	if runtime.GOOS == "windows" {
		t.Skip("unix domain sockets are not fully supported on Windows")
	}
	ctx := testutil.Context(t, testutil.WaitLong)
	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "remote-socket")

	blockReverse := func(_ *agenttest.Client, o *agent.Options) {
		o.BlockReversePortForwarding = true
	}
	//nolint:dogsled
	agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, blockReverse)
	sshClient, err := agentConn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()

	_, err = sshClient.ListenUnix(socketPath)
	require.ErrorContains(t, err, "streamlocal-forward@openssh.com request denied by peer")
}
// TestAgent_LocalBlockedDoesNotAffectReverse verifies that blocking
// local port forwarding does not prevent reverse port forwarding from
// working. A field-name transposition at any plumbing hop would cause
// both directions to be blocked when only one flag is set.
//
// The test checks more than acceptance of the listen request. It also pushes
// data through the forwarded listener, so a forward that is accepted but
// never delivers connections is caught too.
func TestAgent_LocalBlockedDoesNotAffectReverse(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)
	//nolint:dogsled
	agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.BlockLocalPortForwarding = true
	})
	sshClient, err := agentConn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()

	// Reverse forwarding must still work. Retry with a fresh random port
	// until the agent accepts the request or ctx expires.
	localhost := netip.MustParseAddr("127.0.0.1")
	var (
		randomPort uint16
		ll         net.Listener
	)
	for {
		randomPort = testutil.RandomPortNoListen(t)
		addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
		ll, err = sshClient.ListenTCP(addr)
		if err == nil {
			break
		}
		t.Logf("error remote forwarding: %s", err.Error())
		select {
		case <-ctx.Done():
			t.Fatal("timed out getting random listener")
		default:
		}
	}
	defer ll.Close()
	go echoOnce(t, ll)

	// Connect to the forwarded port and make sure traffic actually flows
	// through the agent, not just that the listen request was accepted.
	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
	require.NoError(t, err)
	defer conn.Close()
	requireEcho(t, conn)
}
// TestAgent_ReverseBlockedDoesNotAffectLocal verifies that blocking
// reverse port forwarding does not prevent local port forwarding from
// working: a dial through the agent still reaches an echo server and
// round-trips data.
func TestAgent_ReverseBlockedDoesNotAffectLocal(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)

	echoListener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echoListener.Close()
	echoAddr, ok := echoListener.Addr().(*net.TCPAddr)
	require.True(t, ok)
	go echoOnce(t, echoListener)

	blockReverse := func(_ *agenttest.Client, o *agent.Options) {
		o.BlockReversePortForwarding = true
	}
	//nolint:dogsled
	agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, blockReverse)
	sshClient, err := agentConn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()

	// Local forwarding must still work.
	forwarded, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", echoAddr.Port))
	require.NoError(t, err)
	defer forwarded.Close()
	requireEcho(t, forwarded)
}
// TestAgent_UnixLocalForwarding checks that a unix socket on the agent's host
// can be dialed through the agent's SSH client and echoes back what is
// written to it.
func TestAgent_UnixLocalForwarding(t *testing.T) {
	t.Parallel()
	if runtime.GOOS == "windows" {
		t.Skip("unix domain sockets are not fully supported on Windows")
	}
	ctx := testutil.Context(t, testutil.WaitLong)
	tmpdir := testutil.TempDirUnixSocket(t)
	remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
	l, err := net.Listen("unix", remoteSocketPath)
	require.NoError(t, err)
	defer l.Close()
	go echoOnce(t, l)
	sshClient := setupAgentSSHClient(ctx, t)
	conn, err := sshClient.Dial("unix", remoteSocketPath)
	require.NoError(t, err)
	defer conn.Close()
	_, err = conn.Write([]byte("test"))
	require.NoError(t, err)
	// A single Read on a stream may return fewer bytes than requested even
	// when the rest is on its way, so read the full echo explicitly.
	b := make([]byte, 4)
	_, err = io.ReadFull(conn, b)
	require.NoError(t, err)
	require.Equal(t, "test", string(b))
}
// TestAgent_UnixRemoteForwarding checks that a unix socket listener requested
// through the agent's SSH client accepts a local connection and echoes data
// back.
func TestAgent_UnixRemoteForwarding(t *testing.T) {
	t.Parallel()
	if runtime.GOOS == "windows" {
		t.Skip("unix domain sockets are not fully supported on Windows")
	}
	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "remote-socket")
	ctx := testutil.Context(t, testutil.WaitLong)
	sshClient := setupAgentSSHClient(ctx, t)

	forwarded, err := sshClient.ListenUnix(socketPath)
	require.NoError(t, err)
	defer forwarded.Close()
	go echoOnce(t, forwarded)

	conn, err := net.Dial("unix", socketPath)
	require.NoError(t, err)
	defer conn.Close()
	requireEcho(t, conn)
}
// TestAgent_SFTP exercises the agent's SFTP support. It checks the initial
// working directory reported by Getwd and creating a file. It also checks
// the SSH connection report emitted when the client disconnects.
func TestAgent_SFTP(t *testing.T) {
	t.Parallel()

	// sftpPath rewrites a native path into the form Getwd is expected to
	// report: unchanged on unix, "/" + forward-slashed path on Windows.
	sftpPath := func(p string) string {
		if runtime.GOOS != "windows" {
			return p
		}
		return "/" + strings.ReplaceAll(p, "\\", "/")
	}

	t.Run("DefaultWorkingDirectory", func(t *testing.T) {
		t.Parallel()
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		currentUser, err := user.Current()
		require.NoError(t, err, "get current user")
		home := sftpPath(currentUser.HomeDir)

		//nolint:dogsled
		conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
		sshClient, err := conn.SSHClient(ctx)
		require.NoError(t, err)
		defer sshClient.Close()
		client, err := sftp.NewClient(sshClient)
		require.NoError(t, err)
		defer client.Close()

		wd, err := client.Getwd()
		require.NoError(t, err, "get working directory")
		require.Equal(t, home, wd, "working directory should be user home")

		localFile := filepath.Join(t.TempDir(), "sftp")
		// SFTP only accepts unix-y paths; on Windows the slashed path gains
		// a leading "/", e.g. "/C:/Users/...".
		remoteFile := filepath.ToSlash(localFile)
		if !path.IsAbs(remoteFile) {
			remoteFile = path.Join("/", remoteFile)
		}
		created, err := client.Create(remoteFile)
		require.NoError(t, err)
		require.NoError(t, created.Close())
		_, err = os.Stat(localFile)
		require.NoError(t, err)

		// Close the client to trigger disconnect event.
		_ = client.Close()
		assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
	})

	t.Run("CustomWorkingDirectory", func(t *testing.T) {
		t.Parallel()
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		// Create a custom directory for the agent to use.
		customDir := t.TempDir()

		//nolint:dogsled
		conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{Directory: customDir}, 0)
		sshClient, err := conn.SSHClient(ctx)
		require.NoError(t, err)
		defer sshClient.Close()
		client, err := sftp.NewClient(sshClient)
		require.NoError(t, err)
		defer client.Close()

		wd, err := client.Getwd()
		require.NoError(t, err, "get working directory")
		require.Equal(t, sftpPath(customDir), wd, "working directory should be custom directory")

		// Close the client to trigger disconnect event.
		_ = client.Close()
		assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
	})
}
// TestAgent_SCP copies a file to the agent's host over SCP, checks that it
// exists on disk, and verifies the SSH connection report sent when the
// client disconnects.
func TestAgent_SCP(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	//nolint:dogsled
	conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
	sshClient, err := conn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()
	scpClient, err := scp.NewClientBySSH(sshClient)
	require.NoError(t, err)
	defer scpClient.Close()
	tempFile := filepath.Join(t.TempDir(), "scp")
	content := "hello world"
	// Bound the copy by the test's context. With context.Background() a
	// stalled transfer would only be stopped by the global test timeout.
	err = scpClient.CopyFile(ctx, strings.NewReader(content), tempFile, "0755")
	require.NoError(t, err)
	_, err = os.Stat(tempFile)
	require.NoError(t, err)
	// Close the client to trigger disconnect event.
	scpClient.Close()
	assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
}
// TestAgent_FileTransferBlocked verifies that, with BlockFileTransfer set,
// SFTP, SCP and every command in agentssh.BlockedFileTransferCommands are
// refused. Each case also expects a connection report carrying
// agentssh.BlockedFileTransferErrorCode.
func TestAgent_FileTransferBlocked(t *testing.T) {
	t.Parallel()
	assertFileTransferBlocked := func(t *testing.T, errorMessage string) {
		t.Helper()
		// NOTE: Checking content of the error message is flaky. Most likely there is a race condition, which results
		// in stopping the client in different phases, and returning different errors:
		// - client read the full error message: File transfer has been disabled.
		// - client's stream was terminated before reading the error message: EOF
		// - client just read the error code (Windows): Process exited with status 65
		isErr := strings.Contains(errorMessage, agentssh.BlockedFileTransferErrorMessage) ||
			strings.Contains(errorMessage, "EOF") ||
			strings.Contains(errorMessage, "Process exited with status 65")
		require.True(t, isErr, "Message: "+errorMessage)
	}
	t.Run("SFTP", func(t *testing.T) {
		t.Parallel()
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()
		//nolint:dogsled
		conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
			o.BlockFileTransfer = true
		})
		sshClient, err := conn.SSHClient(ctx)
		require.NoError(t, err)
		defer sshClient.Close()
		_, err = sftp.NewClient(sshClient)
		require.Error(t, err)
		assertFileTransferBlocked(t, err.Error())
		assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
	})
	t.Run("SCP with go-scp package", func(t *testing.T) {
		t.Parallel()
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()
		//nolint:dogsled
		conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
			o.BlockFileTransfer = true
		})
		sshClient, err := conn.SSHClient(ctx)
		require.NoError(t, err)
		defer sshClient.Close()
		scpClient, err := scp.NewClientBySSH(sshClient)
		require.NoError(t, err)
		defer scpClient.Close()
		tempFile := filepath.Join(t.TempDir(), "scp")
		// Use the test's context so a hung copy cannot outlive the timeout.
		err = scpClient.CopyFile(ctx, strings.NewReader("hello world"), tempFile, "0755")
		require.Error(t, err)
		assertFileTransferBlocked(t, err.Error())
		assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
	})
	t.Run("Forbidden commands", func(t *testing.T) {
		t.Parallel()
		for _, c := range agentssh.BlockedFileTransferCommands {
			t.Run(c, func(t *testing.T) {
				t.Parallel()
				ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
				defer cancel()
				//nolint:dogsled
				conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
					o.BlockFileTransfer = true
				})
				sshClient, err := conn.SSHClient(ctx)
				require.NoError(t, err)
				defer sshClient.Close()
				session, err := sshClient.NewSession()
				require.NoError(t, err)
				defer session.Close()
				stdout, err := session.StdoutPipe()
				require.NoError(t, err)
				//nolint:govet // we don't need `c := c` in Go 1.22
				err = session.Start(c)
				require.NoError(t, err)
				msg, err := io.ReadAll(stdout)
				require.NoError(t, err)
				assertFileTransferBlocked(t, string(msg))
				assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
			})
		}
	})
}
// TestAgent_EnvironmentVariables checks that a variable set in the manifest's
// EnvironmentVariables is visible to a command run over SSH.
func TestAgent_EnvironmentVariables(t *testing.T) {
	t.Parallel()
	const (
		key   = "EXAMPLE"
		value = "value"
	)
	session := setupSSHSession(t, agentsdk.Manifest{
		EnvironmentVariables: map[string]string{key: value},
	}, codersdk.ServiceBannerConfig{}, nil)

	var command string
	switch runtime.GOOS {
	case "windows":
		command = "cmd.exe /c echo %" + key + "%"
	default:
		command = "sh -c 'echo $" + key + "'"
	}
	output, err := session.Output(command)
	require.NoError(t, err)
	require.Equal(t, value, strings.TrimSpace(string(output)))
}
// TestAgent_EnvironmentVariableExpansion checks that a manifest variable whose
// value references an unset variable ends up empty in an SSH session.
func TestAgent_EnvironmentVariableExpansion(t *testing.T) {
	t.Parallel()
	const key = "EXAMPLE"
	session := setupSSHSession(t, agentsdk.Manifest{
		EnvironmentVariables: map[string]string{
			key: "$SOMETHINGNOTSET",
		},
	}, codersdk.ServiceBannerConfig{}, nil)

	command, expect := "sh -c 'echo $"+key+"'", ""
	if runtime.GOOS == "windows" {
		// The test expects cmd.exe to echo the reference itself when the
		// variable ends up empty.
		command, expect = "cmd.exe /c echo %"+key+"%", "%EXAMPLE%"
	}
	output, err := session.Output(command)
	require.NoError(t, err)
	// Output should be empty, because the variable is not set!
	require.Equal(t, expect, strings.TrimSpace(string(output)))
}
// TestAgent_CoderEnvVars checks that each CODER* variable listed below is
// non-empty inside an SSH session.
func TestAgent_CoderEnvVars(t *testing.T) {
	t.Parallel()
	variables := []string{
		"CODER",
		"CODER_WORKSPACE_NAME",
		"CODER_WORKSPACE_OWNER_NAME",
		"CODER_WORKSPACE_AGENT_NAME",
	}
	for _, name := range variables {
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
			var command string
			switch runtime.GOOS {
			case "windows":
				command = "cmd.exe /c echo %" + name + "%"
			default:
				command = "sh -c 'echo $" + name + "'"
			}
			output, err := session.Output(command)
			require.NoError(t, err)
			require.NotEmpty(t, strings.TrimSpace(string(output)))
		})
	}
}
// TestAgent_SSHConnectionEnvVars checks that SSH_CONNECTION and SSH_CLIENT are
// non-empty inside an SSH session.
func TestAgent_SSHConnectionEnvVars(t *testing.T) {
	t.Parallel()
	// Note: the SSH_TTY environment variable should only be set for TTYs.
	// For some reason this test produces a TTY locally and a non-TTY in CI
	// so we don't test for the absence of SSH_TTY.
	variables := []string{"SSH_CONNECTION", "SSH_CLIENT"}
	for _, name := range variables {
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
			var command string
			switch runtime.GOOS {
			case "windows":
				command = "cmd.exe /c echo %" + name + "%"
			default:
				command = "sh -c 'echo $" + name + "'"
			}
			output, err := session.Output(command)
			require.NoError(t, err)
			require.NotEmpty(t, strings.TrimSpace(string(output)))
		})
	}
}
// TestAgent_SSHConnectionLoginVars checks that USER, LOGNAME and SHELL inside
// an SSH session match what usershell.SystemEnvInfo reports for the current
// user.
func TestAgent_SSHConnectionLoginVars(t *testing.T) {
	t.Parallel()

	envInfo := usershell.SystemEnvInfo{}
	currentUser, err := envInfo.User()
	require.NoError(t, err, "get current user")
	loginShell, err := envInfo.Shell(currentUser.Username)
	require.NoError(t, err, "get current shell")

	cases := []struct {
		variable string
		expected string
	}{
		{variable: "USER", expected: currentUser.Username},
		{variable: "LOGNAME", expected: currentUser.Username},
		{variable: "SHELL", expected: loginShell},
	}
	for _, tc := range cases {
		t.Run(tc.variable, func(t *testing.T) {
			t.Parallel()
			session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
			var command string
			switch runtime.GOOS {
			case "windows":
				command = "cmd.exe /c echo %" + tc.variable + "%"
			default:
				command = "sh -c 'echo $" + tc.variable + "'"
			}
			output, err := session.Output(command)
			require.NoError(t, err)
			require.Equal(t, tc.expected, strings.TrimSpace(string(output)))
		})
	}
}
// TestAgent_Metadata checks metadata collection. A script with Interval 0 is
// collected once and not again. A script with a positive Interval keeps
// being re-collected.
func TestAgent_Metadata(t *testing.T) {
	t.Parallel()
	echoHello := "echo 'hello'"
	t.Run("Once", func(t *testing.T) {
		t.Parallel()
		//nolint:dogsled
		_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
			Metadata: []codersdk.WorkspaceAgentMetadataDescription{
				{
					Key:      "greeting1",
					Interval: 0,
					Script:   echoHello,
				},
				{
					Key:      "greeting2",
					Interval: 1,
					Script:   echoHello,
				},
			},
		}, 0, func(_ *agenttest.Client, opts *agent.Options) {
			opts.ReportMetadataInterval = testutil.IntervalFast
		})
		var gotMd map[string]agentsdk.Metadata
		require.Eventually(t, func() bool {
			gotMd = client.GetMetadata()
			return len(gotMd) == 2
		}, testutil.WaitShort, testutil.IntervalFast/2)
		collectedAt1 := gotMd["greeting1"].CollectedAt
		collectedAt2 := gotMd["greeting2"].CollectedAt
		// Wait for greeting2 (Interval: 1) to be re-collected. testify runs
		// the condition on its own goroutine, where a panic would crash the
		// whole test binary. So an unexpected key count only stops the
		// polling, and the count is asserted on the test goroutine below.
		require.Eventually(t, func() bool {
			gotMd = client.GetMetadata()
			if len(gotMd) != 2 {
				return true
			}
			return !gotMd["greeting2"].CollectedAt.Equal(collectedAt2)
		}, testutil.WaitShort, testutil.IntervalFast/2)
		require.Len(t, gotMd, 2, "unexpected number of metadata")
		require.Equal(t, collectedAt1, gotMd["greeting1"].CollectedAt, "metadata should not be collected again")
	})
	t.Run("Many", func(t *testing.T) {
		t.Parallel()
		//nolint:dogsled
		_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
			Metadata: []codersdk.WorkspaceAgentMetadataDescription{
				{
					Key:      "greeting",
					Interval: 1,
					Timeout:  100,
					Script:   echoHello,
				},
			},
		}, 0, func(_ *agenttest.Client, opts *agent.Options) {
			opts.ReportMetadataInterval = testutil.IntervalFast
		})
		var gotMd map[string]agentsdk.Metadata
		require.Eventually(t, func() bool {
			gotMd = client.GetMetadata()
			return len(gotMd) == 1
		}, testutil.WaitShort, testutil.IntervalFast/2)
		collectedAt1 := gotMd["greeting"].CollectedAt
		require.Equal(t, "hello", strings.TrimSpace(gotMd["greeting"].Value))
		require.Eventually(t, func() bool {
			gotMd = client.GetMetadata()
			return gotMd["greeting"].CollectedAt.After(collectedAt1)
		}, testutil.WaitShort, testutil.IntervalFast/2, "expected metadata to be collected again")
	})
}
// TestAgentMetadata_Timing checks that a metadata script re-runs at roughly
// its configured interval: never faster than ideal and no slower than half
// the ideal rate. It also checks that a failing script's error is reported.
// The test is skipped on Windows and unless testutil.SkipIfNotTiming allows
// it.
func TestAgentMetadata_Timing(t *testing.T) {
	if runtime.GOOS == "windows" {
		// Shell scripting in Windows is a pain, and we have already tested
		// that the OS logic works in the simpler tests.
		t.SkipNow()
	}
	testutil.SkipIfNotTiming(t)
	t.Parallel()

	greetingPath := filepath.Join(t.TempDir(), "greeting")
	const (
		reportInterval = 2
		intervalUnit   = 100 * time.Millisecond
	)
	// Every run appends "hello" to greetingPath so runs can be counted.
	script := "echo hello | tee -a " + greetingPath

	//nolint:dogsled
	_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
		Metadata: []codersdk.WorkspaceAgentMetadataDescription{
			{Key: "greeting", Interval: reportInterval, Script: script},
			{Key: "bad", Interval: reportInterval, Script: "exit 1"},
		},
	}, 0, func(_ *agenttest.Client, opts *agent.Options) {
		opts.ReportMetadataInterval = intervalUnit
	})

	require.Eventually(t, func() bool {
		return len(client.GetMetadata()) == 2
	}, testutil.WaitShort, testutil.IntervalMedium)

	// The sleep lives in the loop's post statement so that `continue` below
	// still waits before the next sample.
	for start := time.Now(); time.Since(start) < testutil.WaitMedium; time.Sleep(testutil.IntervalMedium) {
		md := client.GetMetadata()
		require.Len(t, md, 2, "got: %+v", md)
		require.Equal(t, "hello\n", md["greeting"].Value)
		require.Equal(t, "run cmd: exit status 1", md["bad"].Error)

		contents, err := os.ReadFile(greetingPath)
		require.NoError(t, err)

		numGreetings := bytes.Count(contents, []byte("hello"))
		idealNumGreetings := time.Since(start) / (reportInterval * intervalUnit)
		if idealNumGreetings < 50 {
			// There is an insufficient sample size.
			continue
		}
		// We allow a 50% error margin because the report loop may backlog
		// in CI and other toasters. In production, there is no hard
		// guarantee on timing either, and the frontend gives similar
		// wiggle room to the staleness of the value.
		upperBound := int(idealNumGreetings) + 1
		lowerBound := int(idealNumGreetings) / 2

		t.Logf("numGreetings: %d, idealNumGreetings: %d", numGreetings, idealNumGreetings)
		// The report loop may slow down on load, but it should never, ever
		// speed up.
		switch {
		case numGreetings > upperBound:
			t.Fatalf("too many greetings: %d > %d in %v", numGreetings, upperBound, time.Since(start))
		case numGreetings < lowerBound:
			t.Fatalf("too few greetings: %d < %d", numGreetings, lowerBound)
		}
	}
}
// TestAgent_Lifecycle checks the sequence of lifecycle states the agent
// reports when start/stop scripts succeed, fail or time out. It also checks
// that a shutdown script runs only once even if the agent is closed twice.
//
// Waits on the test goroutine use require.Eventually rather than
// assert.Eventually. On a timeout the subtest must stop instead of falling
// through to got[:len(want)], which would panic on a short history and
// could race with the still-running condition goroutine.
func TestAgent_Lifecycle(t *testing.T) {
	t.Parallel()
	t.Run("StartTimeout", func(t *testing.T) {
		t.Parallel()
		_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
			Scripts: []codersdk.WorkspaceAgentScript{{
				Script:     "sleep 3",
				Timeout:    time.Millisecond,
				RunOnStart: true,
			}},
		}, 0)
		want := []codersdk.WorkspaceAgentLifecycle{
			codersdk.WorkspaceAgentLifecycleStarting,
			codersdk.WorkspaceAgentLifecycleStartTimeout,
		}
		var got []codersdk.WorkspaceAgentLifecycle
		require.Eventually(t, func() bool {
			got = client.GetLifecycleStates()
			return slices.Contains(got, want[len(want)-1])
		}, testutil.WaitShort, testutil.IntervalMedium)
		require.Equal(t, want, got[:len(want)])
	})
	t.Run("StartError", func(t *testing.T) {
		t.Parallel()
		_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
			Scripts: []codersdk.WorkspaceAgentScript{{
				Script:     "false",
				Timeout:    30 * time.Second,
				RunOnStart: true,
			}},
		}, 0)
		want := []codersdk.WorkspaceAgentLifecycle{
			codersdk.WorkspaceAgentLifecycleStarting,
			codersdk.WorkspaceAgentLifecycleStartError,
		}
		var got []codersdk.WorkspaceAgentLifecycle
		require.Eventually(t, func() bool {
			got = client.GetLifecycleStates()
			return slices.Contains(got, want[len(want)-1])
		}, testutil.WaitShort, testutil.IntervalMedium)
		require.Equal(t, want, got[:len(want)])
	})
	t.Run("Ready", func(t *testing.T) {
		t.Parallel()
		_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
			Scripts: []codersdk.WorkspaceAgentScript{{
				Script:     "echo foo",
				Timeout:    30 * time.Second,
				RunOnStart: true,
			}},
		}, 0)
		want := []codersdk.WorkspaceAgentLifecycle{
			codersdk.WorkspaceAgentLifecycleStarting,
			codersdk.WorkspaceAgentLifecycleReady,
		}
		var got []codersdk.WorkspaceAgentLifecycle
		require.Eventually(t, func() bool {
			got = client.GetLifecycleStates()
			return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
		}, testutil.WaitShort, testutil.IntervalMedium)
		require.Equal(t, want, got)
	})
	t.Run("ShuttingDown", func(t *testing.T) {
		t.Parallel()
		_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
			Scripts: []codersdk.WorkspaceAgentScript{{
				Script:    "sleep 3",
				Timeout:   30 * time.Second,
				RunOnStop: true,
			}},
		}, 0)
		require.Eventually(t, func() bool {
			return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
		}, testutil.WaitShort, testutil.IntervalMedium)
		// Start close asynchronously so that we can inspect the state.
		done := make(chan struct{})
		go func() {
			defer close(done)
			err := closer.Close()
			// assert, not require: this runs off the test goroutine.
			assert.NoError(t, err)
		}()
		t.Cleanup(func() {
			<-done
		})
		want := []codersdk.WorkspaceAgentLifecycle{
			codersdk.WorkspaceAgentLifecycleStarting,
			codersdk.WorkspaceAgentLifecycleReady,
			codersdk.WorkspaceAgentLifecycleShuttingDown,
		}
		var got []codersdk.WorkspaceAgentLifecycle
		require.Eventually(t, func() bool {
			got = client.GetLifecycleStates()
			return slices.Contains(got, want[len(want)-1])
		}, testutil.WaitShort, testutil.IntervalMedium)
		require.Equal(t, want, got[:len(want)])
	})
	t.Run("ShutdownTimeout", func(t *testing.T) {
		t.Parallel()
		_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
			Scripts: []codersdk.WorkspaceAgentScript{{
				Script:    "sleep 3",
				Timeout:   time.Millisecond,
				RunOnStop: true,
			}},
		}, 0)
		require.Eventually(t, func() bool {
			return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
		}, testutil.WaitShort, testutil.IntervalMedium)
		// Start close asynchronously so that we can inspect the state.
		done := make(chan struct{})
		go func() {
			defer close(done)
			err := closer.Close()
			// assert, not require: this runs off the test goroutine.
			assert.NoError(t, err)
		}()
		t.Cleanup(func() {
			<-done
		})
		want := []codersdk.WorkspaceAgentLifecycle{
			codersdk.WorkspaceAgentLifecycleStarting,
			codersdk.WorkspaceAgentLifecycleReady,
			codersdk.WorkspaceAgentLifecycleShuttingDown,
			codersdk.WorkspaceAgentLifecycleShutdownTimeout,
		}
		var got []codersdk.WorkspaceAgentLifecycle
		require.Eventually(t, func() bool {
			got = client.GetLifecycleStates()
			return slices.Contains(got, want[len(want)-1])
		}, testutil.WaitShort, testutil.IntervalMedium)
		require.Equal(t, want, got[:len(want)])
	})
	t.Run("ShutdownError", func(t *testing.T) {
		t.Parallel()
		_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
			Scripts: []codersdk.WorkspaceAgentScript{{
				Script:    "false",
				Timeout:   30 * time.Second,
				RunOnStop: true,
			}},
		}, 0)
		require.Eventually(t, func() bool {
			return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
		}, testutil.WaitShort, testutil.IntervalMedium)
		// Start close asynchronously so that we can inspect the state.
		done := make(chan struct{})
		go func() {
			defer close(done)
			err := closer.Close()
			// assert, not require: this runs off the test goroutine.
			assert.NoError(t, err)
		}()
		t.Cleanup(func() {
			<-done
		})
		want := []codersdk.WorkspaceAgentLifecycle{
			codersdk.WorkspaceAgentLifecycleStarting,
			codersdk.WorkspaceAgentLifecycleReady,
			codersdk.WorkspaceAgentLifecycleShuttingDown,
			codersdk.WorkspaceAgentLifecycleShutdownError,
		}
		var got []codersdk.WorkspaceAgentLifecycle
		require.Eventually(t, func() bool {
			got = client.GetLifecycleStates()
			return slices.Contains(got, want[len(want)-1])
		}, testutil.WaitShort, testutil.IntervalMedium)
		require.Equal(t, want, got[:len(want)])
	})
	t.Run("ShutdownScriptOnce", func(t *testing.T) {
		t.Parallel()
		logger := testutil.Logger(t)
		ctx := testutil.Context(t, testutil.WaitMedium)
		expected := "this-is-shutdown"
		derpMap, _ := tailnettest.RunDERPAndSTUN(t)
		statsCh := make(chan *proto.Stats, 50)
		client := agenttest.NewClient(t,
			logger,
			uuid.New(),
			agentsdk.Manifest{
				DERPMap: derpMap,
				Scripts: []codersdk.WorkspaceAgentScript{{
					ID:         uuid.New(),
					LogPath:    "coder-startup-script.log",
					Script:     "echo 1",
					RunOnStart: true,
				}, {
					ID:        uuid.New(),
					LogPath:   "coder-shutdown-script.log",
					Script:    "echo " + expected,
					RunOnStop: true,
				}},
			},
			statsCh,
			tailnet.NewCoordinator(logger),
		)
		defer client.Close()
		fs := afero.NewMemMapFs()
		// Named testAgent so the local does not shadow the agent package.
		testAgent := agent.New(agent.Options{
			Client:     client,
			Logger:     logger.Named("agent"),
			Filesystem: fs,
		})
		// agent.Close() loads the shutdown script from the agent metadata.
		// The metadata is populated just before execution of the startup script, so it's mandatory to wait
		// until the startup starts.
		require.Eventually(t, func() bool {
			outputPath := filepath.Join(os.TempDir(), "coder-startup-script.log")
			content, err := afero.ReadFile(fs, outputPath)
			if err != nil {
				t.Logf("read file %q: %s", outputPath, err)
				return false
			}
			return len(content) > 0 // something is in the startup log file
		}, testutil.WaitShort, testutil.IntervalMedium)
		// In order to avoid shutting down the agent before it is fully started and triggering
		// errors, we'll wait until the agent is fully up. It's a bit hokey, but among the last things the agent starts
		// is the stats reporting, so getting a stats report is a good indication the agent is fully up.
		_ = testutil.TryReceive(ctx, t, statsCh)
		err := testAgent.Close()
		require.NoError(t, err, "agent should be closed successfully")
		outputPath := filepath.Join(os.TempDir(), "coder-shutdown-script.log")
		logFirstRead, err := afero.ReadFile(fs, outputPath)
		require.NoError(t, err, "log file should be present")
		require.Equal(t, expected, string(bytes.TrimSpace(logFirstRead)))
		// Make sure that script can't be executed twice.
		err = testAgent.Close()
		require.NoError(t, err, "don't need to close the agent twice, no effect")
		logSecondRead, err := afero.ReadFile(fs, outputPath)
		require.NoError(t, err, "log file should be present")
		require.Equal(t, string(bytes.TrimSpace(logFirstRead)), string(bytes.TrimSpace(logSecondRead)))
	})
}
// TestAgent_Startup checks the expanded directory the agent reports at
// startup for several manifest Directory values. An empty value stays empty.
// "~", "$HOME" and a relative path resolve against the user's home
// directory.
func TestAgent_Startup(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name      string
		directory string
		// underHome marks cases whose expectation is the home directory,
		// joined with homeSuffix when that is non-empty. Other cases
		// expect "".
		underHome  bool
		homeSuffix string
	}{
		{name: "EmptyDirectory", directory: ""},
		{name: "HomeDirectory", directory: "~", underHome: true},
		{name: "NotAbsoluteDirectory", directory: "coder/coder", underHome: true, homeSuffix: "coder/coder"},
		{name: "HomeEnvironmentVariable", directory: "$HOME", underHome: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			ctx := testutil.Context(t, testutil.WaitShort)
			_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
				Directory: tc.directory,
			}, 0)
			startup := testutil.TryReceive(ctx, t, client.GetStartup())

			want := ""
			if tc.underHome {
				homeDir, err := os.UserHomeDir()
				require.NoError(t, err)
				want = homeDir
				if tc.homeSuffix != "" {
					want = filepath.Join(homeDir, tc.homeSuffix)
				}
			}
			require.Equal(t, want, startup.GetExpandedDirectory())
		})
	}
}
//nolint:paralleltest // This test sets an environment variable.
func TestAgent_ReconnectingPTY(t *testing.T) {
if runtime.GOOS == "windows" {
// This might be our implementation, or ConPTY itself.
// It's difficult to find extensive tests for it, so
// it seems like it could be either.
t.Skip("ConPTY appears to be inconsistent on Windows.")
}
backends := []string{"Buffered", "Screen"}
_, err := exec.LookPath("screen")
hasScreen := err == nil
// Make sure UTF-8 works even with LANG set to something like C.
t.Setenv("LANG", "C")
for _, backendType := range backends {
t.Run(backendType, func(t *testing.T) {
if backendType == "Screen" {
if runtime.GOOS != "linux" {
t.Skipf("`screen` is not supported on %s", runtime.GOOS)
} else if !hasScreen {
t.Skip("`screen` not found")
}
} else if hasScreen && runtime.GOOS == "linux" {
// Set up a PATH that does not have screen in it.
bashPath, err := exec.LookPath("bash")
require.NoError(t, err)
dir, err := os.MkdirTemp("/tmp", "coder-test-reconnecting-pty-PATH")
require.NoError(t, err, "create temp dir for reconnecting pty PATH")
err = os.Symlink(bashPath, filepath.Join(dir, "bash"))
require.NoError(t, err, "symlink bash into reconnecting pty PATH")
t.Setenv("PATH", dir)
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
idConnectionReport := uuid.New()
id := uuid.New()
// Test that the connection is reported. This must be tested in the
// first connection because we care about verifying all of these.
netConn0, err := conn.ReconnectingPTY(ctx, idConnectionReport, 80, 80, "bash --norc")
require.NoError(t, err)
_ = netConn0.Close()
assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "")
// --norc disables executing .bashrc, which is often used to customize the bash prompt
netConn1, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
require.NoError(t, err)
defer netConn1.Close()
tr1 := testutil.NewTerminalReader(t, netConn1)
// A second simultaneous connection.
netConn2, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
require.NoError(t, err)
defer netConn2.Close()
tr2 := testutil.NewTerminalReader(t, netConn2)
matchPrompt := func(line string) bool {
return strings.Contains(line, "$ ") || strings.Contains(line, "# ")
}
matchEchoCommand := func(line string) bool {
return strings.Contains(line, "echo test")
}
matchEchoOutput := func(line string) bool {
return strings.Contains(line, "test") && !strings.Contains(line, "echo")
}
matchExitCommand := func(line string) bool {
return strings.Contains(line, "exit")
}
matchExitOutput := func(line string) bool {
return strings.Contains(line, "exit") || strings.Contains(line, "logout")
}
// Wait for the prompt before writing commands. If the command arrives before the prompt is written, screen
// will sometimes put the command output on the same line as the command and the test will flake
require.NoError(t, tr1.ReadUntil(ctx, matchPrompt), "find prompt")
require.NoError(t, tr2.ReadUntil(ctx, matchPrompt), "find prompt")
data, err := json.Marshal(workspacesdk.ReconnectingPTYRequest{
Data: "echo test\r",
})
require.NoError(t, err)
_, err = netConn1.Write(data)
require.NoError(t, err)
// Once for typing the command...
require.NoError(t, tr1.ReadUntil(ctx, matchEchoCommand), "find echo command")
// And another time for the actual output.
require.NoError(t, tr1.ReadUntil(ctx, matchEchoOutput), "find echo output")
// Same for the other connection.
require.NoError(t, tr2.ReadUntil(ctx, matchEchoCommand), "find echo command")
require.NoError(t, tr2.ReadUntil(ctx, matchEchoOutput), "find echo output")
_ = netConn1.Close()
_ = netConn2.Close()
netConn3, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
require.NoError(t, err)
defer netConn3.Close()
tr3 := testutil.NewTerminalReader(t, netConn3)
// Same output again!
require.NoError(t, tr3.ReadUntil(ctx, matchEchoCommand), "find echo command")
require.NoError(t, tr3.ReadUntil(ctx, matchEchoOutput), "find echo output")
// Exit should cause the connection to close.
data, err = json.Marshal(workspacesdk.ReconnectingPTYRequest{
Data: "exit\r",
})
require.NoError(t, err)
_, err = netConn3.Write(data)
require.NoError(t, err)
// Once for the input and again for the output.
require.NoError(t, tr3.ReadUntil(ctx, matchExitCommand), "find exit command")
require.NoError(t, tr3.ReadUntil(ctx, matchExitOutput), "find exit output")
// Wait for the connection to close.
require.ErrorIs(t, tr3.ReadUntil(ctx, nil), io.EOF)
// Try a non-shell command. It should output then immediately exit.
netConn4, err := conn.ReconnectingPTY(ctx, uuid.New(), 80, 80, "echo test")
require.NoError(t, err)
defer netConn4.Close()
tr4 := testutil.NewTerminalReader(t, netConn4)
require.NoError(t, tr4.ReadUntil(ctx, matchEchoOutput), "find echo output")
require.ErrorIs(t, tr4.ReadUntil(ctx, nil), io.EOF)
// Ensure that UTF-8 is supported. Avoid the terminal emulator because it
// does not appear to support UTF-8, just make sure the bytes that come
// back have the character in it.
netConn5, err := conn.ReconnectingPTY(ctx, uuid.New(), 80, 80, "echo ")
require.NoError(t, err)
defer netConn5.Close()
bytes, err := io.ReadAll(netConn5)
require.NoError(t, err)
require.Contains(t, string(bytes), "")
})
}
}
// This tests end-to-end functionality of connecting to a running container
// and executing a command. It creates a real Docker container and runs a
// command. As such, it does not run by default in CI.
// You can run it manually as follows:
//
// CODER_TEST_USE_DOCKER=1 go test -count=1 ./agent -run TestAgent_ReconnectingPTYContainer
func TestAgent_ReconnectingPTYContainer(t *testing.T) {
	t.Parallel()
	if os.Getenv("CODER_TEST_USE_DOCKER") != "1" {
		t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
	}
	if _, err := exec.LookPath("devcontainer"); err != nil {
		t.Skip("This test requires the devcontainer CLI: npm install -g @devcontainers/cli")
	}

	pool, err := dockertest.NewPool("")
	require.NoError(t, err, "Could not connect to docker")
	// Keep the container running for the lifetime of the test; it is purged
	// explicitly below.
	ct, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: "busybox",
		Tag:        "latest",
		Cmd:        []string{"sleep", "infinity"},
	}, func(config *docker.HostConfig) {
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{Name: "no"}
	})
	require.NoError(t, err, "Could not start container")
	defer func() {
		err := pool.Purge(ct)
		require.NoError(t, err, "Could not stop container")
	}()
	// Wait for container to start
	require.Eventually(t, func() bool {
		running, ok := pool.ContainerByName(ct.Container.Name)
		return ok && running.Container.State.Running
	}, testutil.WaitShort, testutil.IntervalSlow, "Container did not start in time")

	// nolint: dogsled
	conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.Devcontainers = true
		o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
			agentcontainers.WithContainerLabelIncludeFilter("this.label.does.not.exist.ignore.devcontainers", "true"),
		)
	})
	ctx := testutil.Context(t, testutil.WaitLong)
	ac, err := conn.ReconnectingPTY(ctx, uuid.New(), 80, 80, "/bin/sh", func(arp *workspacesdk.AgentReconnectingPTYInit) {
		arp.Container = ct.Container.ID
	})
	require.NoError(t, err, "failed to create ReconnectingPTY")
	defer ac.Close()
	tr := testutil.NewTerminalReader(t, ac)

	require.NoError(t, tr.ReadUntil(ctx, func(line string) bool {
		return strings.Contains(line, "#") || strings.Contains(line, "$")
	}), "find prompt")

	require.NoError(t, json.NewEncoder(ac).Encode(workspacesdk.ReconnectingPTYRequest{
		Data: "hostname\r",
	}), "write hostname")
	require.NoError(t, tr.ReadUntil(ctx, func(line string) bool {
		return strings.Contains(line, "hostname")
	}), "find hostname command")

	require.NoError(t, tr.ReadUntil(ctx, func(line string) bool {
		return strings.Contains(line, ct.Container.Config.Hostname)
	}), "find hostname output")
	require.NoError(t, json.NewEncoder(ac).Encode(workspacesdk.ReconnectingPTYRequest{
		Data: "exit\r",
	}), "write exit command")

	// Wait for the connection to close.
	require.ErrorIs(t, tr.ReadUntil(ctx, nil), io.EOF)
}
// subAgentRequestPayload is the JSON body that runSubAgentMain (the fake
// sub-agent) POSTs to CODER_AGENT_URL, and that the HTTP handler in
// TestAgent_DevcontainerAutostart decodes and forwards to the test.
type subAgentRequestPayload struct {
	// Token is the value of the sub-agent's CODER_AGENT_TOKEN.
	Token     string `json:"token"`
	// Directory is the sub-agent's working directory (from os.Getwd).
	Directory string `json:"directory"`
}
// runSubAgentMain is the main function for the fake sub-agent that
// connects to the control plane. It runs when this test binary is
// re-executed with CODER_TEST_RUN_SUB_AGENT_MAIN=1 (see TestMain).
//
// It reads the CODER_AGENT_URL and CODER_AGENT_TOKEN environment
// variables, POSTs the token together with the current working directory
// as JSON to CODER_AGENT_URL, and returns a process exit code:
//
//   - 0: the server responded with 200 OK.
//   - 1: a local failure (working directory, JSON encoding, request setup).
//   - 10: CODER_AGENT_URL or CODER_AGENT_TOKEN is unset.
//   - 11: the HTTP request itself failed.
//   - 12: the server responded with a non-200 status.
func runSubAgentMain() int {
	// Named agentURL rather than url to avoid shadowing the net/url package.
	agentURL := os.Getenv("CODER_AGENT_URL")
	token := os.Getenv("CODER_AGENT_TOKEN")
	if agentURL == "" || token == "" {
		_, _ = fmt.Fprintln(os.Stderr, "CODER_AGENT_URL and CODER_AGENT_TOKEN must be set")
		return 10
	}

	dir, err := os.Getwd()
	if err != nil {
		_, _ = fmt.Fprintf(os.Stderr, "failed to get current working directory: %v\n", err)
		return 1
	}

	b, err := json.Marshal(subAgentRequestPayload{
		Token:     token,
		Directory: dir,
	})
	if err != nil {
		_, _ = fmt.Fprintf(os.Stderr, "failed to marshal payload: %v\n", err)
		return 1
	}

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	// Bind the timeout at construction time instead of building the request
	// without a context and copying it afterwards via req.WithContext.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, agentURL, bytes.NewReader(b))
	if err != nil {
		_, _ = fmt.Fprintf(os.Stderr, "failed to create request: %v\n", err)
		return 1
	}

	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		_, _ = fmt.Fprintf(os.Stderr, "agent connection failed: %v\n", err)
		return 11
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		_, _ = fmt.Fprintf(os.Stderr, "agent exiting with non-zero exit code %d\n", resp.StatusCode)
		return 12
	}
	_, _ = fmt.Println("sub-agent connected successfully")
	return 0
}
// This tests end-to-end functionality of auto-starting a devcontainer.
// It runs "devcontainer up" which creates a real Docker container. As
// such, it does not run by default in CI.
//
// You can run it manually as follows:
//
// CODER_TEST_USE_DOCKER=1 go test -count=1 ./agent -run TestAgent_DevcontainerAutostart
//
//nolint:paralleltest // This test sets an environment variable.
func TestAgent_DevcontainerAutostart(t *testing.T) {
	if os.Getenv("CODER_TEST_USE_DOCKER") != "1" {
		t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
	}
	if _, err := exec.LookPath("devcontainer"); err != nil {
		t.Skip("This test requires the devcontainer CLI: npm install -g @devcontainers/cli")
	}

	// This HTTP handler handles requests from runSubAgentMain which
	// acts as a fake sub-agent. We want to verify that the sub-agent
	// connects and sends its token. We use a channel to signal
	// that the sub-agent has connected successfully and then we wait
	// until we receive another signal to return from the handler. This
	// keeps the agent "alive" for as long as we want.
	subAgentConnected := make(chan subAgentRequestPayload, 1)
	subAgentReady := make(chan struct{}, 1)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/api/v2/workspaceagents/me/") {
			return
		}

		t.Logf("Sub-agent request received: %s %s", r.Method, r.URL.Path)

		if r.Method != http.MethodPost {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}

		// Read the token from the request body.
		var payload subAgentRequestPayload
		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
			http.Error(w, "Failed to read token", http.StatusBadRequest)
			t.Logf("Failed to read token: %v", err)
			return
		}
		defer r.Body.Close()

		t.Logf("Sub-agent request payload received: %+v", payload)

		// Signal that the sub-agent has connected successfully.
		select {
		case <-t.Context().Done():
			t.Logf("Test context done, not processing sub-agent request")
			return
		case subAgentConnected <- payload:
		}

		// Wait for the signal to return from the handler.
		select {
		case <-t.Context().Done():
			t.Logf("Test context done, not waiting for sub-agent ready")
			return
		case <-subAgentReady:
		}

		w.WriteHeader(http.StatusOK)
	}))
	// Close the server via t.Cleanup rather than defer. Server.Close blocks
	// until in-flight requests finish, and the handler above only returns
	// once subAgentReady is closed or t.Context() is cancelled. t.Context()
	// is cancelled just before cleanup functions run (after deferred calls),
	// so a deferred Close would hang if the test failed before reaching
	// close(subAgentReady).
	t.Cleanup(srv.Close)

	pool, err := dockertest.NewPool("")
	require.NoError(t, err, "Could not connect to docker")

	// Prepare temporary devcontainer for test (mywork).
	devcontainerID := uuid.New()
	tmpdir := t.TempDir()
	t.Setenv("HOME", tmpdir)
	tempWorkspaceFolder := filepath.Join(tmpdir, "mywork")
	unexpandedWorkspaceFolder := filepath.Join("~", "mywork")
	t.Logf("Workspace folder: %s", tempWorkspaceFolder)
	t.Logf("Unexpanded workspace folder: %s", unexpandedWorkspaceFolder)
	devcontainerPath := filepath.Join(tempWorkspaceFolder, ".devcontainer")
	err = os.MkdirAll(devcontainerPath, 0o755)
	require.NoError(t, err, "create devcontainer directory")
	devcontainerFile := filepath.Join(devcontainerPath, "devcontainer.json")
	err = os.WriteFile(devcontainerFile, []byte(`{
        "name": "mywork",
        "image": "ubuntu:latest",
        "cmd": ["sleep", "infinity"],
        "runArgs": ["--network=host", "--label=`+agentcontainers.DevcontainerIsTestRunLabel+`=true"]
    }`), 0o600)
	require.NoError(t, err, "write devcontainer.json")

	manifest := agentsdk.Manifest{
		// Set up pre-conditions for auto-starting a devcontainer, the script
		// is expected to be prepared by the provisioner normally.
		Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
			{
				ID:   devcontainerID,
				Name: "test",
				// Use an unexpanded path to test the expansion.
				WorkspaceFolder: unexpandedWorkspaceFolder,
			},
		},
		Scripts: []codersdk.WorkspaceAgentScript{
			{
				ID:          devcontainerID,
				LogSourceID: agentsdk.ExternalLogSourceID,
				RunOnStart:  true,
				Script:      "echo this-will-be-replaced",
				DisplayName: "Dev Container (test)",
			},
		},
	}
	mClock := quartz.NewMock(t)
	mClock.Set(time.Now())
	tickerFuncTrap := mClock.Trap().TickerFunc("agentcontainers")

	//nolint:dogsled
	_, agentClient, _, _, _ := setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.Devcontainers = true
		o.DevcontainerAPIOptions = append(
			o.DevcontainerAPIOptions,
			// Only match this specific dev container.
			agentcontainers.WithClock(mClock),
			agentcontainers.WithContainerLabelIncludeFilter("devcontainer.local_folder", tempWorkspaceFolder),
			agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerIsTestRunLabel, "true"),
			agentcontainers.WithSubAgentURL(srv.URL),
			// The agent will copy "itself", but in the case of this test, the
			// agent is actually this test binary. So we'll tell the test binary
			// to execute the sub-agent main function via this env.
			agentcontainers.WithSubAgentEnv("CODER_TEST_RUN_SUB_AGENT_MAIN=1"),
		)
	})

	t.Logf("Waiting for container with label: devcontainer.local_folder=%s", tempWorkspaceFolder)

	var container docker.APIContainers
	require.Eventually(t, func() bool {
		containers, err := pool.Client.ListContainers(docker.ListContainersOptions{All: true})
		if err != nil {
			t.Logf("Error listing containers: %v", err)
			return false
		}

		for _, c := range containers {
			t.Logf("Found container: %s with labels: %v", c.ID[:12], c.Labels)
			if labelValue, ok := c.Labels["devcontainer.local_folder"]; ok {
				if labelValue == tempWorkspaceFolder {
					t.Logf("Found matching container: %s", c.ID[:12])
					container = c
					return true
				}
			}
		}

		return false
	}, testutil.WaitSuperLong, testutil.IntervalMedium, "no container with workspace folder label found")
	defer func() {
		// We can't rely on pool here because the container is not
		// managed by it (it is managed by @devcontainer/cli).
		err := pool.Client.RemoveContainer(docker.RemoveContainerOptions{
			ID:            container.ID,
			RemoveVolumes: true,
			Force:         true,
		})
		assert.NoError(t, err, "remove container")
	}()

	containerInfo, err := pool.Client.InspectContainer(container.ID)
	require.NoError(t, err, "inspect container")
	t.Logf("Container state: status: %v", containerInfo.State.Status)
	require.True(t, containerInfo.State.Running, "container should be running")

	ctx := testutil.Context(t, testutil.WaitLong)

	// Ensure the container update routine runs.
	tickerFuncTrap.MustWait(ctx).MustRelease(ctx)
	tickerFuncTrap.Close()

	// Since the agent does RefreshContainers, and the ticker function
	// is set to skip instead of queue, we must advance the clock
	// multiple times to ensure that the sub-agent is created.
	var subAgents []*proto.SubAgent
	for {
		_, next := mClock.AdvanceNext()
		next.MustWait(ctx)

		// Verify that a subagent was created.
		subAgents = agentClient.GetSubAgents()
		if len(subAgents) > 0 {
			t.Logf("Found sub-agents: %d", len(subAgents))
			break
		}
	}
	require.Len(t, subAgents, 1, "expected one sub agent")

	subAgent := subAgents[0]
	subAgentID, err := uuid.FromBytes(subAgent.GetId())
	require.NoError(t, err, "failed to parse sub-agent ID")
	t.Logf("Connecting to sub-agent: %s (ID: %s)", subAgent.Name, subAgentID)

	gotDir, err := agentClient.GetSubAgentDirectory(subAgentID)
	require.NoError(t, err, "failed to get sub-agent directory")
	require.Equal(t, "/workspaces/mywork", gotDir, "sub-agent directory should match")

	subAgentToken, err := uuid.FromBytes(subAgent.GetAuthToken())
	require.NoError(t, err, "failed to parse sub-agent token")

	payload := testutil.RequireReceive(ctx, t, subAgentConnected)
	require.Equal(t, subAgentToken.String(), payload.Token, "sub-agent token should match")
	require.Equal(t, "/workspaces/mywork", payload.Directory, "sub-agent directory should match")

	// Allow the subagent to exit.
	close(subAgentReady)
}
// TestAgent_DevcontainerRecreate tests that RecreateDevcontainer
// recreates a devcontainer and emits logs.
//
// This tests end-to-end functionality of auto-starting a devcontainer.
// It runs "devcontainer up" which creates a real Docker container. As
// such, it does not run by default in CI.
//
// You can run it manually as follows:
//
// CODER_TEST_USE_DOCKER=1 go test -count=1 ./agent -run TestAgent_DevcontainerRecreate
func TestAgent_DevcontainerRecreate(t *testing.T) {
	if os.Getenv("CODER_TEST_USE_DOCKER") != "1" {
		t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
	}
	// Like the other devcontainer tests, this relies on the devcontainer
	// CLI; skip with a clear message rather than failing obscurely.
	if _, err := exec.LookPath("devcontainer"); err != nil {
		t.Skip("This test requires the devcontainer CLI: npm install -g @devcontainers/cli")
	}
	t.Parallel()

	pool, err := dockertest.NewPool("")
	require.NoError(t, err, "Could not connect to docker")

	// Prepare temporary devcontainer for test (mywork).
	devcontainerID := uuid.New()
	devcontainerLogSourceID := uuid.New()
	workspaceFolder := filepath.Join(t.TempDir(), "mywork")
	t.Logf("Workspace folder: %s", workspaceFolder)
	devcontainerPath := filepath.Join(workspaceFolder, ".devcontainer")
	err = os.MkdirAll(devcontainerPath, 0o755)
	require.NoError(t, err, "create devcontainer directory")
	devcontainerFile := filepath.Join(devcontainerPath, "devcontainer.json")
	err = os.WriteFile(devcontainerFile, []byte(`{
        "name": "mywork",
        "image": "busybox:latest",
        "cmd": ["sleep", "infinity"],
        "runArgs": ["--label=`+agentcontainers.DevcontainerIsTestRunLabel+`=true"]
    }`), 0o600)
	require.NoError(t, err, "write devcontainer.json")

	manifest := agentsdk.Manifest{
		// Set up pre-conditions for auto-starting a devcontainer, the
		// script is used to extract the log source ID.
		Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
			{
				ID:              devcontainerID,
				Name:            "test",
				WorkspaceFolder: workspaceFolder,
			},
		},
		Scripts: []codersdk.WorkspaceAgentScript{
			{
				ID:          devcontainerID,
				LogSourceID: devcontainerLogSourceID,
			},
		},
	}

	//nolint:dogsled
	conn, client, _, _, _ := setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.Devcontainers = true
		o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
			agentcontainers.WithContainerLabelIncludeFilter("devcontainer.local_folder", workspaceFolder),
			agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerIsTestRunLabel, "true"),
		)
	})

	ctx := testutil.Context(t, testutil.WaitLong)

	// We enabled autostart for the devcontainer, so ready is a good
	// indication that the devcontainer is up and running. Importantly,
	// this also means that the devcontainer startup is no longer
	// producing logs that may interfere with the recreate logs.
	testutil.Eventually(ctx, t, func(context.Context) bool {
		states := client.GetLifecycleStates()
		return slices.Contains(states, codersdk.WorkspaceAgentLifecycleReady)
	}, testutil.IntervalMedium, "devcontainer not ready")

	t.Logf("Looking for container with label: devcontainer.local_folder=%s", workspaceFolder)

	var container codersdk.WorkspaceAgentContainer
	testutil.Eventually(ctx, t, func(context.Context) bool {
		resp, err := conn.ListContainers(ctx)
		if err != nil {
			t.Logf("Error listing containers: %v", err)
			return false
		}
		for _, c := range resp.Containers {
			t.Logf("Found container: %s with labels: %v", c.ID[:12], c.Labels)
			if v, ok := c.Labels["devcontainer.local_folder"]; ok && v == workspaceFolder {
				t.Logf("Found matching container: %s", c.ID[:12])
				container = c
				return true
			}
		}
		return false
	}, testutil.IntervalMedium, "no container with workspace folder label found")
	defer func(container codersdk.WorkspaceAgentContainer) {
		// We can't rely on pool here because the container is not
		// managed by it (it is managed by @devcontainer/cli).
		err := pool.Client.RemoveContainer(docker.RemoveContainerOptions{
			ID:            container.ID,
			RemoveVolumes: true,
			Force:         true,
		})
		assert.Error(t, err, "container should be removed by recreate")
	}(container)

	ctx = testutil.Context(t, testutil.WaitLong) // Reset context.

	// Capture logs via ScriptLogger.
	logsCh := make(chan *proto.BatchCreateLogsRequest, 1)
	client.SetLogsChannel(logsCh)

	// Invoke recreate to trigger the destruction and recreation of the
	// devcontainer, we do it in a goroutine so we can process logs
	// concurrently.
	recreateDone := make(chan struct{})
	go func() {
		defer close(recreateDone)
		_, err := conn.RecreateDevcontainer(ctx, devcontainerID.String())
		assert.NoError(t, err, "recreate devcontainer should succeed")
	}()
	// Do not let the goroutine outlive the test: reporting to t after the
	// test has completed panics. The call is bounded by ctx.
	defer func() { <-recreateDone }()

	t.Logf("Checking recreate logs for outcome...")

	// Wait for the logs to be emitted, the @devcontainer/cli up command
	// will emit a log with the outcome at the end suggesting we did
	// receive all the logs.
waitForOutcomeLoop:
	for {
		batch := testutil.RequireReceive(ctx, t, logsCh)

		if bytes.Equal(batch.LogSourceId, devcontainerLogSourceID[:]) {
			for _, log := range batch.Logs {
				t.Logf("Received log: %s", log.Output)
				if strings.Contains(log.Output, "\"outcome\"") {
					break waitForOutcomeLoop
				}
			}
		}
	}

	t.Logf("Checking there's a new container with label: devcontainer.local_folder=%s", workspaceFolder)

	// Make sure the container exists and isn't the same as the old one.
	testutil.Eventually(ctx, t, func(context.Context) bool {
		resp, err := conn.ListContainers(ctx)
		if err != nil {
			t.Logf("Error listing containers: %v", err)
			return false
		}
		for _, c := range resp.Containers {
			t.Logf("Found container: %s with labels: %v", c.ID[:12], c.Labels)
			if v, ok := c.Labels["devcontainer.local_folder"]; ok && v == workspaceFolder {
				if c.ID == container.ID {
					t.Logf("Found same container: %s", c.ID[:12])
					return false
				}
				t.Logf("Found new container: %s", c.ID[:12])
				container = c
				return true
			}
		}
		return false
	}, testutil.IntervalMedium, "new devcontainer not found")
	defer func(container codersdk.WorkspaceAgentContainer) {
		// We can't rely on pool here because the container is not
		// managed by it (it is managed by @devcontainer/cli).
		err := pool.Client.RemoveContainer(docker.RemoveContainerOptions{
			ID:            container.ID,
			RemoveVolumes: true,
			Force:         true,
		})
		assert.NoError(t, err, "remove container")
	}(container)
}
// TestAgent_DevcontainersDisabledForSubAgent verifies that an agent whose
// manifest carries a ParentID (i.e. a sub agent) refuses Dev Container
// requests even when the Devcontainers option is switched on.
func TestAgent_DevcontainersDisabledForSubAgent(t *testing.T) {
	t.Parallel()

	const (
		wantFeatureErr = "Dev Container feature not supported."
		wantNestedErr  = "Dev Container integration inside other Dev Containers is explicitly not supported."
	)

	// A non-nil ParentID is what marks this agent as a sub agent.
	subAgentManifest := agentsdk.Manifest{
		AgentID:  uuid.New(),
		ParentID: uuid.New(),
	}

	// Ask for devcontainers anyway; the agent is expected to reject them.
	//nolint:dogsled
	agentConn, _, _, _, _ := setupAgent(t, subAgentManifest, 0, func(_ *agenttest.Client, opts *agent.Options) {
		opts.Devcontainers = true
	})

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancel()

	// Tailnet reachability (all that setupAgent waits for) does not imply
	// that the HTTP API is already serving the rejection, so poll for it.
	var listErr error
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		_, listErr = agentConn.ListContainers(ctx)
		if listErr == nil {
			return false
		}
		t.Logf("Error listing containers: %v", listErr)
		return strings.Contains(listErr.Error(), wantFeatureErr)
	}, testutil.IntervalFast, "containers endpoint should reject devcontainers inside sub agents")

	require.Error(t, listErr)
	require.Contains(t, listErr.Error(), wantFeatureErr)
	require.Contains(t, listErr.Error(), wantNestedErr)
}
// TestAgent_DevcontainerPrebuildClaim tests that we correctly handle
// the claiming process for running devcontainers.
//
// You can run it manually as follows:
//
// CODER_TEST_USE_DOCKER=1 go test -count=1 ./agent -run TestAgent_DevcontainerPrebuildClaim
//
//nolint:paralleltest // This test sets an environment variable.
func TestAgent_DevcontainerPrebuildClaim(t *testing.T) {
	if os.Getenv("CODER_TEST_USE_DOCKER") != "1" {
		t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
	}
	if _, err := exec.LookPath("devcontainer"); err != nil {
		t.Skip("This test requires the devcontainer CLI: npm install -g @devcontainers/cli")
	}

	pool, err := dockertest.NewPool("")
	require.NoError(t, err, "Could not connect to docker")

	var (
		ctx = testutil.Context(t, testutil.WaitShort)

		devcontainerID          = uuid.New()
		devcontainerLogSourceID = uuid.New()

		workspaceFolder    = filepath.Join(t.TempDir(), "project")
		devcontainerPath   = filepath.Join(workspaceFolder, ".devcontainer")
		devcontainerConfig = filepath.Join(devcontainerPath, "devcontainer.json")
	)

	// findDevcontainer polls the agent for the devcontainer whose container
	// carries the local-folder label for workspaceFolder. Transient
	// ListContainers errors are logged and retried rather than failing the
	// test, matching the other devcontainer tests in this file.
	findDevcontainer := func(conn workspacesdk.AgentConn, out *codersdk.WorkspaceAgentDevcontainer) {
		testutil.Eventually(ctx, t, func(ctx context.Context) bool {
			resp, err := conn.ListContainers(ctx)
			if err != nil {
				t.Logf("Error listing containers: %v", err)
				return false
			}
			for _, dc := range resp.Devcontainers {
				if dc.Container == nil {
					continue
				}
				v, ok := dc.Container.Labels[agentcontainers.DevcontainerLocalFolderLabel]
				if ok && v == workspaceFolder {
					*out = dc
					return true
				}
			}
			return false
		}, testutil.IntervalMedium, "devcontainer not found")
	}

	// Given: A devcontainer project.
	t.Logf("Workspace folder: %s", workspaceFolder)

	err = os.MkdirAll(devcontainerPath, 0o755)
	require.NoError(t, err, "create dev container directory")

	// Given: This devcontainer project specifies an app that uses the owner name and workspace name.
	err = os.WriteFile(devcontainerConfig, []byte(`{
		"name": "project",
		"image": "busybox:latest",
		"cmd": ["sleep", "infinity"],
		"runArgs": ["--label=`+agentcontainers.DevcontainerIsTestRunLabel+`=true"],
		"customizations": {
			"coder": {
				"apps": [{
					"slug": "zed",
					"url": "zed://ssh/${localEnv:CODER_WORKSPACE_AGENT_NAME}.${localEnv:CODER_WORKSPACE_NAME}.${localEnv:CODER_WORKSPACE_OWNER_NAME}.coder${containerWorkspaceFolder}"
				}]
			}
		}
	}`), 0o600)
	require.NoError(t, err, "write devcontainer config")

	// Given: A manifest with a prebuild username and workspace name.
	manifest := agentsdk.Manifest{
		OwnerName:     "prebuilds",
		WorkspaceName: "prebuilds-xyz-123",

		Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
			{ID: devcontainerID, Name: "test", WorkspaceFolder: workspaceFolder},
		},
		Scripts: []codersdk.WorkspaceAgentScript{
			{ID: devcontainerID, LogSourceID: devcontainerLogSourceID},
		},
	}

	// When: We create an agent with devcontainers enabled.
	//nolint:dogsled
	conn, client, _, _, _ := setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.Devcontainers = true
		o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
			agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerLocalFolderLabel, workspaceFolder),
			agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerIsTestRunLabel, "true"),
		)
	})

	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
	}, testutil.IntervalMedium, "agent not ready")

	var dcPrebuild codersdk.WorkspaceAgentDevcontainer
	findDevcontainer(conn, &dcPrebuild)

	defer func() {
		// Best-effort cleanup: the container is managed by
		// @devcontainers/cli, not by the dockertest pool.
		_ = pool.Client.RemoveContainer(docker.RemoveContainerOptions{
			ID:            dcPrebuild.Container.ID,
			RemoveVolumes: true,
			Force:         true,
		})
	}()

	// Then: We expect a sub agent to have been created.
	subAgents := client.GetSubAgents()
	require.Len(t, subAgents, 1)

	subAgent := subAgents[0]
	subAgentID, err := uuid.FromBytes(subAgent.GetId())
	require.NoError(t, err)

	// And: We expect there to be 1 app.
	subAgentApps, err := client.GetSubAgentApps(subAgentID)
	require.NoError(t, err)
	require.Len(t, subAgentApps, 1)

	// And: This app should contain the prebuild workspace name and owner name.
	subAgentApp := subAgentApps[0]
	require.Equal(t, "zed://ssh/project.prebuilds-xyz-123.prebuilds.coder/workspaces/project", subAgentApp.GetUrl())

	// Given: We close the client and connection
	client.Close()
	_ = conn.Close()

	// Given: A new manifest with a regular user owner name and workspace name.
	manifest = agentsdk.Manifest{
		OwnerName:     "user",
		WorkspaceName: "user-workspace",

		Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
			{ID: devcontainerID, Name: "test", WorkspaceFolder: workspaceFolder},
		},
		Scripts: []codersdk.WorkspaceAgentScript{
			{ID: devcontainerID, LogSourceID: devcontainerLogSourceID},
		},
	}

	// When: We create an agent with devcontainers enabled.
	//nolint:dogsled
	conn, client, _, _, _ = setupAgent(t, manifest, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.Devcontainers = true
		o.DevcontainerAPIOptions = append(o.DevcontainerAPIOptions,
			agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerLocalFolderLabel, workspaceFolder),
			agentcontainers.WithContainerLabelIncludeFilter(agentcontainers.DevcontainerIsTestRunLabel, "true"),
		)
	})

	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
	}, testutil.IntervalMedium, "agent not ready")

	var dcClaimed codersdk.WorkspaceAgentDevcontainer
	findDevcontainer(conn, &dcClaimed)

	defer func() {
		// Only remove the claimed container if it differs from the prebuild
		// one; otherwise the deferred prebuild cleanup handles it.
		if dcClaimed.Container.ID != dcPrebuild.Container.ID {
			_ = pool.Client.RemoveContainer(docker.RemoveContainerOptions{
				ID:            dcClaimed.Container.ID,
				RemoveVolumes: true,
				Force:         true,
			})
		}
	}()

	// Then: We expect the claimed devcontainer and prebuild devcontainer
	// to be using the same underlying container.
	require.Equal(t, dcPrebuild.Container.ID, dcClaimed.Container.ID)

	// And: We expect there to be a sub agent created.
	subAgents = client.GetSubAgents()
	require.Len(t, subAgents, 1)

	subAgent = subAgents[0]
	subAgentID, err = uuid.FromBytes(subAgent.GetId())
	require.NoError(t, err)

	// And: We expect there to be an app.
	subAgentApps, err = client.GetSubAgentApps(subAgentID)
	require.NoError(t, err)
	require.Len(t, subAgentApps, 1)

	// And: We expect this app to have the user's owner name and workspace name.
	subAgentApp = subAgentApps[0]
	require.Equal(t, "zed://ssh/project.user-workspace.user.coder/workspaces/project", subAgentApp.GetUrl())
}
// TestAgent_Dial verifies that a client can dial TCP and UDP listeners in
// the workspace over tailnet, both via the agent connection's own dialer
// and via the agent's CoderServicePrefix address.
func TestAgent_Dial(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name  string
		setup func(t testing.TB) net.Listener
	}{
		{
			name: "TCP",
			setup: func(t testing.TB) net.Listener {
				l, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err, "create TCP listener")
				return l
			},
		},
		{
			name: "UDP",
			setup: func(t testing.TB) net.Listener {
				addr := net.UDPAddr{
					IP:   net.ParseIP("127.0.0.1"),
					Port: 0,
				}
				l, err := udp.Listen("udp", &addr)
				require.NoError(t, err, "create UDP listener")
				return l
			},
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			// The purpose of this test is to ensure that a client can dial a
			// listener in the workspace over tailnet.
			//
			// The OS sometimes drops packets if the system can't keep up with
			// them. For TCP packets, it's typically fine due to
			// retransmissions, but for UDP packets, it can fail this test.
			//
			// The OS gets involved for the Wireguard traffic (either via DERP
			// or direct UDP), and also for the traffic between the agent and
			// the listener in the "workspace".
			//
			// To avoid this, we'll retry this test up to 3 times.
			//nolint:gocritic // This test is flaky due to uncontrollable OS packet drops under heavy load.
			testutil.RunRetry(t, 3, func(t testing.TB) {
				ctx := testutil.Context(t, testutil.WaitLong)

				l := c.setup(t)
				done := make(chan struct{})
				defer func() {
					l.Close()
					<-done
				}()

				// Accept exactly two connections: one per dial below. Named
				// "accepted" so it does not shadow the test case c.
				go func() {
					defer close(done)
					for range 2 {
						accepted, err := l.Accept()
						if assert.NoError(t, err, "accept connection") {
							testAccept(ctx, t, accepted)
							_ = accepted.Close()
						}
					}
				}()

				agentID := uuid.UUID{0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8}
				//nolint:dogsled
				agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
					AgentID: agentID,
				}, 0)
				require.True(t, agentConn.AwaitReachable(ctx))
				conn, err := agentConn.DialContext(ctx, l.Addr().Network(), l.Addr().String())
				require.NoError(t, err)
				testDial(ctx, t, conn)
				err = conn.Close()
				require.NoError(t, err)

				// also connect via the CoderServicePrefix, to test that we can reach the agent on this
				// IP. This will be required for CoderVPN.
				// Check the parse errors so a malformed address cannot
				// silently turn into a dial to port 0.
				_, rawPort, err := net.SplitHostPort(l.Addr().String())
				require.NoError(t, err, "split listener address")
				port, err := strconv.ParseUint(rawPort, 10, 16)
				require.NoError(t, err, "parse listener port")
				ipp := netip.AddrPortFrom(tailnet.CoderServicePrefix.AddrFromUUID(agentID), uint16(port))

				switch l.Addr().Network() {
				case "tcp":
					conn, err = agentConn.TailnetConn().DialContextTCP(ctx, ipp)
				case "udp":
					conn, err = agentConn.TailnetConn().DialContextUDP(ctx, ipp)
				default:
					t.Fatalf("unknown network: %s", l.Addr().Network())
				}
				require.NoError(t, err)
				testDial(ctx, t, conn)
				err = conn.Close()
				require.NoError(t, err)
			})
		})
	}
}
// TestAgent_UpdatedDERP checks that agents can handle their DERP map being
// updated, and that clients can also handle it.
//
// Flow: a first client connects over the original DERP map; a second DERP
// server is started and its region is renumbered from 1 to 2; that map is
// pushed to the agent and the test waits until the agent reports region 2 as
// its only region and preferred DERP; a second client connects using the new
// map; finally the first client is switched to the new map and must still
// reach the agent. Direct connections are disabled on both sides so all
// traffic goes through DERP.
func TestAgent_UpdatedDERP(t *testing.T) {
	t.Parallel()
	logger := testutil.Logger(t)
	originalDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
	require.NotNil(t, originalDerpMap)
	coordinator := tailnet.NewCoordinator(logger)
	// use t.Cleanup so the coordinator closing doesn't deadlock with in-memory
	// coordination
	t.Cleanup(func() {
		_ = coordinator.Close()
	})
	agentID := uuid.New()
	statsCh := make(chan *proto.Stats, 50)
	fs := afero.NewMemMapFs()
	client := agenttest.NewClient(t,
		logger.Named("agent"),
		agentID,
		agentsdk.Manifest{
			DERPMap: originalDerpMap,
			// Force DERP.
			DisableDirectConnections: true,
		},
		statsCh,
		coordinator,
	)
	t.Cleanup(func() {
		t.Log("closing client")
		client.Close()
	})
	uut := agent.New(agent.Options{
		Client:                 client,
		Filesystem:             fs,
		Logger:                 logger.Named("agent"),
		ReconnectingPTYTimeout: time.Minute,
	})
	t.Cleanup(func() {
		t.Log("closing agent")
		_ = uut.Close()
	})
	// newClientConn sets up a client tailnet connection that uses derpMap,
	// coordinates with the agent through an in-memory coordinator client,
	// blocks direct endpoints (forcing DERP), and waits up to
	// testutil.WaitLong for the agent to become reachable. All resources it
	// creates are released via t.Cleanup.
	newClientConn := func(derpMap *tailcfg.DERPMap, name string) workspacesdk.AgentConn {
		conn, err := tailnet.NewConn(&tailnet.Options{
			Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()},
			DERPMap:   derpMap,
			Logger:    logger.Named(name),
		})
		require.NoError(t, err)
		t.Cleanup(func() {
			t.Logf("closing conn %s", name)
			_ = conn.Close()
		})
		testCtx, testCtxCancel := context.WithCancel(context.Background())
		t.Cleanup(testCtxCancel)
		clientID := uuid.New()
		ctrl := tailnet.NewTunnelSrcCoordController(logger, conn)
		ctrl.AddDestination(agentID)
		auth := tailnet.ClientCoordinateeAuth{AgentID: agentID}
		coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, coordinator))
		t.Cleanup(func() {
			t.Logf("closing coordination %s", name)
			cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
			defer ccancel()
			err := coordination.Close(cctx)
			if err != nil {
				t.Logf("error closing in-memory coordination: %s", err.Error())
			}
			t.Logf("closed coordination %s", name)
		})
		// Force DERP.
		conn.SetBlockEndpoints(true)
		// ErrSkipClose keeps sdkConn.Close from closing conn itself; conn
		// has its own cleanup registered above.
		sdkConn := workspacesdk.NewAgentConn(conn, workspacesdk.AgentConnOptions{
			AgentID:   agentID,
			CloseFunc: func() error { return workspacesdk.ErrSkipClose },
		})
		t.Cleanup(func() {
			t.Logf("closing sdkConn %s", name)
			_ = sdkConn.Close()
		})
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()
		if !sdkConn.AwaitReachable(ctx) {
			t.Fatal("agent not reachable")
		}
		return sdkConn
	}
	conn1 := newClientConn(originalDerpMap, "client1")
	// Change the DERP map.
	newDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
	require.NotNil(t, newDerpMap)
	// Change the region ID, so the agent and clients must actually notice
	// the map changed (region 1 no longer exists, region 2 does).
	newDerpMap.Regions[2] = newDerpMap.Regions[1]
	delete(newDerpMap.Regions, 1)
	newDerpMap.Regions[2].RegionID = 2
	for _, node := range newDerpMap.Regions[2].Nodes {
		node.RegionID = 2
	}
	// Push a new DERP map to the agent.
	err := client.PushDERPMapUpdate(newDerpMap)
	require.NoError(t, err)
	t.Log("pushed DERPMap update to agent")
	// Wait until the agent's tailnet conn exists, knows only region 2, and
	// has picked region 2 as its preferred DERP.
	require.Eventually(t, func() bool {
		conn := uut.TailnetConn()
		if conn == nil {
			return false
		}
		regionIDs := conn.DERPMap().RegionIDs()
		preferredDERP := conn.Node().PreferredDERP
		t.Logf("agent Conn DERPMap with regionIDs %v, PreferredDERP %d", regionIDs, preferredDERP)
		return len(regionIDs) == 1 && regionIDs[0] == 2 && preferredDERP == 2
	}, testutil.WaitLong, testutil.IntervalFast)
	t.Log("agent got the new DERPMap")
	// Connect from a second client and make sure it uses the new DERP map.
	conn2 := newClientConn(newDerpMap, "client2")
	require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs())
	t.Log("conn2 got the new DERPMap")
	// If the first client gets a DERP map update, it should be able to
	// reconnect just fine.
	conn1.TailnetConn().SetDERPMap(newDerpMap)
	require.Equal(t, []int{2}, conn1.TailnetConn().DERPMap().RegionIDs())
	t.Log("set the new DERPMap on conn1")
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	require.True(t, conn1.AwaitReachable(ctx))
	t.Log("conn1 reached agent with new DERP")
}
// TestAgent_Speedtest runs a short upload speedtest against the agent over
// a DERP map and logs the throughput of the final interval. It is currently
// skipped because Tailscale's speedtest code makes it flaky.
func TestAgent_Speedtest(t *testing.T) {
	t.Parallel()
	t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
	//nolint:dogsled
	conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
		DERPMap: derpMap,
	}, 0, func(_ *agenttest.Client, options *agent.Options) {
		options.Logger = logger.Named("agent")
	})
	defer conn.Close()
	res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond)
	require.NoError(t, err)
	// Guard the index below: an empty result would otherwise panic instead
	// of producing a readable test failure.
	require.NotEmpty(t, res, "speedtest returned no results")
	t.Logf("%.2f MBits/s", res[len(res)-1].MBitsPerSecond())
}
// TestAgent_Reconnect verifies that the agent starts a new coordinate call
// each time its current one is hung up, and that the number of token refresh
// calls always equals the number of coordinate calls observed so far.
func TestAgent_Reconnect(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)
	logger := testutil.Logger(t)
	// After the agent is disconnected from a coordinator, it's supposed
	// to reconnect!
	fCoordinator := tailnettest.NewFakeCoordinator()
	agentID := uuid.New()
	statsCh := make(chan *proto.Stats, 50)
	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
	client := agenttest.NewClient(t,
		logger,
		agentID,
		agentsdk.Manifest{
			DERPMap:   derpMap,
			Directory: "/test/workspace",
		},
		statsCh,
		fCoordinator,
	)
	defer client.Close()
	closer := agent.New(agent.Options{
		Client: client,
		Logger: logger.Named("agent"),
	})
	defer closer.Close()
	// Each iteration forces the agent to reconnect by closing
	// the current coordinate call while the tracked HTTP server
	// goroutine (from connection 1's createTailnet) is still
	// alive, widening the race window.
	const reconnections = 5
	for i := range reconnections {
		call := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
		// Coordinate call number i+1 must have been preceded by i+1 token
		// refreshes in total.
		require.Equal(t, i+1, client.GetNumRefreshTokenCalls())
		close(call.Resps) // hang up — triggers reconnect
	}
	// Verify final reconnect succeeds.
	testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
	require.Equal(t, reconnections+1, client.GetNumRefreshTokenCalls())
	closer.Close()
}
// TestAgent_ReconnectNoLifecycleReemit verifies that once the agent has
// reported the Ready lifecycle state, a coordinator disconnect and reconnect
// does not cause any lifecycle states to be reported again.
func TestAgent_ReconnectNoLifecycleReemit(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)
	logger := testutil.Logger(t)
	fCoordinator := tailnettest.NewFakeCoordinator()
	agentID := uuid.New()
	statsCh := make(chan *proto.Stats, 50)
	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
	client := agenttest.NewClient(t,
		logger,
		agentID,
		agentsdk.Manifest{
			DERPMap: derpMap,
			// A start script so the agent goes through its normal
			// startup lifecycle before reaching Ready.
			Scripts: []codersdk.WorkspaceAgentScript{{
				Script:     "echo hello",
				Timeout:    30 * time.Second,
				RunOnStart: true,
			}},
		},
		statsCh,
		fCoordinator,
	)
	defer client.Close()
	closer := agent.New(agent.Options{
		Client: client,
		Logger: logger.Named("agent"),
	})
	defer closer.Close()
	// Wait for the agent to reach Ready state.
	require.Eventually(t, func() bool {
		return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
	}, testutil.WaitShort, testutil.IntervalFast)
	// Snapshot (copy) the reported states so later reports cannot alias it.
	statesBefore := slices.Clone(client.GetLifecycleStates())
	// Disconnect by closing the coordinator response channel.
	call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
	close(call1.Resps)
	// Wait for reconnect.
	testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
	// Wait for a stats report as a deterministic steady-state proof.
	testutil.RequireReceive(ctx, t, statsCh)
	statesAfter := client.GetLifecycleStates()
	require.Equal(t, statesBefore, statesAfter,
		"lifecycle states should not be re-reported after reconnect")
	closer.Close()
}
// TestAgent_WriteVSCodeConfigs starts an agent whose manifest reports one git
// auth config and waits for it to create VS Code's machine-level
// settings.json under the user's home directory in the agent's (in-memory)
// filesystem.
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
	t.Parallel()
	logger := testutil.Logger(t)

	coord := tailnet.NewCoordinator(logger)
	defer coord.Close()

	manifest := agentsdk.Manifest{
		GitAuthConfigs: 1,
		DERPMap:        &tailcfg.DERPMap{},
	}
	agentClient := agenttest.NewClient(t, logger, uuid.New(), manifest, make(chan *proto.Stats, 50), coord)
	defer agentClient.Close()

	memFS := afero.NewMemMapFs()
	agnt := agent.New(agent.Options{
		Client:     agentClient,
		Logger:     logger.Named("agent"),
		Filesystem: memFS,
	})
	defer agnt.Close()

	homeDir, err := os.UserHomeDir()
	require.NoError(t, err)
	settingsPath := filepath.Join(homeDir, ".vscode-server", "data", "Machine", "settings.json")

	// The file appears asynchronously, so poll the in-memory filesystem.
	require.Eventually(t, func() bool {
		_, statErr := memFS.Stat(settingsPath)
		return statErr == nil
	}, testutil.WaitShort, testutil.IntervalFast)
}
// TestAgent_DebugServer exercises the agent's HTTP debug handler: the
// magicsock debug page and its debug-logging toggle, the manifest endpoint
// (including that a workspace secret's value never appears in it), and the
// log endpoint, which must serve the contents of the agent's log file.
func TestAgent_DebugServer(t *testing.T) {
	t.Parallel()
	logDir := t.TempDir()
	logPath := filepath.Join(logDir, "coder-agent.log")
	randLogStr, err := cryptorand.String(32)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(logPath, []byte(randLogStr), 0o600))
	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
	//nolint:dogsled
	conn, _, _, _, agnt := setupAgentWithSecrets(t, agentsdk.Manifest{
		DERPMap: derpMap,
	}, []agentsdk.WorkspaceSecret{
		{EnvName: "DEBUG_SECRET", Value: []byte("super-secret-value-12345")},
	}, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.LogDir = logDir
	})
	awaitReachableCtx := testutil.Context(t, testutil.WaitLong)
	ok := conn.AwaitReachable(awaitReachableCtx)
	require.True(t, ok)
	_ = conn.Close()
	srv := httptest.NewServer(agnt.HTTPDebug())
	t.Cleanup(srv.Close)

	// get issues a GET for path against the debug server and returns the
	// response status code and the fully-read body. The body is closed via
	// defer right after Do succeeds, so it is released even when a later
	// assertion fails the (sub)test.
	get := func(t *testing.T, path string) (int, []byte) {
		t.Helper()
		ctx := testutil.Context(t, testutil.WaitLong)
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+path, nil)
		require.NoError(t, err)
		res, err := srv.Client().Do(req)
		require.NoError(t, err)
		defer res.Body.Close()
		body, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		return res.StatusCode, body
	}

	t.Run("MagicsockDebug", func(t *testing.T) {
		t.Parallel()
		status, body := get(t, "/debug/magicsock")
		require.Equal(t, http.StatusOK, status)
		require.Contains(t, string(body), "<h1>magicsock</h1>")
	})

	t.Run("MagicsockDebugLogging", func(t *testing.T) {
		t.Parallel()

		t.Run("Enable", func(t *testing.T) {
			t.Parallel()
			status, body := get(t, "/debug/magicsock/debug-logging/t")
			require.Equal(t, http.StatusOK, status)
			require.Contains(t, string(body), "updated magicsock debug logging to true")
		})

		t.Run("Disable", func(t *testing.T) {
			t.Parallel()
			status, body := get(t, "/debug/magicsock/debug-logging/0")
			require.Equal(t, http.StatusOK, status)
			require.Contains(t, string(body), "updated magicsock debug logging to false")
		})

		t.Run("Invalid", func(t *testing.T) {
			t.Parallel()
			status, body := get(t, "/debug/magicsock/debug-logging/blah")
			require.Equal(t, http.StatusBadRequest, status)
			require.Contains(t, string(body), `invalid state "blah", must be a boolean`)
		})
	})

	t.Run("Manifest", func(t *testing.T) {
		t.Parallel()
		status, body := get(t, "/debug/manifest")
		require.Equal(t, http.StatusOK, status)
		// The endpoint must return a decodable Manifest. (Asserting NotNil
		// on the decoded struct value would be vacuous, so decoding is the
		// check.)
		var v agentsdk.Manifest
		require.NoError(t, json.Unmarshal(body, &v))
	})

	t.Run("ManifestSecretsStripped", func(t *testing.T) {
		t.Parallel()
		status, body := get(t, "/debug/manifest")
		require.Equal(t, http.StatusOK, status)
		// The response must not contain the secret value.
		require.NotContains(t, string(body), "super-secret-value-12345")
		// Confirm we can decode as a Manifest. The SDK type
		// intentionally has no Secrets field, so there is nothing
		// to leak through JSON encoding.
		var v agentsdk.Manifest
		require.NoError(t, json.Unmarshal(body, &v))
	})

	t.Run("Logs", func(t *testing.T) {
		t.Parallel()
		status, body := get(t, "/debug/logs")
		require.Equal(t, http.StatusOK, status)
		require.NotEmpty(t, string(body))
		require.Contains(t, string(body), randLogStr)
	})
}
// TestAgent_ScriptLogging verifies that output from a RunOnStart script and a
// RunOnStop script is delivered through the client's logs channel, line by
// line and in order. The stop script emits 3000 lines while the agent is
// closing, so the test also checks that shutdown logs are not truncated.
func TestAgent_ScriptLogging(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("bash scripts only")
	}
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitMedium)
	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
	logsCh := make(chan *proto.BatchCreateLogsRequest, 100)
	// Distinct log source IDs for the start and stop scripts.
	lsStart := uuid.UUID{0x11}
	lsStop := uuid.UUID{0x22}
	//nolint:dogsled
	_, _, _, _, agnt := setupAgent(
		t,
		agentsdk.Manifest{
			DERPMap: derpMap,
			Scripts: []codersdk.WorkspaceAgentScript{
				{
					LogSourceID: lsStart,
					RunOnStart:  true,
					Script: `#!/bin/sh
i=0
while [ $i -ne 5 ]
do
i=$(($i+1))
echo "start $i"
done
`,
				},
				{
					LogSourceID: lsStop,
					RunOnStop:   true,
					Script: `#!/bin/sh
i=0
while [ $i -ne 3000 ]
do
i=$(($i+1))
echo "stop $i"
done
`, // send a lot of stop logs to make sure we don't truncate shutdown logs before closing the API conn
				},
			},
		},
		0,
		func(cl *agenttest.Client, _ *agent.Options) {
			cl.SetLogsChannel(logsCh)
		},
	)
	// Consume log batches until all 5 start-script lines have been seen,
	// requiring each line to be the next one in sequence.
	n := 1
	for n <= 5 {
		logs := testutil.TryReceive(ctx, t, logsCh)
		require.NotNil(t, logs)
		for _, l := range logs.GetLogs() {
			require.Equal(t, fmt.Sprintf("start %d", n), l.GetOutput())
			n++
		}
	}
	// Close the agent; the stop script's output must then arrive in full
	// and in order.
	err := agnt.Close()
	require.NoError(t, err)
	n = 1
	for n <= 3000 {
		logs := testutil.TryReceive(ctx, t, logsCh)
		require.NotNil(t, logs)
		for _, l := range logs.GetLogs() {
			require.Equal(t, fmt.Sprintf("stop %d", n), l.GetOutput())
			n++
		}
		t.Logf("got %d stop logs", n-1)
	}
}
// setupAgentSSHClient starts an agent with an empty manifest, dials it, and
// returns an SSH client connected to it. The client is closed when the test
// finishes.
func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client {
	//nolint: dogsled
	conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
	client, err := conn.SSHClient(ctx)
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = client.Close()
	})
	return client
}
// setupSSHSession is setupSSHSessionOnPort with the port fixed to
// workspacesdk.AgentSSHPort.
func setupSSHSession(
	t *testing.T,
	manifest agentsdk.Manifest,
	banner codersdk.BannerConfig,
	prepareFS func(fs afero.Fs),
	opts ...func(*agenttest.Client, *agent.Options),
) *ssh.Session {
	var sshPort uint16 = workspacesdk.AgentSSHPort
	return setupSSHSessionOnPort(t, manifest, banner, prepareFS, sshPort, opts...)
}
// setupSSHSessionOnPort starts an agent via setupAgent, configured so that
// banner is its only announcement banner. If prepareFS is non-nil, it is
// called with the filesystem returned by setupAgent. The function then opens
// an SSH client on the given port and returns a new session. Both the client
// and the session are closed on test cleanup.
func setupSSHSessionOnPort(
	t *testing.T,
	manifest agentsdk.Manifest,
	banner codersdk.BannerConfig,
	prepareFS func(fs afero.Fs),
	port uint16,
	opts ...func(*agenttest.Client, *agent.Options),
) *ssh.Session {
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	// Clone before appending: a variadic slice may share the caller's
	// backing array, and appending in place could overwrite the caller's
	// elements beyond len(opts).
	opts = append(slices.Clone(opts), func(c *agenttest.Client, _ *agent.Options) {
		c.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
			return []codersdk.BannerConfig{banner}, nil
		})
	})
	//nolint:dogsled
	conn, _, _, fs, _ := setupAgent(t, manifest, 0, opts...)
	if prepareFS != nil {
		prepareFS(fs)
	}
	sshClient, err := conn.SSHClientOnPort(ctx, port)
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = sshClient.Close()
	})
	session, err := sshClient.NewSession()
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = session.Close()
	})
	return session
}
// setupAgent is setupAgentWithSecrets without any workspace secrets. It
// returns a client-side AgentConn that has already reached the agent, the
// agenttest client, its stats channel, the in-memory filesystem created for
// the agent, and the agent itself.
func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
	workspacesdk.AgentConn,
	*agenttest.Client,
	<-chan *proto.Stats,
	afero.Fs,
	agent.Agent,
) {
	var noSecrets []agentsdk.WorkspaceSecret
	conn, client, stats, fs, agnt := setupAgentWithSecrets(t, metadata, noSecrets, ptyTimeout, opts...)
	return conn, client, stats, fs, agnt
}
// setupAgentWithSecrets is like setupAgent but also injects user
// secrets into the agent's proto manifest. Separate from setupAgent
// because agentsdk.Manifest intentionally does not carry secrets; see
// the Manifest doc comment in codersdk/agentsdk.
//
// Unset manifest fields (DERPMap, AgentID, AgentName, WorkspaceName,
// OwnerName, WorkspaceID) are filled with test defaults. Each opt is applied
// to the agenttest client and the agent options before the agent is created.
// A client tailnet connection is then coordinated with the agent through an
// in-memory coordinator client, and the test fails if the agent is not
// reachable within testutil.WaitLong. Everything created here is torn down
// via t.Cleanup.
func setupAgentWithSecrets(t testing.TB, metadata agentsdk.Manifest, secrets []agentsdk.WorkspaceSecret, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
	workspacesdk.AgentConn,
	*agenttest.Client,
	<-chan *proto.Stats,
	afero.Fs,
	agent.Agent,
) {
	logger := slogtest.Make(t, &slogtest.Options{
		// Agent can drop errors when shutting down, and some, like the
		// fasthttplistener connection closed error, are unexported.
		IgnoreErrors: true,
	}).Leveled(slog.LevelDebug)
	// Fill in defaults for any manifest fields the caller left unset.
	if metadata.DERPMap == nil {
		metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
	}
	if metadata.AgentID == uuid.Nil {
		metadata.AgentID = uuid.New()
	}
	if metadata.AgentName == "" {
		metadata.AgentName = "test-agent"
	}
	if metadata.WorkspaceName == "" {
		metadata.WorkspaceName = "test-workspace"
	}
	if metadata.OwnerName == "" {
		metadata.OwnerName = "test-user"
	}
	if metadata.WorkspaceID == uuid.Nil {
		metadata.WorkspaceID = uuid.New()
	}
	coordinator := tailnet.NewCoordinator(logger)
	t.Cleanup(func() {
		_ = coordinator.Close()
	})
	statsCh := make(chan *proto.Stats, 50)
	fs := afero.NewMemMapFs()
	c := agenttest.NewClientWithSecrets(t, logger.Named("agenttest"), metadata.AgentID, metadata, secrets, statsCh, coordinator)
	t.Cleanup(c.Close)
	options := agent.Options{
		Client:                 c,
		Filesystem:             fs,
		Logger:                 logger.Named("agent"),
		ReconnectingPTYTimeout: ptyTimeout,
		EnvironmentVariables:   map[string]string{},
	}
	// Let callers customize the client and the agent options.
	for _, opt := range opts {
		opt(c, &options)
	}
	agnt := agent.New(options)
	t.Cleanup(func() {
		_ = agnt.Close()
	})
	// Client side: a tailnet conn with a random address in the Tailscale
	// service prefix, using the same DERP map as the agent.
	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.TailscaleServicePrefix.RandomAddr(), 128)},
		DERPMap:   metadata.DERPMap,
		Logger:    logger.Named("client"),
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = conn.Close()
	})
	// testCtx bounds the coordination close in the cleanup below; its own
	// cancel is registered earlier, so it runs after that cleanup.
	testCtx, testCtxCancel := context.WithCancel(context.Background())
	t.Cleanup(testCtxCancel)
	clientID := uuid.New()
	ctrl := tailnet.NewTunnelSrcCoordController(logger, conn)
	ctrl.AddDestination(metadata.AgentID)
	auth := tailnet.ClientCoordinateeAuth{AgentID: metadata.AgentID}
	coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(
		logger, clientID, auth, coordinator))
	t.Cleanup(func() {
		cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
		defer ccancel()
		err := coordination.Close(cctx)
		if err != nil {
			t.Logf("error closing in-mem coordination: %s", err.Error())
		}
	})
	agentConn := workspacesdk.NewAgentConn(conn, workspacesdk.AgentConnOptions{
		AgentID: metadata.AgentID,
	})
	t.Cleanup(func() {
		_ = agentConn.Close()
	})
	// Ideally we wouldn't wait too long here, but sometimes the
	// networking needs more time to resolve itself.
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	if !agentConn.AwaitReachable(ctx) {
		t.Fatal("agent not reachable")
	}
	return agentConn, c, statsCh, fs, agnt
}
// dialTestPayload is the fixed message exchanged by testDial and testAccept.
var dialTestPayload = []byte("dean-was-here123")

// testDial is the client half of a round trip: it writes dialTestPayload to c
// and expects to read the same bytes back. If ctx has a deadline it is applied
// to c for the duration of the exchange and cleared afterwards. Failures are
// reported with assert, so the calling test keeps running.
func testDial(ctx context.Context, t testing.TB, c net.Conn) {
	t.Helper()
	deadline, hasDeadline := ctx.Deadline()
	if hasDeadline {
		assert.NoError(t, c.SetDeadline(deadline))
		defer func() {
			assert.NoError(t, c.SetDeadline(time.Time{}))
		}()
	}
	assertWritePayload(t, c, dialTestPayload)
	assertReadPayload(t, c, dialTestPayload)
}
// testAccept is the server half of testDial: it reads dialTestPayload from c
// and writes it back. If ctx has a deadline it is applied to c for the
// exchange and cleared afterwards; c is always closed when testAccept returns.
func testAccept(ctx context.Context, t testing.TB, c net.Conn) {
	t.Helper()
	defer c.Close()
	deadline, hasDeadline := ctx.Deadline()
	if hasDeadline {
		assert.NoError(t, c.SetDeadline(deadline))
		defer func() {
			assert.NoError(t, c.SetDeadline(time.Time{}))
		}()
	}
	assertReadPayload(t, c, dialTestPayload)
	assertWritePayload(t, c, dialTestPayload)
}
// assertReadPayload performs a single Read from r and asserts that it
// returned exactly payload. The buffer is 16 bytes larger than payload, so
// extra bytes delivered by that Read show up as a length mismatch.
func assertReadPayload(t testing.TB, r io.Reader, payload []byte) {
	t.Helper()
	buf := make([]byte, len(payload)+16)
	got, err := r.Read(buf)
	assert.NoError(t, err, "read payload")
	assert.Equal(t, len(payload), got, "read payload length does not match")
	assert.Equal(t, payload, buf[:got])
}
// assertWritePayload writes payload to w with a single Write call and asserts
// that the call succeeded and reported writing every byte.
func assertWritePayload(t testing.TB, w io.Writer, payload []byte) {
	t.Helper()
	written, err := w.Write(payload)
	assert.NoError(t, err, "write payload")
	assert.Equal(t, len(payload), written, "written payload length does not match")
}
// testSessionOutput requests an xterm PTY on session, starts a shell, sends
// "exit 0", and waits up to testutil.WaitLong for the session to exit
// cleanly. It then requires the captured stdout to contain none of the
// strings in unexpected, every string in expected, and, if expectedRe is
// non-nil, to match expectedRe.
func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected []string, expectedRe *regexp.Regexp) {
	t.Helper()
	err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
	require.NoError(t, err)
	ptty := ptytest.New(t)
	var stdout bytes.Buffer
	session.Stdout = &stdout
	session.Stderr = ptty.Output()
	session.Stdin = ptty.Input()
	err = session.Shell()
	require.NoError(t, err)
	ptty.WriteLine("exit 0")
	waitErr := make(chan error, 1)
	go func() {
		waitErr <- session.Wait()
	}()
	select {
	case err = <-waitErr:
		require.NoError(t, err)
	case <-time.After(testutil.WaitLong):
		require.Fail(t, "timed out waiting for session to exit")
	}
	// Read stdout once, after the session has exited. Distinct loop variable
	// names avoid shadowing the expected/unexpected parameters.
	output := stdout.String()
	for _, notWant := range unexpected {
		require.NotContains(t, output, notWant, "should not show output")
	}
	for _, want := range expected {
		require.Contains(t, output, want, "should show output")
	}
	if expectedRe != nil {
		require.Regexp(t, expectedRe, output)
	}
}
// TestAgent_Metrics_SSH opens a non-PTY SSH shell against an agent that uses
// a dedicated Prometheus registry. It checks that the registry exposes
// exactly the expected metrics, with the expected names, types, values and
// leading labels, and then exits the shell cleanly.
func TestAgent_Metrics_SSH(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	registry := prometheus.NewRegistry()
	//nolint:dogsled
	conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.PrometheusRegistry = registry
	})
	sshClient, err := conn.SSHClient(ctx)
	require.NoError(t, err)
	defer sshClient.Close()
	session, err := sshClient.NewSession()
	require.NoError(t, err)
	defer session.Close()
	stdin, err := session.StdinPipe()
	require.NoError(t, err)
	err = session.Shell()
	require.NoError(t, err)
	// expected is listed in the order registry.Gather reports metrics
	// (families sorted by name, then metrics by label values).
	expected := []struct {
		Name    string
		Type    proto.Stats_Metric_Type
		CheckFn func(float64) error
		Labels  []*proto.Stats_Metric_Label
	}{
		{
			Name: "agent_reconnecting_pty_connections_total",
			Type: proto.Stats_Metric_COUNTER,
			CheckFn: func(v float64) error {
				if v == 0 {
					return nil
				}
				return xerrors.Errorf("expected 0, got %f", v)
			},
		},
		{
			Name: "agent_sessions_total",
			Type: proto.Stats_Metric_COUNTER,
			CheckFn: func(v float64) error {
				if v == 1 {
					return nil
				}
				return xerrors.Errorf("expected 1, got %f", v)
			},
			Labels: []*proto.Stats_Metric_Label{
				{
					Name:  "magic_type",
					Value: "ssh",
				},
				{
					Name:  "pty",
					Value: "no",
				},
			},
		},
		{
			Name: "agent_ssh_server_failed_connections_total",
			Type: proto.Stats_Metric_COUNTER,
			CheckFn: func(v float64) error {
				if v == 0 {
					return nil
				}
				return xerrors.Errorf("expected 0, got %f", v)
			},
		},
		{
			Name: "agent_ssh_server_sftp_connections_total",
			Type: proto.Stats_Metric_COUNTER,
			CheckFn: func(v float64) error {
				if v == 0 {
					return nil
				}
				return xerrors.Errorf("expected 0, got %f", v)
			},
		},
		{
			Name: "agent_ssh_server_sftp_server_errors_total",
			Type: proto.Stats_Metric_COUNTER,
			CheckFn: func(v float64) error {
				if v == 0 {
					return nil
				}
				return xerrors.Errorf("expected 0, got %f", v)
			},
		},
		{
			Name: "coderd_agentstats_currently_reachable_peers",
			Type: proto.Stats_Metric_GAUGE,
			CheckFn: func(float64) error {
				// We can't reliably ping a peer here, and networking is out of
				// scope of this test, so we just test that the metric exists
				// with the correct labels.
				return nil
			},
			Labels: []*proto.Stats_Metric_Label{
				{
					Name:  "connection_type",
					Value: "derp",
				},
			},
		},
		{
			Name: "coderd_agentstats_currently_reachable_peers",
			Type: proto.Stats_Metric_GAUGE,
			CheckFn: func(float64) error {
				return nil
			},
			Labels: []*proto.Stats_Metric_Label{
				{
					Name:  "connection_type",
					Value: "p2p",
				},
			},
		},
		{
			Name: "coderd_agentstats_startup_script_seconds",
			Type: proto.Stats_Metric_GAUGE,
			CheckFn: func(f float64) error {
				if f >= 0 {
					return nil
				}
				return xerrors.Errorf("expected >= 0, got %f", f)
			},
			Labels: []*proto.Stats_Metric_Label{
				{
					Name:  "success",
					Value: "true",
				},
			},
		},
	}
	// Poll until the registry reports exactly len(expected) metrics. require
	// (not assert) stops the test on timeout, so the comparison loop below
	// never indexes past the end of expected. The closure only publishes a
	// successful result and never touches the outer err.
	var actual []*promgo.MetricFamily
	require.Eventually(t, func() bool {
		families, gatherErr := registry.Gather()
		if gatherErr != nil {
			return false
		}
		count := 0
		for _, mf := range families {
			count += len(mf.GetMetric())
		}
		if count != len(expected) {
			return false
		}
		actual = families
		return true
	}, testutil.WaitLong, testutil.IntervalFast)
	i := 0
	for _, mf := range actual {
		for _, m := range mf.GetMetric() {
			want := expected[i]
			i++
			assert.Equal(t, want.Name, mf.GetName())
			assert.Equal(t, want.Type.String(), mf.GetType().String())
			if want.Type == proto.Stats_Metric_GAUGE {
				assert.NoError(t, want.CheckFn(m.GetGauge().GetValue()), "check fn for %s failed", want.Name)
			} else if want.Type == proto.Stats_Metric_COUNTER {
				assert.NoError(t, want.CheckFn(m.GetCounter().GetValue()), "check fn for %s failed", want.Name)
			}
			// Compare the leading labels by name and value. This avoids an
			// index panic when too few labels are present, and avoids
			// comparing protobuf messages as whole structs.
			gotLabels := m.GetLabel()
			if !assert.GreaterOrEqual(t, len(gotLabels), len(want.Labels), "not enough labels on %s", want.Name) {
				continue
			}
			for j, lbl := range want.Labels {
				assert.Equal(t, lbl.Name, gotLabels[j].GetName(), "label %d name on %s", j, want.Name)
				assert.Equal(t, lbl.Value, gotLabels[j].GetValue(), "label %d value on %s", j, want.Name)
			}
		}
	}
	_, err = stdin.Write([]byte("exit 0\n"))
	require.NoError(t, err, "writing exit to stdin")
	_ = stdin.Close()
	err = session.Wait()
	require.NoError(t, err, "waiting for session to exit")
}
// echoOnce accepts a single connection on ll, reads exactly 4 bytes from it,
// and echoes them back before closing the connection. If Accept fails, it
// returns without reporting a test failure. Read/write failures are reported
// with assert, so it is safe to call from a non-test goroutine.
func echoOnce(t *testing.T, ll net.Listener) {
	t.Helper()
	conn, err := ll.Accept()
	if err != nil {
		return
	}
	defer conn.Close()
	b := make([]byte, 4)
	// A single Read on a stream may return fewer bytes than requested;
	// io.ReadFull keeps reading until b is full (or an error occurs).
	_, err = io.ReadFull(conn, b)
	if !assert.NoError(t, err) {
		return
	}
	_, err = conn.Write(b)
	assert.NoError(t, err)
}
// requireEcho writes the 4-byte message "test" to conn, reads exactly 4 bytes
// back, and requires them to equal what was sent.
func requireEcho(t *testing.T, conn net.Conn) {
	t.Helper()
	const payload = "test"
	_, err := conn.Write([]byte(payload))
	require.NoError(t, err)
	b := make([]byte, len(payload))
	// io.ReadFull guards against a short Read splitting the echo.
	_, err = io.ReadFull(conn, b)
	require.NoError(t, err)
	require.Equal(t, payload, string(b))
}
// assertConnectionReport waits (up to testutil.WaitMedium) for agentClient to
// have recorded at least two connection reports. It then asserts that there
// are exactly two:
//   - a CONNECT followed by a DISCONNECT, both of connectionType, with
//     non-decreasing timestamps and non-empty IPs;
//   - the CONNECT has status code 0 and an empty reason;
//   - the DISCONNECT has status code status and, when reason is non-empty,
//     a reason containing it (otherwise the actual reason is logged).
//
// Failures are reported with assert, so the calling test continues.
func assertConnectionReport(t testing.TB, agentClient *agenttest.Client, connectionType proto.Connection_Type, status int, reason string) {
	t.Helper()
	var reports []*proto.ReportConnectionRequest
	if !assert.Eventually(t, func() bool {
		reports = agentClient.GetConnectionReports()
		return len(reports) >= 2
	}, testutil.WaitMedium, testutil.IntervalFast, "waiting for 2 connection reports or more") {
		// Message arguments are evaluated when Eventually is called, before
		// any polling, so a count passed there would always be 0. Log the
		// current count after the timeout instead.
		t.Logf("got %d connection reports", len(agentClient.GetConnectionReports()))
		return
	}
	assert.Len(t, reports, 2, "want 2 connection reports")
	connect := reports[0].GetConnection()
	disconnect := reports[1].GetConnection()
	assert.Equal(t, proto.Connection_CONNECT, connect.GetAction(), "first report should be connect")
	assert.Equal(t, proto.Connection_DISCONNECT, disconnect.GetAction(), "second report should be disconnect")
	assert.Equal(t, connectionType, connect.GetType(), "connect type should be %s", connectionType)
	assert.Equal(t, connectionType, disconnect.GetType(), "disconnect type should be %s", connectionType)
	t1 := connect.GetTimestamp().AsTime()
	t2 := disconnect.GetTimestamp().AsTime()
	assert.False(t, t1.After(t2), "connect timestamp should be before or equal to disconnect timestamp")
	assert.NotEmpty(t, connect.GetIp(), "connect ip should not be empty")
	assert.NotEmpty(t, disconnect.GetIp(), "disconnect ip should not be empty")
	assert.Equal(t, 0, int(connect.GetStatusCode()), "connect status code should be 0")
	assert.Equal(t, status, int(disconnect.GetStatusCode()), "disconnect status code should be %d", status)
	assert.Equal(t, "", connect.GetReason(), "connect reason should be empty")
	if reason != "" {
		assert.Contains(t, disconnect.GetReason(), reason, "disconnect reason should contain %s", reason)
	} else {
		t.Logf("connection report disconnect reason: %s", disconnect.GetReason())
	}
}