diff --git a/coderd/database/pubsub/pubsub.go b/coderd/database/pubsub/pubsub.go index 86f7217b16..97d289e223 100644 --- a/coderd/database/pubsub/pubsub.go +++ b/coderd/database/pubsub/pubsub.go @@ -56,14 +56,14 @@ type msgOrErr struct { err error } -// msgQueue implements a fixed length queue with the ability to replace elements +// MsgQueue implements a fixed length queue with the ability to replace elements // after they are queued (but before they are dequeued). // // The purpose of this data structure is to build something that works a bit // like a golang channel, but if the queue is full, then we can replace the // last element with an error so that the subscriber can get notified that some // messages were dropped, all without blocking. -type msgQueue struct { +type MsgQueue struct { ctx context.Context cond *sync.Cond q [BufferSize]msgOrErr @@ -74,11 +74,11 @@ type msgQueue struct { le ListenerWithErr } -func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue { +func NewMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *MsgQueue { if l == nil && le == nil { panic("l or le must be non-nil") } - q := &msgQueue{ + q := &MsgQueue{ ctx: ctx, cond: sync.NewCond(&sync.Mutex{}), l: l, @@ -88,7 +88,7 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue return q } -func (q *msgQueue) run() { +func (q *MsgQueue) run() { for { // wait until there is something on the queue or we are closed q.cond.L.Lock() @@ -125,7 +125,7 @@ func (q *msgQueue) run() { } } -func (q *msgQueue) enqueue(msg []byte) { +func (q *MsgQueue) Enqueue(msg []byte) { q.cond.L.Lock() defer q.cond.L.Unlock() @@ -149,15 +149,15 @@ func (q *msgQueue) enqueue(msg []byte) { q.cond.Broadcast() } -func (q *msgQueue) close() { +func (q *MsgQueue) Close() { q.cond.L.Lock() defer q.cond.L.Unlock() defer q.cond.Broadcast() q.closed = true } -// dropped records an error in the queue that messages might have been dropped -func (q *msgQueue) 
dropped() { +// Dropped records an error in the queue that messages might have been dropped +func (q *MsgQueue) Dropped() { q.cond.L.Lock() defer q.cond.L.Unlock() @@ -195,7 +195,7 @@ func (l pqListenerShim) NotifyChan() <-chan *pq.Notification { } type queueSet struct { - m map[*msgQueue]struct{} + m map[*MsgQueue]struct{} // unlistenInProgress will be non-nil if another goroutine is unlistening for the event this // queueSet corresponds to. If non-nil, that goroutine will close the channel when it is done. unlistenInProgress chan struct{} @@ -203,7 +203,7 @@ type queueSet struct { func newQueueSet() *queueSet { return &queueSet{ - m: make(map[*msgQueue]struct{}), + m: make(map[*MsgQueue]struct{}), } } @@ -243,19 +243,19 @@ const BufferSize = 2048 // Subscribe calls the listener when an event matching the name is received. func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { - return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil)) + return p.subscribeQueue(event, NewMsgQueue(context.Background(), listener, nil)) } func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { - return p.subscribeQueue(event, newMsgQueue(context.Background(), nil, listener)) + return p.subscribeQueue(event, NewMsgQueue(context.Background(), nil, listener)) } -func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) { +func (p *PGPubsub) subscribeQueue(event string, newQ *MsgQueue) (cancel func(), err error) { defer func() { if err != nil { // if we hit an error, we need to close the queue so we don't // leak its goroutine. 
- newQ.close() + newQ.Close() p.subscribesTotal.WithLabelValues("false").Inc() } else { p.subscribesTotal.WithLabelValues("true").Inc() @@ -325,7 +325,7 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), func() { p.qMu.Lock() defer p.qMu.Unlock() - newQ.close() + newQ.Close() qSet, ok := p.queues[event] if !ok { p.logger.Critical(context.Background(), "event was removed before cancel", slog.F("event", event)) @@ -436,7 +436,7 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) { } extra := []byte(notif.Extra) for q := range qSet.m { - q.enqueue(extra) + q.Enqueue(extra) } } @@ -445,7 +445,7 @@ func (p *PGPubsub) recordReconnect() { defer p.qMu.Unlock() for _, qSet := range p.queues { for q := range qSet.m { - q.dropped() + q.Dropped() } } } diff --git a/coderd/database/pubsub/pubsub_internal_test.go b/coderd/database/pubsub/pubsub_internal_test.go index 0f699b4e4d..0c51d7a8e8 100644 --- a/coderd/database/pubsub/pubsub_internal_test.go +++ b/coderd/database/pubsub/pubsub_internal_test.go @@ -13,135 +13,6 @@ import ( "github.com/coder/coder/v2/testutil" ) -func Test_msgQueue_ListenerWithError(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - m := make(chan string) - e := make(chan error) - uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) { - m <- string(msg) - e <- err - }) - defer uut.close() - - // We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5. - // PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned - // when we wrap around the end of the circular buffer. This tests that we correctly handle - // the wrapping and aren't dequeueing misaligned data. 
- cycles := (BufferSize / 5) * 2 // almost twice around the ring - for j := 0; j < cycles; j++ { - for i := 0; i < 4; i++ { - uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i))) - } - uut.dropped() - for i := 0; i < 4; i++ { - select { - case <-ctx.Done(): - t.Fatal("timed out") - case msg := <-m: - require.Equal(t, fmt.Sprintf("%d%d", j, i), msg) - } - select { - case <-ctx.Done(): - t.Fatal("timed out") - case err := <-e: - require.NoError(t, err) - } - } - select { - case <-ctx.Done(): - t.Fatal("timed out") - case msg := <-m: - require.Equal(t, "", msg) - } - select { - case <-ctx.Done(): - t.Fatal("timed out") - case err := <-e: - require.ErrorIs(t, err, ErrDroppedMessages) - } - } -} - -func Test_msgQueue_Listener(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - m := make(chan string) - uut := newMsgQueue(ctx, func(ctx context.Context, msg []byte) { - m <- string(msg) - }, nil) - defer uut.close() - - // We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5. - // PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned - // when we wrap around the end of the circular buffer. This tests that we correctly handle - // the wrapping and aren't dequeueing misaligned data. - cycles := (BufferSize / 5) * 2 // almost twice around the ring - for j := 0; j < cycles; j++ { - for i := 0; i < 4; i++ { - uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i))) - } - uut.dropped() - for i := 0; i < 4; i++ { - select { - case <-ctx.Done(): - t.Fatal("timed out") - case msg := <-m: - require.Equal(t, fmt.Sprintf("%d%d", j, i), msg) - } - } - // Listener skips over errors, so we only read out the 4 real messages. 
- } -} - -func Test_msgQueue_Full(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - - firstDequeue := make(chan struct{}) - allowRead := make(chan struct{}) - n := 0 - errors := make(chan error) - uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) { - if n == 0 { - close(firstDequeue) - } - <-allowRead - if err == nil { - require.Equal(t, fmt.Sprintf("%d", n), string(msg)) - n++ - return - } - errors <- err - }) - defer uut.close() - - // we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks - // but only after we've dequeued a message, and then another extra because we want to exceed - // the capacity, not just reach it. - for i := 0; i < BufferSize+2; i++ { - uut.enqueue([]byte(fmt.Sprintf("%d", i))) - // ensure the first dequeue has happened before proceeding, so that this function isn't racing - // against the goroutine that dequeues items. - <-firstDequeue - } - close(allowRead) - - select { - case <-ctx.Done(): - t.Fatal("timed out") - case err := <-errors: - require.ErrorIs(t, err, ErrDroppedMessages) - } - // Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last - // message we send doesn't get queued, AND, it bumps a message out of the queue to make room - // for the error, so we read 2 less than we sent. 
- require.Equal(t, BufferSize, n) -} - func TestPubSub_DoesntBlockNotify(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) diff --git a/coderd/database/pubsub/pubsub_test.go b/coderd/database/pubsub/pubsub_test.go index 066b9ce59a..3dbfa92f52 100644 --- a/coderd/database/pubsub/pubsub_test.go +++ b/coderd/database/pubsub/pubsub_test.go @@ -3,6 +3,7 @@ package pubsub_test import ( "context" "database/sql" + "fmt" "testing" "time" @@ -201,3 +202,132 @@ func TestPGPubsubDriver(t *testing.T) { } }, testutil.IntervalMedium, "subscriber did not receive message after reconnect") } + +func Test_MsgQueue_ListenerWithError(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + m := make(chan string) + e := make(chan error) + uut := pubsub.NewMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) { + m <- string(msg) + e <- err + }) + defer uut.Close() + + // We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5. + // pubsub.BufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned + // when we wrap around the end of the circular buffer. This tests that we correctly handle + // the wrapping and aren't dequeueing misaligned data. 
+ cycles := (pubsub.BufferSize / 5) * 2 // almost twice around the ring + for j := 0; j < cycles; j++ { + for i := 0; i < 4; i++ { + uut.Enqueue([]byte(fmt.Sprintf("%d%d", j, i))) + } + uut.Dropped() + for i := 0; i < 4; i++ { + select { + case <-ctx.Done(): + t.Fatal("timed out") + case msg := <-m: + require.Equal(t, fmt.Sprintf("%d%d", j, i), msg) + } + select { + case <-ctx.Done(): + t.Fatal("timed out") + case err := <-e: + require.NoError(t, err) + } + } + select { + case <-ctx.Done(): + t.Fatal("timed out") + case msg := <-m: + require.Equal(t, "", msg) + } + select { + case <-ctx.Done(): + t.Fatal("timed out") + case err := <-e: + require.ErrorIs(t, err, pubsub.ErrDroppedMessages) + } + } +} + +func Test_MsgQueue_Listener(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + m := make(chan string) + uut := pubsub.NewMsgQueue(ctx, func(ctx context.Context, msg []byte) { + m <- string(msg) + }, nil) + defer uut.Close() + + // We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5. + // pubsub.BufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned + // when we wrap around the end of the circular buffer. This tests that we correctly handle + // the wrapping and aren't dequeueing misaligned data. + cycles := (pubsub.BufferSize / 5) * 2 // almost twice around the ring + for j := 0; j < cycles; j++ { + for i := 0; i < 4; i++ { + uut.Enqueue([]byte(fmt.Sprintf("%d%d", j, i))) + } + uut.Dropped() + for i := 0; i < 4; i++ { + select { + case <-ctx.Done(): + t.Fatal("timed out") + case msg := <-m: + require.Equal(t, fmt.Sprintf("%d%d", j, i), msg) + } + } + // Listener skips over errors, so we only read out the 4 real messages. 
+ } +} + +func Test_MsgQueue_Full(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + firstDequeue := make(chan struct{}) + allowRead := make(chan struct{}) + n := 0 + errors := make(chan error) + uut := pubsub.NewMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) { + if n == 0 { + close(firstDequeue) + } + <-allowRead + if err == nil { + require.Equal(t, fmt.Sprintf("%d", n), string(msg)) + n++ + return + } + errors <- err + }) + defer uut.Close() + + // we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks + // but only after we've dequeued a message, and then another extra because we want to exceed + // the capacity, not just reach it. + for i := 0; i < pubsub.BufferSize+2; i++ { + uut.Enqueue([]byte(fmt.Sprintf("%d", i))) + // ensure the first dequeue has happened before proceeding, so that this function isn't racing + // against the goroutine that dequeues items. + <-firstDequeue + } + close(allowRead) + + select { + case <-ctx.Done(): + t.Fatal("timed out") + case err := <-errors: + require.ErrorIs(t, err, pubsub.ErrDroppedMessages) + } + // Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last + // message we send doesn't get queued, AND, it bumps a message out of the queue to make room + // for the error, so we read 2 less than we sent. + require.Equal(t, pubsub.BufferSize, n) +}