mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add nats cluster peer support (#25632)
This commit is contained in:
@@ -0,0 +1,148 @@
|
||||
package nats
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// SetPeerAddresses replaces the configured NATS cluster peer routes.
|
||||
func (p *Pubsub) SetPeerAddresses(addresses []string) error {
|
||||
p.clusterMu.Lock()
|
||||
defer p.clusterMu.Unlock()
|
||||
|
||||
if p.ctx.Err() != nil {
|
||||
return errClosed
|
||||
}
|
||||
if !p.clustered {
|
||||
return xerrors.New("nats pubsub was not started with clustering enabled")
|
||||
}
|
||||
|
||||
routes, err := parsePeerAddresses(addresses)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
self := &url.URL{Scheme: "nats", Host: p.ns.ClusterAddr().String()}
|
||||
routes = filterSelfRoutes(routes, self)
|
||||
routes = sortRouteURLs(routes)
|
||||
|
||||
if sortedURLsEqual(p.currentRoutes, routes) {
|
||||
return nil
|
||||
}
|
||||
|
||||
newOpts := p.serverOpts.Clone()
|
||||
newOpts.Routes = cloneRouteURLs(routes)
|
||||
if err := p.ns.ReloadOptions(newOpts); err != nil {
|
||||
return xerrors.Errorf("reload nats peer addresses: %w", err)
|
||||
}
|
||||
p.serverOpts = newOpts.Clone()
|
||||
p.currentRoutes = cloneRouteURLs(routes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePeerAddresses(addresses []string) ([]*url.URL, error) {
|
||||
routesByAddress := make(map[string]*url.URL, len(addresses))
|
||||
for i, address := range addresses {
|
||||
trimmed := strings.TrimSpace(address)
|
||||
if trimmed == "" {
|
||||
return nil, xerrors.Errorf("peer address %d is empty", i)
|
||||
}
|
||||
|
||||
normalizedHost, err := normalizeHostPort(trimmed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routesByAddress[normalizedHost] = &url.URL{
|
||||
Scheme: "nats",
|
||||
Host: normalizedHost,
|
||||
}
|
||||
}
|
||||
|
||||
routes := make([]*url.URL, 0, len(routesByAddress))
|
||||
for _, route := range routesByAddress {
|
||||
routes = append(routes, route)
|
||||
}
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// filterSelfRoutes returns a new slice containing every route whose string
// form differs from self's string form, preserving input order. The input
// slice is not modified and the result is never nil. A nil self matches
// nothing, so all routes are kept.
func filterSelfRoutes(routes []*url.URL, self *url.URL) []*url.URL {
	filtered := make([]*url.URL, 0, len(routes))
	if self == nil {
		return append(filtered, routes...)
	}

	// self is loop-invariant, so render it once instead of per route.
	selfAddr := self.String()
	for _, route := range routes {
		if route.String() == selfAddr {
			continue
		}
		filtered = append(filtered, route)
	}
	return filtered
}
|
||||
|
||||
func normalizeHostPort(address string) (string, error) {
|
||||
route, err := url.Parse(address)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("parse peer address %q: %w", address, err)
|
||||
}
|
||||
if route.User != nil {
|
||||
return "", xerrors.Errorf("peer address %q must not include userinfo", address)
|
||||
}
|
||||
if route.Path != "" || route.RawQuery != "" || route.Fragment != "" {
|
||||
return "", xerrors.Errorf("peer address %q must not include path, query, or fragment", address)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(route.Host)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("split %q host port: %w", address, err)
|
||||
}
|
||||
if host == "" || port == "" {
|
||||
return "", xerrors.Errorf("%q must include host and port", address)
|
||||
}
|
||||
|
||||
portNumber, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("parse %q port: %w", address, err)
|
||||
}
|
||||
if portNumber <= 0 || portNumber > 65535 {
|
||||
return "", xerrors.Errorf("peer address %q must include a valid port", address)
|
||||
}
|
||||
return net.JoinHostPort(host, strconv.Itoa(portNumber)), nil
|
||||
}
|
||||
|
||||
// sortRouteURLs orders routes in place by the lexical order of their string
// forms and returns the same slice for call chaining.
func sortRouteURLs(routes []*url.URL) []*url.URL {
	byString := func(left, right *url.URL) int {
		return strings.Compare(left.String(), right.String())
	}
	slices.SortFunc(routes, byString)
	return routes
}
|
||||
|
||||
// sortedURLsEqual reports whether a and b have the same length and the same
// string form at every index. It compares positionally, so both slices are
// assumed to be sorted the same way.
func sortedURLsEqual(a, b []*url.URL) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, left := range a {
		if right := b[idx]; left.String() != right.String() {
			return false
		}
	}
	return true
}
|
||||
|
||||
// cloneRouteURLs returns a slice of the same length whose elements point at
// fresh copies of the input URL structs. A nil slice yields nil, and nil
// elements stay nil at the same index.
func cloneRouteURLs(routes []*url.URL) []*url.URL {
	if routes == nil {
		return nil
	}
	out := make([]*url.URL, len(routes))
	for idx := range routes {
		if src := routes[idx]; src != nil {
			dup := *src
			out[idx] = &dup
		}
	}
	return out
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package nats
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_parsePeerAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
routes, err := parsePeerAddresses([]string{
|
||||
"whatever://127.0.0.1:4222 ",
|
||||
"http://[::1]:7222",
|
||||
"nats://example.com:6222",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, []string{
|
||||
"nats://127.0.0.1:4222",
|
||||
"nats://[::1]:7222",
|
||||
"nats://example.com:6222",
|
||||
}, routeStrings(routes))
|
||||
})
|
||||
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
routes, err := parsePeerAddresses(nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, routes)
|
||||
})
|
||||
|
||||
t.Run("Dedupes", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
routes, err := parsePeerAddresses([]string{
|
||||
"nats://b.example:6222",
|
||||
"nats://a.example:6222",
|
||||
"nats://b.example:6222",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, []string{
|
||||
"nats://a.example:6222",
|
||||
"nats://b.example:6222",
|
||||
}, routeStrings(routes))
|
||||
})
|
||||
|
||||
t.Run("Invalid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, address := range []string{
|
||||
"",
|
||||
" ",
|
||||
"127.0.0.1:4222",
|
||||
"127.0.0.1",
|
||||
":4222",
|
||||
"127.0.0.1:0",
|
||||
"127.0.0.1:bad",
|
||||
"nats://127.0.0.1",
|
||||
"nats://:4222",
|
||||
"nats://127.0.0.1:0",
|
||||
"nats://127.0.0.1:bad",
|
||||
"nats://user@127.0.0.1:4222",
|
||||
"nats://127.0.0.1:4222/path",
|
||||
"nats://127.0.0.1:4222?x=1",
|
||||
"nats://127.0.0.1:4222#frag",
|
||||
} {
|
||||
t.Run(address, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := parsePeerAddresses([]string{address})
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_filterSelfRoutes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
routes, err := parsePeerAddresses([]string{
|
||||
"nats://b.example:6222",
|
||||
"http://self.example:6222",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
routes = filterSelfRoutes(routes, &url.URL{Scheme: "nats", Host: "self.example:6222"})
|
||||
require.Equal(t, []string{"nats://b.example:6222"}, routeStrings(routes))
|
||||
}
|
||||
|
||||
// Cluster tests bind free ports and reload shared route state.
|
||||
func TestPubsub_SetPeerAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := newTestPubsub(t, clusterTestOptions(t))
|
||||
b := newTestPubsub(t, clusterTestOptions(t))
|
||||
c := newTestPubsub(t, clusterTestOptions(t))
|
||||
|
||||
addrB := clusterRouteAddress(t, b)
|
||||
addrC := clusterRouteAddress(t, c)
|
||||
require.NoError(t, a.SetPeerAddresses([]string{addrC, addrB}))
|
||||
requireRoutesEqual(t, a.currentRoutes, addrB, addrC)
|
||||
|
||||
require.NoError(t, a.SetPeerAddresses([]string{addrB, addrC}))
|
||||
requireRoutesEqual(t, a.currentRoutes, addrB, addrC)
|
||||
|
||||
require.NoError(t, a.SetPeerAddresses(nil))
|
||||
require.Empty(t, a.currentRoutes)
|
||||
require.Empty(t, a.serverOpts.Routes)
|
||||
})
|
||||
|
||||
t.Run("StandaloneConfigError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, defaultTestOptions())
|
||||
err := ps.SetPeerAddresses(nil)
|
||||
require.ErrorContains(t, err, "not started with clustering enabled")
|
||||
})
|
||||
|
||||
t.Run("Closed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, clusterTestOptions(t))
|
||||
require.NoError(t, ps.Close())
|
||||
err := ps.SetPeerAddresses(nil)
|
||||
require.True(t, errors.Is(err, errClosed), "got %v", err)
|
||||
})
|
||||
|
||||
t.Run("DropsSelfRoute", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, clusterTestOptions(t))
|
||||
require.NoError(t, ps.SetPeerAddresses([]string{clusterRouteAddress(t, ps)}))
|
||||
require.Empty(t, ps.currentRoutes)
|
||||
})
|
||||
}
|
||||
+42
-1
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +20,12 @@ import (
|
||||
// DefaultMaxPending is the per-client outbound pending byte budget.
|
||||
const DefaultMaxPending int64 = 128 << 20
|
||||
|
||||
const (
|
||||
defaultClusterName = "coder"
|
||||
defaultClusterPort = 6222
|
||||
defaultRoutePoolSize = 3
|
||||
)
|
||||
|
||||
var errClosed = xerrors.New("nats pubsub closed")
|
||||
|
||||
// PendingLimits configures per-subscription NATS pending limits set
|
||||
@@ -65,6 +72,23 @@ type Options struct {
|
||||
// shared subscription is pinned to one connection by a stable hash
|
||||
// of its subject. Zero or negative means 1.
|
||||
SubscribeConns int
|
||||
|
||||
// ClusterHost is the embedded NATS route listener host. Empty means
|
||||
// all interfaces when cluster mode is enabled.
|
||||
ClusterHost string
|
||||
|
||||
// ClusterPort is the embedded NATS route listener port. Zero means
|
||||
// 6222 when cluster mode is enabled.
|
||||
ClusterPort int
|
||||
|
||||
// RoutePoolSize is the NATS route pool size. Zero means the package
|
||||
// default when cluster mode is enabled.
|
||||
RoutePoolSize int
|
||||
|
||||
// disableCluster is intended only for testing. Since we cannot reload a server
|
||||
// with a cluster host/port after initialization, we start all production servers
|
||||
// with clustering enabled.
|
||||
disableCluster bool
|
||||
}
|
||||
|
||||
// Pubsub is an embedded NATS-backed implementation of pubsub.Pubsub.
|
||||
@@ -97,6 +121,11 @@ type Pubsub struct {
|
||||
// cleanup observes the canceled context.
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
clusterMu sync.Mutex
|
||||
clustered bool
|
||||
serverOpts *natsserver.Options
|
||||
currentRoutes []*url.URL
|
||||
}
|
||||
|
||||
// natsSub maps to one underlying *natsgo.Subscription. The first
|
||||
@@ -203,13 +232,25 @@ func (p *Pubsub) buildConnHandlers() connHandlers {
|
||||
// embedded server and the publisher and subscriber connection pools.
|
||||
// Close shuts down all owned resources.
|
||||
func New(ctx context.Context, logger slog.Logger, opts Options) (*Pubsub, error) {
|
||||
ns, err := startEmbeddedServer(logger, opts)
|
||||
sopts, err := buildServerOptions(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ns, err := startEmbeddedServer(sopts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info(context.Background(), "embedded nats server started",
|
||||
slog.F("client_url", ns.ClientURL()),
|
||||
)
|
||||
|
||||
p := newPubsub(ctx, logger, opts)
|
||||
p.ns = ns
|
||||
p.clustered = !opts.disableCluster
|
||||
p.serverOpts = sopts.Clone()
|
||||
p.currentRoutes = cloneRouteURLs(sopts.Routes)
|
||||
handlers := p.buildConnHandlers()
|
||||
|
||||
publishPool, err := newConnPool(ns, opts, handlers, opts.PublishConns, "coder-pubsub-pub")
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
package nats //nolint:testpackage // Exercises internal pubsub state and helpers.
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
natsserver "github.com/nats-io/nats-server/v2/server"
|
||||
natsgo "github.com/nats-io/nats.go"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -82,7 +84,7 @@ func Test_pickConn(t *testing.T) {
|
||||
func subjectForConn(t *testing.T, pool []*natsgo.Conn, conn *natsgo.Conn, prefix string) string {
|
||||
t.Helper()
|
||||
|
||||
for i := 0; i < 10_000; i++ {
|
||||
for i := range 10_000 {
|
||||
subject := fmt.Sprintf("%s_%d", prefix, i)
|
||||
if pickConn(pool, subject) == conn {
|
||||
return subject
|
||||
@@ -97,17 +99,13 @@ func Test_New(t *testing.T) {
|
||||
|
||||
t.Run("ConnectionCount", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
ps, err := New(ctx, logger, Options{})
|
||||
require.NoError(t, err)
|
||||
ps := newTestPubsub(t, defaultTestOptions())
|
||||
t.Cleanup(func() { _ = ps.Close() })
|
||||
|
||||
const n = 50
|
||||
cancels := make([]func(), 0, n)
|
||||
for i := 0; i < n; i++ {
|
||||
c, err := ps.Subscribe(fmt.Sprintf("cc_evt_%d", i), func(context.Context, []byte) {})
|
||||
for i := range n {
|
||||
c, err := ps.Subscribe(fmt.Sprintf("cc_evt_%d", i), func(_ context.Context, _ []byte) {})
|
||||
require.NoError(t, err)
|
||||
cancels = append(cancels, c)
|
||||
}
|
||||
@@ -130,10 +128,9 @@ func Test_SubscribeWithErr(t *testing.T) {
|
||||
|
||||
t.Run("SameSubjectSharesSubscription", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
ps, err := New(ctx, logger, Options{})
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ps, err := New(ctx, logger, defaultTestOptions())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = ps.Close() })
|
||||
|
||||
@@ -155,10 +152,10 @@ func Test_Pubsub_buildConnHandlers(t *testing.T) {
|
||||
|
||||
t.Run("DisconnectSignalsDropsForMatchingSubscriberConn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ps := newPubsub(ctx, logger, Options{})
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ps := newPubsub(ctx, logger, defaultTestOptions())
|
||||
|
||||
var subConnA, subConnB, pubConn natsgo.Conn
|
||||
ps.subscribePool = []*natsgo.Conn{&subConnA, &subConnB}
|
||||
@@ -205,8 +202,7 @@ func Test_localSub_init(t *testing.T) {
|
||||
|
||||
t.Run("SerializesCallbacks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
dataStarted := make(chan struct{})
|
||||
dropDelivered := make(chan struct{})
|
||||
@@ -219,7 +215,7 @@ func Test_localSub_init(t *testing.T) {
|
||||
|
||||
s := &localSub{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
cancel: func() {},
|
||||
listener: func(_ context.Context, _ []byte, ferr error) {
|
||||
if active.Add(1) != 1 {
|
||||
concurrent.Store(true)
|
||||
@@ -279,10 +275,9 @@ func Test_localSub_init(t *testing.T) {
|
||||
|
||||
t.Run("CrossSubjectListenerIsolation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
ps, err := New(ctx, logger, Options{})
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
ps, err := New(ctx, logger, defaultTestOptions())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = ps.Close() })
|
||||
|
||||
@@ -312,7 +307,7 @@ func Test_localSub_init(t *testing.T) {
|
||||
|
||||
total := defaultListenerQueueSize + 256
|
||||
payload := make([]byte, 4*1024)
|
||||
for i := 0; i < total; i++ {
|
||||
for range total {
|
||||
require.NoError(t, ps.Publish("iso_slow", payload))
|
||||
require.NoError(t, ps.Publish("iso_fast", []byte("ping")))
|
||||
}
|
||||
@@ -337,3 +332,161 @@ func Test_localSub_init(t *testing.T) {
|
||||
require.Equal(t, 2, ps.ns.NumClients(), "slow consumer must not disconnect subConn")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPubsubCluster(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// OK verifies that SetPeerAddresses changes the active cluster topology.
|
||||
// A starts connected to B, then C is added and receives both global and
|
||||
// C-only messages. B is then removed from A's peers, while C continues to
|
||||
// receive global and C-only messages.
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := newTestPubsub(t, clusterTestOptions(t))
|
||||
b := newTestPubsub(t, clusterTestOptions(t))
|
||||
c := newTestPubsub(t, clusterTestOptions(t))
|
||||
|
||||
addrB := clusterRouteAddress(t, b)
|
||||
addrC := clusterRouteAddress(t, c)
|
||||
|
||||
require.NoError(t, a.SetPeerAddresses([]string{addrB}))
|
||||
requireRoutesEqual(t, a.currentRoutes, addrB)
|
||||
|
||||
globalEvent := "global"
|
||||
bGlobal := make(chan []byte, 8)
|
||||
cancelBGlobal, err := b.Subscribe(globalEvent, func(_ context.Context, msg []byte) {
|
||||
bGlobal <- msg
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancelBGlobal()
|
||||
|
||||
waitForRouteSubscription(t, a, globalEvent)
|
||||
publishAndFlush(t, a, globalEvent, "from-a-to-b")
|
||||
require.Equal(t, "from-a-to-b", string(receiveMessage(t, bGlobal)))
|
||||
|
||||
// Add C's subscriptions before adding C as an extra peer to A.
|
||||
cGlobal := make(chan []byte, 8)
|
||||
cancelCGlobal, err := c.Subscribe(globalEvent, func(_ context.Context, msg []byte) {
|
||||
cGlobal <- msg
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancelCGlobal()
|
||||
|
||||
cSubject := "c-only-subscriber"
|
||||
cUnique := make(chan []byte, 8)
|
||||
cancelCUnique, err := c.Subscribe(cSubject, func(_ context.Context, msg []byte) {
|
||||
cUnique <- msg
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancelCUnique()
|
||||
|
||||
// Add C to A's peer list. B and C should both receive global messages,
|
||||
// while the C-only subject should route only to C.
|
||||
require.NoError(t, a.SetPeerAddresses([]string{addrC, addrB}))
|
||||
requireRoutesEqual(t, a.currentRoutes, addrB, addrC)
|
||||
|
||||
waitForRouteSubscription(t, a, globalEvent)
|
||||
waitForRouteSubscription(t, a, cSubject)
|
||||
|
||||
publishAndFlush(t, a, globalEvent, "new-global-msg")
|
||||
require.Equal(t, "new-global-msg", string(receiveMessage(t, bGlobal)))
|
||||
require.Equal(t, "new-global-msg", string(receiveMessage(t, cGlobal)))
|
||||
|
||||
publishAndFlush(t, a, cSubject, "c-unique-msg")
|
||||
require.Equal(t, "c-unique-msg", string(receiveMessage(t, cUnique)))
|
||||
|
||||
// Remove B from A's peer list. Only C should receive the next messages.
|
||||
require.NoError(t, a.SetPeerAddresses([]string{addrC}))
|
||||
requireRoutesEqual(t, a.currentRoutes, addrC)
|
||||
|
||||
publishAndFlush(t, a, globalEvent, "no-b-peer")
|
||||
require.Equal(t, "no-b-peer", string(receiveMessage(t, cGlobal)))
|
||||
|
||||
publishAndFlush(t, a, cSubject, "c-messages-still-work")
|
||||
require.Equal(t, "c-messages-still-work", string(receiveMessage(t, cUnique)))
|
||||
})
|
||||
}
|
||||
|
||||
func defaultTestOptions() Options {
|
||||
return Options{disableCluster: true}
|
||||
}
|
||||
|
||||
func clusterTestOptions(t *testing.T) Options {
|
||||
t.Helper()
|
||||
return Options{
|
||||
ClusterHost: "127.0.0.1",
|
||||
ClusterPort: natsserver.RANDOM_PORT,
|
||||
disableCluster: false,
|
||||
}
|
||||
}
|
||||
|
||||
func newTestPubsub(t *testing.T, opts Options) *Pubsub {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
ps, err := New(ctx, logger, opts)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = ps.Close()
|
||||
})
|
||||
return ps
|
||||
}
|
||||
|
||||
func clusterRouteAddress(t *testing.T, ps *Pubsub) string {
|
||||
t.Helper()
|
||||
addr := ps.ns.ClusterAddr()
|
||||
require.NotNil(t, addr)
|
||||
return "nats://" + addr.String()
|
||||
}
|
||||
|
||||
func waitForRouteSubscription(t *testing.T, ps *Pubsub, subject string) {
|
||||
t.Helper()
|
||||
require.Eventually(t, func() bool {
|
||||
routes, err := ps.ns.Routez(&natsserver.RoutezOptions{Subscriptions: true})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, route := range routes.Routes {
|
||||
for _, sub := range route.Subs {
|
||||
if sub == subject {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func publishAndFlush(t *testing.T, ps *Pubsub, event, message string) {
|
||||
t.Helper()
|
||||
require.NoError(t, ps.Publish(event, []byte(message)))
|
||||
require.NoError(t, ps.Flush())
|
||||
}
|
||||
|
||||
func receiveMessage(t *testing.T, got <-chan []byte) []byte {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-got:
|
||||
return msg
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for message")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func requireRoutesEqual(t *testing.T, routes []*url.URL, addresses ...string) {
|
||||
t.Helper()
|
||||
want, err := parsePeerAddresses(addresses)
|
||||
require.NoError(t, err)
|
||||
want = sortRouteURLs(want)
|
||||
require.True(t, sortedURLsEqual(want, routes), "want %v, got %v", routeStrings(want), routeStrings(routes))
|
||||
}
|
||||
|
||||
// routeStrings renders each route with URL.String, preserving order. The
// result is a non-nil slice even when routes is empty.
func routeStrings(routes []*url.URL) []string {
	rendered := make([]string, 0, len(routes))
	for idx := range routes {
		rendered = append(rendered, routes[idx].String())
	}
	return rendered
}
|
||||
|
||||
@@ -7,25 +7,30 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
natsserver "github.com/nats-io/nats-server/v2/server"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
xnats "github.com/coder/coder/v2/coderd/x/nats"
|
||||
"github.com/coder/coder/v2/coderd/x/nats"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func newTestPubsub(t *testing.T, opts xnats.Options) *xnats.Pubsub {
|
||||
func newPubsub(t *testing.T, opts nats.Options) *nats.Pubsub {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ps, err := xnats.New(ctx, logger, opts)
|
||||
|
||||
if opts.ClusterPort == 0 {
|
||||
opts.ClusterPort = natsserver.RANDOM_PORT
|
||||
}
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
ps, err := nats.New(ctx, logger, opts)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = ps.Close()
|
||||
cancel()
|
||||
})
|
||||
return ps
|
||||
}
|
||||
@@ -35,7 +40,7 @@ func TestPubsub(t *testing.T) {
|
||||
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{})
|
||||
ps := newPubsub(t, nats.Options{})
|
||||
|
||||
got := make(chan []byte, 1)
|
||||
cancel, err := ps.Subscribe("test_event", func(_ context.Context, msg []byte) {
|
||||
@@ -56,7 +61,7 @@ func TestPubsub(t *testing.T) {
|
||||
|
||||
t.Run("SubscribeWithErrNormalMessage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{})
|
||||
ps := newPubsub(t, nats.Options{})
|
||||
|
||||
got := make(chan []byte, 1)
|
||||
cancel, err := ps.SubscribeWithErr("evt", func(_ context.Context, msg []byte, err error) {
|
||||
@@ -78,7 +83,7 @@ func TestPubsub(t *testing.T) {
|
||||
|
||||
t.Run("EchoDefault", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{})
|
||||
ps := newPubsub(t, nats.Options{})
|
||||
|
||||
got := make(chan []byte, 1)
|
||||
cancel, err := ps.Subscribe("echo_evt", func(_ context.Context, msg []byte) {
|
||||
@@ -99,7 +104,7 @@ func TestPubsub(t *testing.T) {
|
||||
|
||||
t.Run("Ordering", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{})
|
||||
ps := newPubsub(t, nats.Options{})
|
||||
|
||||
const n = 100
|
||||
got := make(chan []byte, n)
|
||||
@@ -129,7 +134,7 @@ func TestPubsub(t *testing.T) {
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
ps, err := xnats.New(ctx, logger, xnats.Options{})
|
||||
ps, err := nats.New(ctx, logger, nats.Options{})
|
||||
require.NoError(t, err)
|
||||
|
||||
var first, second error
|
||||
@@ -147,8 +152,8 @@ func TestPubsub(t *testing.T) {
|
||||
|
||||
t.Run("SubscribeWithErrReceivesDropError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ps := newTestPubsub(t, xnats.Options{
|
||||
PendingLimits: xnats.PendingLimits{Msgs: 1, Bytes: 1024 * 1024},
|
||||
ps := newPubsub(t, nats.Options{
|
||||
PendingLimits: nats.PendingLimits{Msgs: 1, Bytes: 1024 * 1024},
|
||||
})
|
||||
|
||||
const event = "slow_evt_sync"
|
||||
|
||||
+30
-17
@@ -1,20 +1,18 @@
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
natsserver "github.com/nats-io/nats-server/v2/server"
|
||||
natsgo "github.com/nats-io/nats.go"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
const readyTimeout = 10 * time.Second
|
||||
|
||||
// buildServerOptions constructs the embedded NATS server options. The
|
||||
// server runs standalone with a loopback random client listener.
|
||||
// server runs with a loopback random client listener and an optional
|
||||
// cluster route listener.
|
||||
func buildServerOptions(opts Options) (*natsserver.Options, error) {
|
||||
maxPayload := opts.MaxPayload
|
||||
if maxPayload == 0 {
|
||||
@@ -37,16 +35,34 @@ func buildServerOptions(opts Options) (*natsserver.Options, error) {
|
||||
sopts.Host = "127.0.0.1"
|
||||
sopts.Port = natsserver.RANDOM_PORT
|
||||
|
||||
if !opts.disableCluster {
|
||||
clusterHost := opts.ClusterHost
|
||||
if clusterHost == "" {
|
||||
clusterHost = natsserver.DEFAULT_HOST
|
||||
}
|
||||
clusterPort := opts.ClusterPort
|
||||
if clusterPort == 0 {
|
||||
clusterPort = defaultClusterPort
|
||||
}
|
||||
routePoolSize := opts.RoutePoolSize
|
||||
if routePoolSize == 0 {
|
||||
routePoolSize = defaultRoutePoolSize
|
||||
}
|
||||
|
||||
sopts.Cluster = natsserver.ClusterOpts{
|
||||
Name: defaultClusterName,
|
||||
Host: clusterHost,
|
||||
Port: clusterPort,
|
||||
PoolSize: routePoolSize,
|
||||
}
|
||||
}
|
||||
|
||||
return sopts, nil
|
||||
}
|
||||
|
||||
// startEmbeddedServer starts an in-process standalone NATS server.
|
||||
func startEmbeddedServer(logger slog.Logger, opts Options) (*natsserver.Server, error) {
|
||||
sopts, err := buildServerOptions(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ns, err := natsserver.NewServer(sopts)
|
||||
// startEmbeddedServer starts an in-process NATS server.
|
||||
func startEmbeddedServer(opts *natsserver.Options) (*natsserver.Server, error) {
|
||||
ns, err := natsserver.NewServer(opts)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("new embedded nats server: %w", err)
|
||||
}
|
||||
@@ -56,9 +72,6 @@ func startEmbeddedServer(logger slog.Logger, opts Options) (*natsserver.Server,
|
||||
ns.WaitForShutdown()
|
||||
return nil, xerrors.Errorf("embedded nats server not ready within %s", readyTimeout)
|
||||
}
|
||||
logger.Info(context.Background(), "embedded nats server started",
|
||||
slog.F("client_url", ns.ClientURL()),
|
||||
)
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
@@ -92,13 +105,13 @@ func connectClient(ns *natsserver.Server, opts Options, handlers connHandlers, c
|
||||
if handlers.errH != nil {
|
||||
connOpts = append(connOpts, natsgo.ErrorHandler(handlers.errH))
|
||||
}
|
||||
url := ns.ClientURL()
|
||||
clientURL := ns.ClientURL()
|
||||
if opts.InProcess {
|
||||
// InProcessServer overrides URL dialing with a net.Pipe; the
|
||||
// url argument is ignored but must still be syntactically valid.
|
||||
// URL argument is ignored but must still be syntactically valid.
|
||||
connOpts = append(connOpts, natsgo.InProcessServer(ns))
|
||||
}
|
||||
nc, err := natsgo.Connect(url, connOpts...)
|
||||
nc, err := natsgo.Connect(clientURL, connOpts...)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("connect client: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user