From 1a774ab7ce99063a2e01beb94de3fcbccaf84dbe Mon Sep 17 00:00:00 2001 From: Ethan <39577870+ethanndickson@users.noreply.github.com> Date: Sat, 14 Mar 2026 05:26:48 +1100 Subject: [PATCH] fix(tailnet): retry after transport dial timeouts (#22977) (cherry-pick/v2.31) (#22992) Backport of #22977 to 2.31 --- tailnet/controllers.go | 2 +- tailnet/controllers_test.go | 103 ++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 1 deletion(-) diff --git a/tailnet/controllers.go b/tailnet/controllers.go index b51ddcdffe..969d65461c 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -1429,7 +1429,7 @@ func (c *Controller) Run(ctx context.Context) { tailnetClients, err := c.Dialer.Dial(c.ctx, c.ResumeTokenCtrl) if err != nil { - if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { + if c.ctx.Err() != nil { return } diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go index 247698d6c0..962b15c34e 100644 --- a/tailnet/controllers_test.go +++ b/tailnet/controllers_test.go @@ -1075,6 +1075,84 @@ func TestController_Disconnects(t *testing.T) { _ = testutil.TryReceive(testCtx, t, uut.Closed()) } +func TestController_RetriesWrappedDeadlineExceeded(t *testing.T) { + t.Parallel() + testCtx := testutil.Context(t, testutil.WaitShort) + ctx, cancel := context.WithCancel(testCtx) + defer cancel() + + logger := testutil.Logger(t) + dialer := &scriptedDialer{ + attempts: make(chan int, 10), + dialFn: func(ctx context.Context, attempt int) (tailnet.ControlProtocolClients, error) { + if attempt == 1 { + return tailnet.ControlProtocolClients{}, &net.OpError{ + Op: "dial", + Net: "tcp", + Err: context.DeadlineExceeded, + } + } + + <-ctx.Done() + return tailnet.ControlProtocolClients{}, ctx.Err() + }, + } + + uut := tailnet.NewController(logger.Named("ctrl"), dialer) + uut.Run(ctx) + + require.Equal(t, 1, testutil.TryReceive(testCtx, t, dialer.attempts)) + require.Equal(t, 2, testutil.TryReceive(testCtx, t, dialer.attempts)) + + select { + case <-uut.Closed(): + t.Fatal("controller exited after wrapped deadline exceeded") + default: + } + + cancel() + _ = testutil.TryReceive(testCtx, t, uut.Closed()) +} + +func TestController_DoesNotRedialAfterCancel(t *testing.T) { + t.Parallel() + testCtx := testutil.Context(t, testutil.WaitShort) + ctx, cancel := context.WithCancel(testCtx) + logger := testutil.Logger(t) + + fClient := newFakeWorkspaceUpdateClient(testCtx, t) + dialer := &scriptedDialer{ + attempts: make(chan int, 10), + dialFn: func(_ context.Context, _ int) (tailnet.ControlProtocolClients, error) { + return tailnet.ControlProtocolClients{ + WorkspaceUpdates: fClient, + Closer: fakeCloser{}, + }, nil + }, + } + fCtrl := newFakeUpdatesController(testCtx, t) + + uut := tailnet.NewController(logger.Named("ctrl"), dialer) + uut.WorkspaceUpdatesCtrl = fCtrl + uut.Run(ctx) + + require.Equal(t, 1, testutil.TryReceive(testCtx, t, dialer.attempts)) + call := testutil.TryReceive(testCtx, t, fCtrl.calls) + require.Equal(t, fClient, call.client) + testutil.RequireSend[tailnet.CloserWaiter](testCtx, t, call.resp, newFakeCloserWaiter()) + + cancel() + closeCall := testutil.TryReceive(testCtx, t, fClient.close) + testutil.RequireSend(testCtx, t, closeCall, nil) + _ = testutil.TryReceive(testCtx, t, uut.Closed()) + + select { + case attempt := <-dialer.attempts: + t.Fatalf("unexpected redial attempt after cancel: %d", attempt) + default: + } +} + func TestController_TelemetrySuccess(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) @@ -2070,6 +2148,31 @@ func newFakeCloserWaiter() *fakeCloserWaiter { } } +type scriptedDialer struct { + attempts chan int + dialFn func(context.Context, int) (tailnet.ControlProtocolClients, error) + + mu sync.Mutex + attemptN int +} + +func (d *scriptedDialer) Dial(ctx context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) { + d.mu.Lock() + d.attemptN++ + attempt := d.attemptN + d.mu.Unlock() + + if d.attempts != nil { + select { + case d.attempts <- attempt: + case <-ctx.Done(): + return tailnet.ControlProtocolClients{}, ctx.Err() + } + } + + return d.dialFn(ctx, attempt) +} + type fakeWorkspaceUpdatesDialer struct { client tailnet.WorkspaceUpdatesClient }