diff --git a/agent/agent_test.go b/agent/agent_test.go index 8593123ef7..af8364a30d 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -2668,11 +2668,11 @@ func TestAgent_Dial(t *testing.T) { cases := []struct { name string - setup func(t *testing.T) net.Listener + setup func(t testing.TB) net.Listener }{ { name: "TCP", - setup: func(t *testing.T) net.Listener { + setup: func(t testing.TB) net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "create TCP listener") return l @@ -2680,7 +2680,7 @@ func TestAgent_Dial(t *testing.T) { }, { name: "UDP", - setup: func(t *testing.T) net.Listener { + setup: func(t testing.TB) net.Listener { addr := net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 0, @@ -2698,57 +2698,68 @@ func TestAgent_Dial(t *testing.T) { // The purpose of this test is to ensure that a client can dial a // listener in the workspace over tailnet. - l := c.setup(t) - done := make(chan struct{}) - defer func() { - l.Close() - <-done - }() + // + // The OS sometimes drops packets if the system can't keep up with + // them. For TCP packets, it's typically fine due to + // retransmissions, but for UDP packets, it can fail this test. + // + // The OS gets involved for the Wireguard traffic (either via DERP + // or direct UDP), and also for the traffic between the agent and + // the listener in the "workspace". + // + // To avoid this, we'll retry this test up to 3 times. + testutil.RunRetry(t, 3, func(t testing.TB) { + ctx := testutil.Context(t, testutil.WaitLong) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() + l := c.setup(t) + done := make(chan struct{}) + defer func() { + l.Close() + <-done + }() - go func() { - defer close(done) - for range 2 { - c, err := l.Accept() - if assert.NoError(t, err, "accept connection") { - testAccept(ctx, t, c) - _ = c.Close() + go func() { + defer close(done) + for range 2 { + c, err := l.Accept() + if assert.NoError(t, err, "accept connection") { + testAccept(ctx, t, c) + _ = c.Close() + } } + }() + + agentID := uuid.UUID{0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8} + //nolint:dogsled + agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{ + AgentID: agentID, + }, 0) + require.True(t, agentConn.AwaitReachable(ctx)) + conn, err := agentConn.DialContext(ctx, l.Addr().Network(), l.Addr().String()) + require.NoError(t, err) + testDial(ctx, t, conn) + err = conn.Close() + require.NoError(t, err) + + // also connect via the CoderServicePrefix, to test that we can reach the agent on this + // IP. This will be required for CoderVPN. + _, rawPort, _ := net.SplitHostPort(l.Addr().String()) + port, _ := strconv.ParseUint(rawPort, 10, 16) + ipp := netip.AddrPortFrom(tailnet.CoderServicePrefix.AddrFromUUID(agentID), uint16(port)) + + switch l.Addr().Network() { + case "tcp": + conn, err = agentConn.Conn.DialContextTCP(ctx, ipp) + case "udp": + conn, err = agentConn.Conn.DialContextUDP(ctx, ipp) + default: + t.Fatalf("unknown network: %s", l.Addr().Network()) } - }() - - agentID := uuid.UUID{0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8} - //nolint:dogsled - agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{ - AgentID: agentID, - }, 0) - require.True(t, agentConn.AwaitReachable(ctx)) - conn, err := agentConn.DialContext(ctx, l.Addr().Network(), l.Addr().String()) - require.NoError(t, err) - testDial(ctx, t, conn) - err = conn.Close() - require.NoError(t, err) - - // also connect via the CoderServicePrefix, to test that we can reach the agent on this - // IP. This will be required for CoderVPN. - _, rawPort, _ := net.SplitHostPort(l.Addr().String()) - port, _ := strconv.ParseUint(rawPort, 10, 16) - ipp := netip.AddrPortFrom(tailnet.CoderServicePrefix.AddrFromUUID(agentID), uint16(port)) - - switch l.Addr().Network() { - case "tcp": - conn, err = agentConn.Conn.DialContextTCP(ctx, ipp) - case "udp": - conn, err = agentConn.Conn.DialContextUDP(ctx, ipp) - default: - t.Fatalf("unknown network: %s", l.Addr().Network()) - } - require.NoError(t, err) - testDial(ctx, t, conn) - err = conn.Close() - require.NoError(t, err) + require.NoError(t, err) + testDial(ctx, t, conn) + err = conn.Close() + require.NoError(t, err) + }) }) } } @@ -3251,7 +3262,7 @@ func setupSSHSessionOnPort( return session } -func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( +func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( *workspacesdk.AgentConn, *agenttest.Client, <-chan *proto.Stats, @@ -3349,7 +3360,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati var dialTestPayload = []byte("dean-was-here123") -func testDial(ctx context.Context, t *testing.T, c net.Conn) { +func testDial(ctx context.Context, t testing.TB, c net.Conn) { t.Helper() if deadline, ok := ctx.Deadline(); ok { @@ -3365,7 +3376,7 @@ func testDial(ctx context.Context, t *testing.T, c net.Conn) { assertReadPayload(t, c, dialTestPayload) } -func testAccept(ctx context.Context, t *testing.T, c net.Conn) { +func testAccept(ctx context.Context, t testing.TB, c net.Conn) { t.Helper() defer c.Close() @@ -3382,7 +3393,7 @@ func testAccept(ctx context.Context, t *testing.T, c net.Conn) { assertWritePayload(t, c, dialTestPayload) } -func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { +func assertReadPayload(t testing.TB, r io.Reader, payload []byte) { t.Helper() b := make([]byte, len(payload)+16) n, err := r.Read(b) @@ -3391,11 +3402,11 @@ func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { assert.Equal(t, payload, b[:n]) } -func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { +func assertWritePayload(t testing.TB, w io.Writer, payload []byte) { t.Helper() n, err := w.Write(payload) assert.NoError(t, err, "write payload") - assert.Equal(t, len(payload), n, "payload length does not match") + assert.Equal(t, len(payload), n, "written payload length does not match") } func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected []string, expectedRe *regexp.Regexp) { diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index 89327cddd8..50b83aaf4f 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -45,7 +45,7 @@ func DERPIsEmbedded(cfg *derpAndSTUNCfg) { } // RunDERPAndSTUN creates a DERP mapping for tests. -func RunDERPAndSTUN(t *testing.T, opts ...DERPAndStunOption) (*tailcfg.DERPMap, *derp.Server) { +func RunDERPAndSTUN(t testing.TB, opts ...DERPAndStunOption) (*tailcfg.DERPMap, *derp.Server) { cfg := new(derpAndSTUNCfg) for _, o := range opts { o(cfg) diff --git a/testutil/ctx.go b/testutil/ctx.go index e23c48da85..acbf14e5bb 100644 --- a/testutil/ctx.go +++ b/testutil/ctx.go @@ -6,7 +6,7 @@ import ( "time" ) -func Context(t *testing.T, dur time.Duration) context.Context { +func Context(t testing.TB, dur time.Duration) context.Context { ctx, cancel := context.WithTimeout(context.Background(), dur) t.Cleanup(cancel) return ctx diff --git a/testutil/retry.go b/testutil/retry.go new file mode 100644 index 0000000000..0314ae6580 --- /dev/null +++ b/testutil/retry.go @@ -0,0 +1,238 @@ +package testutil + +import ( + "context" + "fmt" + "runtime" + "slices" + "sync" + "testing" + "time" +) + +// RunRetry runs a test function up to `count` times, retrying if it fails. If +// all attempts fail or the context is canceled, the test will fail. It is safe +// to use the parent context in the test function, but do note that the context +// deadline will apply to all attempts. +// +// DO NOT USE THIS FUNCTION IN TESTS UNLESS YOU HAVE A GOOD REASON. It should +// only be used in tests that can flake under high load. It is not a replacement +// for writing a good test. +// +// Note that the `testing.TB` supplied to the function is a fake implementation +// for all runs. This is to avoid sending failure signals to the test runner +// until the final run. Unrecovered panics will still always be bubbled up to +// the test runner. +// +// Some functions are not implemented and will panic when using the fake +// implementation: +// - Chdir +// - Setenv +// - Skip, SkipNow, Skipf, Skipped +// - TempDir +// +// Cleanup functions will be executed after each attempt. +func RunRetry(t *testing.T, count int, fn func(t testing.TB)) { + t.Helper() + + for i := 1; i <= count; i++ { + // Canceled in the attempt goroutine before running cleanup functions. + attemptCtx, attemptCancel := context.WithCancel(t.Context()) + attemptT := &fakeT{ + T: t, + ctx: attemptCtx, + name: fmt.Sprintf("%s (attempt %d/%d)", t.Name(), i, count), + } + + // Run the test in a goroutine so we can capture runtime.Goexit() + // and run cleanup functions. + done := make(chan struct{}, 1) + go func() { + defer close(done) + defer func() { + // As per t.Context(), the context is canceled right before + // cleanup functions are executed. + attemptCancel() + attemptT.runCleanupFns() + }() + + t.Logf("testutil.RunRetry: running test: attempt %d/%d", i, count) + fn(attemptT) + }() + + // We don't wait on the context here, because we want to be sure that + // the test function and cleanup functions have finished before + // returning from the test. + <-done + if !attemptT.Failed() { + t.Logf("testutil.RunRetry: test passed on attempt %d/%d", i, count) + return + } + t.Logf("testutil.RunRetry: test failed on attempt %d/%d", i, count) + + // Wait a few seconds in case the test failure was due to system load. + // There's not really a good way to check for this, so we just do it + // every time. + // No point waiting on t.Context() here because it doesn't factor in + // the test deadline, and only gets canceled when the test function + // completes. + time.Sleep(2 * time.Second) + } + t.Fatalf("testutil.RunRetry: all %d attempts failed", count) +} + +// fakeT is a fake implementation of testing.TB that never fails and only logs +// errors. Fatal errors will cause the goroutine to exit without failing the +// test. +// +// The behavior of the fake implementation should be as close as possible to +// the real implementation from the test function's perspective (minus +// intentionally unimplemented methods). +type fakeT struct { + *testing.T + ctx context.Context + name string + + mu sync.Mutex + failed bool + cleanupFns []func() +} + +var _ testing.TB = &fakeT{} + +func (t *fakeT) runCleanupFns() { + t.mu.Lock() + cleanupFns := slices.Clone(t.cleanupFns) + t.mu.Unlock() + + // Execute in LIFO order to match the behavior of *testing.T. + slices.Reverse(cleanupFns) + for _, fn := range cleanupFns { + fn() + } +} + +// Chdir implements testing.TB. +func (*fakeT) Chdir(_ string) { + panic("t.Chdir is not implemented in testutil.RunRetry closures") +} + +// Cleanup implements testing.TB. Cleanup registers a function to be called when +// the test completes. Cleanup functions will be called in last added, first +// called order. +func (t *fakeT) Cleanup(fn func()) { + t.mu.Lock() + defer t.mu.Unlock() + + t.cleanupFns = append(t.cleanupFns, fn) +} + +// Context implements testing.TB. Context returns a context that is canceled +// just before Cleanup-registered functions are called. +func (t *fakeT) Context() context.Context { + return t.ctx +} + +// Error implements testing.TB. Error is equivalent to Log followed by Fail. +func (t *fakeT) Error(args ...any) { + t.T.Helper() + t.T.Log(args...) + t.Fail() +} + +// Errorf implements testing.TB. Errorf is equivalent to Logf followed by Fail. +func (t *fakeT) Errorf(format string, args ...any) { + t.T.Helper() + t.T.Logf(format, args...) + t.Fail() +} + +// Fail implements testing.TB. Fail marks the function as having failed but +// continues execution. +func (t *fakeT) Fail() { + t.T.Helper() + t.mu.Lock() + defer t.mu.Unlock() + t.failed = true + t.T.Log("testutil.RunRetry: t.Fail called in testutil.RunRetry closure") +} + +// FailNow implements testing.TB. FailNow marks the function as having failed +// and stops its execution by calling runtime.Goexit (which then runs all the +// deferred calls in the current goroutine). +func (t *fakeT) FailNow() { + t.T.Helper() + t.mu.Lock() + defer t.mu.Unlock() + t.failed = true + t.T.Log("testutil.RunRetry: t.FailNow called in testutil.RunRetry closure") + runtime.Goexit() +} + +// Failed implements testing.TB. Failed reports whether the function has failed. +func (t *fakeT) Failed() bool { + t.T.Helper() + t.mu.Lock() + defer t.mu.Unlock() + return t.failed +} + +// Fatal implements testing.TB. Fatal is equivalent to Log followed by FailNow. +func (t *fakeT) Fatal(args ...any) { + t.T.Helper() + t.T.Log(args...) + t.FailNow() +} + +// Fatalf implements testing.TB. Fatalf is equivalent to Logf followed by +// FailNow. +func (t *fakeT) Fatalf(format string, args ...any) { + t.T.Helper() + t.T.Logf(format, args...) + t.FailNow() +} + +// Helper is proxied to the original *testing.T. This is to avoid the fake +// method appearing in the call stack. + +// Log is proxied to the original *testing.T. + +// Logf is proxied to the original *testing.T. + +// Name implements testing.TB. +func (t *fakeT) Name() string { + return t.name +} + +// Setenv implements testing.TB. +func (*fakeT) Setenv(_ string, _ string) { + panic("t.Setenv is not implemented in testutil.RunRetry closures") +} + +// Skip implements testing.TB. +func (*fakeT) Skip(_ ...any) { + panic("t.Skip is not implemented in testutil.RunRetry closures") +} + +// SkipNow implements testing.TB. +func (*fakeT) SkipNow() { + panic("t.SkipNow is not implemented in testutil.RunRetry closures") +} + +// Skipf implements testing.TB. +func (*fakeT) Skipf(_ string, _ ...any) { + panic("t.Skipf is not implemented in testutil.RunRetry closures") +} + +// Skipped implements testing.TB. +func (*fakeT) Skipped() bool { + panic("t.Skipped is not implemented in testutil.RunRetry closures") +} + +// TempDir implements testing.TB. +func (*fakeT) TempDir() string { + panic("t.TempDir is not implemented in testutil.RunRetry closures") +} + +// private is proxied to the original *testing.T. It cannot be implemented by +// our fake implementation since it's a private method.