mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: filter out cross-talk on TestPortForward (#25503)
<!-- If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting. --> Fixes https://github.com/coder/internal/issues/1539 Protects from port cross-talk by adding a short random prefix to our socket communication and instructing the service on the workspace agent side of the test to ignore any connections that don't use the prefix.
This commit is contained in:
+101
-58
@@ -1,10 +1,13 @@
|
|||||||
package cli_test
|
package cli_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -41,6 +44,22 @@ func TestPortForward_None(t *testing.T) {
|
|||||||
require.ErrorContains(t, err, "no port-forwards")
|
require.ErrorContains(t, err, "no port-forwards")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func listenLocalUDPWithPrefix(t *testing.T, prefix []byte) net.Listener {
|
||||||
|
addr := net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 0,
|
||||||
|
}
|
||||||
|
cfg := udp.ListenConfig{AcceptFilter: func(bytes []byte) bool {
|
||||||
|
if len(bytes) < len(prefix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return slices.Equal(prefix, bytes[:len(prefix)])
|
||||||
|
}}
|
||||||
|
l, err := cfg.Listen("udp", &addr)
|
||||||
|
require.NoError(t, err, "create UDP listener")
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
func TestPortForward(t *testing.T) {
|
func TestPortForward(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
@@ -50,8 +69,9 @@ func TestPortForward(t *testing.T) {
|
|||||||
// of connection. Has one format arg (string) for the remote address.
|
// of connection. Has one format arg (string) for the remote address.
|
||||||
flag []string
|
flag []string
|
||||||
// setupRemote creates a "remote" listener to emulate a service in the
|
// setupRemote creates a "remote" listener to emulate a service in the
|
||||||
// workspace.
|
// workspace. The prefix is generated per test case and can be used to
|
||||||
setupRemote func(t *testing.T) net.Listener
|
// filter connections.
|
||||||
|
setupRemote func(t *testing.T, prefix []byte) net.Listener
|
||||||
// the local address(es) to "dial"
|
// the local address(es) to "dial"
|
||||||
localAddress []string
|
localAddress []string
|
||||||
}{
|
}{
|
||||||
@@ -59,7 +79,7 @@ func TestPortForward(t *testing.T) {
|
|||||||
name: "TCP",
|
name: "TCP",
|
||||||
network: "tcp",
|
network: "tcp",
|
||||||
flag: []string{"--tcp=5555:%v", "--tcp=6666:%v"},
|
flag: []string{"--tcp=5555:%v", "--tcp=6666:%v"},
|
||||||
setupRemote: func(t *testing.T) net.Listener {
|
setupRemote: func(t *testing.T, _ []byte) net.Listener {
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err, "create TCP listener")
|
require.NoError(t, err, "create TCP listener")
|
||||||
return l
|
return l
|
||||||
@@ -70,7 +90,7 @@ func TestPortForward(t *testing.T) {
|
|||||||
name: "TCP-opportunistic-ipv6",
|
name: "TCP-opportunistic-ipv6",
|
||||||
network: "tcp",
|
network: "tcp",
|
||||||
flag: []string{"--tcp=5566:%v", "--tcp=6655:%v"},
|
flag: []string{"--tcp=5566:%v", "--tcp=6655:%v"},
|
||||||
setupRemote: func(t *testing.T) net.Listener {
|
setupRemote: func(t *testing.T, _ []byte) net.Listener {
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err, "create TCP listener")
|
require.NoError(t, err, "create TCP listener")
|
||||||
return l
|
return l
|
||||||
@@ -78,39 +98,23 @@ func TestPortForward(t *testing.T) {
|
|||||||
localAddress: []string{"[::1]:5566", "[::1]:6655"},
|
localAddress: []string{"[::1]:5566", "[::1]:6655"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "UDP",
|
name: "UDP",
|
||||||
network: "udp",
|
network: "udp",
|
||||||
flag: []string{"--udp=7777:%v", "--udp=8888:%v"},
|
flag: []string{"--udp=7777:%v", "--udp=8888:%v"},
|
||||||
setupRemote: func(t *testing.T) net.Listener {
|
setupRemote: listenLocalUDPWithPrefix,
|
||||||
addr := net.UDPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 0,
|
|
||||||
}
|
|
||||||
l, err := udp.Listen("udp", &addr)
|
|
||||||
require.NoError(t, err, "create UDP listener")
|
|
||||||
return l
|
|
||||||
},
|
|
||||||
localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"},
|
localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "UDP-opportunistic-ipv6",
|
name: "UDP-opportunistic-ipv6",
|
||||||
network: "udp",
|
network: "udp",
|
||||||
flag: []string{"--udp=7788:%v", "--udp=8877:%v"},
|
flag: []string{"--udp=7788:%v", "--udp=8877:%v"},
|
||||||
setupRemote: func(t *testing.T) net.Listener {
|
setupRemote: listenLocalUDPWithPrefix,
|
||||||
addr := net.UDPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 0,
|
|
||||||
}
|
|
||||||
l, err := udp.Listen("udp", &addr)
|
|
||||||
require.NoError(t, err, "create UDP listener")
|
|
||||||
return l
|
|
||||||
},
|
|
||||||
localAddress: []string{"[::1]:7788", "[::1]:8877"},
|
localAddress: []string{"[::1]:7788", "[::1]:8877"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "TCPWithAddress",
|
name: "TCPWithAddress",
|
||||||
network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"},
|
network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"},
|
||||||
setupRemote: func(t *testing.T) net.Listener {
|
setupRemote: func(t *testing.T, _ []byte) net.Listener {
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err, "create TCP listener")
|
require.NoError(t, err, "create TCP listener")
|
||||||
return l
|
return l
|
||||||
@@ -120,7 +124,7 @@ func TestPortForward(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "TCP-IPv6",
|
name: "TCP-IPv6",
|
||||||
network: "tcp", flag: []string{"--tcp=[fe80::99]:9999:%v", "--tcp=[fe80::10]:1010:%v"},
|
network: "tcp", flag: []string{"--tcp=[fe80::99]:9999:%v", "--tcp=[fe80::10]:1010:%v"},
|
||||||
setupRemote: func(t *testing.T) net.Listener {
|
setupRemote: func(t *testing.T, _ []byte) net.Listener {
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err, "create TCP listener")
|
require.NoError(t, err, "create TCP listener")
|
||||||
return l
|
return l
|
||||||
@@ -146,7 +150,8 @@ func TestPortForward(t *testing.T) {
|
|||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
t.Run(c.name+"_OnePort", func(t *testing.T) {
|
t.Run(c.name+"_OnePort", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
p1 := setupTestListener(t, c.setupRemote(t))
|
prefix := generateRandomPrefix(t)
|
||||||
|
p1 := setupTestListener(t, c.setupRemote(t, prefix), prefix)
|
||||||
|
|
||||||
// Create a flag that forwards from local to listener 1.
|
// Create a flag that forwards from local to listener 1.
|
||||||
flag := fmt.Sprintf(c.flag[0], p1)
|
flag := fmt.Sprintf(c.flag[0], p1)
|
||||||
@@ -182,8 +187,8 @@ func TestPortForward(t *testing.T) {
|
|||||||
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[0]))
|
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[0]))
|
||||||
require.NoError(t, err, "open connection 2 to 'local' listener")
|
require.NoError(t, err, "open connection 2 to 'local' listener")
|
||||||
defer c2.Close()
|
defer c2.Close()
|
||||||
testDial(t, c2)
|
testDial(t, c2, prefix)
|
||||||
testDial(t, c1)
|
testDial(t, c1, prefix)
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
err = <-errC
|
err = <-errC
|
||||||
@@ -199,10 +204,9 @@ func TestPortForward(t *testing.T) {
|
|||||||
|
|
||||||
t.Run(c.name+"_TwoPorts", func(t *testing.T) {
|
t.Run(c.name+"_TwoPorts", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
var (
|
prefix := generateRandomPrefix(t)
|
||||||
p1 = setupTestListener(t, c.setupRemote(t))
|
p1 := setupTestListener(t, c.setupRemote(t, prefix), prefix)
|
||||||
p2 = setupTestListener(t, c.setupRemote(t))
|
p2 := setupTestListener(t, c.setupRemote(t, prefix), prefix)
|
||||||
)
|
|
||||||
|
|
||||||
// Create a flags for listener 1 and listener 2.
|
// Create a flags for listener 1 and listener 2.
|
||||||
flag1 := fmt.Sprintf(c.flag[0], p1)
|
flag1 := fmt.Sprintf(c.flag[0], p1)
|
||||||
@@ -237,8 +241,8 @@ func TestPortForward(t *testing.T) {
|
|||||||
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[1]))
|
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[1]))
|
||||||
require.NoError(t, err, "open connection 2 to 'local' listener 2")
|
require.NoError(t, err, "open connection 2 to 'local' listener 2")
|
||||||
defer c2.Close()
|
defer c2.Close()
|
||||||
testDial(t, c2)
|
testDial(t, c2, prefix)
|
||||||
testDial(t, c1)
|
testDial(t, c1, prefix)
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
err = <-errC
|
err = <-errC
|
||||||
@@ -260,9 +264,10 @@ func TestPortForward(t *testing.T) {
|
|||||||
flags = []string{}
|
flags = []string{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prefix := generateRandomPrefix(t)
|
||||||
// Start listeners and populate arrays with the cases.
|
// Start listeners and populate arrays with the cases.
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
p := setupTestListener(t, c.setupRemote(t))
|
p := setupTestListener(t, c.setupRemote(t, prefix), prefix)
|
||||||
|
|
||||||
dials = append(dials, testutil.NewAddr(c.network, c.localAddress[0]))
|
dials = append(dials, testutil.NewAddr(c.network, c.localAddress[0]))
|
||||||
flags = append(flags, fmt.Sprintf(c.flag[0], p))
|
flags = append(flags, fmt.Sprintf(c.flag[0], p))
|
||||||
@@ -302,7 +307,7 @@ func TestPortForward(t *testing.T) {
|
|||||||
|
|
||||||
// Test each connection in reverse order.
|
// Test each connection in reverse order.
|
||||||
for i := len(conns) - 1; i >= 0; i-- {
|
for i := len(conns) - 1; i >= 0; i-- {
|
||||||
testDial(t, conns[i])
|
testDial(t, conns[i], prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
@@ -320,9 +325,11 @@ func TestPortForward(t *testing.T) {
|
|||||||
t.Run("IPv6Busy", func(t *testing.T) {
|
t.Run("IPv6Busy", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
prefix := generateRandomPrefix(t)
|
||||||
|
|
||||||
remoteLis, err := net.Listen("tcp", "127.0.0.1:0")
|
remoteLis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err, "create TCP listener")
|
require.NoError(t, err, "create TCP listener")
|
||||||
p1 := setupTestListener(t, remoteLis)
|
p1 := setupTestListener(t, remoteLis, prefix)
|
||||||
|
|
||||||
// Create a flag that forwards from local 5555 to remote listener port.
|
// Create a flag that forwards from local 5555 to remote listener port.
|
||||||
flag := fmt.Sprintf("--tcp=5555:%v", p1)
|
flag := fmt.Sprintf("--tcp=5555:%v", p1)
|
||||||
@@ -360,7 +367,7 @@ func TestPortForward(t *testing.T) {
|
|||||||
c1, err := iNet.Dial(dialCtx, testutil.NewAddr("tcp", "127.0.0.1:5555"))
|
c1, err := iNet.Dial(dialCtx, testutil.NewAddr("tcp", "127.0.0.1:5555"))
|
||||||
require.NoError(t, err, "open connection 1 to 'local' listener")
|
require.NoError(t, err, "open connection 1 to 'local' listener")
|
||||||
defer c1.Close()
|
defer c1.Close()
|
||||||
testDial(t, c1)
|
testDial(t, c1, prefix)
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
err = <-errC
|
err = <-errC
|
||||||
@@ -375,6 +382,17 @@ func TestPortForward(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateRandomPrefix generates a unique prefix per test case to ensure that we can filter out any cross-talk on the
|
||||||
|
// local network.
|
||||||
|
func generateRandomPrefix(t *testing.T) []byte {
|
||||||
|
t.Helper()
|
||||||
|
prefix := make([]byte, 16)
|
||||||
|
n, err := rand.Read(prefix)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 16, n)
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
|
||||||
// runAgent creates a fake workspace and starts an agent locally for that
|
// runAgent creates a fake workspace and starts an agent locally for that
|
||||||
// workspace. The agent will be cleaned up on test completion.
|
// workspace. The agent will be cleaned up on test completion.
|
||||||
// nolint:unused
|
// nolint:unused
|
||||||
@@ -398,8 +416,8 @@ func runAgent(t *testing.T, client *codersdk.Client, owner uuid.UUID, db databas
|
|||||||
}
|
}
|
||||||
|
|
||||||
// setupTestListener starts accepting connections and echoing a single packet.
|
// setupTestListener starts accepting connections and echoing a single packet.
|
||||||
// Returns the listener and the listen port.
|
// Returns the listen port.
|
||||||
func setupTestListener(t *testing.T, l net.Listener) string {
|
func setupTestListener(t *testing.T, l net.Listener, prefix []byte) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Wait for listener to completely exit before releasing.
|
// Wait for listener to completely exit before releasing.
|
||||||
@@ -423,7 +441,7 @@ func setupTestListener(t *testing.T, l net.Listener) string {
|
|||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
testAccept(t, c)
|
echoIfPrefixed(t, c, prefix)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -432,30 +450,54 @@ func setupTestListener(t *testing.T, l net.Listener) string {
|
|||||||
addr := l.Addr().String()
|
addr := l.Addr().String()
|
||||||
_, port, err := net.SplitHostPort(addr)
|
_, port, err := net.SplitHostPort(addr)
|
||||||
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
|
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
|
||||||
addr = port
|
return port
|
||||||
|
|
||||||
return addr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialTestPayload = []byte("dean-was-here123")
|
const dialTestPayload = "dean-was-here123"
|
||||||
|
|
||||||
func testDial(t *testing.T, c net.Conn) {
|
func newPayload(prefix []byte) []byte {
|
||||||
|
payload := make([]byte, 0, len(dialTestPayload)+len(prefix))
|
||||||
|
payload = append(payload, prefix...)
|
||||||
|
payload = append(payload, dialTestPayload...)
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func testDial(t *testing.T, c net.Conn, prefix []byte) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
assertWritePayload(t, c, dialTestPayload)
|
assertWritePayload(t, c, prefix)
|
||||||
assertReadPayload(t, c, dialTestPayload)
|
assertReadPayload(t, c, prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testAccept(t *testing.T, c net.Conn) {
|
func echoIfPrefixed(t *testing.T, c net.Conn, prefix []byte) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
assertReadPayload(t, c, dialTestPayload)
|
// here we don't want to assert anything, because the listener is exposed to the OS, so who knows what might
|
||||||
assertWritePayload(t, c, dialTestPayload)
|
// connect. If we get the expected prefix to our message, echo it back.
|
||||||
|
b := make([]byte, 2048)
|
||||||
|
n, err := c.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("read failed (could be crosstalk): %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n < len(prefix) {
|
||||||
|
t.Logf("short read (could be crosstalk): read %x", b[:n])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(b, prefix) {
|
||||||
|
t.Logf("missing prefix (could be crosstalk), wanted %x got %x", prefix, b[:n])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = c.Write(b[:n])
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("write failed: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
|
func assertReadPayload(t *testing.T, r io.Reader, prefix []byte) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
payload := newPayload(prefix)
|
||||||
b := make([]byte, len(payload)+16)
|
b := make([]byte, len(payload)+16)
|
||||||
n, err := r.Read(b)
|
n, err := r.Read(b)
|
||||||
assert.NoError(t, err, "read payload")
|
assert.NoError(t, err, "read payload")
|
||||||
@@ -463,8 +505,9 @@ func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
|
|||||||
assert.Equal(t, payload, b[:n])
|
assert.Equal(t, payload, b[:n])
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
|
func assertWritePayload(t *testing.T, w io.Writer, prefix []byte) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
payload := newPayload(prefix)
|
||||||
n, err := w.Write(payload)
|
n, err := w.Write(payload)
|
||||||
assert.NoError(t, err, "write payload")
|
assert.NoError(t, err, "write payload")
|
||||||
assert.Equal(t, len(payload), n, "payload length does not match")
|
assert.Equal(t, len(payload), n, "payload length does not match")
|
||||||
|
|||||||
Reference in New Issue
Block a user