feat: support graceful disconnect in PGCoordinator (#10937)

Adds support for graceful disconnect to PGCoordinator.  When peers gracefully disconnect, they send a disconnect message.  This triggers the peer to be disconnected from all tunneled peers.

The Multi-Agent Client supports graceful disconnect, since it is in memory and we know that when it is closed, we really mean to disconnect.

The v1 agent and client Websocket connections do not support graceful disconnect, since the v1 protocol doesn't have this feature.  That means that if a v1 peer connects to a v2 peer, then when the v1 peer's coordinator connection is closed, the v2 peer will see it as "lost", since we don't know whether the v1 peer meant to disconnect or just lost connectivity to the coordinator.
This commit is contained in:
Spike Curtis
2023-12-01 09:55:25 +04:00
committed by GitHub
parent 967db2801b
commit 0cab6e7763
4 changed files with 326 additions and 123 deletions
+35 -18
View File
@@ -24,16 +24,17 @@ type connIO struct {
// coordCtx is the parent context, that is, the context of the Coordinator
coordCtx context.Context
// peerCtx is the context of the connection to our peer
peerCtx context.Context
cancel context.CancelFunc
logger slog.Logger
requests <-chan *proto.CoordinateRequest
responses chan<- *proto.CoordinateResponse
bindings chan<- binding
tunnels chan<- tunnel
auth agpl.TunnelAuth
mu sync.Mutex
closed bool
peerCtx context.Context
cancel context.CancelFunc
logger slog.Logger
requests <-chan *proto.CoordinateRequest
responses chan<- *proto.CoordinateResponse
bindings chan<- binding
tunnels chan<- tunnel
auth agpl.TunnelAuth
mu sync.Mutex
closed bool
disconnected bool
name string
start int64
@@ -76,20 +77,29 @@ func newConnIO(coordContext context.Context,
func (c *connIO) recvLoop() {
defer func() {
// withdraw bindings & tunnels when we exit. We need to use the parent context here, since
// withdraw bindings & tunnels when we exit. We need to use the coordinator context here, since
// our own context might be canceled, but we still need to withdraw.
b := binding{
bKey: bKey(c.UniqueID()),
kind: proto.CoordinateResponse_PeerUpdate_LOST,
}
if c.disconnected {
b.kind = proto.CoordinateResponse_PeerUpdate_DISCONNECTED
}
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
}
t := tunnel{
tKey: tKey{src: c.UniqueID()},
active: false,
}
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
// only remove tunnels on graceful disconnect. If we remove tunnels for lost peers, then
// this will look like a disconnect from the peer perspective, since we query for active peers
// by using the tunnel as a join in the database
if c.disconnected {
t := tunnel{
tKey: tKey{src: c.UniqueID()},
active: false,
}
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
}
}
}()
defer c.Close()
@@ -111,6 +121,8 @@ func (c *connIO) recvLoop() {
}
}
// errDisconnect is returned by handleRequest when the peer sends a Disconnect request, i.e. it is
// leaving on purpose rather than having lost its connection.
var errDisconnect = xerrors.New("graceful disconnect")
func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
c.logger.Debug(c.peerCtx, "got request")
if req.UpdateSelf != nil {
@@ -118,6 +130,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
b := binding{
bKey: bKey(c.UniqueID()),
node: req.UpdateSelf.Node,
kind: proto.CoordinateResponse_PeerUpdate_NODE,
}
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
@@ -169,7 +182,11 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
return err
}
}
// TODO: (spikecurtis) support Disconnect
if req.Disconnect != nil {
c.logger.Debug(c.peerCtx, "graceful disconnect")
c.disconnected = true
return errDisconnect
}
return nil
}
+6 -6
View File
@@ -58,7 +58,7 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyNoAgents(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
@@ -106,7 +106,7 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyNoAgents(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a
@@ -168,7 +168,7 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyNoAgents(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a
@@ -220,7 +220,7 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyNoAgents(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two
@@ -273,7 +273,7 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyNoAgents(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a
@@ -344,5 +344,5 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
require.NoError(t, agent2.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyNoAgents(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
+109 -69
View File
@@ -203,6 +203,7 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
}})
},
OnRemove: func(_ agpl.Queue) {
_ = sendCtx(c.ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
cancel()
},
}).Init()
@@ -352,9 +353,14 @@ func v1SendLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg
_ = q.CoordinatorClose()
return
}
// don't send empty updates
if len(nodes) == 0 {
logger.Debug(ctx, "skipping enqueueing 0-length v1 update")
continue
}
err = q.Enqueue(nodes)
if err != nil {
logger.Error(ctx, "failed to enqueue multi-agent update", slog.Error(err))
logger.Error(ctx, "failed to enqueue v1 update", slog.Error(err))
}
}
}
@@ -597,6 +603,7 @@ type bKey uuid.UUID
// binding is an update to the stored record of a single peer, identified by the embedded bKey.
type binding struct {
bKey
// node is the peer's node; it is marshaled and upserted for every kind other than DISCONNECTED.
node *proto.Node
// kind selects how the binding is written: DISCONNECTED deletes the peer record, LOST upserts
// it with lost status, and any other kind upserts it with OK status.
kind proto.CoordinateResponse_PeerUpdate_Kind
}
// binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff.
@@ -675,22 +682,7 @@ func (b *binder) worker() {
func (b *binder) writeOne(bnd binding) error {
var err error
if bnd.node != nil {
var nodeRaw []byte
nodeRaw, err = gProto.Marshal(bnd.node)
if err != nil {
// this is very bad news, but it should never happen because the node was Unmarshalled or converted by this
// process earlier.
b.logger.Critical(b.ctx, "failed to marshal node", slog.Error(err))
return err
}
_, err = b.store.UpsertTailnetPeer(b.ctx, database.UpsertTailnetPeerParams{
ID: uuid.UUID(bnd.bKey),
CoordinatorID: b.coordinatorID,
Node: nodeRaw,
Status: database.TailnetStatusOk,
})
} else {
if bnd.kind == proto.CoordinateResponse_PeerUpdate_DISCONNECTED {
_, err = b.store.DeleteTailnetPeer(b.ctx, database.DeleteTailnetPeerParams{
ID: uuid.UUID(bnd.bKey),
CoordinatorID: b.coordinatorID,
@@ -699,6 +691,25 @@ func (b *binder) writeOne(bnd binding) error {
if xerrors.Is(err, sql.ErrNoRows) {
err = nil
}
} else {
var nodeRaw []byte
nodeRaw, err = gProto.Marshal(bnd.node)
if err != nil {
// this is very bad news, but it should never happen because the node was Unmarshalled or converted by this
// process earlier.
b.logger.Critical(b.ctx, "failed to marshal node", slog.Error(err))
return err
}
status := database.TailnetStatusOk
if bnd.kind == proto.CoordinateResponse_PeerUpdate_LOST {
status = database.TailnetStatusLost
}
_, err = b.store.UpsertTailnetPeer(b.ctx, database.UpsertTailnetPeerParams{
ID: uuid.UUID(bnd.bKey),
CoordinatorID: b.coordinatorID,
Node: nodeRaw,
Status: status,
})
}
if err != nil && !database.IsQueryCanceledError(err) {
@@ -710,16 +721,27 @@ func (b *binder) writeOne(bnd binding) error {
return err
}
// storeBinding stores the latest binding, where we interpret node == nil as removing the binding. This keeps the map
// storeBinding stores the latest binding, where we interpret kind == DISCONNECTED as removing the binding. This keeps the map
// from growing without bound.
func (b *binder) storeBinding(bnd binding) {
b.mu.Lock()
defer b.mu.Unlock()
if bnd.node != nil {
switch bnd.kind {
case proto.CoordinateResponse_PeerUpdate_NODE:
b.latest[bnd.bKey] = bnd
} else {
// nil node is interpreted as removing binding
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(b.latest, bnd.bKey)
case proto.CoordinateResponse_PeerUpdate_LOST:
// we need to coalesce with the previously stored node, since it must
// be non-nil in the database
old, ok := b.latest[bnd.bKey]
if !ok {
// lost before we ever got a node update. No action
return
}
bnd.node = old.node
b.latest[bnd.bKey] = bnd
}
}
@@ -732,6 +754,7 @@ func (b *binder) retrieveBinding(bk bKey) binding {
bnd = binding{
bKey: bk,
node: nil,
kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
}
}
return bnd
@@ -752,9 +775,8 @@ type mapper struct {
// latest is the most recent, unfiltered snapshot of the mappings we know about
latest []mapping
// sent is the state of mappings we have actually enqueued; used to compute diffs for updates. It is a map from peer
// ID to node.
sent map[uuid.UUID]*proto.Node
// sent is the state of mappings we have actually enqueued; used to compute diffs for updates.
sent map[uuid.UUID]mapping
// called to filter mappings to healthy coordinators
heartbeats *heartbeats
@@ -771,7 +793,7 @@ func newMapper(c *connIO, logger slog.Logger, h *heartbeats) *mapper {
update: make(chan struct{}),
mappings: make(chan []mapping),
heartbeats: h,
sent: make(map[uuid.UUID]*proto.Node),
sent: make(map[uuid.UUID]mapping),
}
go m.run()
return m
@@ -779,19 +801,19 @@ func newMapper(c *connIO, logger slog.Logger, h *heartbeats) *mapper {
func (m *mapper) run() {
for {
var nodes map[uuid.UUID]*proto.Node
var best map[uuid.UUID]mapping
select {
case <-m.ctx.Done():
return
case mappings := <-m.mappings:
m.logger.Debug(m.ctx, "got new mappings")
m.latest = mappings
nodes = m.mappingsToNodes(mappings)
best = m.bestMappings(mappings)
case <-m.update:
m.logger.Debug(m.ctx, "triggered update")
nodes = m.mappingsToNodes(m.latest)
best = m.bestMappings(m.latest)
}
update := m.nodesToUpdate(nodes)
update := m.bestToUpdate(best)
if update == nil {
m.logger.Debug(m.ctx, "skipping nil node update")
continue
@@ -802,67 +824,83 @@ func (m *mapper) run() {
}
}
// mappingsToNodes takes a set of mappings and resolves the best set of nodes. We may get several mappings for a
// bestMappings takes a set of mappings and resolves the best mapping for each peer. We may get several mappings for a
// particular connection, from different coordinators in the distributed system. Furthermore, some coordinators
// might be considered invalid on account of missing heartbeats. We take the most recent mapping from a valid
// coordinator as the "best" mapping.
func (m *mapper) mappingsToNodes(mappings []mapping) map[uuid.UUID]*proto.Node {
func (m *mapper) bestMappings(mappings []mapping) map[uuid.UUID]mapping {
mappings = m.heartbeats.filter(mappings)
best := make(map[uuid.UUID]mapping, len(mappings))
for _, m := range mappings {
bestM, ok := best[m.peer]
if !ok || m.updatedAt.After(bestM.updatedAt) {
best[m.peer] = m
for _, mpng := range mappings {
bestM, ok := best[mpng.peer]
switch {
case !ok:
// no current best
best[mpng.peer] = mpng
// NODE always beats LOST mapping, since the LOST could be from a coordinator that's
// slow updating the DB, and the peer has reconnected to a different coordinator and
// given a NODE mapping.
case bestM.kind == proto.CoordinateResponse_PeerUpdate_LOST && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
best[mpng.peer] = mpng
case mpng.updatedAt.After(bestM.updatedAt) && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
// newer, and it's a NODE update.
best[mpng.peer] = mpng
}
}
nodes := make(map[uuid.UUID]*proto.Node, len(best))
for k, m := range best {
nodes[k] = m.node
}
return nodes
return best
}
func (m *mapper) nodesToUpdate(nodes map[uuid.UUID]*proto.Node) *proto.CoordinateResponse {
func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateResponse {
resp := new(proto.CoordinateResponse)
for k, n := range nodes {
sn, ok := m.sent[k]
if !ok {
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
Uuid: agpl.UUIDToByteSlice(k),
Node: n,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Reason: "new",
})
for k, mpng := range best {
var reason string
sm, ok := m.sent[k]
switch {
case !ok && mpng.kind == proto.CoordinateResponse_PeerUpdate_LOST:
// we don't need to send a "lost" update if we've never sent an update about this peer
continue
}
eq, err := sn.Equal(n)
if err != nil {
m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sn), slog.F("new", n))
}
if !eq {
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
Uuid: agpl.UUIDToByteSlice(k),
Node: n,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Reason: "update",
})
case !ok && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
reason = "new"
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_LOST && mpng.kind == proto.CoordinateResponse_PeerUpdate_LOST:
// was lost and remains lost, no update needed
continue
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_LOST && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
reason = "found"
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_LOST:
reason = "lost"
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
eq, err := sm.node.Equal(mpng.node)
if err != nil {
m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.kind))
continue
}
if eq {
continue
}
reason = "update"
}
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
Uuid: agpl.UUIDToByteSlice(k),
Node: mpng.node,
Kind: mpng.kind,
Reason: reason,
})
m.sent[k] = mpng
}
for k := range m.sent {
if _, ok := nodes[k]; !ok {
if _, ok := best[k]; !ok {
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
Uuid: agpl.UUIDToByteSlice(k),
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
Reason: "disconnected",
})
delete(m.sent, k)
}
}
m.sent = nodes
if len(resp.PeerUpdates) == 0 {
return nil
}
@@ -1069,10 +1107,6 @@ func (q *querier) mappingQuery(peer mKey) error {
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return err
}
if len(bindings) == 0 {
logger.Debug(q.ctx, "no mappings, nothing to do")
return nil
}
mappings, err := q.bindingsToMappings(bindings)
if err != nil {
logger.Debug(q.ctx, "failed to convert mappings", slog.Error(err))
@@ -1100,11 +1134,16 @@ func (q *querier) bindingsToMappings(bindings []database.GetTailnetTunnelPeerBin
q.logger.Error(q.ctx, "failed to unmarshal node", slog.Error(err))
return nil, backoff.Permanent(err)
}
kind := proto.CoordinateResponse_PeerUpdate_NODE
if binding.Status == database.TailnetStatusLost {
kind = proto.CoordinateResponse_PeerUpdate_LOST
}
mappings = append(mappings, mapping{
peer: binding.PeerID,
coordinator: binding.CoordinatorID,
updatedAt: binding.UpdatedAt,
node: node,
kind: kind,
})
}
return mappings, nil
@@ -1326,6 +1365,7 @@ type mapping struct {
coordinator uuid.UUID
updatedAt time.Time
node *proto.Node
kind proto.CoordinateResponse_PeerUpdate_Kind
}
// querierWorkKey describes two kinds of work the querier needs to do. If peerUpdate
+176 -30
View File
@@ -71,7 +71,7 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
require.NoError(t, err)
<-client.errChan
<-client.closeChan
assertEventuallyNoClientsForAgent(ctx, t, store, agentID)
assertEventuallyLost(ctx, t, store, client.id)
}
func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
@@ -108,7 +108,7 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
require.NoError(t, err)
<-agent.errChan
<-agent.closeChan
assertEventuallyNoAgents(ctx, t, store, agent.id)
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
@@ -184,8 +184,8 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
_ = client.recvErr(ctx, t)
client.waitForClose(ctx, t)
assertEventuallyNoAgents(ctx, t, store, agent.id)
assertEventuallyNoClientsForAgent(ctx, t, store, agent.id)
assertEventuallyLost(ctx, t, store, agent.id)
assertEventuallyLost(ctx, t, store, client.id)
}
func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
@@ -272,7 +272,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
_ = client.recvErr(ctx, t)
client.waitForClose(ctx, t)
assertEventuallyNoClientsForAgent(ctx, t, store, agent.id)
assertEventuallyLost(ctx, t, store, client.id)
}
func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
@@ -519,8 +519,8 @@ func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
require.ErrorIs(t, err, io.ErrClosedPipe)
client.waitForClose(ctx, t)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyNoAgents(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, client.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
func TestPGCoordinator_Unhealthy(t *testing.T) {
@@ -624,6 +624,63 @@ func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
p2.assertEventuallyHasDERP(p1.id, 1)
}
// TestPGCoordinator_GracefulDisconnect connects two peers to a single coordinator, tunnels p1 to
// p2, and then has p2 send a Disconnect request. p1 must see p2 removed from its peer set, and
// p2's response channel must be closed.
func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}

	db, pubsub := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()

	debugLogger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coord, err := tailnet.NewPGCoordV2(ctx, debugLogger, pubsub, db)
	require.NoError(t, err)
	defer coord.Close()

	p1 := newTestPeer(ctx, t, coord, "p1")
	defer p1.close(ctx)
	p2 := newTestPeer(ctx, t, coord, "p2")
	defer p2.close(ctx)

	// establish the tunnel and wait until each side has learned the other's node
	p1.addTunnel(p2.id)
	p1.updateDERP(1)
	p2.updateDERP(2)
	p1.assertEventuallyHasDERP(p2.id, 2)
	p2.assertEventuallyHasDERP(p1.id, 1)

	// graceful disconnect of p2
	p2.disconnect()
	p1.assertEventuallyDisconnected(p2.id)
	p2.assertEventuallyResponsesClosed()
}
// TestPGCoordinator_Lost connects two peers to a single coordinator, tunnels p1 to p2, and then
// closes p2 without sending a Disconnect request. p1 must eventually see p2 reported as LOST.
func TestPGCoordinator_Lost(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}

	db, pubsub := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()

	debugLogger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coord, err := tailnet.NewPGCoordV2(ctx, debugLogger, pubsub, db)
	require.NoError(t, err)
	defer coord.Close()

	p1 := newTestPeer(ctx, t, coord, "p1")
	defer p1.close(ctx)
	p2 := newTestPeer(ctx, t, coord, "p2")
	defer p2.close(ctx)

	// establish the tunnel and wait until each side has learned the other's node
	p1.addTunnel(p2.id)
	p1.updateDERP(1)
	p2.updateDERP(2)
	p1.assertEventuallyHasDERP(p2.id, 2)
	p2.assertEventuallyHasDERP(p1.id, 1)

	// close p2 without a Disconnect request
	p2.close(ctx)
	p1.assertEventuallyLost(p2.id)
}
type testConn struct {
ws, serverWS net.Conn
nodeChan chan []*agpl.Node
@@ -813,6 +870,7 @@ func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.Mu
}
func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agentID)
if xerrors.Is(err, sql.ErrNoRows) {
@@ -825,6 +883,25 @@ func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.
}, testutil.WaitShort, testutil.IntervalFast)
}
// assertEventuallyLost polls the store until GetTailnetPeers succeeds for the given ID and none
// of the returned peers has status OK. sql.ErrNoRows counts as "not yet", so polling continues.
// The assertion fails if the condition is not met within testutil.WaitShort.
func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
	t.Helper()
	assert.Eventually(t, func() bool {
		peers, err := store.GetTailnetPeers(ctx, agentID)
		if xerrors.Is(err, sql.ErrNoRows) {
			return false
		}
		if err != nil {
			// assert.Eventually runs this closure on its own goroutine, and t.Fatal (FailNow)
			// may only be called from the goroutine running the test. Record the failure with
			// t.Errorf instead, and return true to stop polling: the test has already failed.
			t.Errorf("get tailnet peers: %v", err)
			return true
		}
		for _, peer := range peers {
			if peer.Status == database.TailnetStatusOk {
				return false
			}
		}
		return true
	}, testutil.WaitShort, testutil.IntervalFast)
}
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {
@@ -839,6 +916,11 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store
}, testutil.WaitShort, testutil.IntervalFast)
}
// peerStatus is what a testPeer remembers about another peer from the most recent NODE or LOST
// update it received for that peer.
type peerStatus struct {
// preferredDERP is the preferred DERP of the node carried by that update
preferredDERP int32
// status is the kind of that update
status proto.CoordinateResponse_PeerUpdate_Kind
}
type testPeer struct {
ctx context.Context
cancel context.CancelFunc
@@ -847,11 +929,11 @@ type testPeer struct {
name string
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
derps map[uuid.UUID]int32
peers map[uuid.UUID]peerStatus
}
func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer {
p := &testPeer{t: t, name: name, derps: make(map[uuid.UUID]int32)}
p := &testPeer{t: t, name: name, peers: make(map[uuid.UUID]peerStatus)}
p.ctx, p.cancel = context.WithCancel(ctx)
if len(id) > 1 {
t.Fatal("too many")
@@ -890,38 +972,102 @@ func (p *testPeer) updateDERP(derp int32) {
}
}
// disconnect sends a Disconnect request to the coordinator on behalf of this peer. It reports a
// test error if the peer's context is done before the request can be sent.
func (p *testPeer) disconnect() {
	p.t.Helper()
	req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
	select {
	case <-p.ctx.Done():
		// the message names the operation that actually timed out, so failures are not
		// mistaken for a node update
		p.t.Errorf("timeout sending disconnect for %s", p.name)
	case p.reqs <- req:
	}
}
func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) {
p.t.Helper()
for {
d, ok := p.derps[other]
if ok && d == derp {
o, ok := p.peers[other]
if ok && o.preferredDERP == derp {
return
}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout waiting for response for %s", p.name)
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
case resp, ok := <-p.resps:
if !ok {
p.t.Errorf("responses closed for %s", p.name)
return
}
}
}
// assertEventuallyDisconnected processes responses until the other peer is no longer present in
// p.peers. It fails the test and gives up if handling a response returns an error.
func (p *testPeer) assertEventuallyDisconnected(other uuid.UUID) {
	p.t.Helper()
	for {
		if _, tracked := p.peers[other]; !tracked {
			return
		}
		// assert.NoError records a failure only when err is non-nil
		if err := p.handleOneResp(); !assert.NoError(p.t, err) {
			return
		}
	}
}
// assertEventuallyLost processes responses until the recorded status of the other peer is LOST.
// It fails the test and gives up if handling a response returns an error.
func (p *testPeer) assertEventuallyLost(other uuid.UUID) {
	p.t.Helper()
	// a peer we have never heard about yields the zero peerStatus, so the loop keeps reading
	for p.peers[other].status != proto.CoordinateResponse_PeerUpdate_LOST {
		if err := p.handleOneResp(); !assert.NoError(p.t, err) {
			return
		}
	}
}
// assertEventuallyResponsesClosed processes responses until handleOneResp reports that the
// response channel is closed. Any other error fails the test and ends the wait.
func (p *testPeer) assertEventuallyResponsesClosed() {
	p.t.Helper()
	for {
		switch err := p.handleOneResp(); {
		case xerrors.Is(err, responsesClosed):
			// this is the outcome we are waiting for
			return
		case err != nil:
			assert.NoError(p.t, err)
			return
		}
	}
}
// responsesClosed is the sentinel error handleOneResp returns when the peer's response channel
// has been closed.
var responsesClosed = xerrors.New("responses closed")
func (p *testPeer) handleOneResp() error {
select {
case <-p.ctx.Done():
return p.ctx.Err()
case resp, ok := <-p.resps:
if !ok {
return responsesClosed
}
for _, update := range resp.PeerUpdates {
id, err := uuid.FromBytes(update.Uuid)
if err != nil {
return err
}
for _, update := range resp.PeerUpdates {
id, err := uuid.FromBytes(update.Uuid)
if !assert.NoError(p.t, err) {
return
}
switch update.Kind {
case proto.CoordinateResponse_PeerUpdate_NODE:
p.derps[id] = update.Node.PreferredDerp
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(p.derps, id)
default:
p.t.Errorf("unhandled update kind %s", update.Kind)
switch update.Kind {
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
p.peers[id] = peerStatus{
preferredDERP: update.GetNode().GetPreferredDerp(),
status: update.Kind,
}
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(p.peers, id)
default:
return xerrors.Errorf("unhandled update kind %s", update.Kind)
}
}
}
return nil
}
func (p *testPeer) close(ctx context.Context) {