diff --git a/tailnet/controllers.go b/tailnet/controllers.go index 82909fbcf6..b99016e80b 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -986,6 +986,30 @@ func (a *Agent) Clone() Agent { func (t *TunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient) CloserWaiter { t.mu.Lock() defer t.mu.Unlock() + + // Preserve workspace state from the previous updater so that DNS + // hosts remain programmed while we wait for the new server + // snapshot. Without this, a control-plane reconnection would + // leave the internal DNS resolver empty until the first + // workspace update arrives, breaking .coder resolution. + var previousWorkspaces map[uuid.UUID]*Workspace + if t.updater != nil { + t.updater.Lock() + if len(t.updater.workspaces) > 0 { + previousWorkspaces = make(map[uuid.UUID]*Workspace, len(t.updater.workspaces)) + for id, w := range t.updater.workspaces { + clone := w.Clone() + previousWorkspaces[id] = &clone + } + } + t.updater.Unlock() + } + + workspaces := make(map[uuid.UUID]*Workspace) + if previousWorkspaces != nil { + workspaces = previousWorkspaces + } + updater := &tunnelUpdater{ client: client, errChan: make(chan error, 1), @@ -996,8 +1020,26 @@ func (t *TunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient) updateHandler: t.updateHandler, ownerUsername: t.ownerUsername, recvLoopDone: make(chan struct{}), - workspaces: make(map[uuid.UUID]*Workspace), + workspaces: workspaces, } + + // If we inherited workspace state, immediately re-program DNS + // hosts so the resolver stays populated during the reconnection + // window. + if len(previousWorkspaces) > 0 { + dnsNames := updater.updateDNSNamesLocked() + if updater.dnsHostsSetter != nil { + t.logger.Debug(context.Background(), "re-applying DNS hosts from previous session", + slog.F("num_hosts", len(dnsNames)), + ) + if err := updater.dnsHostsSetter.SetDNSHosts(dnsNames); err != nil { + t.logger.Warn(context.Background(), "failed to re-apply DNS hosts from previous session", + slog.Error(err), + ) + } + } + } + t.updater = updater go t.updater.recvLoop() return t.updater @@ -1157,6 +1199,14 @@ func (t *tunnelUpdater) handleUpdate(update *proto.WorkspaceUpdate, updateKind U t.Lock() defer t.Unlock() + // Snapshots represent the complete state, not a diff. Clear any + // inherited or previously accumulated workspace data so that + // workspaces absent from the snapshot (e.g. stopped while we + // were disconnected) do not linger. + if updateKind == Snapshot { + t.workspaces = make(map[uuid.UUID]*Workspace) + } + currentUpdate := WorkspaceUpdate{ UpsertedWorkspaces: []*Workspace{}, UpsertedAgents: []*Agent{}, diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go index 962b15c34e..899ffac7eb 100644 --- a/tailnet/controllers_test.go +++ b/tailnet/controllers_test.go @@ -2023,6 +2023,216 @@ func TestTunnelAllWorkspaceUpdatesController_HandleErrors(t *testing.T) { } } +func TestTunnelAllWorkspaceUpdatesController_ReconnectPreservesDNS(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + + fConn := &fakeCoordinatee{} + tsc := tailnet.NewTunnelSrcCoordController(logger, fConn) + fDNS := newFakeDNSSetter(ctx, t) + uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, + tailnet.WithDNS(fDNS, "testy", tailnet.DNSNameOptions{Suffix: "mctest"}), + ) + + // Connect coordinator + coordC := newFakeCoordinatorClient(ctx, t) + coordCW := tsc.New(coordC) + + // Connect first workspace update client and send a snapshot + w1ID := testUUID(1) + w1a1ID := testUUID(1, 1) + updateC1 := newFakeWorkspaceUpdateClient(ctx, t) + updateCW1 := uut.New(updateC1) + + initUp := &proto.WorkspaceUpdate{ + UpsertedWorkspaces: []*proto.Workspace{ + {Id: w1ID[:], Name: "w1"}, + }, + UpsertedAgents: []*proto.Agent{ + {Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]}, + }, + } + + upRecvCall := testutil.TryReceive(ctx, t, updateC1.recv) + testutil.RequireSend(ctx, t, upRecvCall.resp, initUp) + + // Consume the AddTunnel coordination call + coordCall := testutil.TryReceive(ctx, t, coordC.reqs) + testutil.RequireSend(ctx, t, coordCall.err, nil) + + // Consume the initial DNS set call + dnsCall := testutil.TryReceive(ctx, t, fDNS.calls) + require.NotEmpty(t, dnsCall.hosts) + initialHosts := dnsCall.hosts + testutil.RequireSend(ctx, t, dnsCall.err, nil) + + // Simulate disconnect: hang up the first workspace update client + upRecvCall = testutil.TryReceive(ctx, t, updateC1.recv) + testutil.RequireSend(ctx, t, upRecvCall.err, io.EOF) + err := testutil.TryReceive(ctx, t, updateCW1.Wait()) + require.ErrorIs(t, err, io.EOF) + + // Reconnect with a new workspace update client. The controller + // should carry over workspace state and immediately re-apply DNS. + updateC2 := newFakeWorkspaceUpdateClient(ctx, t) + + // New() will call SetDNSHosts synchronously with the inherited + // hosts. We need to consume that call concurrently because the + // fake DNS setter blocks until the test reads from the channel. + reconnDNSHosts := make(chan map[dnsname.FQDN][]netip.Addr, 1) + go func() { + call := testutil.TryReceive(ctx, t, fDNS.calls) + reconnDNSHosts <- call.hosts + testutil.RequireSend(ctx, t, call.err, nil) + }() + + updateCW2 := uut.New(updateC2) + + // Verify DNS was re-applied with the previously known hosts + // before any server snapshot arrives on the new client. + gotHosts := testutil.TryReceive(ctx, t, reconnDNSHosts) + require.Equal(t, initialHosts, gotHosts) + + // Cleanup: tear down the second workspace update client + upRecvCall = testutil.TryReceive(ctx, t, updateC2.recv) + testutil.RequireSend(ctx, t, upRecvCall.err, io.EOF) + err = testutil.TryReceive(ctx, t, updateCW2.Wait()) + require.ErrorIs(t, err, io.EOF) + + // Cleanup: tear down coordinator + coordRecv := testutil.TryReceive(ctx, t, coordC.resps) + testutil.RequireSend(ctx, t, coordRecv.err, io.EOF) + cCall := testutil.TryReceive(ctx, t, coordC.close) + testutil.RequireSend(ctx, t, cCall, nil) + err = testutil.TryReceive(ctx, t, coordCW.Wait()) + require.ErrorIs(t, err, io.EOF) +} + +func TestTunnelAllWorkspaceUpdatesController_SnapshotClearsInheritedState(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := testutil.Logger(t) + + fConn := &fakeCoordinatee{} + tsc := tailnet.NewTunnelSrcCoordController(logger, fConn) + fDNS := newFakeDNSSetter(ctx, t) + uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, + tailnet.WithDNS(fDNS, "testy", tailnet.DNSNameOptions{Suffix: "mctest"}), + ) + + coordC := newFakeCoordinatorClient(ctx, t) + coordCW := tsc.New(coordC) + + // First connection: send snapshot with two workspaces. + w1ID := testUUID(1) + w1a1ID := testUUID(1, 1) + w2ID := testUUID(2) + w2a1ID := testUUID(2, 1) + updateC1 := newFakeWorkspaceUpdateClient(ctx, t) + updateCW1 := uut.New(updateC1) + + initUp := &proto.WorkspaceUpdate{ + UpsertedWorkspaces: []*proto.Workspace{ + {Id: w1ID[:], Name: "w1"}, + {Id: w2ID[:], Name: "w2"}, + }, + UpsertedAgents: []*proto.Agent{ + {Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]}, + {Id: w2a1ID[:], Name: "w2a1", WorkspaceId: w2ID[:]}, + }, + } + + upRecvCall := testutil.TryReceive(ctx, t, updateC1.recv) + testutil.RequireSend(ctx, t, upRecvCall.resp, initUp) + + // Consume AddTunnel coordination calls for both agents. + for range 2 { + coordCall := testutil.TryReceive(ctx, t, coordC.reqs) + testutil.RequireSend(ctx, t, coordCall.err, nil) + } + + // Consume the initial DNS set call. + dnsCall := testutil.TryReceive(ctx, t, fDNS.calls) + require.NotEmpty(t, dnsCall.hosts) + initialHosts := dnsCall.hosts + testutil.RequireSend(ctx, t, dnsCall.err, nil) + + // Verify both workspaces are in the DNS hosts. + var initialHostCount int + for range initialHosts { + initialHostCount++ + } + require.Greater(t, initialHostCount, 0) + + // Simulate disconnect. + upRecvCall = testutil.TryReceive(ctx, t, updateC1.recv) + testutil.RequireSend(ctx, t, upRecvCall.err, io.EOF) + err := testutil.TryReceive(ctx, t, updateCW1.Wait()) + require.ErrorIs(t, err, io.EOF) + + // Reconnect. New() re-applies inherited DNS synchronously. + updateC2 := newFakeWorkspaceUpdateClient(ctx, t) + reconnDNSHosts := make(chan map[dnsname.FQDN][]netip.Addr, 1) + go func() { + call := testutil.TryReceive(ctx, t, fDNS.calls) + reconnDNSHosts <- call.hosts + testutil.RequireSend(ctx, t, call.err, nil) + }() + + updateCW2 := uut.New(updateC2) + + // The inherited DNS should have both workspaces. + gotHosts := testutil.TryReceive(ctx, t, reconnDNSHosts) + require.Equal(t, initialHosts, gotHosts) + + // Server sends a snapshot that only contains w2 (w1 stopped + // while we were disconnected). Because it is a snapshot, w1 + // must be removed from the updater state. + snapshotUp := &proto.WorkspaceUpdate{ + UpsertedWorkspaces: []*proto.Workspace{ + {Id: w2ID[:], Name: "w2"}, + }, + UpsertedAgents: []*proto.Agent{ + {Id: w2a1ID[:], Name: "w2a1", WorkspaceId: w2ID[:]}, + }, + } + + upRecvCall = testutil.TryReceive(ctx, t, updateC2.recv) + testutil.RequireSend(ctx, t, upRecvCall.resp, snapshotUp) + + // The snapshot should trigger a RemoveTunnel for w1's agent and + // an AddTunnel for w2's agent (since state was cleared and rebuilt). + // Consume all coordination calls. + coordCall := testutil.TryReceive(ctx, t, coordC.reqs) + testutil.RequireSend(ctx, t, coordCall.err, nil) + + // Consume the DNS update after the snapshot. + dnsCall = testutil.TryReceive(ctx, t, fDNS.calls) + require.NotEmpty(t, dnsCall.hosts) + testutil.RequireSend(ctx, t, dnsCall.err, nil) + + // After the snapshot, DNS should only contain w2 entries. The + // stale w1 entries must be gone. + for fqdn := range dnsCall.hosts { + require.NotContains(t, string(fqdn), "w1", + "stale workspace w1 should not be in DNS after snapshot") + } + + // Cleanup. + upRecvCall = testutil.TryReceive(ctx, t, updateC2.recv) + testutil.RequireSend(ctx, t, upRecvCall.err, io.EOF) + err = testutil.TryReceive(ctx, t, updateCW2.Wait()) + require.ErrorIs(t, err, io.EOF) + + coordRecv := testutil.TryReceive(ctx, t, coordC.resps) + testutil.RequireSend(ctx, t, coordRecv.err, io.EOF) + cCall := testutil.TryReceive(ctx, t, coordC.close) + testutil.RequireSend(ctx, t, cCall, nil) + err = testutil.TryReceive(ctx, t, coordCW.Wait()) + require.ErrorIs(t, err, io.EOF) +} + func TestBasicDERPController_RewriteDERPMap(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort)