package tailnet_test

import (
	"context"
	"database/sql"
	"net/netip"
	"sync"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
	"go.uber.org/mock/gomock"
	"golang.org/x/xerrors"
	gProto "google.golang.org/protobuf/proto"

	"cdr.dev/slog/v3"
	"cdr.dev/slog/v3/sloggers/slogtest"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbmock"
	"github.com/coder/coder/v2/coderd/database/dbtestutil"
	"github.com/coder/coder/v2/coderd/database/pubsub"
	"github.com/coder/coder/v2/enterprise/tailnet"
	agpl "github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
	agpltest "github.com/coder/coder/v2/tailnet/test"
	"github.com/coder/coder/v2/testutil"
	"github.com/coder/quartz"
)

// TestMain runs the package's tests under goleak so that leaked goroutines
// fail the run.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m, testutil.GoleakOptions...)
}

// TestPGCoordinatorSingle_ClientWithoutAgent connects a client that tunnels to
// an agent ID which never connects, waits for the client's node to appear in
// the tunnel peer bindings, and then checks the client is marked lost after an
// ungraceful disconnect.
func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()

	agentID := uuid.New()
	client := agpltest.NewClient(ctx, t, coordinator, "client", agentID)
	defer client.Close(ctx)
	client.UpdateDERP(10)
	require.Eventually(t, func() bool {
		clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID)
		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
			// The condition runs on a goroutine started by Eventually, so
			// FailNow-style calls (t.Fatalf) are not allowed here. Record the
			// failure and report "not yet" instead.
			t.Errorf("database error: %v", err)
			return false
		}
		if len(clients) == 0 {
			return false
		}
		node := new(proto.Node)
		err = gProto.Unmarshal(clients[0].Node, node)
		assert.NoError(t, err)
		assert.EqualValues(t, 10, node.PreferredDerp)
		return true
	}, testutil.WaitShort, testutil.IntervalFast)
	client.UngracefulDisconnect(ctx)
	assertEventuallyLost(ctx, t, store, client.ID)
}

// TestPGCoordinatorSingle_AgentWithoutClients connects a lone agent, waits for
// its node to appear in the peers table, and then checks it is marked lost
// after an ungraceful disconnect.
func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()

	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
	defer agent.Close(ctx)
	agent.UpdateDERP(10)
	require.Eventually(t, func() bool {
		agents, err := store.GetTailnetPeers(ctx, agent.ID)
		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
			// Not on the test goroutine: must not call t.Fatalf here.
			t.Errorf("database error: %v", err)
			return false
		}
		if len(agents) == 0 {
			return false
		}
		node := new(proto.Node)
		err = gProto.Unmarshal(agents[0].Node, node)
		assert.NoError(t, err)
		assert.EqualValues(t, 10, node.PreferredDerp)
		return true
	}, testutil.WaitShort, testutil.IntervalFast)
	agent.UngracefulDisconnect(ctx)
	assertEventuallyLost(ctx, t, store, agent.ID)
}

// TestPGCoordinatorSingle_AgentInvalidIP sends a node update whose address is a
// random prefix rather than one derived from the agent's ID, and expects the
// coordinator to close the connection with an authorization error.
func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()

	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
	defer agent.Close(ctx)
	prefix := agpl.TailscaleServicePrefix.RandomPrefix()
	agent.UpdateNode(&proto.Node{
		Addresses:     []string{prefix.String()},
		PreferredDerp: 10,
	})

	// The agent connection should be closed immediately after sending an invalid addr
	agent.AssertEventuallyResponsesClosed(
		agpl.AuthorizationError{Wrapped: agpl.InvalidNodeAddressError{Addr: prefix.Addr().String()}}.Error())
	assertEventuallyLost(ctx, t, store, agent.ID)
}

