test: dial socket when testing coder ssh unix socket forwarding (#19563)

Closes https://github.com/coder/internal/issues/942

The flakey test, `RemoteForwardUnixSocket`, was using `netstat` to check if the unix socket was forwarded properly. In the flake, it looks like netstat was hanging. This PR has `RemoteForwardUnixSocket` be rewritten to match the implementation of `RemoteForwardMultipleUnixSockets`, where we send bytes over the socket in-process instead. More importantly, that test hasn't flaked (yet).

Note: The implementation has been copied directly from the other test, comments and all.
This commit is contained in:
Ethan
2025-08-27 21:30:54 +10:00
committed by GitHub
parent 5c2022e08c
commit fcef2ec3a5
+63 -18
View File
@@ -20,6 +20,7 @@ import (
"regexp"
"runtime"
"strings"
"sync"
"testing"
"time"
@@ -1318,9 +1319,6 @@ func TestSSH(t *testing.T) {
tmpdir := tempDirUnixSocket(t)
localSock := filepath.Join(tmpdir, "local.sock")
l, err := net.Listen("unix", localSock)
require.NoError(t, err)
defer l.Close()
remoteSock := filepath.Join(tmpdir, "remote.sock")
inv, root := clitest.New(t,
@@ -1332,23 +1330,62 @@ func TestSSH(t *testing.T) {
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t).Attach(inv)
inv.Stderr = pty.Output()
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.NoError(t, err, "ssh command failed")
})
// Wait for the prompt or any output really to indicate the command has
// started and accepting input on stdin.
w := clitest.StartWithWaiter(t, inv.WithContext(ctx))
defer w.Wait() // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly).
// Since something was output, it should be safe to write input.
// This could show a prompt or "running startup scripts", so it's
// not indicative of the SSH connection being ready.
_ = pty.Peek(ctx, 1)
// This needs to support most shells on Linux or macOS
// We can't include exactly what's expected in the input, as that will always be matched
pty.WriteLine(fmt.Sprintf(`echo "results: $(netstat -an | grep %s | wc -l | tr -d ' ')"`, remoteSock))
pty.ExpectMatchContext(ctx, "results: 1")
// Ensure the SSH connection is ready by testing the shell
// input/output.
pty.WriteLine("echo ping' 'pong")
pty.ExpectMatchContext(ctx, "ping pong")
// Start the listener on the "local machine".
l, err := net.Listen("unix", localSock)
require.NoError(t, err)
defer l.Close()
testutil.Go(t, func() {
var wg sync.WaitGroup
defer wg.Wait()
for {
fd, err := l.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
assert.NoError(t, err, "listener accept failed")
}
return
}
wg.Add(1)
go func() {
defer wg.Done()
defer fd.Close()
agentssh.Bicopy(ctx, fd, fd)
}()
}
})
// Dial the forwarded socket on the "remote machine".
d := &net.Dialer{}
fd, err := d.DialContext(ctx, "unix", remoteSock)
require.NoError(t, err)
defer fd.Close()
// Ping / pong to ensure the socket is working.
_, err = fd.Write([]byte("hello world"))
require.NoError(t, err)
buf := make([]byte, 11)
_, err = fd.Read(buf)
require.NoError(t, err)
require.Equal(t, "hello world", string(buf))
// And we're done.
pty.WriteLine("exit")
<-cmdDone
})
// Test that we can forward a local unix socket to a remote unix socket and
@@ -1377,6 +1414,8 @@ func TestSSH(t *testing.T) {
require.NoError(t, err)
defer l.Close()
testutil.Go(t, func() {
var wg sync.WaitGroup
defer wg.Wait()
for {
fd, err := l.Accept()
if err != nil {
@@ -1386,10 +1425,12 @@ func TestSSH(t *testing.T) {
return
}
testutil.Go(t, func() {
wg.Add(1)
go func() {
defer wg.Done()
defer fd.Close()
agentssh.Bicopy(ctx, fd, fd)
})
}()
}
})
@@ -1522,6 +1563,8 @@ func TestSSH(t *testing.T) {
require.NoError(t, err)
defer l.Close() //nolint:revive // Defer is fine in this loop, we only run it twice.
testutil.Go(t, func() {
var wg sync.WaitGroup
defer wg.Wait()
for {
fd, err := l.Accept()
if err != nil {
@@ -1531,10 +1574,12 @@ func TestSSH(t *testing.T) {
return
}
testutil.Go(t, func() {
wg.Add(1)
go func() {
defer wg.Done()
defer fd.Close()
agentssh.Bicopy(ctx, fd, fd)
})
}()
}
})