refactor(enterprise/replicasync): support multiple callbacks

This commit is contained in:
Jon Ayers
2026-05-26 19:44:39 +00:00
parent 14cb14a3b4
commit 903836763c
3 changed files with 46 additions and 14 deletions
+2 -2
View File
@@ -953,7 +953,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
coordinator = haCoordinator
}
api.replicaManager.SetCallback(func() {
api.replicaManager.AddCallback(func() {
// Only update DERP mesh if the built-in server is enabled.
if api.Options.DeploymentValues.DERP.Server.Enable {
addresses := make([]string, 0)
@@ -973,7 +973,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
if api.Options.DeploymentValues.DERP.Server.Enable {
api.derpMesh.SetAddresses([]string{}, false)
}
api.replicaManager.SetCallback(func() {
api.replicaManager.AddCallback(func() {
// If the amount of replicas change, so should our entitlements.
// This is to display a warning in the UI if the user is unlicensed.
_ = api.updateEntitlements(api.ctx)
+15 -11
View File
@@ -122,10 +122,10 @@ type Manager struct {
closed chan (struct{})
closeCancel context.CancelFunc
self database.Replica
mutex sync.Mutex
peers []database.Replica
callback func()
self database.Replica
mutex sync.Mutex
peers []database.Replica
callbacks []func()
}
func (m *Manager) ID() uuid.UUID {
@@ -312,7 +312,6 @@ func (m *Manager) syncReplicas(ctx context.Context) error {
}
m.mutex.Lock()
defer m.mutex.Unlock()
// nolint:gocritic // Updating a replica is a system function.
replica, err := m.db.UpdateReplica(dbauthz.AsSystemRestricted(ctx), database.UpdateReplicaParams{
ID: m.self.ID,
@@ -330,6 +329,7 @@ func (m *Manager) syncReplicas(ctx context.Context) error {
})
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
m.mutex.Unlock()
return xerrors.Errorf("update replica: %w", err)
}
// self replica has been cleaned up, we must reinsert
@@ -348,6 +348,7 @@ func (m *Manager) syncReplicas(ctx context.Context) error {
Primary: m.self.Primary,
})
if err != nil {
m.mutex.Unlock()
return xerrors.Errorf("update replica: %w", err)
}
}
@@ -355,12 +356,15 @@ func (m *Manager) syncReplicas(ctx context.Context) error {
// Publish an update occurred!
err = m.PublishUpdate()
if err != nil {
m.mutex.Unlock()
return xerrors.Errorf("publish replica update: %w", err)
}
}
m.self = replica
if m.callback != nil {
go m.callback()
callbacks := append([]func(){}, m.callbacks...)
m.mutex.Unlock()
for _, callback := range callbacks {
go callback()
}
return nil
}
@@ -439,12 +443,12 @@ func (m *Manager) regionID() int32 {
return m.self.RegionID
}
// SetCallback sets a function to execute whenever new peers
// AddCallback adds a function to execute whenever new peers
// are refreshed or updated.
func (m *Manager) SetCallback(callback func()) {
func (m *Manager) AddCallback(callback func()) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.callback = callback
m.callbacks = append(m.callbacks, callback)
m.mutex.Unlock()
// Instantly call the callback to inform replicas!
go callback()
}
+29 -1
View File
@@ -207,6 +207,34 @@ func TestReplica(t *testing.T) {
return len(server.Regional()) == 0
}, testutil.WaitShort, testutil.IntervalFast)
})
t.Run("MultipleCallbacks", func(t *testing.T) {
t.Parallel()
dh := &derpyHandler{}
defer dh.requireOnlyDERPPaths(t)
srv := httptest.NewServer(dh)
defer srv.Close()
db, pubsub := dbtestutil.NewDB(t)
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
server, err := replicasync.New(ctx, testutil.Logger(t), db, pubsub, &replicasync.Options{
RelayAddress: srv.URL,
})
require.NoError(t, err)
defer server.Close()
var first atomic.Int64
var second atomic.Int64
server.AddCallback(func() { first.Add(1) })
server.AddCallback(func() { second.Add(1) })
require.Eventually(t, func() bool {
return first.Load() >= 1 && second.Load() >= 1
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, server.UpdateNow(ctx))
require.Eventually(t, func() bool {
return first.Load() >= 2 && second.Load() >= 2
}, testutil.WaitShort, testutil.IntervalFast)
})
t.Run("TwentyConcurrent", func(t *testing.T) {
// Ensures that twenty concurrent replicas can spawn and all
// discover each other in parallel!
@@ -233,7 +261,7 @@ func TestReplica(t *testing.T) {
done := false
var m sync.Mutex
server.SetCallback(func() {
server.AddCallback(func() {
m.Lock()
defer m.Unlock()
if len(server.AllPrimary()) != count {