chore: export MsgQueue from pubsub package (#25707)

<!--

If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting.

-->

Makes `MsgQueue` exported, so it can be used in pubsub implementations outside PGPubsub.
This commit is contained in:
Spike Curtis
2026-05-27 10:11:51 -04:00
committed by GitHub
parent d1e27889eb
commit 6f06ace949
3 changed files with 148 additions and 147 deletions
+18 -18
View File
@@ -56,14 +56,14 @@ type msgOrErr struct {
err error err error
} }
// msgQueue implements a fixed length queue with the ability to replace elements // MsgQueue implements a fixed length queue with the ability to replace elements
// after they are queued (but before they are dequeued). // after they are queued (but before they are dequeued).
// //
// The purpose of this data structure is to build something that works a bit // The purpose of this data structure is to build something that works a bit
// like a golang channel, but if the queue is full, then we can replace the // like a golang channel, but if the queue is full, then we can replace the
// last element with an error so that the subscriber can get notified that some // last element with an error so that the subscriber can get notified that some
// messages were dropped, all without blocking. // messages were dropped, all without blocking.
type msgQueue struct { type MsgQueue struct {
ctx context.Context ctx context.Context
cond *sync.Cond cond *sync.Cond
q [BufferSize]msgOrErr q [BufferSize]msgOrErr
@@ -74,11 +74,11 @@ type msgQueue struct {
le ListenerWithErr le ListenerWithErr
} }
func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue { func NewMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *MsgQueue {
if l == nil && le == nil { if l == nil && le == nil {
panic("l or le must be non-nil") panic("l or le must be non-nil")
} }
q := &msgQueue{ q := &MsgQueue{
ctx: ctx, ctx: ctx,
cond: sync.NewCond(&sync.Mutex{}), cond: sync.NewCond(&sync.Mutex{}),
l: l, l: l,
@@ -88,7 +88,7 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
return q return q
} }
func (q *msgQueue) run() { func (q *MsgQueue) run() {
for { for {
// wait until there is something on the queue or we are closed // wait until there is something on the queue or we are closed
q.cond.L.Lock() q.cond.L.Lock()
@@ -125,7 +125,7 @@ func (q *msgQueue) run() {
} }
} }
func (q *msgQueue) enqueue(msg []byte) { func (q *MsgQueue) Enqueue(msg []byte) {
q.cond.L.Lock() q.cond.L.Lock()
defer q.cond.L.Unlock() defer q.cond.L.Unlock()
@@ -149,15 +149,15 @@ func (q *msgQueue) enqueue(msg []byte) {
q.cond.Broadcast() q.cond.Broadcast()
} }
func (q *msgQueue) close() { func (q *MsgQueue) Close() {
q.cond.L.Lock() q.cond.L.Lock()
defer q.cond.L.Unlock() defer q.cond.L.Unlock()
defer q.cond.Broadcast() defer q.cond.Broadcast()
q.closed = true q.closed = true
} }
// dropped records an error in the queue that messages might have been dropped // Dropped records an error in the queue that messages might have been Dropped
func (q *msgQueue) dropped() { func (q *MsgQueue) Dropped() {
q.cond.L.Lock() q.cond.L.Lock()
defer q.cond.L.Unlock() defer q.cond.L.Unlock()
@@ -195,7 +195,7 @@ func (l pqListenerShim) NotifyChan() <-chan *pq.Notification {
} }
type queueSet struct { type queueSet struct {
m map[*msgQueue]struct{} m map[*MsgQueue]struct{}
// unlistenInProgress will be non-nil if another goroutine is unlistening for the event this // unlistenInProgress will be non-nil if another goroutine is unlistening for the event this
// queueSet corresponds to. If non-nil, that goroutine will close the channel when it is done. // queueSet corresponds to. If non-nil, that goroutine will close the channel when it is done.
unlistenInProgress chan struct{} unlistenInProgress chan struct{}
@@ -203,7 +203,7 @@ type queueSet struct {
func newQueueSet() *queueSet { func newQueueSet() *queueSet {
return &queueSet{ return &queueSet{
m: make(map[*msgQueue]struct{}), m: make(map[*MsgQueue]struct{}),
} }
} }
@@ -243,19 +243,19 @@ const BufferSize = 2048
// Subscribe calls the listener when an event matching the name is received. // Subscribe calls the listener when an event matching the name is received.
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil)) return p.subscribeQueue(event, NewMsgQueue(context.Background(), listener, nil))
} }
func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(context.Background(), nil, listener)) return p.subscribeQueue(event, NewMsgQueue(context.Background(), nil, listener))
} }
func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) { func (p *PGPubsub) subscribeQueue(event string, newQ *MsgQueue) (cancel func(), err error) {
defer func() { defer func() {
if err != nil { if err != nil {
// if we hit an error, we need to close the queue so we don't // if we hit an error, we need to close the queue so we don't
// leak its goroutine. // leak its goroutine.
newQ.close() newQ.Close()
p.subscribesTotal.WithLabelValues("false").Inc() p.subscribesTotal.WithLabelValues("false").Inc()
} else { } else {
p.subscribesTotal.WithLabelValues("true").Inc() p.subscribesTotal.WithLabelValues("true").Inc()
@@ -325,7 +325,7 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
func() { func() {
p.qMu.Lock() p.qMu.Lock()
defer p.qMu.Unlock() defer p.qMu.Unlock()
newQ.close() newQ.Close()
qSet, ok := p.queues[event] qSet, ok := p.queues[event]
if !ok { if !ok {
p.logger.Critical(context.Background(), "event was removed before cancel", slog.F("event", event)) p.logger.Critical(context.Background(), "event was removed before cancel", slog.F("event", event))
@@ -436,7 +436,7 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
} }
extra := []byte(notif.Extra) extra := []byte(notif.Extra)
for q := range qSet.m { for q := range qSet.m {
q.enqueue(extra) q.Enqueue(extra)
} }
} }
@@ -445,7 +445,7 @@ func (p *PGPubsub) recordReconnect() {
defer p.qMu.Unlock() defer p.qMu.Unlock()
for _, qSet := range p.queues { for _, qSet := range p.queues {
for q := range qSet.m { for q := range qSet.m {
q.dropped() q.Dropped()
} }
} }
} }
@@ -13,135 +13,6 @@ import (
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
func Test_msgQueue_ListenerWithError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
m := make(chan string)
e := make(chan error)
uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) {
m <- string(msg)
e <- err
})
defer uut.close()
// We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5.
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (BufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
}
uut.dropped()
for i := 0; i < 4; i++ {
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-e:
require.NoError(t, err)
}
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, "", msg)
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-e:
require.ErrorIs(t, err, ErrDroppedMessages)
}
}
}
func Test_msgQueue_Listener(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
m := make(chan string)
uut := newMsgQueue(ctx, func(ctx context.Context, msg []byte) {
m <- string(msg)
}, nil)
defer uut.close()
// We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5.
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (BufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
}
uut.dropped()
for i := 0; i < 4; i++ {
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
}
}
// Listener skips over errors, so we only read out the 4 real messages.
}
}
func Test_msgQueue_Full(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
firstDequeue := make(chan struct{})
allowRead := make(chan struct{})
n := 0
errors := make(chan error)
uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) {
if n == 0 {
close(firstDequeue)
}
<-allowRead
if err == nil {
require.Equal(t, fmt.Sprintf("%d", n), string(msg))
n++
return
}
errors <- err
})
defer uut.close()
// we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks
// but only after we've dequeued a message, and then another extra because we want to exceed
// the capacity, not just reach it.
for i := 0; i < BufferSize+2; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d", i)))
// ensure the first dequeue has happened before proceeding, so that this function isn't racing
// against the goroutine that dequeues items.
<-firstDequeue
}
close(allowRead)
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-errors:
require.ErrorIs(t, err, ErrDroppedMessages)
}
// Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last
// message we send doesn't get queued, AND, it bumps a message out of the queue to make room
// for the error, so we read 2 less than we sent.
require.Equal(t, BufferSize, n)
}
func TestPubSub_DoesntBlockNotify(t *testing.T) { func TestPubSub_DoesntBlockNotify(t *testing.T) {
t.Parallel() t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort) ctx := testutil.Context(t, testutil.WaitShort)
+130
View File
@@ -3,6 +3,7 @@ package pubsub_test
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"testing" "testing"
"time" "time"
@@ -201,3 +202,132 @@ func TestPGPubsubDriver(t *testing.T) {
} }
}, testutil.IntervalMedium, "subscriber did not receive message after reconnect") }, testutil.IntervalMedium, "subscriber did not receive message after reconnect")
} }
func Test_MsgQueue_ListenerWithError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
m := make(chan string)
e := make(chan error)
uut := pubsub.NewMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) {
m <- string(msg)
e <- err
})
defer uut.Close()
// We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5.
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (pubsub.BufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.Enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
}
uut.Dropped()
for i := 0; i < 4; i++ {
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-e:
require.NoError(t, err)
}
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, "", msg)
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-e:
require.ErrorIs(t, err, pubsub.ErrDroppedMessages)
}
}
}
func Test_MsgQueue_Listener(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
m := make(chan string)
uut := pubsub.NewMsgQueue(ctx, func(ctx context.Context, msg []byte) {
m <- string(msg)
}, nil)
defer uut.Close()
// We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5.
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (pubsub.BufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.Enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
}
uut.Dropped()
for i := 0; i < 4; i++ {
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
}
}
// Listener skips over errors, so we only read out the 4 real messages.
}
}
func Test_MsgQueue_Full(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
firstDequeue := make(chan struct{})
allowRead := make(chan struct{})
n := 0
errors := make(chan error)
uut := pubsub.NewMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) {
if n == 0 {
close(firstDequeue)
}
<-allowRead
if err == nil {
require.Equal(t, fmt.Sprintf("%d", n), string(msg))
n++
return
}
errors <- err
})
defer uut.Close()
// we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks
// but only after we've dequeued a message, and then another extra because we want to exceed
// the capacity, not just reach it.
for i := 0; i < pubsub.BufferSize+2; i++ {
uut.Enqueue([]byte(fmt.Sprintf("%d", i)))
// ensure the first dequeue has happened before proceeding, so that this function isn't racing
// against the goroutine that dequeues items.
<-firstDequeue
}
close(allowRead)
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-errors:
require.ErrorIs(t, err, pubsub.ErrDroppedMessages)
}
// Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last
// message we send doesn't get queued, AND, it bumps a message out of the queue to make room
// for the error, so we read 2 less than we sent.
require.Equal(t, pubsub.BufferSize, n)
}