chore: retry TestAgent_Dial subtests (#19387)

Closes https://github.com/coder/internal/issues/595
This commit is contained in:
Dean Sheather
2025-08-18 23:51:19 +10:00
committed by GitHub
parent a8c89a120f
commit e2ba9e7d62
4 changed files with 307 additions and 58 deletions
+67 -56
View File
@@ -2668,11 +2668,11 @@ func TestAgent_Dial(t *testing.T) {
cases := []struct { cases := []struct {
name string name string
setup func(t *testing.T) net.Listener setup func(t testing.TB) net.Listener
}{ }{
{ {
name: "TCP", name: "TCP",
setup: func(t *testing.T) net.Listener { setup: func(t testing.TB) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener") require.NoError(t, err, "create TCP listener")
return l return l
@@ -2680,7 +2680,7 @@ func TestAgent_Dial(t *testing.T) {
}, },
{ {
name: "UDP", name: "UDP",
setup: func(t *testing.T) net.Listener { setup: func(t testing.TB) net.Listener {
addr := net.UDPAddr{ addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: 0, Port: 0,
@@ -2698,57 +2698,68 @@ func TestAgent_Dial(t *testing.T) {
// The purpose of this test is to ensure that a client can dial a // The purpose of this test is to ensure that a client can dial a
// listener in the workspace over tailnet. // listener in the workspace over tailnet.
l := c.setup(t) //
done := make(chan struct{}) // The OS sometimes drops packets if the system can't keep up with
defer func() { // them. For TCP packets, it's typically fine due to
l.Close() // retransmissions, but for UDP packets, it can fail this test.
<-done //
}() // The OS gets involved for the Wireguard traffic (either via DERP
// or direct UDP), and also for the traffic between the agent and
// the listener in the "workspace".
//
// To avoid this, we'll retry this test up to 3 times.
testutil.RunRetry(t, 3, func(t testing.TB) {
ctx := testutil.Context(t, testutil.WaitLong)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) l := c.setup(t)
defer cancel() done := make(chan struct{})
defer func() {
l.Close()
<-done
}()
go func() { go func() {
defer close(done) defer close(done)
for range 2 { for range 2 {
c, err := l.Accept() c, err := l.Accept()
if assert.NoError(t, err, "accept connection") { if assert.NoError(t, err, "accept connection") {
testAccept(ctx, t, c) testAccept(ctx, t, c)
_ = c.Close() _ = c.Close()
}
} }
}()
agentID := uuid.UUID{0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8}
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
AgentID: agentID,
}, 0)
require.True(t, agentConn.AwaitReachable(ctx))
conn, err := agentConn.DialContext(ctx, l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
testDial(ctx, t, conn)
err = conn.Close()
require.NoError(t, err)
// also connect via the CoderServicePrefix, to test that we can reach the agent on this
// IP. This will be required for CoderVPN.
_, rawPort, _ := net.SplitHostPort(l.Addr().String())
port, _ := strconv.ParseUint(rawPort, 10, 16)
ipp := netip.AddrPortFrom(tailnet.CoderServicePrefix.AddrFromUUID(agentID), uint16(port))
switch l.Addr().Network() {
case "tcp":
conn, err = agentConn.Conn.DialContextTCP(ctx, ipp)
case "udp":
conn, err = agentConn.Conn.DialContextUDP(ctx, ipp)
default:
t.Fatalf("unknown network: %s", l.Addr().Network())
} }
}() require.NoError(t, err)
testDial(ctx, t, conn)
agentID := uuid.UUID{0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8} err = conn.Close()
//nolint:dogsled require.NoError(t, err)
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{ })
AgentID: agentID,
}, 0)
require.True(t, agentConn.AwaitReachable(ctx))
conn, err := agentConn.DialContext(ctx, l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
testDial(ctx, t, conn)
err = conn.Close()
require.NoError(t, err)
// also connect via the CoderServicePrefix, to test that we can reach the agent on this
// IP. This will be required for CoderVPN.
_, rawPort, _ := net.SplitHostPort(l.Addr().String())
port, _ := strconv.ParseUint(rawPort, 10, 16)
ipp := netip.AddrPortFrom(tailnet.CoderServicePrefix.AddrFromUUID(agentID), uint16(port))
switch l.Addr().Network() {
case "tcp":
conn, err = agentConn.Conn.DialContextTCP(ctx, ipp)
case "udp":
conn, err = agentConn.Conn.DialContextUDP(ctx, ipp)
default:
t.Fatalf("unknown network: %s", l.Addr().Network())
}
require.NoError(t, err)
testDial(ctx, t, conn)
err = conn.Close()
require.NoError(t, err)
}) })
} }
} }
@@ -3251,7 +3262,7 @@ func setupSSHSessionOnPort(
return session return session
} }
func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
*workspacesdk.AgentConn, *workspacesdk.AgentConn,
*agenttest.Client, *agenttest.Client,
<-chan *proto.Stats, <-chan *proto.Stats,
@@ -3349,7 +3360,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
var dialTestPayload = []byte("dean-was-here123") var dialTestPayload = []byte("dean-was-here123")
func testDial(ctx context.Context, t *testing.T, c net.Conn) { func testDial(ctx context.Context, t testing.TB, c net.Conn) {
t.Helper() t.Helper()
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
@@ -3365,7 +3376,7 @@ func testDial(ctx context.Context, t *testing.T, c net.Conn) {
assertReadPayload(t, c, dialTestPayload) assertReadPayload(t, c, dialTestPayload)
} }
func testAccept(ctx context.Context, t *testing.T, c net.Conn) { func testAccept(ctx context.Context, t testing.TB, c net.Conn) {
t.Helper() t.Helper()
defer c.Close() defer c.Close()
@@ -3382,7 +3393,7 @@ func testAccept(ctx context.Context, t *testing.T, c net.Conn) {
assertWritePayload(t, c, dialTestPayload) assertWritePayload(t, c, dialTestPayload)
} }
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { func assertReadPayload(t testing.TB, r io.Reader, payload []byte) {
t.Helper() t.Helper()
b := make([]byte, len(payload)+16) b := make([]byte, len(payload)+16)
n, err := r.Read(b) n, err := r.Read(b)
@@ -3391,11 +3402,11 @@ func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
assert.Equal(t, payload, b[:n]) assert.Equal(t, payload, b[:n])
} }
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { func assertWritePayload(t testing.TB, w io.Writer, payload []byte) {
t.Helper() t.Helper()
n, err := w.Write(payload) n, err := w.Write(payload)
assert.NoError(t, err, "write payload") assert.NoError(t, err, "write payload")
assert.Equal(t, len(payload), n, "payload length does not match") assert.Equal(t, len(payload), n, "written payload length does not match")
} }
func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected []string, expectedRe *regexp.Regexp) { func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected []string, expectedRe *regexp.Regexp) {
+1 -1
View File
@@ -45,7 +45,7 @@ func DERPIsEmbedded(cfg *derpAndSTUNCfg) {
} }
// RunDERPAndSTUN creates a DERP mapping for tests. // RunDERPAndSTUN creates a DERP mapping for tests.
func RunDERPAndSTUN(t *testing.T, opts ...DERPAndStunOption) (*tailcfg.DERPMap, *derp.Server) { func RunDERPAndSTUN(t testing.TB, opts ...DERPAndStunOption) (*tailcfg.DERPMap, *derp.Server) {
cfg := new(derpAndSTUNCfg) cfg := new(derpAndSTUNCfg)
for _, o := range opts { for _, o := range opts {
o(cfg) o(cfg)
+1 -1
View File
@@ -6,7 +6,7 @@ import (
"time" "time"
) )
func Context(t *testing.T, dur time.Duration) context.Context { func Context(t testing.TB, dur time.Duration) context.Context {
ctx, cancel := context.WithTimeout(context.Background(), dur) ctx, cancel := context.WithTimeout(context.Background(), dur)
t.Cleanup(cancel) t.Cleanup(cancel)
return ctx return ctx
+238
View File
@@ -0,0 +1,238 @@
package testutil
import (
"context"
"fmt"
"runtime"
"slices"
"sync"
"testing"
"time"
)
// RunRetry runs a test function up to `count` times, retrying if it fails. If
// all attempts fail or the context is canceled, the test will fail. It is safe
// to use the parent context in the test function, but do note that the context
// deadline will apply to all attempts.
//
// DO NOT USE THIS FUNCTION IN TESTS UNLESS YOU HAVE A GOOD REASON. It should
// only be used in tests that can flake under high load. It is not a replacement
// for writing a good test.
//
// Note that the `testing.TB` supplied to the function is a fake implementation
// for all runs. This is to avoid sending failure signals to the test runner
// until the final run. Unrecovered panics will still always be bubbled up to
// the test runner.
//
// Some functions are not implemented and will panic when using the fake
// implementation:
// - Chdir
// - Setenv
// - Skip, SkipNow, Skipf, Skipped
// - TempDir
//
// Cleanup functions will be executed after each attempt.
func RunRetry(t *testing.T, count int, fn func(t testing.TB)) {
t.Helper()
for i := 1; i <= count; i++ {
// Canceled in the attempt goroutine before running cleanup functions.
attemptCtx, attemptCancel := context.WithCancel(t.Context())
attemptT := &fakeT{
T: t,
ctx: attemptCtx,
name: fmt.Sprintf("%s (attempt %d/%d)", t.Name(), i, count),
}
// Run the test in a goroutine so we can capture runtime.Goexit()
// and run cleanup functions.
done := make(chan struct{}, 1)
go func() {
defer close(done)
defer func() {
// As per t.Context(), the context is canceled right before
// cleanup functions are executed.
attemptCancel()
attemptT.runCleanupFns()
}()
t.Logf("testutil.RunRetry: running test: attempt %d/%d", i, count)
fn(attemptT)
}()
// We don't wait on the context here, because we want to be sure that
// the test function and cleanup functions have finished before
// returning from the test.
<-done
if !attemptT.Failed() {
t.Logf("testutil.RunRetry: test passed on attempt %d/%d", i, count)
return
}
t.Logf("testutil.RunRetry: test failed on attempt %d/%d", i, count)
// Wait a few seconds in case the test failure was due to system load.
// There's not really a good way to check for this, so we just do it
// every time.
// No point waiting on t.Context() here because it doesn't factor in
// the test deadline, and only gets canceled when the test function
// completes.
time.Sleep(2 * time.Second)
}
t.Fatalf("testutil.RunRetry: all %d attempts failed", count)
}
// fakeT is a fake implementation of testing.TB that never fails and only logs
// errors. Fatal errors will cause the goroutine to exit without failing the
// test.
//
// The behavior of the fake implementation should be as close as possible to
// the real implementation from the test function's perspective (minus
// intentionally unimplemented methods).
type fakeT struct {
*testing.T
ctx context.Context
name string
mu sync.Mutex
failed bool
cleanupFns []func()
}
var _ testing.TB = &fakeT{}
func (t *fakeT) runCleanupFns() {
t.mu.Lock()
cleanupFns := slices.Clone(t.cleanupFns)
t.mu.Unlock()
// Execute in LIFO order to match the behavior of *testing.T.
slices.Reverse(cleanupFns)
for _, fn := range cleanupFns {
fn()
}
}
// Chdir implements testing.TB.
func (*fakeT) Chdir(_ string) {
panic("t.Chdir is not implemented in testutil.RunRetry closures")
}
// Cleanup implements testing.TB. Cleanup registers a function to be called when
// the test completes. Cleanup functions will be called in last added, first
// called order.
func (t *fakeT) Cleanup(fn func()) {
t.mu.Lock()
defer t.mu.Unlock()
t.cleanupFns = append(t.cleanupFns, fn)
}
// Context implements testing.TB. Context returns a context that is canceled
// just before Cleanup-registered functions are called.
func (t *fakeT) Context() context.Context {
return t.ctx
}
// Error implements testing.TB. Error is equivalent to Log followed by Fail.
func (t *fakeT) Error(args ...any) {
t.T.Helper()
t.T.Log(args...)
t.Fail()
}
// Errorf implements testing.TB. Errorf is equivalent to Logf followed by Fail.
func (t *fakeT) Errorf(format string, args ...any) {
t.T.Helper()
t.T.Logf(format, args...)
t.Fail()
}
// Fail implements testing.TB. Fail marks the function as having failed but
// continues execution.
func (t *fakeT) Fail() {
t.T.Helper()
t.mu.Lock()
defer t.mu.Unlock()
t.failed = true
t.T.Log("testutil.RunRetry: t.Fail called in testutil.RunRetry closure")
}
// FailNow implements testing.TB. FailNow marks the function as having failed
// and stops its execution by calling runtime.Goexit (which then runs all the
// deferred calls in the current goroutine).
func (t *fakeT) FailNow() {
t.T.Helper()
t.mu.Lock()
defer t.mu.Unlock()
t.failed = true
t.T.Log("testutil.RunRetry: t.FailNow called in testutil.RunRetry closure")
runtime.Goexit()
}
// Failed implements testing.TB. Failed reports whether the function has failed.
func (t *fakeT) Failed() bool {
t.T.Helper()
t.mu.Lock()
defer t.mu.Unlock()
return t.failed
}
// Fatal implements testing.TB. Fatal is equivalent to Log followed by FailNow.
func (t *fakeT) Fatal(args ...any) {
t.T.Helper()
t.T.Log(args...)
t.FailNow()
}
// Fatalf implements testing.TB. Fatalf is equivalent to Logf followed by
// FailNow.
func (t *fakeT) Fatalf(format string, args ...any) {
t.T.Helper()
t.T.Logf(format, args...)
t.FailNow()
}
// Helper is proxied to the original *testing.T. This is to avoid the fake
// method appearing in the call stack.
// Log is proxied to the original *testing.T.
// Logf is proxied to the original *testing.T.
// Name implements testing.TB.
func (t *fakeT) Name() string {
return t.name
}
// Setenv implements testing.TB.
func (*fakeT) Setenv(_ string, _ string) {
panic("t.Setenv is not implemented in testutil.RunRetry closures")
}
// Skip implements testing.TB.
func (*fakeT) Skip(_ ...any) {
panic("t.Skip is not implemented in testutil.RunRetry closures")
}
// SkipNow implements testing.TB.
func (*fakeT) SkipNow() {
panic("t.SkipNow is not implemented in testutil.RunRetry closures")
}
// Skipf implements testing.TB.
func (*fakeT) Skipf(_ string, _ ...any) {
panic("t.Skipf is not implemented in testutil.RunRetry closures")
}
// Skipped implements testing.TB.
func (*fakeT) Skipped() bool {
panic("t.Skipped is not implemented in testutil.RunRetry closures")
}
// TempDir implements testing.TB.
func (*fakeT) TempDir() string {
panic("t.TempDir is not implemented in testutil.RunRetry closures")
}
// private is proxied to the original *testing.T. It cannot be implemented by
// our fake implementation since it's a private method.