mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: support graceful disconnect in PGCoordinator (#10937)
Adds support for graceful disconnect to PGCoordinator. When peers gracefully disconnect, they send a disconnect message. This triggers the peer to be disconnected from all tunneled peers. The Multi-Agent Client supports graceful disconnect, since it is in memory and we know that when it is closed, we really mean to disconnect. The v1 agent and client Websocket connections do not support graceful disconnect, since the v1 protocol doesn't have this feature. That means that if a v1 peer connects to a v2 peer, when the v1 peer's coordinator connection is closed, the v2 peer will see it as "lost" since we don't know whether the v1 peer meant to disconnect, or it just lost connectivity to the coordinator.
This commit is contained in:
@@ -24,16 +24,17 @@ type connIO struct {
|
||||
// coordCtx is the parent context, that is, the context of the Coordinator
|
||||
coordCtx context.Context
|
||||
// peerCtx is the context of the connection to our peer
|
||||
peerCtx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger slog.Logger
|
||||
requests <-chan *proto.CoordinateRequest
|
||||
responses chan<- *proto.CoordinateResponse
|
||||
bindings chan<- binding
|
||||
tunnels chan<- tunnel
|
||||
auth agpl.TunnelAuth
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
peerCtx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger slog.Logger
|
||||
requests <-chan *proto.CoordinateRequest
|
||||
responses chan<- *proto.CoordinateResponse
|
||||
bindings chan<- binding
|
||||
tunnels chan<- tunnel
|
||||
auth agpl.TunnelAuth
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
disconnected bool
|
||||
|
||||
name string
|
||||
start int64
|
||||
@@ -76,20 +77,29 @@ func newConnIO(coordContext context.Context,
|
||||
|
||||
func (c *connIO) recvLoop() {
|
||||
defer func() {
|
||||
// withdraw bindings & tunnels when we exit. We need to use the parent context here, since
|
||||
// withdraw bindings & tunnels when we exit. We need to use the coordinator context here, since
|
||||
// our own context might be canceled, but we still need to withdraw.
|
||||
b := binding{
|
||||
bKey: bKey(c.UniqueID()),
|
||||
kind: proto.CoordinateResponse_PeerUpdate_LOST,
|
||||
}
|
||||
if c.disconnected {
|
||||
b.kind = proto.CoordinateResponse_PeerUpdate_DISCONNECTED
|
||||
}
|
||||
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
|
||||
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
|
||||
}
|
||||
t := tunnel{
|
||||
tKey: tKey{src: c.UniqueID()},
|
||||
active: false,
|
||||
}
|
||||
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
||||
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
|
||||
// only remove tunnels on graceful disconnect. If we remove tunnels for lost peers, then
|
||||
// this will look like a disconnect from the peer perspective, since we query for active peers
|
||||
// by using the tunnel as a join in the database
|
||||
if c.disconnected {
|
||||
t := tunnel{
|
||||
tKey: tKey{src: c.UniqueID()},
|
||||
active: false,
|
||||
}
|
||||
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
||||
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer c.Close()
|
||||
@@ -111,6 +121,8 @@ func (c *connIO) recvLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
var errDisconnect = xerrors.New("graceful disconnect")
|
||||
|
||||
func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
||||
c.logger.Debug(c.peerCtx, "got request")
|
||||
if req.UpdateSelf != nil {
|
||||
@@ -118,6 +130,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
||||
b := binding{
|
||||
bKey: bKey(c.UniqueID()),
|
||||
node: req.UpdateSelf.Node,
|
||||
kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
}
|
||||
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
|
||||
c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
|
||||
@@ -169,7 +182,11 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// TODO: (spikecurtis) support Disconnect
|
||||
if req.Disconnect != nil {
|
||||
c.logger.Debug(c.peerCtx, "graceful disconnect")
|
||||
c.disconnected = true
|
||||
return errDisconnect
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
assertEventuallyLost(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
|
||||
@@ -106,7 +106,7 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
assertEventuallyLost(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a
|
||||
@@ -168,7 +168,7 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
assertEventuallyLost(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a
|
||||
@@ -220,7 +220,7 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
assertEventuallyLost(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two
|
||||
@@ -273,7 +273,7 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
assertEventuallyLost(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a
|
||||
@@ -344,5 +344,5 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
|
||||
require.NoError(t, agent2.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
assertEventuallyLost(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
+109
-69
@@ -203,6 +203,7 @@ func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
|
||||
}})
|
||||
},
|
||||
OnRemove: func(_ agpl.Queue) {
|
||||
_ = sendCtx(c.ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
||||
cancel()
|
||||
},
|
||||
}).Init()
|
||||
@@ -352,9 +353,14 @@ func v1SendLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg
|
||||
_ = q.CoordinatorClose()
|
||||
return
|
||||
}
|
||||
// don't send empty updates
|
||||
if len(nodes) == 0 {
|
||||
logger.Debug(ctx, "skipping enqueueing 0-length v1 update")
|
||||
continue
|
||||
}
|
||||
err = q.Enqueue(nodes)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to enqueue multi-agent update", slog.Error(err))
|
||||
logger.Error(ctx, "failed to enqueue v1 update", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -597,6 +603,7 @@ type bKey uuid.UUID
|
||||
type binding struct {
|
||||
bKey
|
||||
node *proto.Node
|
||||
kind proto.CoordinateResponse_PeerUpdate_Kind
|
||||
}
|
||||
|
||||
// binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff.
|
||||
@@ -675,22 +682,7 @@ func (b *binder) worker() {
|
||||
|
||||
func (b *binder) writeOne(bnd binding) error {
|
||||
var err error
|
||||
if bnd.node != nil {
|
||||
var nodeRaw []byte
|
||||
nodeRaw, err = gProto.Marshal(bnd.node)
|
||||
if err != nil {
|
||||
// this is very bad news, but it should never happen because the node was Unmarshalled or converted by this
|
||||
// process earlier.
|
||||
b.logger.Critical(b.ctx, "failed to marshal node", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
_, err = b.store.UpsertTailnetPeer(b.ctx, database.UpsertTailnetPeerParams{
|
||||
ID: uuid.UUID(bnd.bKey),
|
||||
CoordinatorID: b.coordinatorID,
|
||||
Node: nodeRaw,
|
||||
Status: database.TailnetStatusOk,
|
||||
})
|
||||
} else {
|
||||
if bnd.kind == proto.CoordinateResponse_PeerUpdate_DISCONNECTED {
|
||||
_, err = b.store.DeleteTailnetPeer(b.ctx, database.DeleteTailnetPeerParams{
|
||||
ID: uuid.UUID(bnd.bKey),
|
||||
CoordinatorID: b.coordinatorID,
|
||||
@@ -699,6 +691,25 @@ func (b *binder) writeOne(bnd binding) error {
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
} else {
|
||||
var nodeRaw []byte
|
||||
nodeRaw, err = gProto.Marshal(bnd.node)
|
||||
if err != nil {
|
||||
// this is very bad news, but it should never happen because the node was Unmarshalled or converted by this
|
||||
// process earlier.
|
||||
b.logger.Critical(b.ctx, "failed to marshal node", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
status := database.TailnetStatusOk
|
||||
if bnd.kind == proto.CoordinateResponse_PeerUpdate_LOST {
|
||||
status = database.TailnetStatusLost
|
||||
}
|
||||
_, err = b.store.UpsertTailnetPeer(b.ctx, database.UpsertTailnetPeerParams{
|
||||
ID: uuid.UUID(bnd.bKey),
|
||||
CoordinatorID: b.coordinatorID,
|
||||
Node: nodeRaw,
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil && !database.IsQueryCanceledError(err) {
|
||||
@@ -710,16 +721,27 @@ func (b *binder) writeOne(bnd binding) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// storeBinding stores the latest binding, where we interpret node == nil as removing the binding. This keeps the map
|
||||
// storeBinding stores the latest binding, where we interpret kind == DISCONNECTED as removing the binding. This keeps the map
|
||||
// from growing without bound.
|
||||
func (b *binder) storeBinding(bnd binding) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if bnd.node != nil {
|
||||
|
||||
switch bnd.kind {
|
||||
case proto.CoordinateResponse_PeerUpdate_NODE:
|
||||
b.latest[bnd.bKey] = bnd
|
||||
} else {
|
||||
// nil node is interpreted as removing binding
|
||||
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
||||
delete(b.latest, bnd.bKey)
|
||||
case proto.CoordinateResponse_PeerUpdate_LOST:
|
||||
// we need to coalesce with the previously stored node, since it must
|
||||
// be non-nil in the database
|
||||
old, ok := b.latest[bnd.bKey]
|
||||
if !ok {
|
||||
// lost before we ever got a node update. No action
|
||||
return
|
||||
}
|
||||
bnd.node = old.node
|
||||
b.latest[bnd.bKey] = bnd
|
||||
}
|
||||
}
|
||||
|
||||
@@ -732,6 +754,7 @@ func (b *binder) retrieveBinding(bk bKey) binding {
|
||||
bnd = binding{
|
||||
bKey: bk,
|
||||
node: nil,
|
||||
kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
|
||||
}
|
||||
}
|
||||
return bnd
|
||||
@@ -752,9 +775,8 @@ type mapper struct {
|
||||
|
||||
// latest is the most recent, unfiltered snapshot of the mappings we know about
|
||||
latest []mapping
|
||||
// sent is the state of mappings we have actually enqueued; used to compute diffs for updates. It is a map from peer
|
||||
// ID to node.
|
||||
sent map[uuid.UUID]*proto.Node
|
||||
// sent is the state of mappings we have actually enqueued; used to compute diffs for updates.
|
||||
sent map[uuid.UUID]mapping
|
||||
|
||||
// called to filter mappings to healthy coordinators
|
||||
heartbeats *heartbeats
|
||||
@@ -771,7 +793,7 @@ func newMapper(c *connIO, logger slog.Logger, h *heartbeats) *mapper {
|
||||
update: make(chan struct{}),
|
||||
mappings: make(chan []mapping),
|
||||
heartbeats: h,
|
||||
sent: make(map[uuid.UUID]*proto.Node),
|
||||
sent: make(map[uuid.UUID]mapping),
|
||||
}
|
||||
go m.run()
|
||||
return m
|
||||
@@ -779,19 +801,19 @@ func newMapper(c *connIO, logger slog.Logger, h *heartbeats) *mapper {
|
||||
|
||||
func (m *mapper) run() {
|
||||
for {
|
||||
var nodes map[uuid.UUID]*proto.Node
|
||||
var best map[uuid.UUID]mapping
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case mappings := <-m.mappings:
|
||||
m.logger.Debug(m.ctx, "got new mappings")
|
||||
m.latest = mappings
|
||||
nodes = m.mappingsToNodes(mappings)
|
||||
best = m.bestMappings(mappings)
|
||||
case <-m.update:
|
||||
m.logger.Debug(m.ctx, "triggered update")
|
||||
nodes = m.mappingsToNodes(m.latest)
|
||||
best = m.bestMappings(m.latest)
|
||||
}
|
||||
update := m.nodesToUpdate(nodes)
|
||||
update := m.bestToUpdate(best)
|
||||
if update == nil {
|
||||
m.logger.Debug(m.ctx, "skipping nil node update")
|
||||
continue
|
||||
@@ -802,67 +824,83 @@ func (m *mapper) run() {
|
||||
}
|
||||
}
|
||||
|
||||
// mappingsToNodes takes a set of mappings and resolves the best set of nodes. We may get several mappings for a
|
||||
// bestMappings takes a set of mappings and resolves the best set of nodes. We may get several mappings for a
|
||||
// particular connection, from different coordinators in the distributed system. Furthermore, some coordinators
|
||||
// might be considered invalid on account of missing heartbeats. We take the most recent mapping from a valid
|
||||
// coordinator as the "best" mapping.
|
||||
func (m *mapper) mappingsToNodes(mappings []mapping) map[uuid.UUID]*proto.Node {
|
||||
func (m *mapper) bestMappings(mappings []mapping) map[uuid.UUID]mapping {
|
||||
mappings = m.heartbeats.filter(mappings)
|
||||
best := make(map[uuid.UUID]mapping, len(mappings))
|
||||
for _, m := range mappings {
|
||||
bestM, ok := best[m.peer]
|
||||
if !ok || m.updatedAt.After(bestM.updatedAt) {
|
||||
best[m.peer] = m
|
||||
for _, mpng := range mappings {
|
||||
bestM, ok := best[mpng.peer]
|
||||
switch {
|
||||
case !ok:
|
||||
// no current best
|
||||
best[mpng.peer] = mpng
|
||||
|
||||
// NODE always beats LOST mapping, since the LOST could be from a coordinator that's
|
||||
// slow updating the DB, and the peer has reconnected to a different coordinator and
|
||||
// given a NODE mapping.
|
||||
case bestM.kind == proto.CoordinateResponse_PeerUpdate_LOST && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
|
||||
best[mpng.peer] = mpng
|
||||
case mpng.updatedAt.After(bestM.updatedAt) && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
|
||||
// newer, and it's a NODE update.
|
||||
best[mpng.peer] = mpng
|
||||
}
|
||||
}
|
||||
nodes := make(map[uuid.UUID]*proto.Node, len(best))
|
||||
for k, m := range best {
|
||||
nodes[k] = m.node
|
||||
}
|
||||
return nodes
|
||||
return best
|
||||
}
|
||||
|
||||
func (m *mapper) nodesToUpdate(nodes map[uuid.UUID]*proto.Node) *proto.CoordinateResponse {
|
||||
func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateResponse {
|
||||
resp := new(proto.CoordinateResponse)
|
||||
|
||||
for k, n := range nodes {
|
||||
sn, ok := m.sent[k]
|
||||
if !ok {
|
||||
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
|
||||
Uuid: agpl.UUIDToByteSlice(k),
|
||||
Node: n,
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
Reason: "new",
|
||||
})
|
||||
for k, mpng := range best {
|
||||
var reason string
|
||||
sm, ok := m.sent[k]
|
||||
switch {
|
||||
case !ok && mpng.kind == proto.CoordinateResponse_PeerUpdate_LOST:
|
||||
// we don't need to send a "lost" update if we've never sent an update about this peer
|
||||
continue
|
||||
}
|
||||
eq, err := sn.Equal(n)
|
||||
if err != nil {
|
||||
m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sn), slog.F("new", n))
|
||||
}
|
||||
if !eq {
|
||||
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
|
||||
Uuid: agpl.UUIDToByteSlice(k),
|
||||
Node: n,
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
Reason: "update",
|
||||
})
|
||||
case !ok && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
|
||||
reason = "new"
|
||||
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_LOST && mpng.kind == proto.CoordinateResponse_PeerUpdate_LOST:
|
||||
// was lost and remains lost, no update needed
|
||||
continue
|
||||
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_LOST && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
|
||||
reason = "found"
|
||||
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_LOST:
|
||||
reason = "lost"
|
||||
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
|
||||
eq, err := sm.node.Equal(mpng.node)
|
||||
if err != nil {
|
||||
m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.kind))
|
||||
continue
|
||||
}
|
||||
if eq {
|
||||
continue
|
||||
}
|
||||
reason = "update"
|
||||
}
|
||||
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
|
||||
Uuid: agpl.UUIDToByteSlice(k),
|
||||
Node: mpng.node,
|
||||
Kind: mpng.kind,
|
||||
Reason: reason,
|
||||
})
|
||||
m.sent[k] = mpng
|
||||
}
|
||||
|
||||
for k := range m.sent {
|
||||
if _, ok := nodes[k]; !ok {
|
||||
if _, ok := best[k]; !ok {
|
||||
resp.PeerUpdates = append(resp.PeerUpdates, &proto.CoordinateResponse_PeerUpdate{
|
||||
Uuid: agpl.UUIDToByteSlice(k),
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
|
||||
Reason: "disconnected",
|
||||
})
|
||||
delete(m.sent, k)
|
||||
}
|
||||
}
|
||||
|
||||
m.sent = nodes
|
||||
|
||||
if len(resp.PeerUpdates) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -1069,10 +1107,6 @@ func (q *querier) mappingQuery(peer mKey) error {
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
if len(bindings) == 0 {
|
||||
logger.Debug(q.ctx, "no mappings, nothing to do")
|
||||
return nil
|
||||
}
|
||||
mappings, err := q.bindingsToMappings(bindings)
|
||||
if err != nil {
|
||||
logger.Debug(q.ctx, "failed to convert mappings", slog.Error(err))
|
||||
@@ -1100,11 +1134,16 @@ func (q *querier) bindingsToMappings(bindings []database.GetTailnetTunnelPeerBin
|
||||
q.logger.Error(q.ctx, "failed to unmarshal node", slog.Error(err))
|
||||
return nil, backoff.Permanent(err)
|
||||
}
|
||||
kind := proto.CoordinateResponse_PeerUpdate_NODE
|
||||
if binding.Status == database.TailnetStatusLost {
|
||||
kind = proto.CoordinateResponse_PeerUpdate_LOST
|
||||
}
|
||||
mappings = append(mappings, mapping{
|
||||
peer: binding.PeerID,
|
||||
coordinator: binding.CoordinatorID,
|
||||
updatedAt: binding.UpdatedAt,
|
||||
node: node,
|
||||
kind: kind,
|
||||
})
|
||||
}
|
||||
return mappings, nil
|
||||
@@ -1326,6 +1365,7 @@ type mapping struct {
|
||||
coordinator uuid.UUID
|
||||
updatedAt time.Time
|
||||
node *proto.Node
|
||||
kind proto.CoordinateResponse_PeerUpdate_Kind
|
||||
}
|
||||
|
||||
// querierWorkKey describes two kinds of work the querier needs to do. If peerUpdate
|
||||
|
||||
@@ -71,7 +71,7 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
<-client.errChan
|
||||
<-client.closeChan
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agentID)
|
||||
assertEventuallyLost(ctx, t, store, client.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
|
||||
@@ -108,7 +108,7 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
<-agent.errChan
|
||||
<-agent.closeChan
|
||||
assertEventuallyNoAgents(ctx, t, store, agent.id)
|
||||
assertEventuallyLost(ctx, t, store, agent.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
|
||||
@@ -184,8 +184,8 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
|
||||
_ = client.recvErr(ctx, t)
|
||||
client.waitForClose(ctx, t)
|
||||
|
||||
assertEventuallyNoAgents(ctx, t, store, agent.id)
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent.id)
|
||||
assertEventuallyLost(ctx, t, store, agent.id)
|
||||
assertEventuallyLost(ctx, t, store, client.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
|
||||
@@ -272,7 +272,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
|
||||
_ = client.recvErr(ctx, t)
|
||||
client.waitForClose(ctx, t)
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent.id)
|
||||
assertEventuallyLost(ctx, t, store, client.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
|
||||
@@ -519,8 +519,8 @@ func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
|
||||
require.ErrorIs(t, err, io.ErrClosedPipe)
|
||||
client.waitForClose(ctx, t)
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
assertEventuallyLost(ctx, t, store, client.id)
|
||||
assertEventuallyLost(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinator_Unhealthy(t *testing.T) {
|
||||
@@ -624,6 +624,63 @@ func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
|
||||
p2.assertEventuallyHasDERP(p1.id, 1)
|
||||
}
|
||||
|
||||
func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
p1 := newTestPeer(ctx, t, coordinator, "p1")
|
||||
defer p1.close(ctx)
|
||||
p2 := newTestPeer(ctx, t, coordinator, "p2")
|
||||
defer p2.close(ctx)
|
||||
p1.addTunnel(p2.id)
|
||||
p1.updateDERP(1)
|
||||
p2.updateDERP(2)
|
||||
|
||||
p1.assertEventuallyHasDERP(p2.id, 2)
|
||||
p2.assertEventuallyHasDERP(p1.id, 1)
|
||||
|
||||
p2.disconnect()
|
||||
p1.assertEventuallyDisconnected(p2.id)
|
||||
p2.assertEventuallyResponsesClosed()
|
||||
}
|
||||
|
||||
func TestPGCoordinator_Lost(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
p1 := newTestPeer(ctx, t, coordinator, "p1")
|
||||
defer p1.close(ctx)
|
||||
p2 := newTestPeer(ctx, t, coordinator, "p2")
|
||||
defer p2.close(ctx)
|
||||
p1.addTunnel(p2.id)
|
||||
p1.updateDERP(1)
|
||||
p2.updateDERP(2)
|
||||
|
||||
p1.assertEventuallyHasDERP(p2.id, 2)
|
||||
p2.assertEventuallyHasDERP(p1.id, 1)
|
||||
|
||||
p2.close(ctx)
|
||||
p1.assertEventuallyLost(p2.id)
|
||||
}
|
||||
|
||||
type testConn struct {
|
||||
ws, serverWS net.Conn
|
||||
nodeChan chan []*agpl.Node
|
||||
@@ -813,6 +870,7 @@ func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.Mu
|
||||
}
|
||||
|
||||
func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
|
||||
t.Helper()
|
||||
assert.Eventually(t, func() bool {
|
||||
agents, err := store.GetTailnetPeers(ctx, agentID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
@@ -825,6 +883,25 @@ func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func assertEventuallyLost(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
|
||||
t.Helper()
|
||||
assert.Eventually(t, func() bool {
|
||||
peers, err := store.GetTailnetPeers(ctx, agentID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, peer := range peers {
|
||||
if peer.Status == database.TailnetStatusOk {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
|
||||
t.Helper()
|
||||
assert.Eventually(t, func() bool {
|
||||
@@ -839,6 +916,11 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
type peerStatus struct {
|
||||
preferredDERP int32
|
||||
status proto.CoordinateResponse_PeerUpdate_Kind
|
||||
}
|
||||
|
||||
type testPeer struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -847,11 +929,11 @@ type testPeer struct {
|
||||
name string
|
||||
resps <-chan *proto.CoordinateResponse
|
||||
reqs chan<- *proto.CoordinateRequest
|
||||
derps map[uuid.UUID]int32
|
||||
peers map[uuid.UUID]peerStatus
|
||||
}
|
||||
|
||||
func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer {
|
||||
p := &testPeer{t: t, name: name, derps: make(map[uuid.UUID]int32)}
|
||||
p := &testPeer{t: t, name: name, peers: make(map[uuid.UUID]peerStatus)}
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
if len(id) > 1 {
|
||||
t.Fatal("too many")
|
||||
@@ -890,38 +972,102 @@ func (p *testPeer) updateDERP(derp int32) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *testPeer) disconnect() {
|
||||
p.t.Helper()
|
||||
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
p.t.Errorf("timeout updating node for %s", p.name)
|
||||
return
|
||||
case p.reqs <- req:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) {
|
||||
p.t.Helper()
|
||||
for {
|
||||
d, ok := p.derps[other]
|
||||
if ok && d == derp {
|
||||
o, ok := p.peers[other]
|
||||
if ok && o.preferredDERP == derp {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
p.t.Errorf("timeout waiting for response for %s", p.name)
|
||||
if err := p.handleOneResp(); err != nil {
|
||||
assert.NoError(p.t, err)
|
||||
return
|
||||
case resp, ok := <-p.resps:
|
||||
if !ok {
|
||||
p.t.Errorf("responses closed for %s", p.name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *testPeer) assertEventuallyDisconnected(other uuid.UUID) {
|
||||
p.t.Helper()
|
||||
for {
|
||||
_, ok := p.peers[other]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := p.handleOneResp(); err != nil {
|
||||
assert.NoError(p.t, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *testPeer) assertEventuallyLost(other uuid.UUID) {
|
||||
p.t.Helper()
|
||||
for {
|
||||
o := p.peers[other]
|
||||
if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
|
||||
return
|
||||
}
|
||||
if err := p.handleOneResp(); err != nil {
|
||||
assert.NoError(p.t, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *testPeer) assertEventuallyResponsesClosed() {
|
||||
p.t.Helper()
|
||||
for {
|
||||
err := p.handleOneResp()
|
||||
if xerrors.Is(err, responsesClosed) {
|
||||
return
|
||||
}
|
||||
if !assert.NoError(p.t, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var responsesClosed = xerrors.New("responses closed")
|
||||
|
||||
func (p *testPeer) handleOneResp() error {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return p.ctx.Err()
|
||||
case resp, ok := <-p.resps:
|
||||
if !ok {
|
||||
return responsesClosed
|
||||
}
|
||||
for _, update := range resp.PeerUpdates {
|
||||
id, err := uuid.FromBytes(update.Uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, update := range resp.PeerUpdates {
|
||||
id, err := uuid.FromBytes(update.Uuid)
|
||||
if !assert.NoError(p.t, err) {
|
||||
return
|
||||
}
|
||||
switch update.Kind {
|
||||
case proto.CoordinateResponse_PeerUpdate_NODE:
|
||||
p.derps[id] = update.Node.PreferredDerp
|
||||
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
||||
delete(p.derps, id)
|
||||
default:
|
||||
p.t.Errorf("unhandled update kind %s", update.Kind)
|
||||
switch update.Kind {
|
||||
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
|
||||
p.peers[id] = peerStatus{
|
||||
preferredDERP: update.GetNode().GetPreferredDerp(),
|
||||
status: update.Kind,
|
||||
}
|
||||
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
||||
delete(p.peers, id)
|
||||
default:
|
||||
return xerrors.Errorf("unhandled update kind %s", update.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *testPeer) close(ctx context.Context) {
|
||||
|
||||
Reference in New Issue
Block a user