mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add initial NATS implementation (#25602)
This commit is contained in:
@@ -0,0 +1,652 @@
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
natsserver "github.com/nats-io/nats-server/v2/server"
|
||||
natsgo "github.com/nats-io/nats.go"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
// DefaultMaxPending is the per-client outbound pending byte budget.
|
||||
const DefaultMaxPending int64 = 128 << 20
|
||||
|
||||
var errClosed = xerrors.New("nats pubsub closed")
|
||||
|
||||
// PendingLimits holds the per-subscription NATS pending limits that are
// applied through SetPendingLimits on every *natsgo.Subscription.
type PendingLimits struct {
	// Msgs caps the number of pending messages per subscription. A
	// positive value is also used as the capacity of each local
	// listener queue. Zero selects the package default; a negative
	// value disables the limit.
	Msgs int

	// Bytes caps the number of pending bytes per subscription. Zero
	// selects the package default; a negative value disables the limit.
	Bytes int
}
|
||||
|
||||
// Options configures the embedded NATS Pubsub.
|
||||
type Options struct {
|
||||
// MaxPayload is the NATS max payload. Zero means server default.
|
||||
MaxPayload int32
|
||||
|
||||
// MaxPending is the per-client outbound pending byte budget on the
|
||||
// embedded server. Zero or negative means the package default,
|
||||
// 128 MiB.
|
||||
MaxPending int64
|
||||
|
||||
// PendingLimits configures per-subscription NATS pending limits.
|
||||
// Positive Msgs also sets local listener queue capacity.
|
||||
// Zero fields use package defaults: Msgs -1 and Bytes 512 MiB.
|
||||
PendingLimits PendingLimits
|
||||
|
||||
// ReconnectWait controls client reconnect delay. Zero keeps the
|
||||
// NATS default.
|
||||
ReconnectWait time.Duration
|
||||
|
||||
// InProcess, when true, uses nats.InProcessServer instead of TCP
|
||||
// loopback. Intended for benchmarks and tests.
|
||||
InProcess bool
|
||||
|
||||
// PublishConns is the number of publisher connections. Each Publish
|
||||
// is routed by a stable hash of the subject. Zero or negative means 1.
|
||||
PublishConns int
|
||||
|
||||
// SubscribeConns is the number of subscriber connections. Each
|
||||
// shared subscription is pinned to one connection by a stable hash
|
||||
// of its subject. Zero or negative means 1.
|
||||
SubscribeConns int
|
||||
}
|
||||
|
||||
// Pubsub is an embedded NATS-backed implementation of pubsub.Pubsub.
|
||||
//
|
||||
// Each Pubsub owns one embedded server, a pool of publisher
|
||||
// *natsgo.Conns (Options.PublishConns) and a pool of subscriber
|
||||
// *natsgo.Conns (Options.SubscribeConns). Publishes and shared
|
||||
// subscriptions are pinned to a connection by a stable hash of the
|
||||
// subject, so same-subject traffic preserves per-subject ordering and
|
||||
// every local subscriber for a subject coalesces onto one underlying
|
||||
// *natsgo.Subscription.
|
||||
type Pubsub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
logger slog.Logger
|
||||
opts Options
|
||||
|
||||
ns *natsserver.Server
|
||||
// publishPool and subscribePool are immutable after construction so
|
||||
// the hot path can index without holding p.mu.
|
||||
publishPool []*natsgo.Conn
|
||||
subscribePool []*natsgo.Conn
|
||||
|
||||
// subscriptions coalesces concurrent local subscribers on the
|
||||
// same subject onto a single underlying *natsgo.Subscription.
|
||||
subscriptions map[string]*natsSub
|
||||
closeOnce sync.Once
|
||||
|
||||
// ctx is canceled by Close while holding p.mu so subscriber state
|
||||
// cleanup observes the canceled context.
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// natsSub maps to one underlying *natsgo.Subscription. The first
|
||||
// local subscriber creates it; later local subscribers attach to it.
|
||||
// When the last local subscriber detaches, the NATS subscription is
|
||||
// unsubscribed.
|
||||
type natsSub struct {
|
||||
// sub is set before this natsSub is published in Pubsub.subscriptions
|
||||
// and is immutable after that.
|
||||
sub *natsgo.Subscription
|
||||
|
||||
// mu guards localSubs.
|
||||
mu sync.Mutex
|
||||
// localSubs are the local subscribers attached to this NATS subscription.
|
||||
localSubs map[*localSub]struct{}
|
||||
|
||||
// dropMu keeps async error accounting independent from listener fan-out.
|
||||
dropMu sync.Mutex
|
||||
// lastDropped is the cumulative NATS dropped count last reported locally.
|
||||
lastDropped uint64
|
||||
}
|
||||
|
||||
// localSub is the local handle returned by Subscribe /
|
||||
// SubscribeWithErr. Each local subscriber gets its own bounded inbox
|
||||
// and dispatcher goroutine so one slow listener cannot block peers on
|
||||
// the same subject.
|
||||
type localSub struct {
|
||||
cancelOnce sync.Once
|
||||
|
||||
ctx context.Context
|
||||
|
||||
event string
|
||||
listener pubsub.ListenerWithErr
|
||||
|
||||
// queue is the per-listener data fan-out inbox. The shared NATS
|
||||
// callback enqueues non-blockingly; on overflow the message is
|
||||
// dropped and a drop signal is raised.
|
||||
queue chan []byte
|
||||
// dropSignal is a size-1 buffered channel that coalesces drop
|
||||
// notifications from local overflow and NATS slow-consumer
|
||||
// broadcasts onto a single pending wake.
|
||||
dropSignal chan struct{}
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Compile-time assertion that *Pubsub satisfies the pubsub.Pubsub interface.
|
||||
var _ pubsub.Pubsub = (*Pubsub)(nil)
|
||||
|
||||
// newPubsub allocates a *Pubsub with initialized maps and cancel ctx.
|
||||
func newPubsub(ctx context.Context, logger slog.Logger, opts Options) *Pubsub {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Pubsub{
|
||||
logger: logger,
|
||||
opts: opts,
|
||||
subscriptions: make(map[string]*natsSub),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// defaultPendingLimits returns the effective per-subscription pending
|
||||
// limits applied at Subscribe time.
|
||||
func defaultPendingLimits(in PendingLimits) PendingLimits {
|
||||
out := in
|
||||
if out.Msgs == 0 {
|
||||
out.Msgs = -1
|
||||
}
|
||||
if out.Bytes == 0 {
|
||||
out.Bytes = 512 * 1024 * 1024
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// buildConnHandlers returns the connHandlers stack installed on every
|
||||
// owned connection. Handlers close over p so slow-consumer routing
|
||||
// keeps working.
|
||||
func (p *Pubsub) buildConnHandlers() connHandlers {
|
||||
return connHandlers{
|
||||
disconnectErr: func(conn *natsgo.Conn, err error) {
|
||||
if err != nil {
|
||||
p.logger.Warn(p.ctx, "nats client disconnected", slog.Error(err))
|
||||
}
|
||||
p.signalSubscribersDroppedForConn(conn)
|
||||
},
|
||||
reconnect: func(_ *natsgo.Conn) {
|
||||
p.logger.Info(p.ctx, "nats client reconnected")
|
||||
},
|
||||
closed: func(_ *natsgo.Conn) {
|
||||
p.logger.Debug(p.ctx, "nats client closed")
|
||||
},
|
||||
errH: func(_ *natsgo.Conn, sub *natsgo.Subscription, err error) {
|
||||
if err != nil && errors.Is(err, natsgo.ErrSlowConsumer) {
|
||||
p.handleAsyncError(sub, err)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
p.logger.Warn(p.ctx, "nats async error", slog.Error(err))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// New creates an embedded NATS Pubsub. The returned *Pubsub owns the
|
||||
// embedded server and the publisher and subscriber connection pools.
|
||||
// Close shuts down all owned resources.
|
||||
func New(ctx context.Context, logger slog.Logger, opts Options) (*Pubsub, error) {
|
||||
ns, err := startEmbeddedServer(logger, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := newPubsub(ctx, logger, opts)
|
||||
p.ns = ns
|
||||
handlers := p.buildConnHandlers()
|
||||
|
||||
publishPool, err := newConnPool(ns, opts, handlers, opts.PublishConns, "coder-pubsub-pub")
|
||||
if err != nil {
|
||||
p.cancel()
|
||||
ns.Shutdown()
|
||||
ns.WaitForShutdown()
|
||||
return nil, err
|
||||
}
|
||||
subscribePool, err := newConnPool(ns, opts, handlers, opts.SubscribeConns, "coder-pubsub-sub")
|
||||
if err != nil {
|
||||
p.cancel()
|
||||
for _, c := range publishPool {
|
||||
c.Close()
|
||||
}
|
||||
ns.Shutdown()
|
||||
ns.WaitForShutdown()
|
||||
return nil, err
|
||||
}
|
||||
p.publishPool = publishPool
|
||||
p.subscribePool = subscribePool
|
||||
go func() {
|
||||
<-p.ctx.Done()
|
||||
_ = p.Close()
|
||||
}()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func newConnPool(ns *natsserver.Server, opts Options, handlers connHandlers, count int, clientName string) ([]*natsgo.Conn, error) {
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
}
|
||||
pool := make([]*natsgo.Conn, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
// Suffix names when the pool has more than one entry so server
|
||||
// logs can distinguish connections.
|
||||
name := clientName
|
||||
if count > 1 {
|
||||
name = fmt.Sprintf("%s-%d", clientName, i)
|
||||
}
|
||||
nc, err := connectClient(ns, opts, handlers, name)
|
||||
if err != nil {
|
||||
for _, c := range pool {
|
||||
c.Close()
|
||||
}
|
||||
return nil, xerrors.Errorf("dial conn: %w", err)
|
||||
}
|
||||
pool = append(pool, nc)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Publish publishes a message under the given event name. The
|
||||
// publisher connection is selected by a stable hash of the subject so
|
||||
// same-subject publishes preserve per-subject ordering.
|
||||
func (p *Pubsub) Publish(event string, message []byte) error {
|
||||
if p.ctx.Err() != nil {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
if err := pickConn(p.publishPool, event).Publish(event, message); err != nil {
|
||||
return xerrors.Errorf("publish: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush blocks until every publisher connection has flushed buffered
|
||||
// publishes to the embedded server. Returns the first error
|
||||
// encountered; remaining connections are still flushed.
|
||||
func (p *Pubsub) Flush() error {
|
||||
if p.ctx.Err() != nil {
|
||||
return errClosed
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
for i, nc := range p.publishPool {
|
||||
if err := nc.Flush(); err != nil && firstErr == nil {
|
||||
firstErr = xerrors.Errorf("flush pub conn %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// Subscribe subscribes a Listener to the given event name. Errors
|
||||
// such as ErrDroppedMessages are silently ignored, mirroring the
|
||||
// legacy pubsub Listener semantics.
|
||||
func (p *Pubsub) Subscribe(event string, listener pubsub.Listener) (cancel func(), err error) {
|
||||
return p.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
listener(ctx, msg)
|
||||
})
|
||||
}
|
||||
|
||||
// SubscribeWithErr subscribes a ListenerWithErr to the given event
|
||||
// name. The listener also receives error deliveries such as
|
||||
// pubsub.ErrDroppedMessages. Multiple local subscribers on the same
|
||||
// event share a single underlying *natsgo.Subscription with
|
||||
// per-listener bounded inboxes so a slow listener cannot block its
|
||||
// peers.
|
||||
func (p *Pubsub) SubscribeWithErr(event string, listener pubsub.ListenerWithErr) (cancel func(), err error) {
|
||||
s, err := p.addSubscriber(event, listener)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cancelFn := func() {
|
||||
s.close()
|
||||
p.unsubscribeLocal(s)
|
||||
}
|
||||
return cancelFn, nil
|
||||
}
|
||||
|
||||
// listenerQueueSize returns the per-listener inbox capacity. A
|
||||
// positive PendingLimits.Msgs sets the cap (giving callers a knob to
|
||||
// trigger local-overflow drops since coalescing makes NATS-level
|
||||
// slow-consumer signals rare). Otherwise the default is used.
|
||||
func listenerQueueSize(in PendingLimits) int {
|
||||
if in.Msgs > 0 {
|
||||
return in.Msgs
|
||||
}
|
||||
return defaultListenerQueueSize
|
||||
}
|
||||
|
||||
const defaultListenerQueueSize = 1024
|
||||
|
||||
// addSubscriber creates a local subscriber and attaches it to the natsSub
|
||||
// for event. New natsSub entries are published only after NATS setup succeeds.
|
||||
func (p *Pubsub) addSubscriber(event string, listener pubsub.ListenerWithErr) (*localSub, error) {
|
||||
ctx, cancel := context.WithCancel(p.ctx)
|
||||
s := &localSub{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
event: event,
|
||||
listener: listener,
|
||||
queue: make(chan []byte, listenerQueueSize(p.opts.PendingLimits)),
|
||||
dropSignal: make(chan struct{}, 1),
|
||||
}
|
||||
s.init()
|
||||
|
||||
cleanupSub, err := func() (*natsgo.Subscription, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.ctx.Err() != nil {
|
||||
return nil, errClosed
|
||||
}
|
||||
|
||||
nsub, ok := p.subscriptions[event]
|
||||
if ok {
|
||||
nsub.mu.Lock()
|
||||
nsub.localSubs[s] = struct{}{}
|
||||
nsub.mu.Unlock()
|
||||
return nsub.sub, nil
|
||||
}
|
||||
|
||||
nsub = &natsSub{
|
||||
localSubs: map[*localSub]struct{}{
|
||||
s: {},
|
||||
},
|
||||
}
|
||||
|
||||
subConn := pickConn(p.subscribePool, event)
|
||||
natsSubscription, err := subConn.Subscribe(event, nsub.handleMessage)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("subscribe: %w", err)
|
||||
}
|
||||
nsub.sub = natsSubscription
|
||||
|
||||
// Flush the SUB to the server so a publish issued immediately
|
||||
// after Subscribe returns cannot race ahead of registration.
|
||||
if err := subConn.Flush(); err != nil {
|
||||
return natsSubscription, xerrors.Errorf("flush subscribe: %w", err)
|
||||
}
|
||||
limits := defaultPendingLimits(p.opts.PendingLimits)
|
||||
if err := natsSubscription.SetPendingLimits(limits.Msgs, limits.Bytes); err != nil {
|
||||
return natsSubscription, xerrors.Errorf("set pending limits: %w", err)
|
||||
}
|
||||
|
||||
p.subscriptions[event] = nsub
|
||||
return natsSubscription, nil
|
||||
}()
|
||||
if err != nil {
|
||||
s.close()
|
||||
if cleanupSub != nil {
|
||||
if unsubscribeErr := cleanupSub.Unsubscribe(); unsubscribeErr != nil {
|
||||
err = errors.Join(err, xerrors.Errorf("unsubscribe: %w", unsubscribeErr))
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// unsubscribeLocal removes s from its natsSub. If s was the last
|
||||
// listener, it also removes and unsubscribes the underlying NATS
|
||||
// subscription.
|
||||
func (p *Pubsub) unsubscribeLocal(s *localSub) {
|
||||
natsSub := func() *natsgo.Subscription {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
nsub := p.subscriptions[s.event]
|
||||
if nsub == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
nsub.mu.Lock()
|
||||
defer nsub.mu.Unlock()
|
||||
if _, tracked := nsub.localSubs[s]; !tracked {
|
||||
return nil
|
||||
}
|
||||
delete(nsub.localSubs, s)
|
||||
if len(nsub.localSubs) > 0 {
|
||||
return nil
|
||||
}
|
||||
// Last listener: remove the nsub entry so a new Subscribe to this
|
||||
// subject creates a fresh underlying subscription.
|
||||
delete(p.subscriptions, s.event)
|
||||
return nsub.sub
|
||||
}()
|
||||
if natsSub != nil {
|
||||
_ = natsSub.Unsubscribe()
|
||||
}
|
||||
}
|
||||
|
||||
// handleMessage handles messages for the shared subscription. Each
|
||||
// enqueue is non-blocking and does not call user code, so one slow
|
||||
// listener cannot stall the NATS delivery goroutine.
|
||||
//
|
||||
// Zero-copy fan-out: the same msg.Data slice is delivered to every
|
||||
// local listener without cloning. Listeners on a coalesced subject MUST
|
||||
// treat the delivered bytes as immutable.
|
||||
func (nsub *natsSub) handleMessage(msg *natsgo.Msg) {
|
||||
nsub.mu.Lock()
|
||||
defer nsub.mu.Unlock()
|
||||
|
||||
for s := range nsub.localSubs {
|
||||
s.enqueue(msg.Data)
|
||||
}
|
||||
}
|
||||
|
||||
// init starts the per-listener delivery goroutine.
|
||||
func (s *localSub) init() {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case data := <-s.queue:
|
||||
s.listener(s.ctx, data, nil)
|
||||
case <-s.dropSignal:
|
||||
s.listener(s.ctx, nil, pubsub.ErrDroppedMessages)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// close cancels local delivery without waiting for callbacks.
|
||||
func (s *localSub) close() {
|
||||
s.cancelOnce.Do(func() {
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// enqueue non-blockingly sends data onto s.queue. On overflow it drops the
|
||||
// message and raises a drop signal so pubsub.ErrDroppedMessages is surfaced.
|
||||
// If s is canceled the message is silently dropped.
|
||||
func (s *localSub) enqueue(data []byte) {
|
||||
select {
|
||||
case s.queue <- data:
|
||||
default:
|
||||
s.signalDrop()
|
||||
}
|
||||
}
|
||||
|
||||
// signalDrop pushes onto dropSignal without blocking. Multiple drops
|
||||
// between dispatcher dequeues coalesce into a single pending signal, so
|
||||
// the listener observes one ErrDroppedMessages per drop wave.
|
||||
func (s *localSub) signalDrop() {
|
||||
select {
|
||||
case s.dropSignal <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// signalSubscribersDroppedForConn signals local subscribers assigned to conn.
|
||||
func (p *Pubsub) signalSubscribersDroppedForConn(conn *natsgo.Conn) {
|
||||
if conn == nil || len(p.subscribePool) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
subs := make([]*localSub, 0)
|
||||
for event, nsub := range p.subscriptions {
|
||||
if pickConn(p.subscribePool, event) != conn {
|
||||
continue
|
||||
}
|
||||
nsub.mu.Lock()
|
||||
for s := range nsub.localSubs {
|
||||
subs = append(subs, s)
|
||||
}
|
||||
nsub.mu.Unlock()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
for _, s := range subs {
|
||||
s.signalDrop()
|
||||
}
|
||||
}
|
||||
|
||||
// handleAsyncError routes async error callbacks. Only slow-consumer
|
||||
// errors trigger drop accounting.
|
||||
func (p *Pubsub) handleAsyncError(sub *natsgo.Subscription, err error) {
|
||||
if sub == nil || !errors.Is(err, natsgo.ErrSlowConsumer) {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
var nsub *natsSub
|
||||
for _, candidate := range p.subscriptions {
|
||||
if candidate.sub == sub {
|
||||
nsub = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
p.mu.Unlock()
|
||||
if nsub == nil {
|
||||
return
|
||||
}
|
||||
p.handleSlowSubscriber(nsub)
|
||||
}
|
||||
|
||||
// handleSlowSubscriber broadcasts pubsub.ErrDroppedMessages to every
|
||||
// local listener on nsub when NATS reports a new drop delta. The
|
||||
// slow-consumer signal is per-subscription and cannot be narrowed to a
|
||||
// single local listener.
|
||||
func (p *Pubsub) handleSlowSubscriber(nsub *natsSub) {
|
||||
nsub.dropMu.Lock()
|
||||
dropped, err := nsub.sub.Dropped()
|
||||
if err != nil {
|
||||
nsub.dropMu.Unlock()
|
||||
p.logger.Warn(p.ctx, "nats: query dropped count", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if dropped < 0 {
|
||||
nsub.dropMu.Unlock()
|
||||
p.logger.Warn(p.ctx, "nats: negative dropped count")
|
||||
return
|
||||
}
|
||||
// Dropped is cumulative per subscription; signal only new drops.
|
||||
droppedCount := uint64(dropped)
|
||||
if droppedCount < nsub.lastDropped {
|
||||
nsub.lastDropped = droppedCount
|
||||
nsub.dropMu.Unlock()
|
||||
return
|
||||
}
|
||||
if droppedCount == nsub.lastDropped {
|
||||
nsub.dropMu.Unlock()
|
||||
return
|
||||
}
|
||||
nsub.lastDropped = droppedCount
|
||||
nsub.dropMu.Unlock()
|
||||
|
||||
nsub.mu.Lock()
|
||||
defer nsub.mu.Unlock()
|
||||
|
||||
for s := range nsub.localSubs {
|
||||
s.signalDrop()
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops local delivery and shuts down the Pubsub. It is idempotent.
|
||||
// Close does not drain queued listener messages.
|
||||
func (p *Pubsub) Close() error {
|
||||
p.closeOnce.Do(func() {
|
||||
p.mu.Lock()
|
||||
// Cancel while holding p.mu so subscriber state cleanup below
|
||||
// observes the canceled context.
|
||||
p.cancel()
|
||||
var subs []*localSub
|
||||
shareds := make([]*natsSub, 0, len(p.subscriptions))
|
||||
for _, ss := range p.subscriptions {
|
||||
shareds = append(shareds, ss)
|
||||
ss.mu.Lock()
|
||||
for s := range ss.localSubs {
|
||||
subs = append(subs, s)
|
||||
delete(ss.localSubs, s)
|
||||
}
|
||||
ss.mu.Unlock()
|
||||
}
|
||||
clear(p.subscriptions)
|
||||
p.mu.Unlock()
|
||||
|
||||
// Unsubscribe shared subscriptions before closing connections.
|
||||
for _, ss := range shareds {
|
||||
if ss.sub != nil {
|
||||
_ = ss.sub.Unsubscribe()
|
||||
}
|
||||
}
|
||||
|
||||
// Signal per-listener goroutines without waiting for callbacks.
|
||||
for _, s := range subs {
|
||||
s.close()
|
||||
}
|
||||
|
||||
for _, nc := range p.subscribePool {
|
||||
if nc != nil {
|
||||
nc.Close()
|
||||
}
|
||||
}
|
||||
for _, nc := range p.publishPool {
|
||||
if nc != nil {
|
||||
nc.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if p.ns != nil {
|
||||
p.ns.Shutdown()
|
||||
p.ns.WaitForShutdown()
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// pickConn returns the connection assigned to subject. Selection uses
|
||||
// a stable FNV-1a hash so same-subject traffic always targets the same
|
||||
// connection within a process; pools are immutable after construction
|
||||
// so the lookup is lock-free.
|
||||
func pickConn(pool []*natsgo.Conn, subject string) *natsgo.Conn {
|
||||
if len(pool) == 1 {
|
||||
return pool[0]
|
||||
}
|
||||
h := fnv.New32a()
|
||||
_, _ = h.Write([]byte(subject))
|
||||
n := uint32(len(pool)) //nolint:gosec // pool size bounded by Options.{Publish,Subscribe}Conns
|
||||
return pool[h.Sum32()%n]
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
package nats //nolint:testpackage // Exercises internal pubsub state and helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
natsgo "github.com/nats-io/nats.go"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func Test_defaultPendingLimits(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const defaultBytes = 512 * 1024 * 1024
|
||||
testCases := []struct {
|
||||
name string
|
||||
in PendingLimits
|
||||
want PendingLimits
|
||||
}{
|
||||
{
|
||||
name: "AllZero",
|
||||
in: PendingLimits{},
|
||||
want: PendingLimits{Msgs: -1, Bytes: defaultBytes},
|
||||
},
|
||||
{
|
||||
name: "MsgsOnly",
|
||||
in: PendingLimits{Msgs: 8},
|
||||
want: PendingLimits{Msgs: 8, Bytes: defaultBytes},
|
||||
},
|
||||
{
|
||||
name: "BytesOnly",
|
||||
in: PendingLimits{Bytes: 1024},
|
||||
want: PendingLimits{Msgs: -1, Bytes: 1024},
|
||||
},
|
||||
{
|
||||
name: "NegativeMsgs",
|
||||
in: PendingLimits{Msgs: -2},
|
||||
want: PendingLimits{Msgs: -2, Bytes: defaultBytes},
|
||||
},
|
||||
{
|
||||
name: "NegativeBytes",
|
||||
in: PendingLimits{Bytes: -2},
|
||||
want: PendingLimits{Msgs: -1, Bytes: -2},
|
||||
},
|
||||
{
|
||||
name: "NegativeBoth",
|
||||
in: PendingLimits{Msgs: -2, Bytes: -3},
|
||||
want: PendingLimits{Msgs: -2, Bytes: -3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tc.want, defaultPendingLimits(tc.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_pickConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("DifferentSubjects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var a, b natsgo.Conn
|
||||
pool := []*natsgo.Conn{&a, &b}
|
||||
|
||||
require.NotSame(t, pickConn(pool, "a"), pickConn(pool, "b"))
|
||||
})
|
||||
}
|
||||
|
||||
func subjectForConn(t *testing.T, pool []*natsgo.Conn, conn *natsgo.Conn, prefix string) string {
|
||||
t.Helper()
|
||||
|
||||
for i := 0; i < 10_000; i++ {
|
||||
subject := fmt.Sprintf("%s_%d", prefix, i)
|
||||
if pickConn(pool, subject) == conn {
|
||||
return subject
|
||||
}
|
||||
}
|
||||
require.FailNow(t, "no subject matched requested connection")
|
||||
return ""
|
||||
}
|
||||
|
||||
func Test_New(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConnectionCount", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
ps, err := New(ctx, logger, Options{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = ps.Close() })
|
||||
|
||||
const n = 50
|
||||
cancels := make([]func(), 0, n)
|
||||
for i := 0; i < n; i++ {
|
||||
c, err := ps.Subscribe(fmt.Sprintf("cc_evt_%d", i), func(context.Context, []byte) {})
|
||||
require.NoError(t, err)
|
||||
cancels = append(cancels, c)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
for _, c := range cancels {
|
||||
c()
|
||||
}
|
||||
})
|
||||
|
||||
require.Equal(t, 2, ps.ns.NumClients(),
|
||||
"expected exactly 2 client connections (pubConn + subConn), got %d", ps.ns.NumClients())
|
||||
require.Len(t, ps.publishPool, 1, "default PublishConns must be 1")
|
||||
require.Len(t, ps.subscribePool, 1, "default SubscribeConns must be 1")
|
||||
require.NotSame(t, ps.publishPool[0], ps.subscribePool[0], "pubConn and subConn must be distinct")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_SubscribeWithErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SameSubjectSharesSubscription", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
ps, err := New(ctx, logger, Options{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = ps.Close() })
|
||||
|
||||
cancelA, err := ps.Subscribe("coalesce_evt", func(context.Context, []byte) {})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cancelA)
|
||||
cancelB, err := ps.Subscribe("coalesce_evt", func(context.Context, []byte) {})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cancelB)
|
||||
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
require.Len(t, ps.subscriptions, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Pubsub_buildConnHandlers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("DisconnectSignalsDropsForMatchingSubscriberConn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ps := newPubsub(ctx, logger, Options{})
|
||||
|
||||
var subConnA, subConnB, pubConn natsgo.Conn
|
||||
ps.subscribePool = []*natsgo.Conn{&subConnA, &subConnB}
|
||||
matchingEvent := subjectForConn(t, ps.subscribePool, &subConnA, "disconnect_match")
|
||||
otherEvent := subjectForConn(t, ps.subscribePool, &subConnB, "disconnect_other")
|
||||
|
||||
newLocal := func(event string) *localSub {
|
||||
return &localSub{
|
||||
event: event,
|
||||
dropSignal: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
matchingSub := newLocal(matchingEvent)
|
||||
otherSub := newLocal(otherEvent)
|
||||
ps.subscriptions[matchingSub.event] = &natsSub{localSubs: map[*localSub]struct{}{matchingSub: {}}}
|
||||
ps.subscriptions[otherSub.event] = &natsSub{localSubs: map[*localSub]struct{}{otherSub: {}}}
|
||||
|
||||
handlers := ps.buildConnHandlers()
|
||||
handlers.disconnectErr(&subConnA, xerrors.New("disconnect"))
|
||||
|
||||
select {
|
||||
case <-matchingSub.dropSignal:
|
||||
default:
|
||||
require.Fail(t, "matching subscriber did not receive drop signal")
|
||||
}
|
||||
select {
|
||||
case <-otherSub.dropSignal:
|
||||
require.Fail(t, "non-matching subscriber received drop signal")
|
||||
default:
|
||||
}
|
||||
|
||||
handlers.disconnectErr(&pubConn, xerrors.New("publisher disconnect"))
|
||||
select {
|
||||
case <-otherSub.dropSignal:
|
||||
require.Fail(t, "publisher connection disconnect signaled subscriber")
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_localSub_init(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SerializesCallbacks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
dataStarted := make(chan struct{})
|
||||
dropDelivered := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var dataOnce sync.Once
|
||||
var dropOnce sync.Once
|
||||
var releaseOnce sync.Once
|
||||
var active atomic.Int64
|
||||
var concurrent atomic.Bool
|
||||
|
||||
s := &localSub{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
listener: func(_ context.Context, _ []byte, ferr error) {
|
||||
if active.Add(1) != 1 {
|
||||
concurrent.Store(true)
|
||||
}
|
||||
defer active.Add(-1)
|
||||
|
||||
if errors.Is(ferr, pubsub.ErrDroppedMessages) {
|
||||
dropOnce.Do(func() { close(dropDelivered) })
|
||||
return
|
||||
}
|
||||
|
||||
dataOnce.Do(func() { close(dataStarted) })
|
||||
<-release
|
||||
},
|
||||
queue: make(chan []byte, 1),
|
||||
dropSignal: make(chan struct{}, 1),
|
||||
}
|
||||
s.init()
|
||||
t.Cleanup(func() {
|
||||
releaseOnce.Do(func() { close(release) })
|
||||
s.close()
|
||||
})
|
||||
|
||||
s.enqueue([]byte("data"))
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-dataStarted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
s.signalDrop()
|
||||
require.Never(t, func() bool {
|
||||
select {
|
||||
case <-dropDelivered:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.IntervalMedium, testutil.IntervalFast,
|
||||
"drop callback must wait for the blocked data callback")
|
||||
require.False(t, concurrent.Load(), "listener callback ran concurrently")
|
||||
|
||||
releaseOnce.Do(func() { close(release) })
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-dropDelivered:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
require.False(t, concurrent.Load(), "listener callback ran concurrently")
|
||||
})
|
||||
|
||||
t.Run("CrossSubjectListenerIsolation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
ps, err := New(ctx, logger, Options{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = ps.Close() })
|
||||
|
||||
release := make(chan struct{})
|
||||
var releaseOnce sync.Once
|
||||
var slowDrops atomic.Int64
|
||||
var slowBlocked atomic.Bool
|
||||
slowCancel, err := ps.SubscribeWithErr("iso_slow", func(_ context.Context, _ []byte, ferr error) {
|
||||
if ferr != nil && errors.Is(ferr, pubsub.ErrDroppedMessages) {
|
||||
slowDrops.Add(1)
|
||||
return
|
||||
}
|
||||
if slowBlocked.CompareAndSwap(false, true) {
|
||||
<-release
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer slowCancel()
|
||||
|
||||
var fastCount atomic.Int64
|
||||
fastCancel, err := ps.Subscribe("iso_fast", func(_ context.Context, _ []byte) {
|
||||
fastCount.Add(1)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer fastCancel()
|
||||
defer releaseOnce.Do(func() { close(release) })
|
||||
|
||||
total := defaultListenerQueueSize + 256
|
||||
payload := make([]byte, 4*1024)
|
||||
for i := 0; i < total; i++ {
|
||||
require.NoError(t, ps.Publish("iso_slow", payload))
|
||||
require.NoError(t, ps.Publish("iso_fast", []byte("ping")))
|
||||
}
|
||||
require.NoError(t, ps.Flush())
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return fastCount.Load() >= int64(total)
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
require.Zero(t, slowDrops.Load(),
|
||||
"drop callback must wait for the blocked data callback")
|
||||
releaseOnce.Do(func() { close(release) })
|
||||
require.Eventually(t, func() bool {
|
||||
return slowDrops.Load() >= 1
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"slow subscriber must receive at least one ErrDroppedMessages signal")
|
||||
|
||||
require.GreaterOrEqual(t, fastCount.Load(), int64(total),
|
||||
"fast subscriber must keep receiving despite slow peer on shared subConn")
|
||||
require.Len(t, ps.subscribePool, 1)
|
||||
require.False(t, ps.subscribePool[0].IsClosed(), "subConn must not be closed by slow consumer")
|
||||
require.True(t, ps.subscribePool[0].IsConnected(), "subConn must stay connected")
|
||||
require.Equal(t, 2, ps.ns.NumClients(), "slow consumer must not disconnect subConn")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package nats_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
xnats "github.com/coder/coder/v2/coderd/x/nats"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func newTestPubsub(t *testing.T, opts xnats.Options) *xnats.Pubsub {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ps, err := xnats.New(ctx, logger, opts)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = ps.Close()
|
||||
cancel()
|
||||
})
|
||||
return ps
|
||||
}
|
||||
|
||||
func TestPubsub(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{})
|
||||
|
||||
got := make(chan []byte, 1)
|
||||
cancel, err := ps.Subscribe("test_event", func(_ context.Context, msg []byte) {
|
||||
got <- msg
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, ps.Publish("test_event", []byte("hello")))
|
||||
|
||||
select {
|
||||
case msg := <-got:
|
||||
assert.Equal(t, "hello", string(msg))
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for message")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SubscribeWithErrNormalMessage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{})
|
||||
|
||||
got := make(chan []byte, 1)
|
||||
cancel, err := ps.SubscribeWithErr("evt", func(_ context.Context, msg []byte, err error) {
|
||||
assert.NoError(t, err)
|
||||
got <- msg
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, ps.Publish("evt", []byte("payload")))
|
||||
|
||||
select {
|
||||
case msg := <-got:
|
||||
assert.Equal(t, "payload", string(msg))
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for message")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EchoDefault", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{})
|
||||
|
||||
got := make(chan []byte, 1)
|
||||
cancel, err := ps.Subscribe("echo_evt", func(_ context.Context, msg []byte) {
|
||||
got <- msg
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, ps.Publish("echo_evt", []byte("data")))
|
||||
|
||||
select {
|
||||
case msg := <-got:
|
||||
assert.Equal(t, "data", string(msg))
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("default should echo own messages")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Ordering", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{})
|
||||
|
||||
const n = 100
|
||||
got := make(chan []byte, n)
|
||||
cancel, err := ps.Subscribe("ord_evt", func(_ context.Context, msg []byte) {
|
||||
got <- msg
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
require.NoError(t, ps.Publish("ord_evt", []byte(fmt.Sprintf("%d", i))))
|
||||
}
|
||||
|
||||
deadline := time.After(testutil.WaitLong)
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case msg := <-got:
|
||||
assert.Equal(t, fmt.Sprintf("%d", i), string(msg))
|
||||
case <-deadline:
|
||||
t.Fatalf("timed out at message %d/%d", i, n)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdempotent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
ps, err := xnats.New(ctx, logger, xnats.Options{})
|
||||
require.NoError(t, err)
|
||||
|
||||
var first, second error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
first = ps.Close()
|
||||
}()
|
||||
wg.Wait()
|
||||
second = ps.Close()
|
||||
assert.NoError(t, first)
|
||||
assert.NoError(t, second)
|
||||
})
|
||||
|
||||
t.Run("SubscribeWithErrReceivesDropError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{
|
||||
PendingLimits: xnats.PendingLimits{Msgs: 1, Bytes: 1024 * 1024},
|
||||
})
|
||||
|
||||
const event = "slow_evt_sync"
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
dropped := make(chan error, 1)
|
||||
var startedOnce sync.Once
|
||||
var releaseOnce sync.Once
|
||||
defer releaseOnce.Do(func() { close(release) })
|
||||
|
||||
cancel, err := ps.SubscribeWithErr(event, func(_ context.Context, _ []byte, err error) {
|
||||
if err != nil {
|
||||
select {
|
||||
case dropped <- err:
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
startedOnce.Do(func() {
|
||||
close(started)
|
||||
<-release
|
||||
})
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, ps.Publish(event, []byte("first")))
|
||||
require.NoError(t, ps.Flush())
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for first callback")
|
||||
}
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
require.NoError(t, ps.Publish(event, []byte("burst")))
|
||||
}
|
||||
require.NoError(t, ps.Flush())
|
||||
releaseOnce.Do(func() { close(release) })
|
||||
|
||||
select {
|
||||
case err := <-dropped:
|
||||
assert.ErrorIs(t, err, pubsub.ErrDroppedMessages)
|
||||
case <-time.After(testutil.WaitLong):
|
||||
t.Fatal("timed out waiting for drop error")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
natsserver "github.com/nats-io/nats-server/v2/server"
|
||||
natsgo "github.com/nats-io/nats.go"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
const readyTimeout = 10 * time.Second
|
||||
|
||||
// buildServerOptions constructs the embedded NATS server options. The
|
||||
// server runs standalone with a loopback random client listener.
|
||||
func buildServerOptions(opts Options) (*natsserver.Options, error) {
|
||||
maxPayload := opts.MaxPayload
|
||||
if maxPayload == 0 {
|
||||
maxPayload = natsserver.MAX_PAYLOAD_SIZE
|
||||
}
|
||||
maxPending := opts.MaxPending
|
||||
if maxPending <= 0 {
|
||||
maxPending = DefaultMaxPending
|
||||
}
|
||||
|
||||
sopts := &natsserver.Options{
|
||||
JetStream: false,
|
||||
MaxPayload: maxPayload,
|
||||
MaxPending: maxPending,
|
||||
NoLog: true,
|
||||
NoSigs: true,
|
||||
}
|
||||
|
||||
sopts.DontListen = false
|
||||
sopts.Host = "127.0.0.1"
|
||||
sopts.Port = natsserver.RANDOM_PORT
|
||||
|
||||
return sopts, nil
|
||||
}
|
||||
|
||||
// startEmbeddedServer starts an in-process standalone NATS server.
|
||||
func startEmbeddedServer(logger slog.Logger, opts Options) (*natsserver.Server, error) {
|
||||
sopts, err := buildServerOptions(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns, err := natsserver.NewServer(sopts)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("new embedded nats server: %w", err)
|
||||
}
|
||||
go ns.Start()
|
||||
if !ns.ReadyForConnections(readyTimeout) {
|
||||
ns.Shutdown()
|
||||
ns.WaitForShutdown()
|
||||
return nil, xerrors.Errorf("embedded nats server not ready within %s", readyTimeout)
|
||||
}
|
||||
logger.Info(context.Background(), "embedded nats server started",
|
||||
slog.F("client_url", ns.ClientURL()),
|
||||
)
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
type connHandlers struct {
|
||||
disconnectErr natsgo.ConnErrHandler
|
||||
reconnect natsgo.ConnHandler
|
||||
closed natsgo.ConnHandler
|
||||
errH natsgo.ErrHandler
|
||||
}
|
||||
|
||||
// connectClient dials the embedded server's client listener over TCP
|
||||
// loopback (or net.Pipe when opts.InProcess is true) and returns the
|
||||
// resulting *natsgo.Conn. connName identifies the connection in server
|
||||
// logs.
|
||||
func connectClient(ns *natsserver.Server, opts Options, handlers connHandlers, connName string) (*natsgo.Conn, error) {
|
||||
connOpts := []natsgo.Option{
|
||||
natsgo.Name(connName),
|
||||
}
|
||||
if opts.ReconnectWait > 0 {
|
||||
connOpts = append(connOpts, natsgo.ReconnectWait(opts.ReconnectWait))
|
||||
}
|
||||
if handlers.disconnectErr != nil {
|
||||
connOpts = append(connOpts, natsgo.DisconnectErrHandler(handlers.disconnectErr))
|
||||
}
|
||||
if handlers.reconnect != nil {
|
||||
connOpts = append(connOpts, natsgo.ReconnectHandler(handlers.reconnect))
|
||||
}
|
||||
if handlers.closed != nil {
|
||||
connOpts = append(connOpts, natsgo.ClosedHandler(handlers.closed))
|
||||
}
|
||||
if handlers.errH != nil {
|
||||
connOpts = append(connOpts, natsgo.ErrorHandler(handlers.errH))
|
||||
}
|
||||
url := ns.ClientURL()
|
||||
if opts.InProcess {
|
||||
// InProcessServer overrides URL dialing with a net.Pipe; the
|
||||
// url argument is ignored but must still be syntactically valid.
|
||||
connOpts = append(connOpts, natsgo.InProcessServer(ns))
|
||||
}
|
||||
nc, err := natsgo.Connect(url, connOpts...)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("connect client: %w", err)
|
||||
}
|
||||
return nc, nil
|
||||
}
|
||||
Reference in New Issue
Block a user