diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 33732eea3d..ce2c0b069b 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -953,7 +953,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { coordinator = haCoordinator } - api.replicaManager.SetCallback(func() { + api.replicaManager.AddCallback(func() { // Only update DERP mesh if the built-in server is enabled. if api.Options.DeploymentValues.DERP.Server.Enable { addresses := make([]string, 0) @@ -973,7 +973,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { if api.Options.DeploymentValues.DERP.Server.Enable { api.derpMesh.SetAddresses([]string{}, false) } - api.replicaManager.SetCallback(func() { + api.replicaManager.AddCallback(func() { // If the amount of replicas change, so should our entitlements. // This is to display a warning in the UI if the user is unlicensed. _ = api.updateEntitlements(api.ctx) diff --git a/enterprise/replicasync/replicasync.go b/enterprise/replicasync/replicasync.go index f69db6ed94..fc37554477 100644 --- a/enterprise/replicasync/replicasync.go +++ b/enterprise/replicasync/replicasync.go @@ -122,10 +122,10 @@ type Manager struct { closed chan (struct{}) closeCancel context.CancelFunc - self database.Replica - mutex sync.Mutex - peers []database.Replica - callback func() + self database.Replica + mutex sync.Mutex + peers []database.Replica + callbacks []func() } func (m *Manager) ID() uuid.UUID { @@ -312,7 +312,6 @@ func (m *Manager) syncReplicas(ctx context.Context) error { } m.mutex.Lock() - defer m.mutex.Unlock() // nolint:gocritic // Updating a replica is a system function. replica, err := m.db.UpdateReplica(dbauthz.AsSystemRestricted(ctx), database.UpdateReplicaParams{ ID: m.self.ID, @@ -330,6 +329,7 @@ func (m *Manager) syncReplicas(ctx context.Context) error { }) if err != nil { if !errors.Is(err, sql.ErrNoRows) { + m.mutex.Unlock() return xerrors.Errorf("update replica: %w", err) } // self replica has been cleaned up, we must reinsert @@ -348,6 +348,7 @@ func (m *Manager) syncReplicas(ctx context.Context) error { Primary: m.self.Primary, }) if err != nil { + m.mutex.Unlock() return xerrors.Errorf("update replica: %w", err) } } @@ -355,12 +356,15 @@ func (m *Manager) syncReplicas(ctx context.Context) error { // Publish an update occurred! err = m.PublishUpdate() if err != nil { + m.mutex.Unlock() return xerrors.Errorf("publish replica update: %w", err) } } m.self = replica - if m.callback != nil { - go m.callback() + callbacks := append([]func(){}, m.callbacks...) + m.mutex.Unlock() + for _, callback := range callbacks { + go callback() } return nil } @@ -439,12 +443,12 @@ func (m *Manager) regionID() int32 { return m.self.RegionID } -// SetCallback sets a function to execute whenever new peers +// AddCallback adds a function to execute whenever new peers // are refreshed or updated. -func (m *Manager) SetCallback(callback func()) { +func (m *Manager) AddCallback(callback func()) { m.mutex.Lock() - defer m.mutex.Unlock() - m.callback = callback + m.callbacks = append(m.callbacks, callback) + m.mutex.Unlock() // Instantly call the callback to inform replicas! go callback() } diff --git a/enterprise/replicasync/replicasync_test.go b/enterprise/replicasync/replicasync_test.go index 0438db8e21..64116e2a30 100644 --- a/enterprise/replicasync/replicasync_test.go +++ b/enterprise/replicasync/replicasync_test.go @@ -207,6 +207,34 @@ func TestReplica(t *testing.T) { return len(server.Regional()) == 0 }, testutil.WaitShort, testutil.IntervalFast) }) + t.Run("MultipleCallbacks", func(t *testing.T) { + t.Parallel() + dh := &derpyHandler{} + defer dh.requireOnlyDERPPaths(t) + srv := httptest.NewServer(dh) + defer srv.Close() + db, pubsub := dbtestutil.NewDB(t) + ctx, cancelCtx := context.WithCancel(context.Background()) + defer cancelCtx() + server, err := replicasync.New(ctx, testutil.Logger(t), db, pubsub, &replicasync.Options{ + RelayAddress: srv.URL, + }) + require.NoError(t, err) + defer server.Close() + + var first atomic.Int64 + var second atomic.Int64 + server.AddCallback(func() { first.Add(1) }) + server.AddCallback(func() { second.Add(1) }) + require.Eventually(t, func() bool { + return first.Load() >= 1 && second.Load() >= 1 + }, testutil.WaitShort, testutil.IntervalFast) + + require.NoError(t, server.UpdateNow(ctx)) + require.Eventually(t, func() bool { + return first.Load() >= 2 && second.Load() >= 2 + }, testutil.WaitShort, testutil.IntervalFast) + }) t.Run("TwentyConcurrent", func(t *testing.T) { // Ensures that twenty concurrent replicas can spawn and all // discover each other in parallel! @@ -233,7 +261,7 @@ func TestReplica(t *testing.T) { done := false var m sync.Mutex - server.SetCallback(func() { + server.AddCallback(func() { m.Lock() defer m.Unlock() if len(server.AllPrimary()) != count {