fix: wait for PGCoordinator to clean up db state (#13351)

cf. https://github.com/coder/coder/pull/13192#issuecomment-2097657692

We need to wait for PGCoordinator to finish its work before returning from `Close()`, so that its database state gets deleted. This is best effort: if the delete fails, other coordinators will filter the stale state out based on heartbeats.
This commit is contained in:
Spike Curtis
2024-05-24 12:01:03 +04:00
committed by GitHub
parent e5bb0a7a00
commit a0962ba089
3 changed files with 69 additions and 2 deletions
+21 -2
View File
@@ -161,11 +161,12 @@ func newPGCoordInternal(
closed: make(chan struct{}),
}
go func() {
// when the main context is canceled, or the coordinator closed, the binder and tunneler
// always eventually stop. Once they stop it's safe to cancel the querier context, which
// when the main context is canceled, or the coordinator closed, the binder, tunneler, and
// handshaker always eventually stop. Once they stop it's safe to cancel the querier context, which
// has the effect of deleting the coordinator from the database and ceasing heartbeats.
c.binder.workerWG.Wait()
c.tunneler.workerWG.Wait()
c.handshaker.workerWG.Wait()
querierCancel()
}()
logger.Info(ctx, "starting coordinator")
@@ -231,6 +232,7 @@ func (c *pgCoord) Close() error {
c.logger.Info(c.ctx, "closing coordinator")
c.cancel()
c.closeOnce.Do(func() { close(c.closed) })
c.querier.wait()
return nil
}
@@ -795,6 +797,8 @@ type querier struct {
workQ *workQ[querierWorkKey]
wg sync.WaitGroup
heartbeats *heartbeats
updates <-chan hbUpdate
@@ -831,6 +835,7 @@ func newQuerier(ctx context.Context,
}
q.subscribe()
q.wg.Add(2 + numWorkers)
go func() {
<-firstHeartbeat
go q.handleIncoming()
@@ -842,7 +847,13 @@ func newQuerier(ctx context.Context,
return q
}
// wait blocks until every goroutine tracked by the querier's own WaitGroup has
// finished, and then until every goroutine tracked by the heartbeats'
// WaitGroup has finished. It does not cancel anything itself; the caller is
// responsible for having triggered shutdown beforehand.
func (q *querier) wait() {
	own, beats := &q.wg, &q.heartbeats.wg
	own.Wait()
	beats.Wait()
}
func (q *querier) handleIncoming() {
defer q.wg.Done()
for {
select {
case <-q.ctx.Done():
@@ -919,6 +930,7 @@ func (q *querier) cleanupConn(c *connIO) {
}
func (q *querier) worker() {
defer q.wg.Done()
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0 // retry indefinitely
eb.MaxInterval = dbMaxBackoff
@@ -1204,6 +1216,7 @@ func (q *querier) resyncPeerMappings() {
}
func (q *querier) handleUpdates() {
defer q.wg.Done()
for {
select {
case <-q.ctx.Done():
@@ -1451,6 +1464,8 @@ type heartbeats struct {
coordinators map[uuid.UUID]time.Time
timer *time.Timer
wg sync.WaitGroup
// overwritten in tests, but otherwise constant
cleanupPeriod time.Duration
}
@@ -1472,6 +1487,7 @@ func newHeartbeats(
coordinators: make(map[uuid.UUID]time.Time),
cleanupPeriod: cleanupPeriod,
}
h.wg.Add(3)
go h.subscribe()
go h.sendBeats()
go h.cleanupLoop()
@@ -1502,6 +1518,7 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
}
func (h *heartbeats) subscribe() {
defer h.wg.Done()
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0 // retry indefinitely
eb.MaxInterval = dbMaxBackoff
@@ -1611,6 +1628,7 @@ func (h *heartbeats) checkExpiry() {
}
func (h *heartbeats) sendBeats() {
defer h.wg.Done()
// send an initial heartbeat so that other coordinators can start using our bindings right away.
h.sendBeat()
close(h.firstHeartbeat) // signal binder it can start writing
@@ -1662,6 +1680,7 @@ func (h *heartbeats) sendDelete() {
}
func (h *heartbeats) cleanupLoop() {
defer h.wg.Done()
h.cleanup()
tkr := time.NewTicker(h.cleanupPeriod)
defer tkr.Stop()
@@ -66,6 +66,7 @@ func TestHeartbeats_Cleanup(t *testing.T) {
store: mStore,
cleanupPeriod: time.Millisecond,
}
uut.wg.Add(1)
go uut.cleanupLoop()
for i := 0; i < 6; i++ {
+47
View File
@@ -864,6 +864,53 @@ func TestPGCoordinator_Lost(t *testing.T) {
agpltest.LostTest(ctx, t, coordinator)
}
// TestPGCoordinator_DeleteOnClose verifies that Close() on the coordinator does
// not return until the DeleteCoordinator call against the store has finished.
//
// The mocked DeleteCoordinator signals that it was called and then blocks until
// the test releases it. If Close() returns while the delete is still blocked,
// the test fails.
func TestPGCoordinator_DeleteOnClose(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	ctrl := gomock.NewController(t)
	mStore := dbmock.NewMockStore(ctrl)
	ps := pubsub.NewInMemory()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)

	// upsertDone is buffered and signaled with a non-blocking send rather than
	// close(): the expectation below is MinTimes(1), so the callback may run
	// more than once, and closing an already-closed channel would panic.
	upsertDone := make(chan struct{}, 1)
	deleteCalled := make(chan struct{})
	finishDelete := make(chan struct{})
	mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
		MinTimes(1).
		Do(func(_ context.Context, _ uuid.UUID) {
			select {
			case upsertDone <- struct{}{}:
			default:
				// a signal is already pending; nothing more to do
			}
		}).
		Return(database.TailnetCoordinator{}, nil)
	mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).
		Times(1).
		Do(func(_ context.Context, _ uuid.UUID) {
			// Times(1) guarantees this runs at most once, so close() is safe here.
			close(deleteCalled)
			// hold the delete open until the test has checked that Close() is
			// still blocked
			<-finishDelete
		}).
		Return(nil)

	// extra calls we don't particularly care about for this test
	mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
	mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
	mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)

	uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
	require.NoError(t, err)
	testutil.RequireRecvCtx(ctx, t, upsertDone)

	// buffered so the goroutine can always deliver its result and exit, even if
	// the test has already failed
	closeErr := make(chan error, 1)
	go func() {
		closeErr <- uut.Close()
	}()
	select {
	case <-ctx.Done():
		// fail at the context deadline instead of hanging until the global
		// test timeout when Close() never reaches DeleteCoordinator
		t.Fatal("timed out waiting for DeleteCoordinator to be called")
	case <-closeErr:
		t.Fatal("close returned before DeleteCoordinator called")
	case <-deleteCalled:
		close(finishDelete)
		closeResult := testutil.RequireRecvCtx(ctx, t, closeErr)
		require.NoError(t, closeResult)
	}
}
type testConn struct {
ws, serverWS net.Conn
nodeChan chan []*agpl.Node