fix: filter out cross-talk on TestPortForward (#25503)

<!--

If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting.

-->

Fixes https://github.com/coder/internal/issues/1539  
  
Protects from port cross-talk by adding a short random prefix to our socket communication and instructing the service on the workspace agent side of the test to ignore any connections that don't use the prefix.
This commit is contained in:
Spike Curtis
2026-05-20 13:08:57 -04:00
committed by GitHub
parent b229573c7e
commit 05e47b9c0f
+101 -58
View File
@@ -1,10 +1,13 @@
package cli_test package cli_test
import ( import (
"bytes"
"context" "context"
"crypto/rand"
"fmt" "fmt"
"io" "io"
"net" "net"
"slices"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -41,6 +44,22 @@ func TestPortForward_None(t *testing.T) {
require.ErrorContains(t, err, "no port-forwards") require.ErrorContains(t, err, "no port-forwards")
} }
func listenLocalUDPWithPrefix(t *testing.T, prefix []byte) net.Listener {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
cfg := udp.ListenConfig{AcceptFilter: func(bytes []byte) bool {
if len(bytes) < len(prefix) {
return false
}
return slices.Equal(prefix, bytes[:len(prefix)])
}}
l, err := cfg.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
}
func TestPortForward(t *testing.T) { func TestPortForward(t *testing.T) {
t.Parallel() t.Parallel()
cases := []struct { cases := []struct {
@@ -50,8 +69,9 @@ func TestPortForward(t *testing.T) {
// of connection. Has one format arg (string) for the remote address. // of connection. Has one format arg (string) for the remote address.
flag []string flag []string
// setupRemote creates a "remote" listener to emulate a service in the // setupRemote creates a "remote" listener to emulate a service in the
// workspace. // workspace. The prefix is generated per test case and can be used to
setupRemote func(t *testing.T) net.Listener // filter connections.
setupRemote func(t *testing.T, prefix []byte) net.Listener
// the local address(es) to "dial" // the local address(es) to "dial"
localAddress []string localAddress []string
}{ }{
@@ -59,7 +79,7 @@ func TestPortForward(t *testing.T) {
name: "TCP", name: "TCP",
network: "tcp", network: "tcp",
flag: []string{"--tcp=5555:%v", "--tcp=6666:%v"}, flag: []string{"--tcp=5555:%v", "--tcp=6666:%v"},
setupRemote: func(t *testing.T) net.Listener { setupRemote: func(t *testing.T, _ []byte) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener") require.NoError(t, err, "create TCP listener")
return l return l
@@ -70,7 +90,7 @@ func TestPortForward(t *testing.T) {
name: "TCP-opportunistic-ipv6", name: "TCP-opportunistic-ipv6",
network: "tcp", network: "tcp",
flag: []string{"--tcp=5566:%v", "--tcp=6655:%v"}, flag: []string{"--tcp=5566:%v", "--tcp=6655:%v"},
setupRemote: func(t *testing.T) net.Listener { setupRemote: func(t *testing.T, _ []byte) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener") require.NoError(t, err, "create TCP listener")
return l return l
@@ -78,39 +98,23 @@ func TestPortForward(t *testing.T) {
localAddress: []string{"[::1]:5566", "[::1]:6655"}, localAddress: []string{"[::1]:5566", "[::1]:6655"},
}, },
{ {
name: "UDP", name: "UDP",
network: "udp", network: "udp",
flag: []string{"--udp=7777:%v", "--udp=8888:%v"}, flag: []string{"--udp=7777:%v", "--udp=8888:%v"},
setupRemote: func(t *testing.T) net.Listener { setupRemote: listenLocalUDPWithPrefix,
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
},
localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"}, localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"},
}, },
{ {
name: "UDP-opportunistic-ipv6", name: "UDP-opportunistic-ipv6",
network: "udp", network: "udp",
flag: []string{"--udp=7788:%v", "--udp=8877:%v"}, flag: []string{"--udp=7788:%v", "--udp=8877:%v"},
setupRemote: func(t *testing.T) net.Listener { setupRemote: listenLocalUDPWithPrefix,
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
},
localAddress: []string{"[::1]:7788", "[::1]:8877"}, localAddress: []string{"[::1]:7788", "[::1]:8877"},
}, },
{ {
name: "TCPWithAddress", name: "TCPWithAddress",
network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"}, network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"},
setupRemote: func(t *testing.T) net.Listener { setupRemote: func(t *testing.T, _ []byte) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener") require.NoError(t, err, "create TCP listener")
return l return l
@@ -120,7 +124,7 @@ func TestPortForward(t *testing.T) {
{ {
name: "TCP-IPv6", name: "TCP-IPv6",
network: "tcp", flag: []string{"--tcp=[fe80::99]:9999:%v", "--tcp=[fe80::10]:1010:%v"}, network: "tcp", flag: []string{"--tcp=[fe80::99]:9999:%v", "--tcp=[fe80::10]:1010:%v"},
setupRemote: func(t *testing.T) net.Listener { setupRemote: func(t *testing.T, _ []byte) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener") require.NoError(t, err, "create TCP listener")
return l return l
@@ -146,7 +150,8 @@ func TestPortForward(t *testing.T) {
for _, c := range cases { for _, c := range cases {
t.Run(c.name+"_OnePort", func(t *testing.T) { t.Run(c.name+"_OnePort", func(t *testing.T) {
t.Parallel() t.Parallel()
p1 := setupTestListener(t, c.setupRemote(t)) prefix := generateRandomPrefix(t)
p1 := setupTestListener(t, c.setupRemote(t, prefix), prefix)
// Create a flag that forwards from local to listener 1. // Create a flag that forwards from local to listener 1.
flag := fmt.Sprintf(c.flag[0], p1) flag := fmt.Sprintf(c.flag[0], p1)
@@ -182,8 +187,8 @@ func TestPortForward(t *testing.T) {
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[0])) c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[0]))
require.NoError(t, err, "open connection 2 to 'local' listener") require.NoError(t, err, "open connection 2 to 'local' listener")
defer c2.Close() defer c2.Close()
testDial(t, c2) testDial(t, c2, prefix)
testDial(t, c1) testDial(t, c1, prefix)
cancel() cancel()
err = <-errC err = <-errC
@@ -199,10 +204,9 @@ func TestPortForward(t *testing.T) {
t.Run(c.name+"_TwoPorts", func(t *testing.T) { t.Run(c.name+"_TwoPorts", func(t *testing.T) {
t.Parallel() t.Parallel()
var ( prefix := generateRandomPrefix(t)
p1 = setupTestListener(t, c.setupRemote(t)) p1 := setupTestListener(t, c.setupRemote(t, prefix), prefix)
p2 = setupTestListener(t, c.setupRemote(t)) p2 := setupTestListener(t, c.setupRemote(t, prefix), prefix)
)
// Create a flags for listener 1 and listener 2. // Create a flags for listener 1 and listener 2.
flag1 := fmt.Sprintf(c.flag[0], p1) flag1 := fmt.Sprintf(c.flag[0], p1)
@@ -237,8 +241,8 @@ func TestPortForward(t *testing.T) {
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[1])) c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[1]))
require.NoError(t, err, "open connection 2 to 'local' listener 2") require.NoError(t, err, "open connection 2 to 'local' listener 2")
defer c2.Close() defer c2.Close()
testDial(t, c2) testDial(t, c2, prefix)
testDial(t, c1) testDial(t, c1, prefix)
cancel() cancel()
err = <-errC err = <-errC
@@ -260,9 +264,10 @@ func TestPortForward(t *testing.T) {
flags = []string{} flags = []string{}
) )
prefix := generateRandomPrefix(t)
// Start listeners and populate arrays with the cases. // Start listeners and populate arrays with the cases.
for _, c := range cases { for _, c := range cases {
p := setupTestListener(t, c.setupRemote(t)) p := setupTestListener(t, c.setupRemote(t, prefix), prefix)
dials = append(dials, testutil.NewAddr(c.network, c.localAddress[0])) dials = append(dials, testutil.NewAddr(c.network, c.localAddress[0]))
flags = append(flags, fmt.Sprintf(c.flag[0], p)) flags = append(flags, fmt.Sprintf(c.flag[0], p))
@@ -302,7 +307,7 @@ func TestPortForward(t *testing.T) {
// Test each connection in reverse order. // Test each connection in reverse order.
for i := len(conns) - 1; i >= 0; i-- { for i := len(conns) - 1; i >= 0; i-- {
testDial(t, conns[i]) testDial(t, conns[i], prefix)
} }
cancel() cancel()
@@ -320,9 +325,11 @@ func TestPortForward(t *testing.T) {
t.Run("IPv6Busy", func(t *testing.T) { t.Run("IPv6Busy", func(t *testing.T) {
t.Parallel() t.Parallel()
prefix := generateRandomPrefix(t)
remoteLis, err := net.Listen("tcp", "127.0.0.1:0") remoteLis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener") require.NoError(t, err, "create TCP listener")
p1 := setupTestListener(t, remoteLis) p1 := setupTestListener(t, remoteLis, prefix)
// Create a flag that forwards from local 5555 to remote listener port. // Create a flag that forwards from local 5555 to remote listener port.
flag := fmt.Sprintf("--tcp=5555:%v", p1) flag := fmt.Sprintf("--tcp=5555:%v", p1)
@@ -360,7 +367,7 @@ func TestPortForward(t *testing.T) {
c1, err := iNet.Dial(dialCtx, testutil.NewAddr("tcp", "127.0.0.1:5555")) c1, err := iNet.Dial(dialCtx, testutil.NewAddr("tcp", "127.0.0.1:5555"))
require.NoError(t, err, "open connection 1 to 'local' listener") require.NoError(t, err, "open connection 1 to 'local' listener")
defer c1.Close() defer c1.Close()
testDial(t, c1) testDial(t, c1, prefix)
cancel() cancel()
err = <-errC err = <-errC
@@ -375,6 +382,17 @@ func TestPortForward(t *testing.T) {
}) })
} }
// generateRandomPrefix generates a unique prefix per test case to ensure that we can filter out any cross-talk on the
// local network.
func generateRandomPrefix(t *testing.T) []byte {
t.Helper()
prefix := make([]byte, 16)
n, err := rand.Read(prefix)
require.NoError(t, err)
require.Equal(t, 16, n)
return prefix
}
// runAgent creates a fake workspace and starts an agent locally for that // runAgent creates a fake workspace and starts an agent locally for that
// workspace. The agent will be cleaned up on test completion. // workspace. The agent will be cleaned up on test completion.
// nolint:unused // nolint:unused
@@ -398,8 +416,8 @@ func runAgent(t *testing.T, client *codersdk.Client, owner uuid.UUID, db databas
} }
// setupTestListener starts accepting connections and echoing a single packet. // setupTestListener starts accepting connections and echoing a single packet.
// Returns the listener and the listen port. // Returns the listen port.
func setupTestListener(t *testing.T, l net.Listener) string { func setupTestListener(t *testing.T, l net.Listener, prefix []byte) string {
t.Helper() t.Helper()
// Wait for listener to completely exit before releasing. // Wait for listener to completely exit before releasing.
@@ -423,7 +441,7 @@ func setupTestListener(t *testing.T, l net.Listener) string {
wg.Add(1) wg.Add(1)
go func() { go func() {
testAccept(t, c) echoIfPrefixed(t, c, prefix)
wg.Done() wg.Done()
}() }()
} }
@@ -432,30 +450,54 @@ func setupTestListener(t *testing.T, l net.Listener) string {
addr := l.Addr().String() addr := l.Addr().String()
_, port, err := net.SplitHostPort(addr) _, port, err := net.SplitHostPort(addr)
require.NoErrorf(t, err, "split non-Unix listen path %q", addr) require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
addr = port return port
return addr
} }
var dialTestPayload = []byte("dean-was-here123") const dialTestPayload = "dean-was-here123"
func testDial(t *testing.T, c net.Conn) { func newPayload(prefix []byte) []byte {
payload := make([]byte, 0, len(dialTestPayload)+len(prefix))
payload = append(payload, prefix...)
payload = append(payload, dialTestPayload...)
return payload
}
func testDial(t *testing.T, c net.Conn, prefix []byte) {
t.Helper() t.Helper()
assertWritePayload(t, c, dialTestPayload) assertWritePayload(t, c, prefix)
assertReadPayload(t, c, dialTestPayload) assertReadPayload(t, c, prefix)
} }
func testAccept(t *testing.T, c net.Conn) { func echoIfPrefixed(t *testing.T, c net.Conn, prefix []byte) {
t.Helper() t.Helper()
defer c.Close() defer c.Close()
assertReadPayload(t, c, dialTestPayload) // here we don't want to assert anything, because the listener is exposed to the OS, so who knows what might
assertWritePayload(t, c, dialTestPayload) // connect. If we get the expected prefix to our message, echo it back.
b := make([]byte, 2048)
n, err := c.Read(b)
if err != nil {
t.Logf("read failed (could be crosstalk): %v", err)
return
}
if n < len(prefix) {
t.Logf("short read (could be crosstalk): read %x", b[:n])
return
}
if !bytes.HasPrefix(b, prefix) {
t.Logf("missing prefix (could be crosstalk), wanted %x got %x", prefix, b[:n])
return
}
_, err = c.Write(b[:n])
if err != nil {
t.Logf("write failed: %v", err)
}
} }
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { func assertReadPayload(t *testing.T, r io.Reader, prefix []byte) {
t.Helper() t.Helper()
payload := newPayload(prefix)
b := make([]byte, len(payload)+16) b := make([]byte, len(payload)+16)
n, err := r.Read(b) n, err := r.Read(b)
assert.NoError(t, err, "read payload") assert.NoError(t, err, "read payload")
@@ -463,8 +505,9 @@ func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
assert.Equal(t, payload, b[:n]) assert.Equal(t, payload, b[:n])
} }
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { func assertWritePayload(t *testing.T, w io.Writer, prefix []byte) {
t.Helper() t.Helper()
payload := newPayload(prefix)
n, err := w.Write(payload) n, err := w.Write(payload)
assert.NoError(t, err, "write payload") assert.NoError(t, err, "write payload")
assert.Equal(t, len(payload), n, "payload length does not match") assert.Equal(t, len(payload), n, "payload length does not match")