diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go index 45d9c71c3e..6e98dfec4c 100644 --- a/enterprise/tailnet/connio.go +++ b/enterprise/tailnet/connio.go @@ -2,7 +2,6 @@ package tailnet import ( "context" - "io" "sync" "sync/atomic" "time" @@ -104,19 +103,21 @@ func (c *connIO) recvLoop() { }() defer c.Close() for { - req, err := agpl.RecvCtx(c.peerCtx, c.requests) - if err != nil { - if xerrors.Is(err, context.Canceled) || - xerrors.Is(err, context.DeadlineExceeded) || - xerrors.Is(err, io.EOF) { - c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err)) - } else { - c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err)) + select { + case <-c.coordCtx.Done(): + c.logger.Debug(c.coordCtx, "exiting io recvLoop; coordinator exit") + return + case <-c.peerCtx.Done(): + c.logger.Debug(c.peerCtx, "exiting io recvLoop; peer context canceled") + return + case req, ok := <-c.requests: + if !ok { + c.logger.Debug(c.peerCtx, "exiting io recvLoop; requests chan closed") + return + } + if err := c.handleRequest(req); err != nil { + return } - return - } - if err := c.handleRequest(req); err != nil { - return } } } diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go index c9f8f73fe9..bbb3c55735 100644 --- a/enterprise/tailnet/multiagent_test.go +++ b/enterprise/tailnet/multiagent_test.go @@ -4,17 +4,14 @@ import ( "context" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/require" - "golang.org/x/exp/slices" - "tailscale.com/types/key" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/enterprise/tailnet" agpl "github.com/coder/coder/v2/tailnet" - "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/testutil" ) @@ -42,25 +39,48 @@ func TestPGCoordinator_MultiAgent(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord1) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.id) } +func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store) + require.NoError(t, err) + defer coord1.Close() + + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() + + err = coord1.Close() + require.NoError(t, err) + + ma1.RequireEventuallyClosed(ctx) +} + // TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with // a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe // with the MultiAgent closing. @@ -86,20 +106,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord1) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - ma1.unsubscribeAgent(agent1.id) - ma1.close() + ma1.RequireUnsubscribeAgent(agent1.id) + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -131,35 +151,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord1) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - ma1.unsubscribeAgent(agent1.id) + ma1.RequireUnsubscribeAgent(agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() - ma1.sendNodeWithDERP(9) + ma1.SendNodeWithDERP(9) assertNeverHasDERPs(ctx, t, agent1, 9) }() func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() agent1.sendNode(&agpl.Node{PreferredDERP: 8}) - ma1.assertNeverHasDERPs(ctx, 8) + ma1.RequireNeverHasDERPs(ctx, 8) }() - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -196,19 +216,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord2) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord2) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -246,19 +266,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - ma1 := newTestMultiAgent(t, coord2) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord2) + defer ma1.Close() - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -305,129 +325,29 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) { defer agent1.close() agent2.sendNode(&agpl.Node{PreferredDERP: 6}) - ma1 := newTestMultiAgent(t, coord3) - defer ma1.close() + ma1 := tailnettest.NewTestMultiAgent(t, coord3) + defer ma1.Close() - ma1.subscribeAgent(agent1.id) - ma1.assertEventuallyHasDERPs(ctx, 5) + ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - ma1.assertEventuallyHasDERPs(ctx, 1) + ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.subscribeAgent(agent2.id) - ma1.assertEventuallyHasDERPs(ctx, 6) + ma1.RequireSubscribeAgent(agent2.id) + ma1.RequireEventuallyHasDERPs(ctx, 6) agent2.sendNode(&agpl.Node{PreferredDERP: 2}) - ma1.assertEventuallyHasDERPs(ctx, 2) + ma1.RequireEventuallyHasDERPs(ctx, 2) - ma1.sendNodeWithDERP(3) + ma1.SendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent2, 3) - ma1.close() + ma1.Close() require.NoError(t, agent1.close()) require.NoError(t, agent2.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.id) } - -type testMultiAgent struct { - t testing.TB - id uuid.UUID - a agpl.MultiAgentConn - nodeKey []byte - discoKey string -} - -func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent { - nk, err := key.NewNode().Public().MarshalBinary() - require.NoError(t, err) - dk, err := key.NewDisco().Public().MarshalText() - require.NoError(t, err) - m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)} - m.a = coord.ServeMultiAgent(m.id) - return m -} - -func (m *testMultiAgent) sendNodeWithDERP(derp int32) { - m.t.Helper() - err := m.a.UpdateSelf(&proto.Node{ - Key: m.nodeKey, - Disco: m.discoKey, - PreferredDerp: derp, - }) - require.NoError(m.t, err) -} - -func (m *testMultiAgent) close() { - m.t.Helper() - err := m.a.Close() - require.NoError(m.t, err) -} - -func (m *testMultiAgent) subscribeAgent(id uuid.UUID) { - m.t.Helper() - err := m.a.SubscribeAgent(id) - require.NoError(m.t, err) -} - -func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) { - m.t.Helper() - err := m.a.UnsubscribeAgent(id) - require.NoError(m.t, err) -} - -func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) { - m.t.Helper() - for { - resp, ok := m.a.NextUpdate(ctx) - require.True(m.t, ok) - nodes, err := agpl.OnlyNodeUpdates(resp) - require.NoError(m.t, err) - if len(nodes) != len(expected) { - m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) - continue - } - - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range expected { - if !slices.Contains(derps, e) { - m.t.Logf("expected DERP %d to be in %v", e, derps) - continue - } - return - } - } -} - -func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) { - m.t.Helper() - for { - resp, ok := m.a.NextUpdate(ctx) - if !ok { - return - } - nodes, err := agpl.OnlyNodeUpdates(resp) - require.NoError(m.t, err) - if len(nodes) != len(expected) { - m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) - continue - } - - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range expected { - if !slices.Contains(derps, e) { - m.t.Logf("expected DERP %d to be in %v", e, derps) - continue - } - return - } - } -} diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index a5d9241a85..31e2a9dded 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -1017,7 +1017,13 @@ func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logge } func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) { - defer cancel() + defer func() { + cErr := q.Close() + if cErr != nil { + logger.Info(ctx, "error closing response Queue", slog.Error(cErr)) + } + cancel() + }() for { resp, err := RecvCtx(ctx, resps) if err != nil { diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index ab38f91bd0..72a6591051 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -383,6 +383,24 @@ func TestCoordinator_Lost(t *testing.T) { test.LostTest(ctx, t, coordinator) } +func TestCoordinator_MultiAgent_CoordClose(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + coord1 := tailnet.NewCoordinator(logger.Named("coord1")) + defer coord1.Close() + + ma1 := tailnettest.NewTestMultiAgent(t, coord1) + defer ma1.Close() + + err := coord1.Close() + require.NoError(t, err) + + ma1.RequireEventuallyClosed(ctx) +} + func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) { t.Helper() sc := make(chan net.Conn, 1) diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index 794aee549c..e3c66a23ab 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -1,6 +1,7 @@ package tailnettest import ( + "context" "crypto/tls" "fmt" "html" @@ -8,7 +9,11 @@ import ( "net/http" "net/http/httptest" "testing" + "time" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/stun/stuntest" @@ -19,6 +24,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/testutil" ) //go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn @@ -125,3 +132,120 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap { }, } } + +type TestMultiAgent struct { + t testing.TB + id uuid.UUID + a tailnet.MultiAgentConn + nodeKey []byte + discoKey string +} + +func NewTestMultiAgent(t testing.TB, coord tailnet.Coordinator) *TestMultiAgent { + nk, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + dk, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + m := &TestMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)} + m.a = coord.ServeMultiAgent(m.id) + return m +} + +func (m *TestMultiAgent) SendNodeWithDERP(d int32) { + m.t.Helper() + err := m.a.UpdateSelf(&proto.Node{ + Key: m.nodeKey, + Disco: m.discoKey, + PreferredDerp: d, + }) + require.NoError(m.t, err) +} + +func (m *TestMultiAgent) Close() { + m.t.Helper() + err := m.a.Close() + require.NoError(m.t, err) +} + +func (m *TestMultiAgent) RequireSubscribeAgent(id uuid.UUID) { + m.t.Helper() + err := m.a.SubscribeAgent(id) + require.NoError(m.t, err) +} + +func (m *TestMultiAgent) RequireUnsubscribeAgent(id uuid.UUID) { + m.t.Helper() + err := m.a.UnsubscribeAgent(id) + require.NoError(m.t, err) +} + +func (m *TestMultiAgent) RequireEventuallyHasDERPs(ctx context.Context, expected ...int) { + m.t.Helper() + for { + resp, ok := m.a.NextUpdate(ctx) + require.True(m.t, ok) + nodes, err := tailnet.OnlyNodeUpdates(resp) + require.NoError(m.t, err) + if len(nodes) != len(expected) { + m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + m.t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return + } + } +} + +func (m *TestMultiAgent) RequireNeverHasDERPs(ctx context.Context, expected ...int) { + m.t.Helper() + for { + resp, ok := m.a.NextUpdate(ctx) + if !ok { + return + } + nodes, err := tailnet.OnlyNodeUpdates(resp) + require.NoError(m.t, err) + if len(nodes) != len(expected) { + m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + m.t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return + } + } +} + +func (m *TestMultiAgent) RequireEventuallyClosed(ctx context.Context) { + m.t.Helper() + tkr := time.NewTicker(testutil.IntervalFast) + defer tkr.Stop() + for { + select { + case <-ctx.Done(): + m.t.Fatal("timeout") + return // unhittable + case <-tkr.C: + if m.a.IsClosed() { + return + } + } + } +}