test(cli): fix TestSSH/RemoteForward_Unix_Signal flake (#16172)

This commit is contained in:
Mathias Fredriksson
2025-01-17 16:53:09 +02:00
committed by GitHub
parent ea8cd55404
commit 7cf62423ec
+90 -87
View File
@@ -819,102 +819,105 @@ func TestSSH(t *testing.T) {
tmpdir := tempDirUnixSocket(t)
localSock := filepath.Join(tmpdir, "local.sock")
l, err := net.Listen("unix", localSock)
require.NoError(t, err)
defer l.Close()
remoteSock := path.Join(tmpdir, "remote.sock")
for i := 0; i < 2; i++ {
t.Logf("connect %d of 2", i+1)
inv, root := clitest.New(t,
"ssh",
workspace.Name,
"--remote-forward",
remoteSock+":"+localSock,
)
fsn := clitest.NewFakeSignalNotifier(t)
inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
inv.Stdout = io.Discard
inv.Stderr = io.Discard
func() { // Function scope for defer.
t.Logf("Connect %d/2", i+1)
clitest.SetupConfig(t, client, root)
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.Error(t, err)
})
// accept a single connection
msgs := make(chan string, 1)
go func() {
conn, err := l.Accept()
if !assert.NoError(t, err) {
return
}
msg, err := io.ReadAll(conn)
if !assert.NoError(t, err) {
return
}
msgs <- string(msg)
}()
// Unfortunately, there is a race in crypto/ssh where it sends the request to forward
// unix sockets before it is prepared to receive the response, meaning that even after
// the socket exists on the file system, the client might not be ready to accept the
// channel.
//
// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33
//
// To work around this, we attempt to send messages in a loop until one succeeds
success := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
var (
conn net.Conn
err error
inv, root := clitest.New(t,
"ssh",
workspace.Name,
"--remote-forward",
remoteSock+":"+localSock,
)
for {
time.Sleep(testutil.IntervalMedium)
select {
case <-ctx.Done():
t.Error("timeout")
fsn := clitest.NewFakeSignalNotifier(t)
inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
inv.Stdout = io.Discard
inv.Stderr = io.Discard
clitest.SetupConfig(t, client, root)
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.Error(t, err)
})
// accept a single connection
msgs := make(chan string, 1)
l, err := net.Listen("unix", localSock)
require.NoError(t, err)
defer l.Close()
go func() {
conn, err := l.Accept()
if !assert.NoError(t, err) {
return
case <-success:
}
msg, err := io.ReadAll(conn)
if !assert.NoError(t, err) {
return
default:
// Ok
}
conn, err = net.Dial("unix", remoteSock)
if err != nil {
t.Logf("dial error: %s", err)
continue
msgs <- string(msg)
}()
// Unfortunately, there is a race in crypto/ssh where it sends the request to forward
// unix sockets before it is prepared to receive the response, meaning that even after
// the socket exists on the file system, the client might not be ready to accept the
// channel.
//
// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33
//
// To work around this, we attempt to send messages in a loop until one succeeds
success := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
var (
conn net.Conn
err error
)
for {
time.Sleep(testutil.IntervalMedium)
select {
case <-ctx.Done():
t.Error("timeout")
return
case <-success:
return
default:
// Ok
}
conn, err = net.Dial("unix", remoteSock)
if err != nil {
t.Logf("dial error: %s", err)
continue
}
_, err = conn.Write([]byte("test"))
if err != nil {
t.Logf("write error: %s", err)
}
err = conn.Close()
if err != nil {
t.Logf("close error: %s", err)
}
}
_, err = conn.Write([]byte("test"))
if err != nil {
t.Logf("write error: %s", err)
}
err = conn.Close()
if err != nil {
t.Logf("close error: %s", err)
}
}
}()
msg := testutil.RequireRecvCtx(ctx, t, msgs)
require.Equal(t, "test", msg)
close(success)
fsn.Notify()
<-cmdDone
fsn.AssertStopped()
// wait for dial goroutine to complete
_ = testutil.RequireRecvCtx(ctx, t, done)
// wait for the remote socket to get cleaned up before retrying,
// because cleaning up the socket happens asynchronously, and we
// might connect to an old listener on the agent side.
require.Eventually(t, func() bool {
_, err = os.Stat(remoteSock)
return xerrors.Is(err, os.ErrNotExist)
}, testutil.WaitShort, testutil.IntervalFast)
}()
msg := testutil.RequireRecvCtx(ctx, t, msgs)
require.Equal(t, "test", msg)
close(success)
fsn.Notify()
<-cmdDone
fsn.AssertStopped()
// wait for dial goroutine to complete
_ = testutil.RequireRecvCtx(ctx, t, done)
// wait for the remote socket to get cleaned up before retrying,
// because cleaning up the socket happens asynchronously, and we
// might connect to an old listener on the agent side.
require.Eventually(t, func() bool {
_, err = os.Stat(remoteSock)
return xerrors.Is(err, os.ErrNotExist)
}, testutil.WaitShort, testutil.IntervalFast)
}
})