// TestPGCoordinatorSingle_AgentInvalidIPBits sends a node update whose address
// is derived from the agent's ID but uses a 64-bit prefix length, and expects
// the coordinator to close the connection with an authorization error.
func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()

	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
	defer agent.Close(ctx)
	agent.UpdateNode(&proto.Node{
		Addresses: []string{
			netip.PrefixFrom(agpl.TailscaleServicePrefix.AddrFromUUID(agent.ID), 64).String(),
		},
		PreferredDerp: 10,
	})

	// The agent connection should be closed immediately after sending an invalid addr
	agent.AssertEventuallyResponsesClosed(
		agpl.AuthorizationError{Wrapped: agpl.InvalidAddressBitsError{Bits: 64}}.Error())
	assertEventuallyLost(ctx, t, store, agent.ID)
}

// TestPGCoordinatorSingle_AgentValidIP sends a node update whose address is the
// prefix derived from the agent's ID and expects the node to be stored.
func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()

	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
	defer agent.Close(ctx)
	agent.UpdateNode(&proto.Node{
		Addresses: []string{
			agpl.TailscaleServicePrefix.PrefixFromUUID(agent.ID).String(),
		},
		PreferredDerp: 10,
	})
	require.Eventually(t, func() bool {
		agents, err := store.GetTailnetPeers(ctx, agent.ID)
		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
			// Not on the test goroutine: must not call t.Fatalf here.
			t.Errorf("database error: %v", err)
			return false
		}
		if len(agents) == 0 {
			return false
		}
		node := new(proto.Node)
		err = gProto.Unmarshal(agents[0].Node, node)
		assert.NoError(t, err)
		assert.EqualValues(t, 10, node.PreferredDerp)
		return true
	}, testutil.WaitShort, testutil.IntervalFast)
	agent.UngracefulDisconnect(ctx)
	assertEventuallyLost(ctx, t, store, agent.ID)
}

// TestPGCoordinatorSingle_AgentWithClient exercises node updates in both
// directions between one agent and one client on a single coordinator,
// including an agent reconnect and a burst of rapid updates.
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()

	agent := agpltest.NewAgent(ctx, t, coordinator, "original")
	defer agent.Close(ctx)
	agent.UpdateDERP(10)

	client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
	defer client.Close(ctx)

	client.AssertEventuallyHasDERP(agent.ID, 10)
	client.UpdateDERP(11)
	agent.AssertEventuallyHasDERP(client.ID, 11)

	// Ensure an update to the agent node reaches the connIO!
	agent.UpdateDERP(12)
	client.AssertEventuallyHasDERP(agent.ID, 12)

	// Close the agent channel so a new one can connect.
	agent.Close(ctx)

	// Create a new agent connection. This is to simulate a reconnect!
	agent = agpltest.NewPeer(ctx, t, coordinator, "reconnection", agpltest.WithID(agent.ID))
	// Ensure the coordinator sends its client node immediately!
	agent.AssertEventuallyHasDERP(client.ID, 11)

	// Send a bunch of updates in rapid succession, and test that we eventually get the latest. We don't want the
	// coordinator accidentally reordering things.
	for d := int32(13); d < 36; d++ {
		agent.UpdateDERP(d)
	}
	client.AssertEventuallyHasDERP(agent.ID, 35)

	agent.UngracefulDisconnect(ctx)
	client.UngracefulDisconnect(ctx)
	assertEventuallyLost(ctx, t, store, agent.ID)
	assertEventuallyLost(ctx, t, store, client.ID)
}

