feat: add initial NATS implementation (#25602)

This commit is contained in:
Jon Ayers
2026-05-27 12:57:20 -05:00
committed by GitHub
parent 3cf867f84a
commit f6f284ea51
6 changed files with 1321 additions and 0 deletions
+652
View File
@@ -0,0 +1,652 @@
package nats
import (
"context"
"errors"
"fmt"
"hash/fnv"
"sync"
"time"
natsserver "github.com/nats-io/nats-server/v2/server"
natsgo "github.com/nats-io/nats.go"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database/pubsub"
)
// DefaultMaxPending is the per-client outbound pending byte budget.
const DefaultMaxPending int64 = 128 << 20
var errClosed = xerrors.New("nats pubsub closed")
// PendingLimits configures per-subscription NATS pending limits set
// via SetPendingLimits on each *natsgo.Subscription.
type PendingLimits struct {
// Msgs is the per-subscription pending message limit. Positive
// values also set each local listener queue capacity.
// Zero uses the package default. Negative disables this limit.
Msgs int
// Bytes is the per-subscription pending byte limit.
// Zero uses the package default. Negative disables this limit.
Bytes int
}
// Options configures the embedded NATS Pubsub.
type Options struct {
// MaxPayload is the NATS max payload. Zero means server default.
MaxPayload int32
// MaxPending is the per-client outbound pending byte budget on the
// embedded server. Zero or negative means the package default,
// 128 MiB.
MaxPending int64
// PendingLimits configures per-subscription NATS pending limits.
// Positive Msgs also sets local listener queue capacity.
// Zero fields use package defaults: Msgs -1 and Bytes 512 MiB.
PendingLimits PendingLimits
// ReconnectWait controls client reconnect delay. Zero keeps the
// NATS default.
ReconnectWait time.Duration
// InProcess, when true, uses nats.InProcessServer instead of TCP
// loopback. Intended for benchmarks and tests.
InProcess bool
// PublishConns is the number of publisher connections. Each Publish
// is routed by a stable hash of the subject. Zero or negative means 1.
PublishConns int
// SubscribeConns is the number of subscriber connections. Each
// shared subscription is pinned to one connection by a stable hash
// of its subject. Zero or negative means 1.
SubscribeConns int
}
// Pubsub is an embedded NATS-backed implementation of pubsub.Pubsub.
//
// Each Pubsub owns one embedded server, a pool of publisher
// *natsgo.Conns (Options.PublishConns) and a pool of subscriber
// *natsgo.Conns (Options.SubscribeConns). Publishes and shared
// subscriptions are pinned to a connection by a stable hash of the
// subject, so same-subject traffic preserves per-subject ordering and
// every local subscriber for a subject coalesces onto one underlying
// *natsgo.Subscription.
type Pubsub struct {
mu sync.Mutex
logger slog.Logger
opts Options
ns *natsserver.Server
// publishPool and subscribePool are immutable after construction so
// the hot path can index without holding p.mu.
publishPool []*natsgo.Conn
subscribePool []*natsgo.Conn
// subscriptions coalesces concurrent local subscribers on the
// same subject onto a single underlying *natsgo.Subscription.
subscriptions map[string]*natsSub
closeOnce sync.Once
// ctx is canceled by Close while holding p.mu so subscriber state
// cleanup observes the canceled context.
ctx context.Context
cancel context.CancelFunc
}
// natsSub maps to one underlying *natsgo.Subscription. The first
// local subscriber creates it; later local subscribers attach to it.
// When the last local subscriber detaches, the NATS subscription is
// unsubscribed.
type natsSub struct {
// sub is set before this natsSub is published in Pubsub.subscriptions
// and is immutable after that.
sub *natsgo.Subscription
// mu guards localSubs.
mu sync.Mutex
// localSubs are the local subscribers attached to this NATS subscription.
localSubs map[*localSub]struct{}
// dropMu keeps async error accounting independent from listener fan-out.
dropMu sync.Mutex
// lastDropped is the cumulative NATS dropped count last reported locally.
lastDropped uint64
}
// localSub is the local handle returned by Subscribe /
// SubscribeWithErr. Each local subscriber gets its own bounded inbox
// and dispatcher goroutine so one slow listener cannot block peers on
// the same subject.
type localSub struct {
cancelOnce sync.Once
ctx context.Context
event string
listener pubsub.ListenerWithErr
// queue is the per-listener data fan-out inbox. The shared NATS
// callback enqueues non-blockingly; on overflow the message is
// dropped and a drop signal is raised.
queue chan []byte
// dropSignal is a size-1 buffered channel that coalesces drop
// notifications from local overflow and NATS slow-consumer
// broadcasts onto a single pending wake.
dropSignal chan struct{}
cancel context.CancelFunc
}
// Compile-time assertion that *Pubsub satisfies the pubsub.Pubsub interface.
var _ pubsub.Pubsub = (*Pubsub)(nil)
// newPubsub allocates a *Pubsub with initialized maps and cancel ctx.
func newPubsub(ctx context.Context, logger slog.Logger, opts Options) *Pubsub {
ctx, cancel := context.WithCancel(ctx)
return &Pubsub{
logger: logger,
opts: opts,
subscriptions: make(map[string]*natsSub),
ctx: ctx,
cancel: cancel,
}
}
// defaultPendingLimits returns the effective per-subscription pending
// limits applied at Subscribe time.
func defaultPendingLimits(in PendingLimits) PendingLimits {
out := in
if out.Msgs == 0 {
out.Msgs = -1
}
if out.Bytes == 0 {
out.Bytes = 512 * 1024 * 1024
}
return out
}
// buildConnHandlers returns the connHandlers stack installed on every
// owned connection. Handlers close over p so slow-consumer routing
// keeps working.
func (p *Pubsub) buildConnHandlers() connHandlers {
return connHandlers{
disconnectErr: func(conn *natsgo.Conn, err error) {
if err != nil {
p.logger.Warn(p.ctx, "nats client disconnected", slog.Error(err))
}
p.signalSubscribersDroppedForConn(conn)
},
reconnect: func(_ *natsgo.Conn) {
p.logger.Info(p.ctx, "nats client reconnected")
},
closed: func(_ *natsgo.Conn) {
p.logger.Debug(p.ctx, "nats client closed")
},
errH: func(_ *natsgo.Conn, sub *natsgo.Subscription, err error) {
if err != nil && errors.Is(err, natsgo.ErrSlowConsumer) {
p.handleAsyncError(sub, err)
return
}
if err != nil {
p.logger.Warn(p.ctx, "nats async error", slog.Error(err))
}
},
}
}
// New creates an embedded NATS Pubsub. The returned *Pubsub owns the
// embedded server and the publisher and subscriber connection pools.
// Close shuts down all owned resources.
func New(ctx context.Context, logger slog.Logger, opts Options) (*Pubsub, error) {
ns, err := startEmbeddedServer(logger, opts)
if err != nil {
return nil, err
}
p := newPubsub(ctx, logger, opts)
p.ns = ns
handlers := p.buildConnHandlers()
publishPool, err := newConnPool(ns, opts, handlers, opts.PublishConns, "coder-pubsub-pub")
if err != nil {
p.cancel()
ns.Shutdown()
ns.WaitForShutdown()
return nil, err
}
subscribePool, err := newConnPool(ns, opts, handlers, opts.SubscribeConns, "coder-pubsub-sub")
if err != nil {
p.cancel()
for _, c := range publishPool {
c.Close()
}
ns.Shutdown()
ns.WaitForShutdown()
return nil, err
}
p.publishPool = publishPool
p.subscribePool = subscribePool
go func() {
<-p.ctx.Done()
_ = p.Close()
}()
return p, nil
}
func newConnPool(ns *natsserver.Server, opts Options, handlers connHandlers, count int, clientName string) ([]*natsgo.Conn, error) {
if count <= 0 {
count = 1
}
pool := make([]*natsgo.Conn, 0, count)
for i := 0; i < count; i++ {
// Suffix names when the pool has more than one entry so server
// logs can distinguish connections.
name := clientName
if count > 1 {
name = fmt.Sprintf("%s-%d", clientName, i)
}
nc, err := connectClient(ns, opts, handlers, name)
if err != nil {
for _, c := range pool {
c.Close()
}
return nil, xerrors.Errorf("dial conn: %w", err)
}
pool = append(pool, nc)
}
return pool, nil
}
// Publish publishes a message under the given event name. The
// publisher connection is selected by a stable hash of the subject so
// same-subject publishes preserve per-subject ordering.
func (p *Pubsub) Publish(event string, message []byte) error {
if p.ctx.Err() != nil {
return errClosed
}
if err := pickConn(p.publishPool, event).Publish(event, message); err != nil {
return xerrors.Errorf("publish: %w", err)
}
return nil
}
// Flush blocks until every publisher connection has flushed buffered
// publishes to the embedded server. Returns the first error
// encountered; remaining connections are still flushed.
func (p *Pubsub) Flush() error {
if p.ctx.Err() != nil {
return errClosed
}
var firstErr error
for i, nc := range p.publishPool {
if err := nc.Flush(); err != nil && firstErr == nil {
firstErr = xerrors.Errorf("flush pub conn %d: %w", i, err)
}
}
return firstErr
}
// Subscribe subscribes a Listener to the given event name. Errors
// such as ErrDroppedMessages are silently ignored, mirroring the
// legacy pubsub Listener semantics.
func (p *Pubsub) Subscribe(event string, listener pubsub.Listener) (cancel func(), err error) {
return p.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) {
if err != nil {
return
}
listener(ctx, msg)
})
}
// SubscribeWithErr subscribes a ListenerWithErr to the given event
// name. The listener also receives error deliveries such as
// pubsub.ErrDroppedMessages. Multiple local subscribers on the same
// event share a single underlying *natsgo.Subscription with
// per-listener bounded inboxes so a slow listener cannot block its
// peers.
func (p *Pubsub) SubscribeWithErr(event string, listener pubsub.ListenerWithErr) (cancel func(), err error) {
s, err := p.addSubscriber(event, listener)
if err != nil {
return nil, err
}
cancelFn := func() {
s.close()
p.unsubscribeLocal(s)
}
return cancelFn, nil
}
// listenerQueueSize returns the per-listener inbox capacity. A
// positive PendingLimits.Msgs sets the cap (giving callers a knob to
// trigger local-overflow drops since coalescing makes NATS-level
// slow-consumer signals rare). Otherwise the default is used.
func listenerQueueSize(in PendingLimits) int {
if in.Msgs > 0 {
return in.Msgs
}
return defaultListenerQueueSize
}
const defaultListenerQueueSize = 1024
// addSubscriber creates a local subscriber and attaches it to the natsSub
// for event. New natsSub entries are published only after NATS setup succeeds.
func (p *Pubsub) addSubscriber(event string, listener pubsub.ListenerWithErr) (*localSub, error) {
ctx, cancel := context.WithCancel(p.ctx)
s := &localSub{
ctx: ctx,
cancel: cancel,
event: event,
listener: listener,
queue: make(chan []byte, listenerQueueSize(p.opts.PendingLimits)),
dropSignal: make(chan struct{}, 1),
}
s.init()
cleanupSub, err := func() (*natsgo.Subscription, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.ctx.Err() != nil {
return nil, errClosed
}
nsub, ok := p.subscriptions[event]
if ok {
nsub.mu.Lock()
nsub.localSubs[s] = struct{}{}
nsub.mu.Unlock()
return nsub.sub, nil
}
nsub = &natsSub{
localSubs: map[*localSub]struct{}{
s: {},
},
}
subConn := pickConn(p.subscribePool, event)
natsSubscription, err := subConn.Subscribe(event, nsub.handleMessage)
if err != nil {
return nil, xerrors.Errorf("subscribe: %w", err)
}
nsub.sub = natsSubscription
// Flush the SUB to the server so a publish issued immediately
// after Subscribe returns cannot race ahead of registration.
if err := subConn.Flush(); err != nil {
return natsSubscription, xerrors.Errorf("flush subscribe: %w", err)
}
limits := defaultPendingLimits(p.opts.PendingLimits)
if err := natsSubscription.SetPendingLimits(limits.Msgs, limits.Bytes); err != nil {
return natsSubscription, xerrors.Errorf("set pending limits: %w", err)
}
p.subscriptions[event] = nsub
return natsSubscription, nil
}()
if err != nil {
s.close()
if cleanupSub != nil {
if unsubscribeErr := cleanupSub.Unsubscribe(); unsubscribeErr != nil {
err = errors.Join(err, xerrors.Errorf("unsubscribe: %w", unsubscribeErr))
}
}
return nil, err
}
return s, nil
}
// unsubscribeLocal removes s from its natsSub. If s was the last
// listener, it also removes and unsubscribes the underlying NATS
// subscription.
func (p *Pubsub) unsubscribeLocal(s *localSub) {
natsSub := func() *natsgo.Subscription {
p.mu.Lock()
defer p.mu.Unlock()
nsub := p.subscriptions[s.event]
if nsub == nil {
return nil
}
nsub.mu.Lock()
defer nsub.mu.Unlock()
if _, tracked := nsub.localSubs[s]; !tracked {
return nil
}
delete(nsub.localSubs, s)
if len(nsub.localSubs) > 0 {
return nil
}
// Last listener: remove the nsub entry so a new Subscribe to this
// subject creates a fresh underlying subscription.
delete(p.subscriptions, s.event)
return nsub.sub
}()
if natsSub != nil {
_ = natsSub.Unsubscribe()
}
}
// handleMessage handles messages for the shared subscription. Each
// enqueue is non-blocking and does not call user code, so one slow
// listener cannot stall the NATS delivery goroutine.
//
// Zero-copy fan-out: the same msg.Data slice is delivered to every
// local listener without cloning. Listeners on a coalesced subject MUST
// treat the delivered bytes as immutable.
func (nsub *natsSub) handleMessage(msg *natsgo.Msg) {
nsub.mu.Lock()
defer nsub.mu.Unlock()
for s := range nsub.localSubs {
s.enqueue(msg.Data)
}
}
// init starts the per-listener delivery goroutine.
func (s *localSub) init() {
go func() {
for {
select {
case <-s.ctx.Done():
return
case data := <-s.queue:
s.listener(s.ctx, data, nil)
case <-s.dropSignal:
s.listener(s.ctx, nil, pubsub.ErrDroppedMessages)
}
}
}()
}
// close cancels local delivery without waiting for callbacks.
func (s *localSub) close() {
s.cancelOnce.Do(func() {
if s.cancel != nil {
s.cancel()
}
})
}
// enqueue non-blockingly sends data onto s.queue. On overflow it drops the
// message and raises a drop signal so pubsub.ErrDroppedMessages is surfaced.
// If s is canceled the message is silently dropped.
func (s *localSub) enqueue(data []byte) {
select {
case s.queue <- data:
default:
s.signalDrop()
}
}
// signalDrop pushes onto dropSignal without blocking. Multiple drops
// between dispatcher dequeues coalesce into a single pending signal, so
// the listener observes one ErrDroppedMessages per drop wave.
func (s *localSub) signalDrop() {
select {
case s.dropSignal <- struct{}{}:
default:
}
}
// signalSubscribersDroppedForConn signals local subscribers assigned to conn.
func (p *Pubsub) signalSubscribersDroppedForConn(conn *natsgo.Conn) {
if conn == nil || len(p.subscribePool) == 0 {
return
}
p.mu.Lock()
subs := make([]*localSub, 0)
for event, nsub := range p.subscriptions {
if pickConn(p.subscribePool, event) != conn {
continue
}
nsub.mu.Lock()
for s := range nsub.localSubs {
subs = append(subs, s)
}
nsub.mu.Unlock()
}
p.mu.Unlock()
for _, s := range subs {
s.signalDrop()
}
}
// handleAsyncError routes async error callbacks. Only slow-consumer
// errors trigger drop accounting.
func (p *Pubsub) handleAsyncError(sub *natsgo.Subscription, err error) {
if sub == nil || !errors.Is(err, natsgo.ErrSlowConsumer) {
return
}
p.mu.Lock()
var nsub *natsSub
for _, candidate := range p.subscriptions {
if candidate.sub == sub {
nsub = candidate
break
}
}
p.mu.Unlock()
if nsub == nil {
return
}
p.handleSlowSubscriber(nsub)
}
// handleSlowSubscriber broadcasts pubsub.ErrDroppedMessages to every
// local listener on nsub when NATS reports a new drop delta. The
// slow-consumer signal is per-subscription and cannot be narrowed to a
// single local listener.
func (p *Pubsub) handleSlowSubscriber(nsub *natsSub) {
nsub.dropMu.Lock()
dropped, err := nsub.sub.Dropped()
if err != nil {
nsub.dropMu.Unlock()
p.logger.Warn(p.ctx, "nats: query dropped count", slog.Error(err))
return
}
if dropped < 0 {
nsub.dropMu.Unlock()
p.logger.Warn(p.ctx, "nats: negative dropped count")
return
}
// Dropped is cumulative per subscription; signal only new drops.
droppedCount := uint64(dropped)
if droppedCount < nsub.lastDropped {
nsub.lastDropped = droppedCount
nsub.dropMu.Unlock()
return
}
if droppedCount == nsub.lastDropped {
nsub.dropMu.Unlock()
return
}
nsub.lastDropped = droppedCount
nsub.dropMu.Unlock()
nsub.mu.Lock()
defer nsub.mu.Unlock()
for s := range nsub.localSubs {
s.signalDrop()
}
}
// Close stops local delivery and shuts down the Pubsub. It is idempotent.
// Close does not drain queued listener messages.
func (p *Pubsub) Close() error {
p.closeOnce.Do(func() {
p.mu.Lock()
// Cancel while holding p.mu so subscriber state cleanup below
// observes the canceled context.
p.cancel()
var subs []*localSub
shareds := make([]*natsSub, 0, len(p.subscriptions))
for _, ss := range p.subscriptions {
shareds = append(shareds, ss)
ss.mu.Lock()
for s := range ss.localSubs {
subs = append(subs, s)
delete(ss.localSubs, s)
}
ss.mu.Unlock()
}
clear(p.subscriptions)
p.mu.Unlock()
// Unsubscribe shared subscriptions before closing connections.
for _, ss := range shareds {
if ss.sub != nil {
_ = ss.sub.Unsubscribe()
}
}
// Signal per-listener goroutines without waiting for callbacks.
for _, s := range subs {
s.close()
}
for _, nc := range p.subscribePool {
if nc != nil {
nc.Close()
}
}
for _, nc := range p.publishPool {
if nc != nil {
nc.Close()
}
}
if p.ns != nil {
p.ns.Shutdown()
p.ns.WaitForShutdown()
}
})
return nil
}
// pickConn returns the connection assigned to subject. Selection uses
// a stable FNV-1a hash so same-subject traffic always targets the same
// connection within a process; pools are immutable after construction
// so the lookup is lock-free.
func pickConn(pool []*natsgo.Conn, subject string) *natsgo.Conn {
if len(pool) == 1 {
return pool[0]
}
h := fnv.New32a()
_, _ = h.Write([]byte(subject))
n := uint32(len(pool)) //nolint:gosec // pool size bounded by Options.{Publish,Subscribe}Conns
return pool[h.Sum32()%n]
}
+339
View File
@@ -0,0 +1,339 @@
package nats //nolint:testpackage // Exercises internal pubsub state and helpers.
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
natsgo "github.com/nats-io/nats.go"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/testutil"
)
func Test_defaultPendingLimits(t *testing.T) {
t.Parallel()
const defaultBytes = 512 * 1024 * 1024
testCases := []struct {
name string
in PendingLimits
want PendingLimits
}{
{
name: "AllZero",
in: PendingLimits{},
want: PendingLimits{Msgs: -1, Bytes: defaultBytes},
},
{
name: "MsgsOnly",
in: PendingLimits{Msgs: 8},
want: PendingLimits{Msgs: 8, Bytes: defaultBytes},
},
{
name: "BytesOnly",
in: PendingLimits{Bytes: 1024},
want: PendingLimits{Msgs: -1, Bytes: 1024},
},
{
name: "NegativeMsgs",
in: PendingLimits{Msgs: -2},
want: PendingLimits{Msgs: -2, Bytes: defaultBytes},
},
{
name: "NegativeBytes",
in: PendingLimits{Bytes: -2},
want: PendingLimits{Msgs: -1, Bytes: -2},
},
{
name: "NegativeBoth",
in: PendingLimits{Msgs: -2, Bytes: -3},
want: PendingLimits{Msgs: -2, Bytes: -3},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.want, defaultPendingLimits(tc.in))
})
}
}
func Test_pickConn(t *testing.T) {
t.Parallel()
t.Run("DifferentSubjects", func(t *testing.T) {
t.Parallel()
var a, b natsgo.Conn
pool := []*natsgo.Conn{&a, &b}
require.NotSame(t, pickConn(pool, "a"), pickConn(pool, "b"))
})
}
func subjectForConn(t *testing.T, pool []*natsgo.Conn, conn *natsgo.Conn, prefix string) string {
t.Helper()
for i := 0; i < 10_000; i++ {
subject := fmt.Sprintf("%s_%d", prefix, i)
if pickConn(pool, subject) == conn {
return subject
}
}
require.FailNow(t, "no subject matched requested connection")
return ""
}
func Test_New(t *testing.T) {
t.Parallel()
t.Run("ConnectionCount", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ps, err := New(ctx, logger, Options{})
require.NoError(t, err)
t.Cleanup(func() { _ = ps.Close() })
const n = 50
cancels := make([]func(), 0, n)
for i := 0; i < n; i++ {
c, err := ps.Subscribe(fmt.Sprintf("cc_evt_%d", i), func(context.Context, []byte) {})
require.NoError(t, err)
cancels = append(cancels, c)
}
t.Cleanup(func() {
for _, c := range cancels {
c()
}
})
require.Equal(t, 2, ps.ns.NumClients(),
"expected exactly 2 client connections (pubConn + subConn), got %d", ps.ns.NumClients())
require.Len(t, ps.publishPool, 1, "default PublishConns must be 1")
require.Len(t, ps.subscribePool, 1, "default SubscribeConns must be 1")
require.NotSame(t, ps.publishPool[0], ps.subscribePool[0], "pubConn and subConn must be distinct")
})
}
func Test_SubscribeWithErr(t *testing.T) {
t.Parallel()
t.Run("SameSubjectSharesSubscription", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ps, err := New(ctx, logger, Options{})
require.NoError(t, err)
t.Cleanup(func() { _ = ps.Close() })
cancelA, err := ps.Subscribe("coalesce_evt", func(context.Context, []byte) {})
require.NoError(t, err)
t.Cleanup(cancelA)
cancelB, err := ps.Subscribe("coalesce_evt", func(context.Context, []byte) {})
require.NoError(t, err)
t.Cleanup(cancelB)
ps.mu.Lock()
defer ps.mu.Unlock()
require.Len(t, ps.subscriptions, 1)
})
}
func Test_Pubsub_buildConnHandlers(t *testing.T) {
t.Parallel()
t.Run("DisconnectSignalsDropsForMatchingSubscriberConn", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ps := newPubsub(ctx, logger, Options{})
var subConnA, subConnB, pubConn natsgo.Conn
ps.subscribePool = []*natsgo.Conn{&subConnA, &subConnB}
matchingEvent := subjectForConn(t, ps.subscribePool, &subConnA, "disconnect_match")
otherEvent := subjectForConn(t, ps.subscribePool, &subConnB, "disconnect_other")
newLocal := func(event string) *localSub {
return &localSub{
event: event,
dropSignal: make(chan struct{}, 1),
}
}
matchingSub := newLocal(matchingEvent)
otherSub := newLocal(otherEvent)
ps.subscriptions[matchingSub.event] = &natsSub{localSubs: map[*localSub]struct{}{matchingSub: {}}}
ps.subscriptions[otherSub.event] = &natsSub{localSubs: map[*localSub]struct{}{otherSub: {}}}
handlers := ps.buildConnHandlers()
handlers.disconnectErr(&subConnA, xerrors.New("disconnect"))
select {
case <-matchingSub.dropSignal:
default:
require.Fail(t, "matching subscriber did not receive drop signal")
}
select {
case <-otherSub.dropSignal:
require.Fail(t, "non-matching subscriber received drop signal")
default:
}
handlers.disconnectErr(&pubConn, xerrors.New("publisher disconnect"))
select {
case <-otherSub.dropSignal:
require.Fail(t, "publisher connection disconnect signaled subscriber")
default:
}
})
}
func Test_localSub_init(t *testing.T) {
t.Parallel()
t.Run("SerializesCallbacks", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dataStarted := make(chan struct{})
dropDelivered := make(chan struct{})
release := make(chan struct{})
var dataOnce sync.Once
var dropOnce sync.Once
var releaseOnce sync.Once
var active atomic.Int64
var concurrent atomic.Bool
s := &localSub{
ctx: ctx,
cancel: cancel,
listener: func(_ context.Context, _ []byte, ferr error) {
if active.Add(1) != 1 {
concurrent.Store(true)
}
defer active.Add(-1)
if errors.Is(ferr, pubsub.ErrDroppedMessages) {
dropOnce.Do(func() { close(dropDelivered) })
return
}
dataOnce.Do(func() { close(dataStarted) })
<-release
},
queue: make(chan []byte, 1),
dropSignal: make(chan struct{}, 1),
}
s.init()
t.Cleanup(func() {
releaseOnce.Do(func() { close(release) })
s.close()
})
s.enqueue([]byte("data"))
require.Eventually(t, func() bool {
select {
case <-dataStarted:
return true
default:
return false
}
}, testutil.WaitShort, testutil.IntervalFast)
s.signalDrop()
require.Never(t, func() bool {
select {
case <-dropDelivered:
return true
default:
return false
}
}, testutil.IntervalMedium, testutil.IntervalFast,
"drop callback must wait for the blocked data callback")
require.False(t, concurrent.Load(), "listener callback ran concurrently")
releaseOnce.Do(func() { close(release) })
require.Eventually(t, func() bool {
select {
case <-dropDelivered:
return true
default:
return false
}
}, testutil.WaitShort, testutil.IntervalFast)
require.False(t, concurrent.Load(), "listener callback ran concurrently")
})
t.Run("CrossSubjectListenerIsolation", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
ps, err := New(ctx, logger, Options{})
require.NoError(t, err)
t.Cleanup(func() { _ = ps.Close() })
release := make(chan struct{})
var releaseOnce sync.Once
var slowDrops atomic.Int64
var slowBlocked atomic.Bool
slowCancel, err := ps.SubscribeWithErr("iso_slow", func(_ context.Context, _ []byte, ferr error) {
if ferr != nil && errors.Is(ferr, pubsub.ErrDroppedMessages) {
slowDrops.Add(1)
return
}
if slowBlocked.CompareAndSwap(false, true) {
<-release
}
})
require.NoError(t, err)
defer slowCancel()
var fastCount atomic.Int64
fastCancel, err := ps.Subscribe("iso_fast", func(_ context.Context, _ []byte) {
fastCount.Add(1)
})
require.NoError(t, err)
defer fastCancel()
defer releaseOnce.Do(func() { close(release) })
total := defaultListenerQueueSize + 256
payload := make([]byte, 4*1024)
for i := 0; i < total; i++ {
require.NoError(t, ps.Publish("iso_slow", payload))
require.NoError(t, ps.Publish("iso_fast", []byte("ping")))
}
require.NoError(t, ps.Flush())
require.Eventually(t, func() bool {
return fastCount.Load() >= int64(total)
}, testutil.WaitLong, testutil.IntervalFast)
require.Zero(t, slowDrops.Load(),
"drop callback must wait for the blocked data callback")
releaseOnce.Do(func() { close(release) })
require.Eventually(t, func() bool {
return slowDrops.Load() >= 1
}, testutil.WaitLong, testutil.IntervalFast,
"slow subscriber must receive at least one ErrDroppedMessages signal")
require.GreaterOrEqual(t, fastCount.Load(), int64(total),
"fast subscriber must keep receiving despite slow peer on shared subConn")
require.Len(t, ps.subscribePool, 1)
require.False(t, ps.subscribePool[0].IsClosed(), "subConn must not be closed by slow consumer")
require.True(t, ps.subscribePool[0].IsConnected(), "subConn must stay connected")
require.Equal(t, 2, ps.ns.NumClients(), "slow consumer must not disconnect subConn")
})
}
+199
View File
@@ -0,0 +1,199 @@
package nats_test
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/pubsub"
xnats "github.com/coder/coder/v2/coderd/x/nats"
"github.com/coder/coder/v2/testutil"
)
func newTestPubsub(t *testing.T, opts xnats.Options) *xnats.Pubsub {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithCancel(context.Background())
ps, err := xnats.New(ctx, logger, opts)
require.NoError(t, err)
t.Cleanup(func() {
_ = ps.Close()
cancel()
})
return ps
}
func TestPubsub(t *testing.T) {
t.Parallel()
t.Run("RoundTrip", func(t *testing.T) {
t.Parallel()
ps := newTestPubsub(t, xnats.Options{})
got := make(chan []byte, 1)
cancel, err := ps.Subscribe("test_event", func(_ context.Context, msg []byte) {
got <- msg
})
require.NoError(t, err)
defer cancel()
require.NoError(t, ps.Publish("test_event", []byte("hello")))
select {
case msg := <-got:
assert.Equal(t, "hello", string(msg))
case <-time.After(testutil.WaitShort):
t.Fatal("timed out waiting for message")
}
})
t.Run("SubscribeWithErrNormalMessage", func(t *testing.T) {
t.Parallel()
ps := newTestPubsub(t, xnats.Options{})
got := make(chan []byte, 1)
cancel, err := ps.SubscribeWithErr("evt", func(_ context.Context, msg []byte, err error) {
assert.NoError(t, err)
got <- msg
})
require.NoError(t, err)
defer cancel()
require.NoError(t, ps.Publish("evt", []byte("payload")))
select {
case msg := <-got:
assert.Equal(t, "payload", string(msg))
case <-time.After(testutil.WaitShort):
t.Fatal("timed out waiting for message")
}
})
t.Run("EchoDefault", func(t *testing.T) {
t.Parallel()
ps := newTestPubsub(t, xnats.Options{})
got := make(chan []byte, 1)
cancel, err := ps.Subscribe("echo_evt", func(_ context.Context, msg []byte) {
got <- msg
})
require.NoError(t, err)
defer cancel()
require.NoError(t, ps.Publish("echo_evt", []byte("data")))
select {
case msg := <-got:
assert.Equal(t, "data", string(msg))
case <-time.After(testutil.WaitShort):
t.Fatal("default should echo own messages")
}
})
t.Run("Ordering", func(t *testing.T) {
t.Parallel()
ps := newTestPubsub(t, xnats.Options{})
const n = 100
got := make(chan []byte, n)
cancel, err := ps.Subscribe("ord_evt", func(_ context.Context, msg []byte) {
got <- msg
})
require.NoError(t, err)
defer cancel()
for i := 0; i < n; i++ {
require.NoError(t, ps.Publish("ord_evt", []byte(fmt.Sprintf("%d", i))))
}
deadline := time.After(testutil.WaitLong)
for i := 0; i < n; i++ {
select {
case msg := <-got:
assert.Equal(t, fmt.Sprintf("%d", i), string(msg))
case <-deadline:
t.Fatalf("timed out at message %d/%d", i, n)
}
}
})
t.Run("CloseIdempotent", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
ps, err := xnats.New(ctx, logger, xnats.Options{})
require.NoError(t, err)
var first, second error
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
first = ps.Close()
}()
wg.Wait()
second = ps.Close()
assert.NoError(t, first)
assert.NoError(t, second)
})
t.Run("SubscribeWithErrReceivesDropError", func(t *testing.T) {
t.Parallel()
ps := newTestPubsub(t, xnats.Options{
PendingLimits: xnats.PendingLimits{Msgs: 1, Bytes: 1024 * 1024},
})
const event = "slow_evt_sync"
started := make(chan struct{})
release := make(chan struct{})
dropped := make(chan error, 1)
var startedOnce sync.Once
var releaseOnce sync.Once
defer releaseOnce.Do(func() { close(release) })
cancel, err := ps.SubscribeWithErr(event, func(_ context.Context, _ []byte, err error) {
if err != nil {
select {
case dropped <- err:
default:
}
return
}
startedOnce.Do(func() {
close(started)
<-release
})
})
require.NoError(t, err)
defer cancel()
require.NoError(t, ps.Publish(event, []byte("first")))
require.NoError(t, ps.Flush())
select {
case <-started:
case <-time.After(testutil.WaitShort):
t.Fatal("timed out waiting for first callback")
}
for i := 0; i < 8; i++ {
require.NoError(t, ps.Publish(event, []byte("burst")))
}
require.NoError(t, ps.Flush())
releaseOnce.Do(func() { close(release) })
select {
case err := <-dropped:
assert.ErrorIs(t, err, pubsub.ErrDroppedMessages)
case <-time.After(testutil.WaitLong):
t.Fatal("timed out waiting for drop error")
}
})
}
+106
View File
@@ -0,0 +1,106 @@
package nats
import (
"context"
"time"
natsserver "github.com/nats-io/nats-server/v2/server"
natsgo "github.com/nats-io/nats.go"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
const readyTimeout = 10 * time.Second
// buildServerOptions constructs the embedded NATS server options. The
// server runs standalone with a loopback random client listener.
func buildServerOptions(opts Options) (*natsserver.Options, error) {
maxPayload := opts.MaxPayload
if maxPayload == 0 {
maxPayload = natsserver.MAX_PAYLOAD_SIZE
}
maxPending := opts.MaxPending
if maxPending <= 0 {
maxPending = DefaultMaxPending
}
sopts := &natsserver.Options{
JetStream: false,
MaxPayload: maxPayload,
MaxPending: maxPending,
NoLog: true,
NoSigs: true,
}
sopts.DontListen = false
sopts.Host = "127.0.0.1"
sopts.Port = natsserver.RANDOM_PORT
return sopts, nil
}
// startEmbeddedServer starts an in-process standalone NATS server.
func startEmbeddedServer(logger slog.Logger, opts Options) (*natsserver.Server, error) {
sopts, err := buildServerOptions(opts)
if err != nil {
return nil, err
}
ns, err := natsserver.NewServer(sopts)
if err != nil {
return nil, xerrors.Errorf("new embedded nats server: %w", err)
}
go ns.Start()
if !ns.ReadyForConnections(readyTimeout) {
ns.Shutdown()
ns.WaitForShutdown()
return nil, xerrors.Errorf("embedded nats server not ready within %s", readyTimeout)
}
logger.Info(context.Background(), "embedded nats server started",
slog.F("client_url", ns.ClientURL()),
)
return ns, nil
}
type connHandlers struct {
disconnectErr natsgo.ConnErrHandler
reconnect natsgo.ConnHandler
closed natsgo.ConnHandler
errH natsgo.ErrHandler
}
// connectClient dials the embedded server's client listener over TCP
// loopback (or net.Pipe when opts.InProcess is true) and returns the
// resulting *natsgo.Conn. connName identifies the connection in server
// logs.
func connectClient(ns *natsserver.Server, opts Options, handlers connHandlers, connName string) (*natsgo.Conn, error) {
connOpts := []natsgo.Option{
natsgo.Name(connName),
}
if opts.ReconnectWait > 0 {
connOpts = append(connOpts, natsgo.ReconnectWait(opts.ReconnectWait))
}
if handlers.disconnectErr != nil {
connOpts = append(connOpts, natsgo.DisconnectErrHandler(handlers.disconnectErr))
}
if handlers.reconnect != nil {
connOpts = append(connOpts, natsgo.ReconnectHandler(handlers.reconnect))
}
if handlers.closed != nil {
connOpts = append(connOpts, natsgo.ClosedHandler(handlers.closed))
}
if handlers.errH != nil {
connOpts = append(connOpts, natsgo.ErrorHandler(handlers.errH))
}
url := ns.ClientURL()
if opts.InProcess {
// InProcessServer overrides URL dialing with a net.Pipe; the
// url argument is ignored but must still be syntactically valid.
connOpts = append(connOpts, natsgo.InProcessServer(ns))
}
nc, err := natsgo.Connect(url, connOpts...)
if err != nil {
return nil, xerrors.Errorf("connect client: %w", err)
}
return nc, nil
}