mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: wait for PGCoordinator to clean up db state (#13351)
c.f. https://github.com/coder/coder/pull/13192#issuecomment-2097657692 We need to wait for PGCoordinator to finish its work before returning on `Close()`, so that we delete database state (best effort -- if this fails others will filter it out based on heartbeats).
This commit is contained in:
@@ -161,11 +161,12 @@ func newPGCoordInternal(
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
// when the main context is canceled, or the coordinator closed, the binder and tunneler
|
||||
// always eventually stop. Once they stop it's safe to cancel the querier context, which
|
||||
// when the main context is canceled, or the coordinator closed, the binder, tunneler, and
|
||||
// handshaker always eventually stop. Once they stop it's safe to cancel the querier context, which
|
||||
// has the effect of deleting the coordinator from the database and ceasing heartbeats.
|
||||
c.binder.workerWG.Wait()
|
||||
c.tunneler.workerWG.Wait()
|
||||
c.handshaker.workerWG.Wait()
|
||||
querierCancel()
|
||||
}()
|
||||
logger.Info(ctx, "starting coordinator")
|
||||
@@ -231,6 +232,7 @@ func (c *pgCoord) Close() error {
|
||||
c.logger.Info(c.ctx, "closing coordinator")
|
||||
c.cancel()
|
||||
c.closeOnce.Do(func() { close(c.closed) })
|
||||
c.querier.wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -795,6 +797,8 @@ type querier struct {
|
||||
|
||||
workQ *workQ[querierWorkKey]
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
heartbeats *heartbeats
|
||||
updates <-chan hbUpdate
|
||||
|
||||
@@ -831,6 +835,7 @@ func newQuerier(ctx context.Context,
|
||||
}
|
||||
q.subscribe()
|
||||
|
||||
q.wg.Add(2 + numWorkers)
|
||||
go func() {
|
||||
<-firstHeartbeat
|
||||
go q.handleIncoming()
|
||||
@@ -842,7 +847,13 @@ func newQuerier(ctx context.Context,
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *querier) wait() {
|
||||
q.wg.Wait()
|
||||
q.heartbeats.wg.Wait()
|
||||
}
|
||||
|
||||
func (q *querier) handleIncoming() {
|
||||
defer q.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
@@ -919,6 +930,7 @@ func (q *querier) cleanupConn(c *connIO) {
|
||||
}
|
||||
|
||||
func (q *querier) worker() {
|
||||
defer q.wg.Done()
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0 // retry indefinitely
|
||||
eb.MaxInterval = dbMaxBackoff
|
||||
@@ -1204,6 +1216,7 @@ func (q *querier) resyncPeerMappings() {
|
||||
}
|
||||
|
||||
func (q *querier) handleUpdates() {
|
||||
defer q.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
@@ -1451,6 +1464,8 @@ type heartbeats struct {
|
||||
coordinators map[uuid.UUID]time.Time
|
||||
timer *time.Timer
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
// overwritten in tests, but otherwise constant
|
||||
cleanupPeriod time.Duration
|
||||
}
|
||||
@@ -1472,6 +1487,7 @@ func newHeartbeats(
|
||||
coordinators: make(map[uuid.UUID]time.Time),
|
||||
cleanupPeriod: cleanupPeriod,
|
||||
}
|
||||
h.wg.Add(3)
|
||||
go h.subscribe()
|
||||
go h.sendBeats()
|
||||
go h.cleanupLoop()
|
||||
@@ -1502,6 +1518,7 @@ func (h *heartbeats) filter(mappings []mapping) []mapping {
|
||||
}
|
||||
|
||||
func (h *heartbeats) subscribe() {
|
||||
defer h.wg.Done()
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0 // retry indefinitely
|
||||
eb.MaxInterval = dbMaxBackoff
|
||||
@@ -1611,6 +1628,7 @@ func (h *heartbeats) checkExpiry() {
|
||||
}
|
||||
|
||||
func (h *heartbeats) sendBeats() {
|
||||
defer h.wg.Done()
|
||||
// send an initial heartbeat so that other coordinators can start using our bindings right away.
|
||||
h.sendBeat()
|
||||
close(h.firstHeartbeat) // signal binder it can start writing
|
||||
@@ -1662,6 +1680,7 @@ func (h *heartbeats) sendDelete() {
|
||||
}
|
||||
|
||||
func (h *heartbeats) cleanupLoop() {
|
||||
defer h.wg.Done()
|
||||
h.cleanup()
|
||||
tkr := time.NewTicker(h.cleanupPeriod)
|
||||
defer tkr.Stop()
|
||||
|
||||
@@ -66,6 +66,7 @@ func TestHeartbeats_Cleanup(t *testing.T) {
|
||||
store: mStore,
|
||||
cleanupPeriod: time.Millisecond,
|
||||
}
|
||||
uut.wg.Add(1)
|
||||
go uut.cleanupLoop()
|
||||
|
||||
for i := 0; i < 6; i++ {
|
||||
|
||||
@@ -864,6 +864,53 @@ func TestPGCoordinator_Lost(t *testing.T) {
|
||||
agpltest.LostTest(ctx, t, coordinator)
|
||||
}
|
||||
|
||||
func TestPGCoordinator_DeleteOnClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mStore := dbmock.NewMockStore(ctrl)
|
||||
ps := pubsub.NewInMemory()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
upsertDone := make(chan struct{})
|
||||
deleteCalled := make(chan struct{})
|
||||
finishDelete := make(chan struct{})
|
||||
mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
|
||||
MinTimes(1).
|
||||
Do(func(_ context.Context, _ uuid.UUID) { close(upsertDone) }).
|
||||
Return(database.TailnetCoordinator{}, nil)
|
||||
mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).
|
||||
Times(1).
|
||||
Do(func(_ context.Context, _ uuid.UUID) {
|
||||
close(deleteCalled)
|
||||
<-finishDelete
|
||||
}).
|
||||
Return(nil)
|
||||
|
||||
// extra calls we don't particularly care about for this test
|
||||
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
|
||||
mStore.EXPECT().CleanTailnetLostPeers(gomock.Any()).AnyTimes().Return(nil)
|
||||
mStore.EXPECT().CleanTailnetTunnels(gomock.Any()).AnyTimes().Return(nil)
|
||||
|
||||
uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
|
||||
require.NoError(t, err)
|
||||
testutil.RequireRecvCtx(ctx, t, upsertDone)
|
||||
closeErr := make(chan error, 1)
|
||||
go func() {
|
||||
closeErr <- uut.Close()
|
||||
}()
|
||||
select {
|
||||
case <-closeErr:
|
||||
t.Fatal("close returned before DeleteCoordinator called")
|
||||
case <-deleteCalled:
|
||||
close(finishDelete)
|
||||
err := testutil.RequireRecvCtx(ctx, t, closeErr)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
type testConn struct {
|
||||
ws, serverWS net.Conn
|
||||
nodeChan chan []*agpl.Node
|
||||
|
||||
Reference in New Issue
Block a user