// TestPGCoordinatorSingle_MissedHeartbeats drives a coordinator with a mock
// clock and two fake coordinators (DB writes only) to check which agent mapping
// the client sees as the fake coordinators' heartbeats arrive or are missed.
func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()
	logger := testutil.Logger(t)

	mClock := quartz.NewMock(t)
	afTrap := mClock.Trap().AfterFunc("heartbeats", "recvBeat")
	defer afTrap.Close()
	rstTrap := mClock.Trap().TimerReset("heartbeats", "resetExpiryTimerWithLock")
	defer rstTrap.Close()

	coordinator, err := tailnet.NewTestPGCoord(ctx, logger, ps, store, mClock)
	require.NoError(t, err)
	defer coordinator.Close()

	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
	defer agent.Close(ctx)
	agent.UpdateDERP(10)

	client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
	defer client.Close(ctx)

	client.AssertEventuallyHasDERP(agent.ID, 10)
	client.UpdateDERP(11)
	agent.AssertEventuallyHasDERP(client.ID, 11)

	// simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a
	// real coordinator
	fCoord2 := &fakeCoordinator{
		ctx:   ctx,
		t:     t,
		store: store,
		id:    uuid.New(),
	}

	fCoord2.heartbeat()
	afTrap.MustWait(ctx).MustRelease(ctx) // heartbeat timeout started

	fCoord2.agentNode(agent.ID, &agpl.Node{PreferredDERP: 12})
	client.AssertEventuallyHasDERP(agent.ID, 12)

	fCoord3 := &fakeCoordinator{
		ctx:   ctx,
		t:     t,
		store: store,
		id:    uuid.New(),
	}
	fCoord3.heartbeat()
	rstTrap.MustWait(ctx).MustRelease(ctx) // timeout gets reset
	fCoord3.agentNode(agent.ID, &agpl.Node{PreferredDERP: 13})
	client.AssertEventuallyHasDERP(agent.ID, 13)

	// fCoord2 sends in a second heartbeat, one period later (on time)
	mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
	fCoord2.heartbeat()
	rstTrap.MustWait(ctx).MustRelease(ctx) // timeout gets reset

	// when the fCoord3 misses enough heartbeats, the real coordinator should send an update with the
	// node from fCoord2 for the agent.
	mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
	w := mClock.Advance(tailnet.HeartbeatPeriod)
	rstTrap.MustWait(ctx).MustRelease(ctx)
	w.MustWait(ctx)
	client.AssertEventuallyHasDERP(agent.ID, 12)

	// one more heartbeat period will result in fCoord2 being expired, which should cause us to
	// revert to the original agent mapping
	mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
	// note that the timeout doesn't get reset because both fCoord2 and fCoord3 are expired
	client.AssertEventuallyHasDERP(agent.ID, 10)

	// send fCoord3 heartbeat, which should trigger us to consider that mapping valid again.
	fCoord3.heartbeat()
	rstTrap.MustWait(ctx).MustRelease(ctx) // timeout gets reset
	client.AssertEventuallyHasDERP(agent.ID, 13)

	agent.UngracefulDisconnect(ctx)
	client.UngracefulDisconnect(ctx)
	assertEventuallyLost(ctx, t, store, client.ID)
}

// TestPGCoordinatorSingle_MissedHeartbeats_NoDrop checks that when a fake
// coordinator stops heartbeating, the client receives a LOST update for the
// agent rather than a disconnect.
func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)

	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()

	agentID := uuid.New()

	client := agpltest.NewPeer(ctx, t, coordinator, "client")
	defer client.Close(ctx)
	client.AddTunnel(agentID)

	client.UpdateDERP(11)

	// simulate a second coordinator via DB calls only --- our goal is to test
	// broken heart-beating, so we can't use a real coordinator
	fCoord2 := &fakeCoordinator{
		ctx:   ctx,
		t:     t,
		store: store,
		id:    uuid.New(),
	}
	// simulate a single heartbeat, the coordinator is healthy
	fCoord2.heartbeat()

	fCoord2.agentNode(agentID, &agpl.Node{PreferredDERP: 12})
	// since it's healthy the client should get the new node.
	client.AssertEventuallyHasDERP(agentID, 12)

	// the heartbeat should then timeout and we'll get sent a LOST update, NOT a
	// disconnect.
	client.AssertEventuallyLost(agentID)

	client.UngracefulDisconnect(ctx)

	assertEventuallyLost(ctx, t, store, client.ID)
}

