mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
test(cli): fix TestSSH/RemoteForward_Unix_Signal flake (#16172)
This commit is contained in:
committed by
GitHub
parent
ea8cd55404
commit
7cf62423ec
+90
-87
@@ -819,102 +819,105 @@ func TestSSH(t *testing.T) {
|
||||
|
||||
tmpdir := tempDirUnixSocket(t)
|
||||
localSock := filepath.Join(tmpdir, "local.sock")
|
||||
l, err := net.Listen("unix", localSock)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
remoteSock := path.Join(tmpdir, "remote.sock")
|
||||
for i := 0; i < 2; i++ {
|
||||
t.Logf("connect %d of 2", i+1)
|
||||
inv, root := clitest.New(t,
|
||||
"ssh",
|
||||
workspace.Name,
|
||||
"--remote-forward",
|
||||
remoteSock+":"+localSock,
|
||||
)
|
||||
fsn := clitest.NewFakeSignalNotifier(t)
|
||||
inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
|
||||
inv.Stdout = io.Discard
|
||||
inv.Stderr = io.Discard
|
||||
func() { // Function scope for defer.
|
||||
t.Logf("Connect %d/2", i+1)
|
||||
|
||||
clitest.SetupConfig(t, client, root)
|
||||
cmdDone := tGo(t, func() {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
// accept a single connection
|
||||
msgs := make(chan string, 1)
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
msg, err := io.ReadAll(conn)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
msgs <- string(msg)
|
||||
}()
|
||||
|
||||
// Unfortunately, there is a race in crypto/ssh where it sends the request to forward
|
||||
// unix sockets before it is prepared to receive the response, meaning that even after
|
||||
// the socket exists on the file system, the client might not be ready to accept the
|
||||
// channel.
|
||||
//
|
||||
// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33
|
||||
//
|
||||
// To work around this, we attempt to send messages in a loop until one succeeds
|
||||
success := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
var (
|
||||
conn net.Conn
|
||||
err error
|
||||
inv, root := clitest.New(t,
|
||||
"ssh",
|
||||
workspace.Name,
|
||||
"--remote-forward",
|
||||
remoteSock+":"+localSock,
|
||||
)
|
||||
for {
|
||||
time.Sleep(testutil.IntervalMedium)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timeout")
|
||||
fsn := clitest.NewFakeSignalNotifier(t)
|
||||
inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
|
||||
inv.Stdout = io.Discard
|
||||
inv.Stderr = io.Discard
|
||||
|
||||
clitest.SetupConfig(t, client, root)
|
||||
cmdDone := tGo(t, func() {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
// accept a single connection
|
||||
msgs := make(chan string, 1)
|
||||
l, err := net.Listen("unix", localSock)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
case <-success:
|
||||
}
|
||||
msg, err := io.ReadAll(conn)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
default:
|
||||
// Ok
|
||||
}
|
||||
conn, err = net.Dial("unix", remoteSock)
|
||||
if err != nil {
|
||||
t.Logf("dial error: %s", err)
|
||||
continue
|
||||
msgs <- string(msg)
|
||||
}()
|
||||
|
||||
// Unfortunately, there is a race in crypto/ssh where it sends the request to forward
|
||||
// unix sockets before it is prepared to receive the response, meaning that even after
|
||||
// the socket exists on the file system, the client might not be ready to accept the
|
||||
// channel.
|
||||
//
|
||||
// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33
|
||||
//
|
||||
// To work around this, we attempt to send messages in a loop until one succeeds
|
||||
success := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
var (
|
||||
conn net.Conn
|
||||
err error
|
||||
)
|
||||
for {
|
||||
time.Sleep(testutil.IntervalMedium)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timeout")
|
||||
return
|
||||
case <-success:
|
||||
return
|
||||
default:
|
||||
// Ok
|
||||
}
|
||||
conn, err = net.Dial("unix", remoteSock)
|
||||
if err != nil {
|
||||
t.Logf("dial error: %s", err)
|
||||
continue
|
||||
}
|
||||
_, err = conn.Write([]byte("test"))
|
||||
if err != nil {
|
||||
t.Logf("write error: %s", err)
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Logf("close error: %s", err)
|
||||
}
|
||||
}
|
||||
_, err = conn.Write([]byte("test"))
|
||||
if err != nil {
|
||||
t.Logf("write error: %s", err)
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Logf("close error: %s", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
msg := testutil.RequireRecvCtx(ctx, t, msgs)
|
||||
require.Equal(t, "test", msg)
|
||||
close(success)
|
||||
fsn.Notify()
|
||||
<-cmdDone
|
||||
fsn.AssertStopped()
|
||||
// wait for dial goroutine to complete
|
||||
_ = testutil.RequireRecvCtx(ctx, t, done)
|
||||
|
||||
// wait for the remote socket to get cleaned up before retrying,
|
||||
// because cleaning up the socket happens asynchronously, and we
|
||||
// might connect to an old listener on the agent side.
|
||||
require.Eventually(t, func() bool {
|
||||
_, err = os.Stat(remoteSock)
|
||||
return xerrors.Is(err, os.ErrNotExist)
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}()
|
||||
|
||||
msg := testutil.RequireRecvCtx(ctx, t, msgs)
|
||||
require.Equal(t, "test", msg)
|
||||
close(success)
|
||||
fsn.Notify()
|
||||
<-cmdDone
|
||||
fsn.AssertStopped()
|
||||
// wait for dial goroutine to complete
|
||||
_ = testutil.RequireRecvCtx(ctx, t, done)
|
||||
|
||||
// wait for the remote socket to get cleaned up before retrying,
|
||||
// because cleaning up the socket happens asynchronously, and we
|
||||
// might connect to an old listener on the agent side.
|
||||
require.Eventually(t, func() bool {
|
||||
_, err = os.Stat(remoteSock)
|
||||
return xerrors.Is(err, os.ErrNotExist)
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user