From 5cdc9e28a984bb2cf6521de15b1d34d4df970810 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 29 May 2026 11:35:59 -0500 Subject: [PATCH] feat: add nats cluster peer support (#25632) --- coderd/x/nats/cluster.go | 148 ++++++++++++++++++ coderd/x/nats/cluster_internal_test.go | 134 ++++++++++++++++ coderd/x/nats/pubsub.go | 43 +++++- coderd/x/nats/pubsub_internal_test.go | 205 +++++++++++++++++++++---- coderd/x/nats/pubsub_test.go | 31 ++-- coderd/x/nats/server.go | 47 ++++-- 6 files changed, 551 insertions(+), 57 deletions(-) create mode 100644 coderd/x/nats/cluster.go create mode 100644 coderd/x/nats/cluster_internal_test.go diff --git a/coderd/x/nats/cluster.go b/coderd/x/nats/cluster.go new file mode 100644 index 0000000000..7b0fd1ab80 --- /dev/null +++ b/coderd/x/nats/cluster.go @@ -0,0 +1,148 @@ +package nats + +import ( + "net" + "net/url" + "slices" + "strconv" + "strings" + + "golang.org/x/xerrors" +) + +// SetPeerAddresses replaces the configured NATS cluster peer routes. +func (p *Pubsub) SetPeerAddresses(addresses []string) error { + p.clusterMu.Lock() + defer p.clusterMu.Unlock() + + if p.ctx.Err() != nil { + return errClosed + } + if !p.clustered { + return xerrors.New("nats pubsub was not started with clustering enabled") + } + + routes, err := parsePeerAddresses(addresses) + if err != nil { + return err + } + + self := &url.URL{Scheme: "nats", Host: p.ns.ClusterAddr().String()} + routes = filterSelfRoutes(routes, self) + routes = sortRouteURLs(routes) + + if sortedURLsEqual(p.currentRoutes, routes) { + return nil + } + + newOpts := p.serverOpts.Clone() + newOpts.Routes = cloneRouteURLs(routes) + if err := p.ns.ReloadOptions(newOpts); err != nil { + return xerrors.Errorf("reload nats peer addresses: %w", err) + } + p.serverOpts = newOpts.Clone() + p.currentRoutes = cloneRouteURLs(routes) + return nil +} + +func parsePeerAddresses(addresses []string) ([]*url.URL, error) { + routesByAddress := make(map[string]*url.URL, len(addresses)) + for i, address := range addresses { + trimmed := strings.TrimSpace(address) + if trimmed == "" { + return nil, xerrors.Errorf("peer address %d is empty", i) + } + + normalizedHost, err := normalizeHostPort(trimmed) + if err != nil { + return nil, err + } + + routesByAddress[normalizedHost] = &url.URL{ + Scheme: "nats", + Host: normalizedHost, + } + } + + routes := make([]*url.URL, 0, len(routesByAddress)) + for _, route := range routesByAddress { + routes = append(routes, route) + } + return routes, nil +} + +func filterSelfRoutes(routes []*url.URL, self *url.URL) []*url.URL { + filtered := make([]*url.URL, 0, len(routes)) + for _, route := range routes { + if route.String() == self.String() { + continue + } + filtered = append(filtered, route) + } + return filtered +} + +func normalizeHostPort(address string) (string, error) { + route, err := url.Parse(address) + if err != nil { + return "", xerrors.Errorf("parse peer address %q: %w", address, err) + } + if route.User != nil { + return "", xerrors.Errorf("peer address %q must not include userinfo", address) + } + if route.Path != "" || route.RawQuery != "" || route.Fragment != "" { + return "", xerrors.Errorf("peer address %q must not include path, query, or fragment", address) + } + + host, port, err := net.SplitHostPort(route.Host) + if err != nil { + return "", xerrors.Errorf("split %q host port: %w", address, err) + } + if host == "" || port == "" { + return "", xerrors.Errorf("%q must include host and port", address) + } + + portNumber, err := strconv.Atoi(port) + if err != nil { + return "", xerrors.Errorf("parse %q port: %w", address, err) + } + if portNumber <= 0 || portNumber > 65535 { + return "", xerrors.Errorf("peer address %q must include a valid port", address) + } + return net.JoinHostPort(host, strconv.Itoa(portNumber)), nil +} + +func sortRouteURLs(routes []*url.URL) []*url.URL { + slices.SortFunc(routes, func(a, b *url.URL) int { + return strings.Compare(a.String(), b.String()) + }) + return routes +} + +// sortedURLsEqual assumes sorted slices. +func sortedURLsEqual(a, b []*url.URL) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].String() != b[i].String() { + return false + } + } + return true +} + +func cloneRouteURLs(routes []*url.URL) []*url.URL { + if routes == nil { + return nil + } + clones := make([]*url.URL, len(routes)) + for i, route := range routes { + if route == nil { + continue + } + clone := *route + clones[i] = &clone + } + return clones +} diff --git a/coderd/x/nats/cluster_internal_test.go b/coderd/x/nats/cluster_internal_test.go new file mode 100644 index 0000000000..eadf2e561f --- /dev/null +++ b/coderd/x/nats/cluster_internal_test.go @@ -0,0 +1,134 @@ +package nats + +import ( + "errors" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_parsePeerAddresses(t *testing.T) { + t.Parallel() + + t.Run("Valid", func(t *testing.T) { + t.Parallel() + routes, err := parsePeerAddresses([]string{ + "whatever://127.0.0.1:4222 ", + "http://[::1]:7222", + "nats://example.com:6222", + }) + require.NoError(t, err) + require.ElementsMatch(t, []string{ + "nats://127.0.0.1:4222", + "nats://[::1]:7222", + "nats://example.com:6222", + }, routeStrings(routes)) + }) + + t.Run("Empty", func(t *testing.T) { + t.Parallel() + routes, err := parsePeerAddresses(nil) + require.NoError(t, err) + require.Empty(t, routes) + }) + + t.Run("Dedupes", func(t *testing.T) { + t.Parallel() + routes, err := parsePeerAddresses([]string{ + "nats://b.example:6222", + "nats://a.example:6222", + "nats://b.example:6222", + }) + require.NoError(t, err) + require.ElementsMatch(t, []string{ + "nats://a.example:6222", + "nats://b.example:6222", + }, routeStrings(routes)) + }) + + t.Run("Invalid", func(t *testing.T) { + t.Parallel() + for _, address := range []string{ + "", + " ", + "127.0.0.1:4222", + "127.0.0.1", + ":4222", + "127.0.0.1:0", + "127.0.0.1:bad", + "nats://127.0.0.1", + "nats://:4222", + "nats://127.0.0.1:0", + "nats://127.0.0.1:bad", + "nats://user@127.0.0.1:4222", + "nats://127.0.0.1:4222/path", + "nats://127.0.0.1:4222?x=1", + "nats://127.0.0.1:4222#frag", + } { + t.Run(address, func(t *testing.T) { + t.Parallel() + _, err := parsePeerAddresses([]string{address}) + require.Error(t, err) + }) + } + }) +} + +func Test_filterSelfRoutes(t *testing.T) { + t.Parallel() + + routes, err := parsePeerAddresses([]string{ + "nats://b.example:6222", + "http://self.example:6222", + }) + require.NoError(t, err) + + routes = filterSelfRoutes(routes, &url.URL{Scheme: "nats", Host: "self.example:6222"}) + require.Equal(t, []string{"nats://b.example:6222"}, routeStrings(routes)) +} + +// Cluster tests bind free ports and reload shared route state. +func TestPubsub_SetPeerAddresses(t *testing.T) { + t.Parallel() + t.Run("OK", func(t *testing.T) { + t.Parallel() + a := newTestPubsub(t, clusterTestOptions(t)) + b := newTestPubsub(t, clusterTestOptions(t)) + c := newTestPubsub(t, clusterTestOptions(t)) + + addrB := clusterRouteAddress(t, b) + addrC := clusterRouteAddress(t, c) + require.NoError(t, a.SetPeerAddresses([]string{addrC, addrB})) + requireRoutesEqual(t, a.currentRoutes, addrB, addrC) + + require.NoError(t, a.SetPeerAddresses([]string{addrB, addrC})) + requireRoutesEqual(t, a.currentRoutes, addrB, addrC) + + require.NoError(t, a.SetPeerAddresses(nil)) + require.Empty(t, a.currentRoutes) + require.Empty(t, a.serverOpts.Routes) + }) + + t.Run("StandaloneConfigError", func(t *testing.T) { + t.Parallel() + ps := newTestPubsub(t, defaultTestOptions()) + err := ps.SetPeerAddresses(nil) + require.ErrorContains(t, err, "not started with clustering enabled") + }) + + t.Run("Closed", func(t *testing.T) { + t.Parallel() + ps := newTestPubsub(t, clusterTestOptions(t)) + require.NoError(t, ps.Close()) + err := ps.SetPeerAddresses(nil) + require.True(t, errors.Is(err, errClosed), "got %v", err) + }) + + t.Run("DropsSelfRoute", func(t *testing.T) { + t.Parallel() + ps := newTestPubsub(t, clusterTestOptions(t)) + require.NoError(t, ps.SetPeerAddresses([]string{clusterRouteAddress(t, ps)})) + require.Empty(t, ps.currentRoutes) + }) +} diff --git a/coderd/x/nats/pubsub.go b/coderd/x/nats/pubsub.go index afb6776a7f..a41247ed09 100644 --- a/coderd/x/nats/pubsub.go +++ b/coderd/x/nats/pubsub.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "hash/fnv" + "net/url" "sync" "time" @@ -19,6 +20,12 @@ import ( // DefaultMaxPending is the per-client outbound pending byte budget. const DefaultMaxPending int64 = 128 << 20 +const ( + defaultClusterName = "coder" + defaultClusterPort = 6222 + defaultRoutePoolSize = 3 +) + var errClosed = xerrors.New("nats pubsub closed") // PendingLimits configures per-subscription NATS pending limits set @@ -65,6 +72,23 @@ type Options struct { // shared subscription is pinned to one connection by a stable hash // of its subject. Zero or negative means 1. SubscribeConns int + + // ClusterHost is the embedded NATS route listener host. Empty means + // all interfaces when cluster mode is enabled. + ClusterHost string + + // ClusterPort is the embedded NATS route listener port. Zero means + // 6222 when cluster mode is enabled. + ClusterPort int + + // RoutePoolSize is the NATS route pool size. Zero means the package + // default when cluster mode is enabled. + RoutePoolSize int + + // disableCluster is intended only for testing. Since we cannot reload a server + // with a cluster host/port after initialization, we start all production servers + // with clustering enabled. + disableCluster bool } // Pubsub is an embedded NATS-backed implementation of pubsub.Pubsub. @@ -97,6 +121,11 @@ type Pubsub struct { // cleanup observes the canceled context. ctx context.Context cancel context.CancelFunc + + clusterMu sync.Mutex + clustered bool + serverOpts *natsserver.Options + currentRoutes []*url.URL } // natsSub maps to one underlying *natsgo.Subscription. The first @@ -203,13 +232,25 @@ func (p *Pubsub) buildConnHandlers() connHandlers { // embedded server and the publisher and subscriber connection pools. // Close shuts down all owned resources. func New(ctx context.Context, logger slog.Logger, opts Options) (*Pubsub, error) { - ns, err := startEmbeddedServer(logger, opts) + sopts, err := buildServerOptions(opts) if err != nil { return nil, err } + ns, err := startEmbeddedServer(sopts) + if err != nil { + return nil, err + } + + logger.Info(context.Background(), "embedded nats server started", + slog.F("client_url", ns.ClientURL()), + ) + p := newPubsub(ctx, logger, opts) p.ns = ns + p.clustered = !opts.disableCluster + p.serverOpts = sopts.Clone() + p.currentRoutes = cloneRouteURLs(sopts.Routes) handlers := p.buildConnHandlers() publishPool, err := newConnPool(ns, opts, handlers, opts.PublishConns, "coder-pubsub-pub") diff --git a/coderd/x/nats/pubsub_internal_test.go b/coderd/x/nats/pubsub_internal_test.go index 35abb84ed6..3b5263654e 100644 --- a/coderd/x/nats/pubsub_internal_test.go +++ b/coderd/x/nats/pubsub_internal_test.go @@ -1,18 +1,20 @@ -package nats //nolint:testpackage // Exercises internal pubsub state and helpers. +package nats import ( "context" "errors" "fmt" + "net/url" "sync" "sync/atomic" "testing" + "time" + natsserver "github.com/nats-io/nats-server/v2/server" natsgo "github.com/nats-io/nats.go" "github.com/stretchr/testify/require" "golang.org/x/xerrors" - "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/testutil" @@ -82,7 +84,7 @@ func Test_pickConn(t *testing.T) { func subjectForConn(t *testing.T, pool []*natsgo.Conn, conn *natsgo.Conn, prefix string) string { t.Helper() - for i := 0; i < 10_000; i++ { + for i := range 10_000 { subject := fmt.Sprintf("%s_%d", prefix, i) if pickConn(pool, subject) == conn { return subject @@ -97,17 +99,13 @@ func Test_New(t *testing.T) { t.Run("ConnectionCount", func(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - ps, err := New(ctx, logger, Options{}) - require.NoError(t, err) + ps := newTestPubsub(t, defaultTestOptions()) t.Cleanup(func() { _ = ps.Close() }) const n = 50 cancels := make([]func(), 0, n) - for i := 0; i < n; i++ { - c, err := ps.Subscribe(fmt.Sprintf("cc_evt_%d", i), func(context.Context, []byte) {}) + for i := range n { + c, err := ps.Subscribe(fmt.Sprintf("cc_evt_%d", i), func(_ context.Context, _ []byte) {}) require.NoError(t, err) cancels = append(cancels, c) } @@ -130,10 +128,9 @@ func Test_SubscribeWithErr(t *testing.T) { t.Run("SameSubjectSharesSubscription", func(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - ps, err := New(ctx, logger, Options{}) + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitShort) + ps, err := New(ctx, logger, defaultTestOptions()) require.NoError(t, err) t.Cleanup(func() { _ = ps.Close() }) @@ -155,10 +152,10 @@ func Test_Pubsub_buildConnHandlers(t *testing.T) { t.Run("DisconnectSignalsDropsForMatchingSubscriberConn", func(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - ps := newPubsub(ctx, logger, Options{}) + + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitShort) + ps := newPubsub(ctx, logger, defaultTestOptions()) var subConnA, subConnB, pubConn natsgo.Conn ps.subscribePool = []*natsgo.Conn{&subConnA, &subConnB} @@ -205,8 +202,7 @@ func Test_localSub_init(t *testing.T) { t.Run("SerializesCallbacks", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := testutil.Context(t, testutil.WaitShort) dataStarted := make(chan struct{}) dropDelivered := make(chan struct{}) @@ -219,7 +215,7 @@ func Test_localSub_init(t *testing.T) { s := &localSub{ ctx: ctx, - cancel: cancel, + cancel: func() {}, listener: func(_ context.Context, _ []byte, ferr error) { if active.Add(1) != 1 { concurrent.Store(true) @@ -279,10 +275,9 @@ func Test_localSub_init(t *testing.T) { t.Run("CrossSubjectListenerIsolation", func(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - ps, err := New(ctx, logger, Options{}) + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitLong) + ps, err := New(ctx, logger, defaultTestOptions()) require.NoError(t, err) t.Cleanup(func() { _ = ps.Close() }) @@ -312,7 +307,7 @@ func Test_localSub_init(t *testing.T) { total := defaultListenerQueueSize + 256 payload := make([]byte, 4*1024) - for i := 0; i < total; i++ { + for range total { require.NoError(t, ps.Publish("iso_slow", payload)) require.NoError(t, ps.Publish("iso_fast", []byte("ping"))) } @@ -337,3 +332,161 @@ func Test_localSub_init(t *testing.T) { require.Equal(t, 2, ps.ns.NumClients(), "slow consumer must not disconnect subConn") }) } + +func TestPubsubCluster(t *testing.T) { + t.Parallel() + + // OK verifies that SetPeerAddresses changes the active cluster topology. + // A starts connected to B, then C is added and receives both global and + // C-only messages. B is then removed from A's peers, while C continues to + // receive global and C-only messages. + t.Run("OK", func(t *testing.T) { + t.Parallel() + + a := newTestPubsub(t, clusterTestOptions(t)) + b := newTestPubsub(t, clusterTestOptions(t)) + c := newTestPubsub(t, clusterTestOptions(t)) + + addrB := clusterRouteAddress(t, b) + addrC := clusterRouteAddress(t, c) + + require.NoError(t, a.SetPeerAddresses([]string{addrB})) + requireRoutesEqual(t, a.currentRoutes, addrB) + + globalEvent := "global" + bGlobal := make(chan []byte, 8) + cancelBGlobal, err := b.Subscribe(globalEvent, func(_ context.Context, msg []byte) { + bGlobal <- msg + }) + require.NoError(t, err) + defer cancelBGlobal() + + waitForRouteSubscription(t, a, globalEvent) + publishAndFlush(t, a, globalEvent, "from-a-to-b") + require.Equal(t, "from-a-to-b", string(receiveMessage(t, bGlobal))) + + // Add C's subscriptions before adding C as an extra peer to A. + cGlobal := make(chan []byte, 8) + cancelCGlobal, err := c.Subscribe(globalEvent, func(_ context.Context, msg []byte) { + cGlobal <- msg + }) + require.NoError(t, err) + defer cancelCGlobal() + + cSubject := "c-only-subscriber" + cUnique := make(chan []byte, 8) + cancelCUnique, err := c.Subscribe(cSubject, func(_ context.Context, msg []byte) { + cUnique <- msg + }) + require.NoError(t, err) + defer cancelCUnique() + + // Add C to A's peer list. B and C should both receive global messages, + // while the C-only subject should route only to C. + require.NoError(t, a.SetPeerAddresses([]string{addrC, addrB})) + requireRoutesEqual(t, a.currentRoutes, addrB, addrC) + + waitForRouteSubscription(t, a, globalEvent) + waitForRouteSubscription(t, a, cSubject) + + publishAndFlush(t, a, globalEvent, "new-global-msg") + require.Equal(t, "new-global-msg", string(receiveMessage(t, bGlobal))) + require.Equal(t, "new-global-msg", string(receiveMessage(t, cGlobal))) + + publishAndFlush(t, a, cSubject, "c-unique-msg") + require.Equal(t, "c-unique-msg", string(receiveMessage(t, cUnique))) + + // Remove B from A's peer list. Only C should receive the next messages. + require.NoError(t, a.SetPeerAddresses([]string{addrC})) + requireRoutesEqual(t, a.currentRoutes, addrC) + + publishAndFlush(t, a, globalEvent, "no-b-peer") + require.Equal(t, "no-b-peer", string(receiveMessage(t, cGlobal))) + + publishAndFlush(t, a, cSubject, "c-messages-still-work") + require.Equal(t, "c-messages-still-work", string(receiveMessage(t, cUnique))) + }) +} + +func defaultTestOptions() Options { + return Options{disableCluster: true} +} + +func clusterTestOptions(t *testing.T) Options { + t.Helper() + return Options{ + ClusterHost: "127.0.0.1", + ClusterPort: natsserver.RANDOM_PORT, + disableCluster: false, + } +} + +func newTestPubsub(t *testing.T, opts Options) *Pubsub { + t.Helper() + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitLong) + ps, err := New(ctx, logger, opts) + require.NoError(t, err) + t.Cleanup(func() { + _ = ps.Close() + }) + return ps +} + +func clusterRouteAddress(t *testing.T, ps *Pubsub) string { + t.Helper() + addr := ps.ns.ClusterAddr() + require.NotNil(t, addr) + return "nats://" + addr.String() +} + +func waitForRouteSubscription(t *testing.T, ps *Pubsub, subject string) { + t.Helper() + require.Eventually(t, func() bool { + routes, err := ps.ns.Routez(&natsserver.RoutezOptions{Subscriptions: true}) + if err != nil { + return false + } + for _, route := range routes.Routes { + for _, sub := range route.Subs { + if sub == subject { + return true + } + } + } + return false + }, testutil.WaitShort, testutil.IntervalFast) +} + +func publishAndFlush(t *testing.T, ps *Pubsub, event, message string) { + t.Helper() + require.NoError(t, ps.Publish(event, []byte(message))) + require.NoError(t, ps.Flush()) +} + +func receiveMessage(t *testing.T, got <-chan []byte) []byte { + t.Helper() + select { + case msg := <-got: + return msg + case <-time.After(testutil.WaitShort): + t.Fatal("timed out waiting for message") + return nil + } +} + +func requireRoutesEqual(t *testing.T, routes []*url.URL, addresses ...string) { + t.Helper() + want, err := parsePeerAddresses(addresses) + require.NoError(t, err) + want = sortRouteURLs(want) + require.True(t, sortedURLsEqual(want, routes), "want %v, got %v", routeStrings(want), routeStrings(routes)) +} + +func routeStrings(routes []*url.URL) []string { + strings := make([]string, 0, len(routes)) + for _, route := range routes { + strings = append(strings, route.String()) + } + return strings +} diff --git a/coderd/x/nats/pubsub_test.go b/coderd/x/nats/pubsub_test.go index 5be9f766f4..7b65228b7a 100644 --- a/coderd/x/nats/pubsub_test.go +++ b/coderd/x/nats/pubsub_test.go @@ -7,25 +7,30 @@ import ( "testing" "time" + natsserver "github.com/nats-io/nats-server/v2/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database/pubsub" - xnats "github.com/coder/coder/v2/coderd/x/nats" + "github.com/coder/coder/v2/coderd/x/nats" "github.com/coder/coder/v2/testutil" ) -func newTestPubsub(t *testing.T, opts xnats.Options) *xnats.Pubsub { +func newPubsub(t *testing.T, opts nats.Options) *nats.Pubsub { t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithCancel(context.Background()) - ps, err := xnats.New(ctx, logger, opts) + + if opts.ClusterPort == 0 { + opts.ClusterPort = natsserver.RANDOM_PORT + } + + logger := slogtest.Make(t, nil) + ctx := testutil.Context(t, testutil.WaitLong) + ps, err := nats.New(ctx, logger, opts) require.NoError(t, err) t.Cleanup(func() { _ = ps.Close() - cancel() }) return ps } @@ -35,7 +40,7 @@ func TestPubsub(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { t.Parallel() - ps := newTestPubsub(t, xnats.Options{}) + ps := newPubsub(t, nats.Options{}) got := make(chan []byte, 1) cancel, err := ps.Subscribe("test_event", func(_ context.Context, msg []byte) { @@ -56,7 +61,7 @@ func TestPubsub(t *testing.T) { t.Run("SubscribeWithErrNormalMessage", func(t *testing.T) { t.Parallel() - ps := newTestPubsub(t, xnats.Options{}) + ps := newPubsub(t, nats.Options{}) got := make(chan []byte, 1) cancel, err := ps.SubscribeWithErr("evt", func(_ context.Context, msg []byte, err error) { @@ -78,7 +83,7 @@ func TestPubsub(t *testing.T) { t.Run("EchoDefault", func(t *testing.T) { t.Parallel() - ps := newTestPubsub(t, xnats.Options{}) + ps := newPubsub(t, nats.Options{}) got := make(chan []byte, 1) cancel, err := ps.Subscribe("echo_evt", func(_ context.Context, msg []byte) { @@ -99,7 +104,7 @@ func TestPubsub(t *testing.T) { t.Run("Ordering", func(t *testing.T) { t.Parallel() - ps := newTestPubsub(t, xnats.Options{}) + ps := newPubsub(t, nats.Options{}) const n = 100 got := make(chan []byte, n) @@ -129,7 +134,7 @@ func TestPubsub(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - ps, err := xnats.New(ctx, logger, xnats.Options{}) + ps, err := nats.New(ctx, logger, nats.Options{}) require.NoError(t, err) var first, second error @@ -147,8 +152,8 @@ func TestPubsub(t *testing.T) { t.Run("SubscribeWithErrReceivesDropError", func(t *testing.T) { t.Parallel() - ps := newTestPubsub(t, xnats.Options{ - PendingLimits: xnats.PendingLimits{Msgs: 1, Bytes: 1024 * 1024}, + ps := newPubsub(t, nats.Options{ + PendingLimits: nats.PendingLimits{Msgs: 1, Bytes: 1024 * 1024}, }) const event = "slow_evt_sync" diff --git a/coderd/x/nats/server.go b/coderd/x/nats/server.go index c7e038a06d..6013c44feb 100644 --- a/coderd/x/nats/server.go +++ b/coderd/x/nats/server.go @@ -1,20 +1,18 @@ package nats import ( - "context" "time" natsserver "github.com/nats-io/nats-server/v2/server" natsgo "github.com/nats-io/nats.go" "golang.org/x/xerrors" - - "cdr.dev/slog/v3" ) const readyTimeout = 10 * time.Second // buildServerOptions constructs the embedded NATS server options. The -// server runs standalone with a loopback random client listener. +// server runs with a loopback random client listener and an optional +// cluster route listener. func buildServerOptions(opts Options) (*natsserver.Options, error) { maxPayload := opts.MaxPayload if maxPayload == 0 { @@ -37,16 +35,34 @@ func buildServerOptions(opts Options) (*natsserver.Options, error) { sopts.Host = "127.0.0.1" sopts.Port = natsserver.RANDOM_PORT + if !opts.disableCluster { + clusterHost := opts.ClusterHost + if clusterHost == "" { + clusterHost = natsserver.DEFAULT_HOST + } + clusterPort := opts.ClusterPort + if clusterPort == 0 { + clusterPort = defaultClusterPort + } + routePoolSize := opts.RoutePoolSize + if routePoolSize == 0 { + routePoolSize = defaultRoutePoolSize + } + + sopts.Cluster = natsserver.ClusterOpts{ + Name: defaultClusterName, + Host: clusterHost, + Port: clusterPort, + PoolSize: routePoolSize, + } + } + return sopts, nil } -// startEmbeddedServer starts an in-process standalone NATS server. -func startEmbeddedServer(logger slog.Logger, opts Options) (*natsserver.Server, error) { - sopts, err := buildServerOptions(opts) - if err != nil { - return nil, err - } - ns, err := natsserver.NewServer(sopts) +// startEmbeddedServer starts an in-process NATS server. +func startEmbeddedServer(opts *natsserver.Options) (*natsserver.Server, error) { + ns, err := natsserver.NewServer(opts) if err != nil { return nil, xerrors.Errorf("new embedded nats server: %w", err) } @@ -56,9 +72,6 @@ func startEmbeddedServer(logger slog.Logger, opts Options) (*natsserver.Server, ns.WaitForShutdown() return nil, xerrors.Errorf("embedded nats server not ready within %s", readyTimeout) } - logger.Info(context.Background(), "embedded nats server started", - slog.F("client_url", ns.ClientURL()), - ) return ns, nil } @@ -92,13 +105,13 @@ func connectClient(ns *natsserver.Server, opts Options, handlers connHandlers, c if handlers.errH != nil { connOpts = append(connOpts, natsgo.ErrorHandler(handlers.errH)) } - url := ns.ClientURL() + clientURL := ns.ClientURL() if opts.InProcess { // InProcessServer overrides URL dialing with a net.Pipe; the - // url argument is ignored but must still be syntactically valid. + // URL argument is ignored but must still be syntactically valid. connOpts = append(connOpts, natsgo.InProcessServer(ns)) } - nc, err := natsgo.Connect(url, connOpts...) + nc, err := natsgo.Connect(clientURL, connOpts...) if err != nil { return nil, xerrors.Errorf("connect client: %w", err) }