// TestPGCoordinatorSingle_SendsHeartbeats subscribes to the heartbeat pubsub
// event and checks that a new coordinator publishes at least two heartbeats
// spaced by more than three quarters of a heartbeat period.
func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)

	// mu guards heartbeats, which is appended to from the pubsub callback and
	// read from the Eventually condition.
	var (
		mu         sync.Mutex
		heartbeats []time.Time
	)
	unsub, err := ps.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, _ []byte, err error) {
		assert.NoError(t, err)
		mu.Lock()
		defer mu.Unlock()
		heartbeats = append(heartbeats, time.Now())
	})
	require.NoError(t, err)
	defer unsub()

	start := time.Now()
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()

	require.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		if len(heartbeats) < 2 {
			return false
		}
		assert.Greater(t, heartbeats[0].Sub(start), time.Duration(0))
		assert.Greater(t, heartbeats[1].Sub(start), time.Duration(0))
		return assert.Greater(t, heartbeats[1].Sub(heartbeats[0]), tailnet.HeartbeatPeriod*3/4)
	}, testutil.WaitMedium, testutil.IntervalMedium)
}

// TestPGCoordinatorDual_Mainline tests with 2 coordinators, one agent connected to each, and 2 clients per agent.
//
//	            +---------+
//	agent1 ---> | coord1  | <--- client11 (coord 1, agent 1)
//	            |         |
//	            |         | <--- client12 (coord 1, agent 2)
//	            +---------+
//	            +---------+
//	agent2 ---> | coord2  | <--- client21 (coord 2, agent 1)
//	            |         |
//	            |         | <--- client22 (coord2, agent 2)
//	            +---------+
func TestPGCoordinatorDual_Mainline(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
	require.NoError(t, err)
	defer coord1.Close()
	coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
	require.NoError(t, err)
	defer coord2.Close()

	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
	defer agent1.Close(ctx)
	t.Logf("agent1=%s", agent1.ID)
	agent2 := agpltest.NewAgent(ctx, t, coord2, "agent2")
	defer agent2.Close(ctx)
	t.Logf("agent2=%s", agent2.ID)

	client11 := agpltest.NewClient(ctx, t, coord1, "client11", agent1.ID)
	defer client11.Close(ctx)
	t.Logf("client11=%s", client11.ID)
	client12 := agpltest.NewClient(ctx, t, coord1, "client12", agent2.ID)
	defer client12.Close(ctx)
	t.Logf("client12=%s", client12.ID)
	client21 := agpltest.NewClient(ctx, t, coord2, "client21", agent1.ID)
	defer client21.Close(ctx)
	t.Logf("client21=%s", client21.ID)
	client22 := agpltest.NewClient(ctx, t, coord2, "client22", agent2.ID)
	defer client22.Close(ctx)
	t.Logf("client22=%s", client22.ID)

	t.Log("client11 -> Node 11")
	client11.UpdateDERP(11)
	agent1.AssertEventuallyHasDERP(client11.ID, 11)

	t.Log("client21 -> Node 21")
	client21.UpdateDERP(21)
	agent1.AssertEventuallyHasDERP(client21.ID, 21)

	t.Log("client22 -> Node 22")
	client22.UpdateDERP(22)
	agent2.AssertEventuallyHasDERP(client22.ID, 22)

	t.Log("agent2 -> Node 2")
	agent2.UpdateDERP(2)
	client22.AssertEventuallyHasDERP(agent2.ID, 2)
	client12.AssertEventuallyHasDERP(agent2.ID, 2)

	t.Log("client12 -> Node 12")
	client12.UpdateDERP(12)
	agent2.AssertEventuallyHasDERP(client12.ID, 12)

	t.Log("agent1 -> Node 1")
	agent1.UpdateDERP(1)
	client21.AssertEventuallyHasDERP(agent1.ID, 1)
	client11.AssertEventuallyHasDERP(agent1.ID, 1)

	t.Log("close coord2")
	err = coord2.Close()
	require.NoError(t, err)

	// this closes agent2, client22, client21
	agent2.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
	client22.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
	client21.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
	assertEventuallyLost(ctx, t, store, agent2.ID)
	assertEventuallyLost(ctx, t, store, client21.ID)
	assertEventuallyLost(ctx, t, store, client22.ID)

	err = coord1.Close()
	require.NoError(t, err)
	// this closes agent1, client12, client11
	agent1.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
	client12.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
	client11.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
	assertEventuallyLost(ctx, t, store, agent1.ID)
	assertEventuallyLost(ctx, t, store, client11.ID)
	assertEventuallyLost(ctx, t, store, client12.ID)
}

// TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators.
// We use two agent connections, but they share the same AgentID. This could happen due to a reconnection,
// or an infrastructure problem where an old workspace is not fully cleaned up before a new one started.
//
//	            +---------+
//	agent1 ---> | coord1  |
//	            +---------+
//	            +---------+
//	agent2 ---> | coord2  |
//	            +---------+
//	            +---------+
//	            | coord3  | <--- client
//	            +---------+
func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
	require.NoError(t, err)
	defer coord1.Close()
	coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
	require.NoError(t, err)
	defer coord2.Close()
	coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store)
	require.NoError(t, err)
	defer coord3.Close()

	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
	defer agent1.Close(ctx)
	agent2 := agpltest.NewPeer(ctx, t, coord2, "agent2",
		agpltest.WithID(agent1.ID), agpltest.WithAuth(agpl.AgentCoordinateeAuth{ID: agent1.ID}),
	)
	defer agent2.Close(ctx)

	client := agpltest.NewClient(ctx, t, coord3, "client", agent1.ID)
	defer client.Close(ctx)

	client.UpdateDERP(3)
	agent1.AssertEventuallyHasDERP(client.ID, 3)
	agent2.AssertEventuallyHasDERP(client.ID, 3)

	agent1.UpdateDERP(1)
	client.AssertEventuallyHasDERP(agent1.ID, 1)

	// agent2's update overrides agent1 because it is newer
	agent2.UpdateDERP(2)
	client.AssertEventuallyHasDERP(agent1.ID, 2)

	// agent2 disconnects, and we should revert back to agent1
	agent2.Close(ctx)
	client.AssertEventuallyHasDERP(agent1.ID, 1)

	agent1.UpdateDERP(11)
	client.AssertEventuallyHasDERP(agent1.ID, 11)

	client.UpdateDERP(31)
	agent1.AssertEventuallyHasDERP(client.ID, 31)

	agent1.UngracefulDisconnect(ctx)
	client.UngracefulDisconnect(ctx)

	assertEventuallyLost(ctx, t, store, client.ID)
	assertEventuallyLost(ctx, t, store, agent1.ID)
}

// TestPGCoordinator_Unhealthy uses a mock store whose coordinator heartbeat
// upserts fail three times in a row, and checks that connected and newly
// connecting agents are closed with CloseErrUnhealthy until heartbeats succeed
// again.
func TestPGCoordinator_Unhealthy(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()

	ctrl := gomock.NewController(t)
	mStore := dbmock.NewMockStore(ctrl)
	ps := pubsub.NewInMemory()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)

	// calls gates each mocked heartbeat after the first: the mock blocks on a
	// receive until the test sends.
	calls := make(chan struct{})
	// first call succeeds, so that our Agent will successfully connect.
	firstSucceeds := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
		Times(1).
		Return(database.TailnetCoordinator{}, nil)
	// next 3 fail, so the Coordinator becomes unhealthy, and we test that it disconnects the agent
	threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
		After(firstSucceeds).
		Times(3).
		Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
		Return(database.TailnetCoordinator{}, xerrors.New("test disconnect"))
	mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
		MinTimes(1).
		After(threeMissed).
		Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
		Return(database.TailnetCoordinator{}, nil)
	// extra calls we don't particularly care about for this test
	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
	mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
	mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
	mStore.EXPECT().GetTailnetTunnelPeerIDs(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
	mStore.EXPECT().GetTailnetTunnelPeerBindings(gomock.Any(), gomock.Any()).
		AnyTimes().Return(nil, nil)
	mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()).
		AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil)
	mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
	mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any())

	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
	require.NoError(t, err)
	defer func() {
		err := uut.Close()
		require.NoError(t, err)
	}()
	agent1 := agpltest.NewAgent(ctx, t, uut, "agent1")
	defer agent1.Close(ctx)
	for i := 0; i < 3; i++ {
		select {
		case <-ctx.Done():
			t.Fatalf("timeout waiting for call %d", i+1)
		case calls <- struct{}{}:
			// OK
		}
	}
	// connected agent should be disconnected
	agent1.AssertEventuallyResponsesClosed(tailnet.CloseErrUnhealthy)

	// new agent should immediately disconnect
	agent2 := agpltest.NewAgent(ctx, t, uut, "agent2")
	defer agent2.Close(ctx)
	agent2.AssertEventuallyResponsesClosed(tailnet.CloseErrUnhealthy)

	// next heartbeats succeed, so we are healthy
	for i := 0; i < 2; i++ {
		select {
		case <-ctx.Done():
			t.Fatal("timeout")
		case calls <- struct{}{}:
			// OK
		}
	}
	agent3 := agpltest.NewAgent(ctx, t, uut, "agent3")
	defer agent3.Close(ctx)
	agent3.AssertNotClosed(time.Second)
}

