diff --git a/cli/portforward_test.go b/cli/portforward_test.go index 9899bd28cc..91c13efabe 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -1,10 +1,13 @@ package cli_test import ( + "bytes" "context" + "crypto/rand" "fmt" "io" "net" + "slices" "sync" "testing" "time" @@ -41,6 +44,22 @@ func TestPortForward_None(t *testing.T) { require.ErrorContains(t, err, "no port-forwards") } +func listenLocalUDPWithPrefix(t *testing.T, prefix []byte) net.Listener { + addr := net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + cfg := udp.ListenConfig{AcceptFilter: func(bytes []byte) bool { + if len(bytes) < len(prefix) { + return false + } + return slices.Equal(prefix, bytes[:len(prefix)]) + }} + l, err := cfg.Listen("udp", &addr) + require.NoError(t, err, "create UDP listener") + return l +} + func TestPortForward(t *testing.T) { t.Parallel() cases := []struct { @@ -50,8 +69,9 @@ func TestPortForward(t *testing.T) { // of connection. Has one format arg (string) for the remote address. flag []string // setupRemote creates a "remote" listener to emulate a service in the - // workspace. - setupRemote func(t *testing.T) net.Listener + // workspace. The prefix is generated per test case and can be used to + // filter connections. + setupRemote func(t *testing.T, prefix []byte) net.Listener // the local address(es) to "dial" localAddress []string }{ @@ -59,7 +79,7 @@ func TestPortForward(t *testing.T) { name: "TCP", network: "tcp", flag: []string{"--tcp=5555:%v", "--tcp=6666:%v"}, - setupRemote: func(t *testing.T) net.Listener { + setupRemote: func(t *testing.T, _ []byte) net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "create TCP listener") return l @@ -70,7 +90,7 @@ func TestPortForward(t *testing.T) { name: "TCP-opportunistic-ipv6", network: "tcp", flag: []string{"--tcp=5566:%v", "--tcp=6655:%v"}, - setupRemote: func(t *testing.T) net.Listener { + setupRemote: func(t *testing.T, _ []byte) net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "create TCP listener") return l @@ -78,39 +98,23 @@ func TestPortForward(t *testing.T) { localAddress: []string{"[::1]:5566", "[::1]:6655"}, }, { - name: "UDP", - network: "udp", - flag: []string{"--udp=7777:%v", "--udp=8888:%v"}, - setupRemote: func(t *testing.T) net.Listener { - addr := net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 0, - } - l, err := udp.Listen("udp", &addr) - require.NoError(t, err, "create UDP listener") - return l - }, + name: "UDP", + network: "udp", + flag: []string{"--udp=7777:%v", "--udp=8888:%v"}, + setupRemote: listenLocalUDPWithPrefix, localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"}, }, { - name: "UDP-opportunistic-ipv6", - network: "udp", - flag: []string{"--udp=7788:%v", "--udp=8877:%v"}, - setupRemote: func(t *testing.T) net.Listener { - addr := net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 0, - } - l, err := udp.Listen("udp", &addr) - require.NoError(t, err, "create UDP listener") - return l - }, + name: "UDP-opportunistic-ipv6", + network: "udp", + flag: []string{"--udp=7788:%v", "--udp=8877:%v"}, + setupRemote: listenLocalUDPWithPrefix, localAddress: []string{"[::1]:7788", "[::1]:8877"}, }, { name: "TCPWithAddress", network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"}, - setupRemote: func(t *testing.T) net.Listener { + setupRemote: func(t *testing.T, _ []byte) net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "create TCP listener") return l @@ -120,7 +124,7 @@ func TestPortForward(t *testing.T) { { name: "TCP-IPv6", network: "tcp", flag: []string{"--tcp=[fe80::99]:9999:%v", "--tcp=[fe80::10]:1010:%v"}, - setupRemote: func(t *testing.T) net.Listener { + setupRemote: func(t *testing.T, _ []byte) net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "create TCP listener") return l @@ -146,7 +150,8 @@ func TestPortForward(t *testing.T) { for _, c := range cases { t.Run(c.name+"_OnePort", func(t *testing.T) { t.Parallel() - p1 := setupTestListener(t, c.setupRemote(t)) + prefix := generateRandomPrefix(t) + p1 := setupTestListener(t, c.setupRemote(t, prefix), prefix) // Create a flag that forwards from local to listener 1. flag := fmt.Sprintf(c.flag[0], p1) @@ -182,8 +187,8 @@ func TestPortForward(t *testing.T) { c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[0])) require.NoError(t, err, "open connection 2 to 'local' listener") defer c2.Close() - testDial(t, c2) - testDial(t, c1) + testDial(t, c2, prefix) + testDial(t, c1, prefix) cancel() err = <-errC @@ -199,10 +204,9 @@ func TestPortForward(t *testing.T) { t.Run(c.name+"_TwoPorts", func(t *testing.T) { t.Parallel() - var ( - p1 = setupTestListener(t, c.setupRemote(t)) - p2 = setupTestListener(t, c.setupRemote(t)) - ) + prefix := generateRandomPrefix(t) + p1 := setupTestListener(t, c.setupRemote(t, prefix), prefix) + p2 := setupTestListener(t, c.setupRemote(t, prefix), prefix) // Create a flags for listener 1 and listener 2. flag1 := fmt.Sprintf(c.flag[0], p1) @@ -237,8 +241,8 @@ func TestPortForward(t *testing.T) { c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[1])) require.NoError(t, err, "open connection 2 to 'local' listener 2") defer c2.Close() - testDial(t, c2) - testDial(t, c1) + testDial(t, c2, prefix) + testDial(t, c1, prefix) cancel() err = <-errC @@ -260,9 +264,10 @@ func TestPortForward(t *testing.T) { flags = []string{} ) + prefix := generateRandomPrefix(t) // Start listeners and populate arrays with the cases. for _, c := range cases { - p := setupTestListener(t, c.setupRemote(t)) + p := setupTestListener(t, c.setupRemote(t, prefix), prefix) dials = append(dials, testutil.NewAddr(c.network, c.localAddress[0])) flags = append(flags, fmt.Sprintf(c.flag[0], p)) @@ -302,7 +307,7 @@ func TestPortForward(t *testing.T) { // Test each connection in reverse order. for i := len(conns) - 1; i >= 0; i-- { - testDial(t, conns[i]) + testDial(t, conns[i], prefix) } cancel() @@ -320,9 +325,11 @@ func TestPortForward(t *testing.T) { t.Run("IPv6Busy", func(t *testing.T) { t.Parallel() + prefix := generateRandomPrefix(t) + remoteLis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "create TCP listener") - p1 := setupTestListener(t, remoteLis) + p1 := setupTestListener(t, remoteLis, prefix) // Create a flag that forwards from local 5555 to remote listener port. flag := fmt.Sprintf("--tcp=5555:%v", p1) @@ -360,7 +367,7 @@ func TestPortForward(t *testing.T) { c1, err := iNet.Dial(dialCtx, testutil.NewAddr("tcp", "127.0.0.1:5555")) require.NoError(t, err, "open connection 1 to 'local' listener") defer c1.Close() - testDial(t, c1) + testDial(t, c1, prefix) cancel() err = <-errC @@ -375,6 +382,17 @@ func TestPortForward(t *testing.T) { }) } +// generateRandomPrefix generates a unique prefix per test case to ensure that we can filter out any cross-talk on the +// local network. +func generateRandomPrefix(t *testing.T) []byte { + t.Helper() + prefix := make([]byte, 16) + n, err := rand.Read(prefix) + require.NoError(t, err) + require.Equal(t, 16, n) + return prefix +} + // runAgent creates a fake workspace and starts an agent locally for that // workspace. The agent will be cleaned up on test completion. // nolint:unused @@ -398,8 +416,8 @@ func runAgent(t *testing.T, client *codersdk.Client, owner uuid.UUID, db databas } // setupTestListener starts accepting connections and echoing a single packet. -// Returns the listener and the listen port. -func setupTestListener(t *testing.T, l net.Listener) string { +// Returns the listen port. +func setupTestListener(t *testing.T, l net.Listener, prefix []byte) string { t.Helper() // Wait for listener to completely exit before releasing. @@ -423,7 +441,7 @@ func setupTestListener(t *testing.T, l net.Listener) string { wg.Add(1) go func() { - testAccept(t, c) + echoIfPrefixed(t, c, prefix) wg.Done() }() } @@ -432,30 +450,54 @@ func setupTestListener(t *testing.T, l net.Listener) string { addr := l.Addr().String() _, port, err := net.SplitHostPort(addr) require.NoErrorf(t, err, "split non-Unix listen path %q", addr) - addr = port - - return addr + return port } -var dialTestPayload = []byte("dean-was-here123") +const dialTestPayload = "dean-was-here123" -func testDial(t *testing.T, c net.Conn) { +func newPayload(prefix []byte) []byte { + payload := make([]byte, 0, len(dialTestPayload)+len(prefix)) + payload = append(payload, prefix...) + payload = append(payload, dialTestPayload...) + return payload +} + +func testDial(t *testing.T, c net.Conn, prefix []byte) { t.Helper() - assertWritePayload(t, c, dialTestPayload) - assertReadPayload(t, c, dialTestPayload) + assertWritePayload(t, c, prefix) + assertReadPayload(t, c, prefix) } -func testAccept(t *testing.T, c net.Conn) { +func echoIfPrefixed(t *testing.T, c net.Conn, prefix []byte) { t.Helper() defer c.Close() - assertReadPayload(t, c, dialTestPayload) - assertWritePayload(t, c, dialTestPayload) + // here we don't want to assert anything, because the listener is exposed to the OS, so who knows what might + // connect. If we get the expected prefix to our message, echo it back. + b := make([]byte, 2048) + n, err := c.Read(b) + if err != nil { + t.Logf("read failed (could be crosstalk): %v", err) + return + } + if n < len(prefix) { + t.Logf("short read (could be crosstalk): read %x", b[:n]) + return + } + if !bytes.HasPrefix(b, prefix) { + t.Logf("missing prefix (could be crosstalk), wanted %x got %x", prefix, b[:n]) + return + } + _, err = c.Write(b[:n]) + if err != nil { + t.Logf("write failed: %v", err) + } } -func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { +func assertReadPayload(t *testing.T, r io.Reader, prefix []byte) { t.Helper() + payload := newPayload(prefix) b := make([]byte, len(payload)+16) n, err := r.Read(b) assert.NoError(t, err, "read payload") @@ -463,8 +505,9 @@ func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { assert.Equal(t, payload, b[:n]) } -func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { +func assertWritePayload(t *testing.T, w io.Writer, prefix []byte) { t.Helper() + payload := newPayload(prefix) n, err := w.Write(payload) assert.NoError(t, err, "write payload") assert.Equal(t, len(payload), n, "payload length does not match")