fix: filter out cross-talk on TestPortForward (#25503)

<!--

If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting.

-->

Fixes https://github.com/coder/internal/issues/1539  
  
Protects from port cross-talk by adding a short random prefix to our socket communication and instructing the service on the workspace agent side of the test to ignore any connections that don't use the prefix.
This commit is contained in:
Spike Curtis
2026-05-20 13:08:57 -04:00
committed by GitHub
parent b229573c7e
commit 05e47b9c0f
+101 -58
View File
@@ -1,10 +1,13 @@
package cli_test
import (
"bytes"
"context"
"crypto/rand"
"fmt"
"io"
"net"
"slices"
"sync"
"testing"
"time"
@@ -41,6 +44,22 @@ func TestPortForward_None(t *testing.T) {
require.ErrorContains(t, err, "no port-forwards")
}
func listenLocalUDPWithPrefix(t *testing.T, prefix []byte) net.Listener {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
cfg := udp.ListenConfig{AcceptFilter: func(bytes []byte) bool {
if len(bytes) < len(prefix) {
return false
}
return slices.Equal(prefix, bytes[:len(prefix)])
}}
l, err := cfg.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
}
func TestPortForward(t *testing.T) {
t.Parallel()
cases := []struct {
@@ -50,8 +69,9 @@ func TestPortForward(t *testing.T) {
// of connection. Has one format arg (string) for the remote address.
flag []string
// setupRemote creates a "remote" listener to emulate a service in the
// workspace.
setupRemote func(t *testing.T) net.Listener
// workspace. The prefix is generated per test case and can be used to
// filter connections.
setupRemote func(t *testing.T, prefix []byte) net.Listener
// the local address(es) to "dial"
localAddress []string
}{
@@ -59,7 +79,7 @@ func TestPortForward(t *testing.T) {
name: "TCP",
network: "tcp",
flag: []string{"--tcp=5555:%v", "--tcp=6666:%v"},
setupRemote: func(t *testing.T) net.Listener {
setupRemote: func(t *testing.T, _ []byte) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
return l
@@ -70,7 +90,7 @@ func TestPortForward(t *testing.T) {
name: "TCP-opportunistic-ipv6",
network: "tcp",
flag: []string{"--tcp=5566:%v", "--tcp=6655:%v"},
setupRemote: func(t *testing.T) net.Listener {
setupRemote: func(t *testing.T, _ []byte) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
return l
@@ -78,39 +98,23 @@ func TestPortForward(t *testing.T) {
localAddress: []string{"[::1]:5566", "[::1]:6655"},
},
{
name: "UDP",
network: "udp",
flag: []string{"--udp=7777:%v", "--udp=8888:%v"},
setupRemote: func(t *testing.T) net.Listener {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
},
name: "UDP",
network: "udp",
flag: []string{"--udp=7777:%v", "--udp=8888:%v"},
setupRemote: listenLocalUDPWithPrefix,
localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"},
},
{
name: "UDP-opportunistic-ipv6",
network: "udp",
flag: []string{"--udp=7788:%v", "--udp=8877:%v"},
setupRemote: func(t *testing.T) net.Listener {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
},
name: "UDP-opportunistic-ipv6",
network: "udp",
flag: []string{"--udp=7788:%v", "--udp=8877:%v"},
setupRemote: listenLocalUDPWithPrefix,
localAddress: []string{"[::1]:7788", "[::1]:8877"},
},
{
name: "TCPWithAddress",
network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"},
setupRemote: func(t *testing.T) net.Listener {
setupRemote: func(t *testing.T, _ []byte) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
return l
@@ -120,7 +124,7 @@ func TestPortForward(t *testing.T) {
{
name: "TCP-IPv6",
network: "tcp", flag: []string{"--tcp=[fe80::99]:9999:%v", "--tcp=[fe80::10]:1010:%v"},
setupRemote: func(t *testing.T) net.Listener {
setupRemote: func(t *testing.T, _ []byte) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
return l
@@ -146,7 +150,8 @@ func TestPortForward(t *testing.T) {
for _, c := range cases {
t.Run(c.name+"_OnePort", func(t *testing.T) {
t.Parallel()
p1 := setupTestListener(t, c.setupRemote(t))
prefix := generateRandomPrefix(t)
p1 := setupTestListener(t, c.setupRemote(t, prefix), prefix)
// Create a flag that forwards from local to listener 1.
flag := fmt.Sprintf(c.flag[0], p1)
@@ -182,8 +187,8 @@ func TestPortForward(t *testing.T) {
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[0]))
require.NoError(t, err, "open connection 2 to 'local' listener")
defer c2.Close()
testDial(t, c2)
testDial(t, c1)
testDial(t, c2, prefix)
testDial(t, c1, prefix)
cancel()
err = <-errC
@@ -199,10 +204,9 @@ func TestPortForward(t *testing.T) {
t.Run(c.name+"_TwoPorts", func(t *testing.T) {
t.Parallel()
var (
p1 = setupTestListener(t, c.setupRemote(t))
p2 = setupTestListener(t, c.setupRemote(t))
)
prefix := generateRandomPrefix(t)
p1 := setupTestListener(t, c.setupRemote(t, prefix), prefix)
p2 := setupTestListener(t, c.setupRemote(t, prefix), prefix)
// Create a flags for listener 1 and listener 2.
flag1 := fmt.Sprintf(c.flag[0], p1)
@@ -237,8 +241,8 @@ func TestPortForward(t *testing.T) {
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[1]))
require.NoError(t, err, "open connection 2 to 'local' listener 2")
defer c2.Close()
testDial(t, c2)
testDial(t, c1)
testDial(t, c2, prefix)
testDial(t, c1, prefix)
cancel()
err = <-errC
@@ -260,9 +264,10 @@ func TestPortForward(t *testing.T) {
flags = []string{}
)
prefix := generateRandomPrefix(t)
// Start listeners and populate arrays with the cases.
for _, c := range cases {
p := setupTestListener(t, c.setupRemote(t))
p := setupTestListener(t, c.setupRemote(t, prefix), prefix)
dials = append(dials, testutil.NewAddr(c.network, c.localAddress[0]))
flags = append(flags, fmt.Sprintf(c.flag[0], p))
@@ -302,7 +307,7 @@ func TestPortForward(t *testing.T) {
// Test each connection in reverse order.
for i := len(conns) - 1; i >= 0; i-- {
testDial(t, conns[i])
testDial(t, conns[i], prefix)
}
cancel()
@@ -320,9 +325,11 @@ func TestPortForward(t *testing.T) {
t.Run("IPv6Busy", func(t *testing.T) {
t.Parallel()
prefix := generateRandomPrefix(t)
remoteLis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
p1 := setupTestListener(t, remoteLis)
p1 := setupTestListener(t, remoteLis, prefix)
// Create a flag that forwards from local 5555 to remote listener port.
flag := fmt.Sprintf("--tcp=5555:%v", p1)
@@ -360,7 +367,7 @@ func TestPortForward(t *testing.T) {
c1, err := iNet.Dial(dialCtx, testutil.NewAddr("tcp", "127.0.0.1:5555"))
require.NoError(t, err, "open connection 1 to 'local' listener")
defer c1.Close()
testDial(t, c1)
testDial(t, c1, prefix)
cancel()
err = <-errC
@@ -375,6 +382,17 @@ func TestPortForward(t *testing.T) {
})
}
// generateRandomPrefix generates a unique prefix per test case to ensure that we can filter out any cross-talk on the
// local network.
func generateRandomPrefix(t *testing.T) []byte {
t.Helper()
prefix := make([]byte, 16)
n, err := rand.Read(prefix)
require.NoError(t, err)
require.Equal(t, 16, n)
return prefix
}
// runAgent creates a fake workspace and starts an agent locally for that
// workspace. The agent will be cleaned up on test completion.
// nolint:unused
@@ -398,8 +416,8 @@ func runAgent(t *testing.T, client *codersdk.Client, owner uuid.UUID, db databas
}
// setupTestListener starts accepting connections and echoing a single packet.
// Returns the listener and the listen port.
func setupTestListener(t *testing.T, l net.Listener) string {
// Returns the listen port.
func setupTestListener(t *testing.T, l net.Listener, prefix []byte) string {
t.Helper()
// Wait for listener to completely exit before releasing.
@@ -423,7 +441,7 @@ func setupTestListener(t *testing.T, l net.Listener) string {
wg.Add(1)
go func() {
testAccept(t, c)
echoIfPrefixed(t, c, prefix)
wg.Done()
}()
}
@@ -432,30 +450,54 @@ func setupTestListener(t *testing.T, l net.Listener) string {
addr := l.Addr().String()
_, port, err := net.SplitHostPort(addr)
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
addr = port
return addr
return port
}
var dialTestPayload = []byte("dean-was-here123")
const dialTestPayload = "dean-was-here123"
func testDial(t *testing.T, c net.Conn) {
func newPayload(prefix []byte) []byte {
payload := make([]byte, 0, len(dialTestPayload)+len(prefix))
payload = append(payload, prefix...)
payload = append(payload, dialTestPayload...)
return payload
}
func testDial(t *testing.T, c net.Conn, prefix []byte) {
t.Helper()
assertWritePayload(t, c, dialTestPayload)
assertReadPayload(t, c, dialTestPayload)
assertWritePayload(t, c, prefix)
assertReadPayload(t, c, prefix)
}
func testAccept(t *testing.T, c net.Conn) {
func echoIfPrefixed(t *testing.T, c net.Conn, prefix []byte) {
t.Helper()
defer c.Close()
assertReadPayload(t, c, dialTestPayload)
assertWritePayload(t, c, dialTestPayload)
// here we don't want to assert anything, because the listener is exposed to the OS, so who knows what might
// connect. If we get the expected prefix to our message, echo it back.
b := make([]byte, 2048)
n, err := c.Read(b)
if err != nil {
t.Logf("read failed (could be crosstalk): %v", err)
return
}
if n < len(prefix) {
t.Logf("short read (could be crosstalk): read %x", b[:n])
return
}
if !bytes.HasPrefix(b, prefix) {
t.Logf("missing prefix (could be crosstalk), wanted %x got %x", prefix, b[:n])
return
}
_, err = c.Write(b[:n])
if err != nil {
t.Logf("write failed: %v", err)
}
}
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
func assertReadPayload(t *testing.T, r io.Reader, prefix []byte) {
t.Helper()
payload := newPayload(prefix)
b := make([]byte, len(payload)+16)
n, err := r.Read(b)
assert.NoError(t, err, "read payload")
@@ -463,8 +505,9 @@ func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
assert.Equal(t, payload, b[:n])
}
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
func assertWritePayload(t *testing.T, w io.Writer, prefix []byte) {
t.Helper()
payload := newPayload(prefix)
n, err := w.Write(payload)
assert.NoError(t, err, "write payload")
assert.Equal(t, len(payload), n, "payload length does not match")