// TestPGCoordinator_Node_Empty checks that Node returns nil when the store
// reports no peers for the requested ID.
func TestPGCoordinator_Node_Empty(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	ctrl := gomock.NewController(t)
	mStore := dbmock.NewMockStore(ctrl)
	ps := pubsub.NewInMemory()
	logger := testutil.Logger(t)

	id := uuid.New()
	mStore.EXPECT().GetTailnetPeers(gomock.Any(), id).Times(1).Return(nil, nil)

	// extra calls we don't particularly care about for this test
	mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
		AnyTimes().
		Return(database.TailnetCoordinator{}, nil)
	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
	mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
	mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
	mStore.EXPECT().UpdateTailnetPeerStatusByCoordinator(gomock.Any(), gomock.Any()).Times(1)

	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
	require.NoError(t, err)
	defer func() {
		err := uut.Close()
		require.NoError(t, err)
	}()

	node := uut.Node(id)
	require.Nil(t, node)
}

// TestPGCoordinator_BidirectionalTunnels tests when peers create tunnels to each other. We don't
// do this now, but it's schematically possible, so we should make sure it doesn't break anything.
func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()
	agpltest.BidirectionalTunnels(ctx, t, coordinator)
}

// TestPGCoordinator_GracefulDisconnect runs the shared
// agpltest.GracefulDisconnectTest scenario against a PG coordinator.
func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()
	agpltest.GracefulDisconnectTest(ctx, t, coordinator)
}

// TestPGCoordinator_Lost runs the shared agpltest.LostTest scenario against a
// PG coordinator.
func TestPGCoordinator_Lost(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()
	agpltest.LostTest(ctx, t, coordinator)
}

// TestPGCoordinator_NoDeleteOnClose checks that peer nodes written through one
// coordinator are still readable from a second coordinator created on the same
// store after the first has been closed.
func TestPGCoordinator_NoDeleteOnClose(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator.Close()

	agent := agpltest.NewAgent(ctx, t, coordinator, "original")
	defer agent.Close(ctx)
	agent.UpdateDERP(10)

	client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
	defer client.Close(ctx)

	// Simulate some traffic to generate
	// a peer.
	client.AssertEventuallyHasDERP(agent.ID, 10)
	client.UpdateDERP(11)

	agent.AssertEventuallyHasDERP(client.ID, 11)

	anode := coordinator.Node(agent.ID)
	require.NotNil(t, anode)
	cnode := coordinator.Node(client.ID)
	require.NotNil(t, cnode)

	err = coordinator.Close()
	require.NoError(t, err)
	assertEventuallyLost(ctx, t, store, agent.ID)
	assertEventuallyLost(ctx, t, store, client.ID)

	coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coordinator2.Close()

	anode = coordinator2.Node(agent.ID)
	require.NotNil(t, anode)
	assert.Equal(t, 10, anode.PreferredDERP)

	cnode = coordinator2.Node(client.ID)
	require.NotNil(t, cnode)
	assert.Equal(t, 11, cnode.PreferredDERP)
}

// TestPGCoordinatorDual_FailedHeartbeat tests that peers
// disconnect from a coordinator when they are unhealthy,
// are marked as LOST (not DISCONNECTED), and can reconnect to
// a new coordinator and reestablish their tunnels.
func TestPGCoordinatorDual_FailedHeartbeat(t *testing.T) {
	t.Parallel()

	dburl, err := dbtestutil.Open(t)
	require.NoError(t, err)

	store1, ps1, sdb1 := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithURL(dburl))
	defer sdb1.Close()
	store2, ps2, sdb2 := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithURL(dburl))
	defer sdb2.Close()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	t.Cleanup(cancel)

	// We do this to avoid failing due to errors related to the
	// database connection being closed.
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)

	// Create two coordinators, 1 for each peer.
	c1, err := tailnet.NewPGCoord(ctx, logger, ps1, store1)
	require.NoError(t, err)
	c2, err := tailnet.NewPGCoord(ctx, logger, ps2, store2)
	require.NoError(t, err)

	p1 := agpltest.NewPeer(ctx, t, c1, "peer1")
	p2 := agpltest.NewPeer(ctx, t, c2, "peer2")

	// Create a binding between the two.
	p1.AddTunnel(p2.ID)

	// Ensure that messages pass through.
	p1.UpdateDERP(1)
	p2.UpdateDERP(2)
	p1.AssertEventuallyHasDERP(p2.ID, 2)
	p2.AssertEventuallyHasDERP(p1.ID, 1)

	// Close the underlying database connection to induce
	// a heartbeat failure scenario and assert that
	// we eventually disconnect from the coordinator.
	err = sdb1.Close()
	require.NoError(t, err)
	p1.AssertEventuallyResponsesClosed(tailnet.CloseErrUnhealthy)
	p2.AssertEventuallyLost(p1.ID)
	// This basically checks that peer2 had no update
	// performed on their status since we are connected
	// to coordinator2.
	assertEventuallyStatus(ctx, t, store2, p2.ID, database.TailnetStatusOk)

	// Connect peer1 to coordinator2.
	p1.ConnectToCoordinator(ctx, c2)
	// Reestablish binding.
	p1.AddTunnel(p2.ID)
	// Ensure messages still flow back and forth.
	p1.AssertEventuallyHasDERP(p2.ID, 2)
	p1.UpdateDERP(3)
	p2.UpdateDERP(4)
	p2.AssertEventuallyHasDERP(p1.ID, 3)
	p1.AssertEventuallyHasDERP(p2.ID, 4)
	// Make sure peer2 never got an update about peer1 disconnecting.
	p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
}

// TestPGCoordinatorDual_PeerReconnect closes the coordinator one peer is
// attached to, reconnects that peer to the other coordinator, and checks that
// the remaining peer sees LOST but never DISCONNECTED for it.
func TestPGCoordinatorDual_PeerReconnect(t *testing.T) {
	t.Parallel()

	store, ps := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := testutil.Logger(t)

	// Create two coordinators, 1 for each peer.
	c1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	c2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)

	p1 := agpltest.NewPeer(ctx, t, c1, "peer1")
	p2 := agpltest.NewPeer(ctx, t, c2, "peer2")

	// Create a binding between the two.
	p1.AddTunnel(p2.ID)

	// Ensure that messages pass through.
	p1.UpdateDERP(1)
	p2.UpdateDERP(2)
	p1.AssertEventuallyHasDERP(p2.ID, 2)
	p2.AssertEventuallyHasDERP(p1.ID, 1)

	// Close coordinator1. Now we will check that we
	// never send a DISCONNECTED update.
	err = c1.Close()
	require.NoError(t, err)
	p1.AssertEventuallyResponsesClosed(agpl.CloseErrCoordinatorClose)
	p2.AssertEventuallyLost(p1.ID)
	// This basically checks that peer2 had no update
	// performed on their status since we are connected
	// to coordinator2.
	assertEventuallyStatus(ctx, t, store, p2.ID, database.TailnetStatusOk)

	// Connect peer1 to coordinator2.
	p1.ConnectToCoordinator(ctx, c2)
	// Reestablish binding.
	p1.AddTunnel(p2.ID)
	// Ensure messages still flow back and forth.
	p1.AssertEventuallyHasDERP(p2.ID, 2)
	p1.UpdateDERP(3)
	p2.UpdateDERP(4)
	p2.AssertEventuallyHasDERP(p1.ID, 3)
	p1.AssertEventuallyHasDERP(p2.ID, 4)
	// Make sure peer2 never got an update about peer1 disconnecting.
	p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
}

// TestPGCoordinatorPropogatedPeerContext tests that the context for a specific peer
// is propagated through to the `Authorize` method of the coordinatee auth
func TestPGCoordinatorPropogatedPeerContext(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitMedium)
	store, ps := dbtestutil.NewDB(t)
	logger := testutil.Logger(t)

	peerCtx := context.WithValue(ctx, agpltest.FakeSubjectKey{}, struct{}{})
	peerID := uuid.UUID{0x01}
	agentID := uuid.UUID{0x02}

	c1, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer func() {
		err := c1.Close()
		require.NoError(t, err)
	}()

	ch := make(chan struct{})
	auth := agpltest.FakeCoordinateeAuth{
		Chan: ch,
	}

	reqs, _ := c1.Coordinate(peerCtx, peerID, "peer1", auth)

	testutil.RequireSend(ctx, t, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agpl.UUIDToByteSlice(agentID)}})

	_ = testutil.TryReceive(ctx, t, ch)
}

// assertEventuallyStatus polls the store until every peer row returned for
// agentID has the given status, failing the test (without stopping it) if that
// does not happen within testutil.WaitShort. An empty result counts as a match;
// sql.ErrNoRows is treated as "not yet".
func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) {
	t.Helper()
	assert.Eventually(t, func() bool {
		peers, err := store.GetTailnetPeers(ctx, agentID)
		if xerrors.Is(err, sql.ErrNoRows) {
			return false
		}
		if err != nil {
			// The condition runs on a goroutine started by Eventually, so
			// t.Fatal is not allowed here; record the error and keep polling.
			t.Error(err)
			return false
		}
		for _, peer := range peers {
			if peer.Status != status {
				return false
			}
		}
		return true
	}, testutil.WaitShort, testutil.IntervalFast)
}

// assertEventuallyLost asserts that all peer rows for agentID eventually have
// status database.TailnetStatusLost.
func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
	t.Helper()
	assertEventuallyStatus(ctx, t, store, agentID, database.TailnetStatusLost)
}

// assertEventuallyNoClientsForAgent polls the store until no tunnel peer IDs
// are returned for agentID. sql.ErrNoRows counts as "no clients".
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
	t.Helper()
	assert.Eventually(t, func() bool {
		clients, err := store.GetTailnetTunnelPeerIDs(ctx, agentID)
		if xerrors.Is(err, sql.ErrNoRows) {
			return true
		}
		if err != nil {
			// Not on the test goroutine: must not call t.Fatal here.
			t.Error(err)
			return false
		}
		return len(clients) == 0
	}, testutil.WaitShort, testutil.IntervalFast)
}

// fakeCoordinator simulates another coordinator purely through direct store
// writes, so tests can control exactly when heartbeats and peer updates happen.
type fakeCoordinator struct {
	ctx   context.Context
	t     *testing.T
	store database.Store
	id    uuid.UUID
}

// heartbeat upserts this fake coordinator's row in the store, failing the test
// on error.
func (c *fakeCoordinator) heartbeat() {
	c.t.Helper()
	_, err := c.store.UpsertTailnetCoordinator(c.ctx, c.id)
	require.NoError(c.t, err)
}

// agentNode upserts a peer row for agentID, owned by this fake coordinator,
// carrying the protobuf-encoded node with status OK.
func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) {
	c.t.Helper()
	pNode, err := agpl.NodeToProto(node)
	require.NoError(c.t, err)
	nodeRaw, err := gProto.Marshal(pNode)
	require.NoError(c.t, err)
	_, err = c.store.UpsertTailnetPeer(c.ctx, database.UpsertTailnetPeerParams{
		ID:            agentID,
		CoordinatorID: c.id,
		Node:          nodeRaw,
		Status:        database.TailnetStatusOk,
	})
	require.NoError(c.t, err)
}