From 04fca84872d9f2b28463b1d524dd738feb5fa26a Mon Sep 17 00:00:00 2001 From: Ethan <39577870+ethanndickson@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:50:47 +1100 Subject: [PATCH] perf(coderd): reduce duplicated reads in push and webpush paths (#23115) ## Background A 5000-chat scaletest (~50k turns, ~2m45s wall time) completed successfully, but the main bottleneck was **DB pool starvation from repeated reads**, not individually expensive SQL. The push/webpush path showed a few especially noisy reads: - `GetLastChatMessageByRole` for push body generation - `GetEnabledChatProviders` + `GetChatModelConfigByID` for push summary model resolution - `GetWebpushSubscriptionsByUserID` for every webpush dispatch This PR keeps the optimizations that remove those duplicate reads while leaving stream behavior unchanged. ## What changes in this PR ### 1. Reuse resolved chat state for push notifications `maybeSendPushNotification` used to re-read the last assistant message and re-resolve the chat model/provider after `runChat` had already done that work. Now `runChat` returns the final assistant text plus the already-resolved model and provider keys, and the push goroutine uses that state directly. That removes the extra push-path reads for: - `GetLastChatMessageByRole` - the second `resolveChatModel` path - the provider/model lookups that came with that second resolution ### 2. Cache webpush subscriptions during dispatch `Dispatch()` previously hit `GetWebpushSubscriptionsByUserID` on every push. A small per-user in-memory cache now avoids those repeated reads. The follow-up fix keeps that optimization correct: `InvalidateUser()` bumps a per-user generation so an older in-flight fetch cannot repopulate the cache with pre-mutation data after subscribe/unsubscribe. That preserves the cache win without letting local subscription changes be silently overwritten by stale fetch results. ## Why this is safe - The push change only reuses data already produced during the same chat run. It does not change notification semantics; if there is no assistant text to summarize, the existing fallback body still applies. - The webpush change keeps the existing TTL and `410 Gone` cleanup behavior. The generation guard only prevents stale in-flight fetches from poisoning the shared cache after invalidation. - The final PR does **not** change stream setup, pubsub/relay behavior, or chat status snapshot timing. ## Deliberately not included - No stream-path optimization in `Subscribe`. - No inline pubsub message payloads. - No distributed cross-replica webpush cache invalidation. --- coderd/chatd/chatd.go | 69 +++++++----- coderd/chatd/chatd_test.go | 63 ++++++++++- coderd/webpush.go | 7 ++ coderd/webpush/webpush.go | 197 +++++++++++++++++++++++++++++++-- coderd/webpush/webpush_test.go | 153 ++++++++++++++++++++++++- coderd/webpush_test.go | 25 +++-- 6 files changed, 465 insertions(+), 49 deletions(-) diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index ef8c9aafdb..c820e41ec7 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -2038,6 +2038,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { status := database.ChatStatusWaiting wasInterrupted := false lastError := "" + runResult := runChatResult{} remainingQueuedMessages := []database.ChatQueuedMessage{} shouldPublishQueueUpdate := false var promotedMessage *database.ChatMessage @@ -2144,11 +2145,12 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindStatusChange, nil) if !wasInterrupted { - p.maybeSendPushNotification(cleanupCtx, chat, status, lastError, logger) + p.maybeSendPushNotification(cleanupCtx, chat, status, lastError, runResult, logger) } }() - if err := p.runChat(chatCtx, chat, logger); err != nil { + runResult, err := p.runChat(chatCtx, chat, logger) + if err != nil { if errors.Is(err, chatloop.ErrInterrupted) || errors.Is(context.Cause(chatCtx), chatloop.ErrInterrupted) { logger.Info(ctx, "chat interrupted") status = database.ChatStatusWaiting @@ -2205,11 +2207,18 @@ func isShutdownCancellation( return errors.Is(context.Cause(chatCtx), context.Canceled) } +type runChatResult struct { + FinalAssistantText string + PushSummaryModel fantasy.LanguageModel + ProviderKeys chatprovider.ProviderAPIKeys +} + func (p *Server) runChat( ctx context.Context, chat database.Chat, logger slog.Logger, -) error { +) (runChatResult, error) { + result := runChatResult{} var ( model fantasy.LanguageModel modelConfig database.ChatModelConfig @@ -2241,14 +2250,16 @@ func (p *Server) runChat( return nil }) if err := g.Wait(); err != nil { - return err + return result, err } + result.PushSummaryModel = model + result.ProviderKeys = providerKeys // Fire title generation asynchronously so it doesn't block the // chat response. It uses a detached context so it can finish // even after the chat processing context is canceled. - // Snapshot model so the goroutine doesn't race with the - // model = cuModel reassignment below. - titleModel := model + // Snapshot the original chat model so the goroutine doesn't + // race with the model = cuModel reassignment below. + titleModel := result.PushSummaryModel p.inflight.Add(1) go func() { defer p.inflight.Done() @@ -2257,7 +2268,7 @@ func (p *Server) runChat( prompt, err := chatprompt.ConvertMessagesWithFiles(ctx, messages, p.chatFileResolver(), logger) if err != nil { - return xerrors.Errorf("build chat prompt: %w", err) + return result, xerrors.Errorf("build chat prompt: %w", err) } if chat.ParentChatID.Valid { prompt = chatprompt.InsertSystem(prompt, defaultSubagentInstruction) @@ -2389,9 +2400,11 @@ func (p *Server) runChat( prompt = chatprompt.InsertSystem(prompt, resolvedUserPrompt) } - // Use the model config's context_limit as a fallback when the LLM // provider doesn't include context_limit in its response metadata + // Use the model config's context_limit as a fallback when the LLM + // provider doesn't include context_limit in its response metadata // (which is the common case). modelConfigContextLimit := modelConfig.ContextLimit + var finalAssistantText string persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error { // If the chat context has been canceled, bail out before @@ -2455,6 +2468,7 @@ func (p *Server) runChat( for _, block := range assistantBlocks { sdkParts = append(sdkParts, chatprompt.PartFromContent(block)) } + finalAssistantText = strings.TrimSpace(contentBlocksToText(sdkParts)) assistantContent, marshalErr := chatprompt.MarshalParts(sdkParts) if marshalErr != nil { return marshalErr @@ -2630,7 +2644,7 @@ func (p *Server) runChat( chatprovider.UserAgent(), ) if cuErr != nil { - return xerrors.Errorf("resolve computer use model: %w", cuErr) + return result, xerrors.Errorf("resolve computer use model: %w", cuErr) } model = cuModel } @@ -2796,7 +2810,11 @@ func (p *Server) runChat( p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err)) }, }) - return err + if err != nil { + return result, err + } + result.FinalAssistantText = finalAssistantText + return result, nil } // buildProviderTools creates provider-native tool definitions @@ -3301,6 +3319,7 @@ func (p *Server) maybeSendPushNotification( chat database.Chat, status database.ChatStatus, lastError string, + runResult runChatResult, logger slog.Logger, ) { if p.webpushDispatcher == nil || p.webpushDispatcher.PublicKey() == "" { @@ -3328,23 +3347,17 @@ func (p *Server) maybeSendPushNotification( defer p.inflight.Done() pushCtx := context.WithoutCancel(ctx) pushBody := "Agent has finished running." - - msg, err := p.db.GetLastChatMessageByRole(pushCtx, database.GetLastChatMessageByRoleParams{ - ChatID: chat.ID, - Role: database.ChatMessageRoleAssistant, - }) - if err == nil { - content, parseErr := chatprompt.ParseContent(msg) - if parseErr == nil { - assistantText := strings.TrimSpace(contentBlocksToText(content)) - if assistantText != "" { - model, _, keys, resolveErr := p.resolveChatModel(pushCtx, chat) - if resolveErr == nil { - if summary := generatePushSummary(pushCtx, chat.Title, assistantText, model, keys, logger); summary != "" { - pushBody = summary - } - } - } + assistantText := strings.TrimSpace(runResult.FinalAssistantText) + if assistantText != "" && runResult.PushSummaryModel != nil { + if summary := generatePushSummary( + pushCtx, + chat.Title, + assistantText, + runResult.PushSummaryModel, + runResult.ProviderKeys, + logger, + ); summary != "" { + pushBody = summary } } diff --git a/coderd/chatd/chatd_test.go b/coderd/chatd/chatd_test.go index 11920a042c..ced3eb00e7 100644 --- a/coderd/chatd/chatd_test.go +++ b/coderd/chatd/chatd_test.go @@ -2234,12 +2234,10 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) { const assistantText = "I have completed the task successfully and all tests are passing now." const summaryText = "Completed task and verified all tests pass." + var nonStreamingRequests atomic.Int32 openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { if !req.Stream { - // Non-streaming calls are used for title - // generation and push summary generation. - // Return the summary text for both — the title - // result is irrelevant to this test. + nonStreamingRequests.Add(1) return chattest.OpenAINonStreamingResponse(summaryText) } return chattest.OpenAIStreamingResponse( @@ -2286,6 +2284,63 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) { "push body should be the LLM-generated summary") require.NotEqual(t, "Agent has finished running.", msg.Body, "push body should not use the default fallback text") + require.Equal(t, int32(1), nonStreamingRequests.Load(), + "expected exactly one non-streaming request for push summary generation") +} + +func TestSuccessfulChatSendsWebPushFallbackWithoutSummaryForEmptyAssistantText(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + var nonStreamingRequests atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + nonStreamingRequests.Add(1) + return chattest.OpenAINonStreamingResponse("unexpected summary request") + } + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks(" ")..., + ) + }) + + mockPush := &mockWebpushDispatcher{} + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := chatd.New(chatd.Config{ + Logger: logger, + Database: db, + ReplicaID: uuid.New(), + Pubsub: ps, + PendingChatAcquireInterval: 10 * time.Millisecond, + InFlightChatStaleAfter: testutil.WaitSuperLong, + WebpushDispatcher: mockPush, + }) + t.Cleanup(func() { + require.NoError(t, server.Close()) + }) + + user, model := seedChatDependencies(ctx, t, db) + setOpenAIProviderBaseURL(ctx, t, db, openAIURL) + + _, err := server.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "empty-summary-push-test", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")}, + }) + require.NoError(t, err) + + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return mockPush.dispatchCount.Load() >= 1 + }, testutil.IntervalFast) + + msg := mockPush.getLastMessage() + require.Equal(t, "Agent has finished running.", msg.Body, + "push body should fall back when the final assistant text is empty") + require.Equal(t, int32(0), nonStreamingRequests.Load(), + "push summary should not be requested when final assistant text has no usable text") } func TestComputerUseSubagentToolsAndModel(t *testing.T) { diff --git a/coderd/webpush.go b/coderd/webpush.go index 4d7687f8e9..e275873400 100644 --- a/coderd/webpush.go +++ b/coderd/webpush.go @@ -12,6 +12,7 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" + "github.com/coder/coder/v2/coderd/webpush" "github.com/coder/coder/v2/codersdk" ) @@ -54,6 +55,9 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ }) return } + if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok { + invalidator.InvalidateUser(user.ID) + } rw.WriteHeader(http.StatusNoContent) } @@ -111,6 +115,9 @@ func (api *API) deleteUserWebpushSubscription(rw http.ResponseWriter, r *http.Re }) return } + if invalidator, ok := api.WebpushDispatcher.(webpush.SubscriptionCacheInvalidator); ok { + invalidator.InvalidateUser(user.ID) + } rw.WriteHeader(http.StatusNoContent) } diff --git a/coderd/webpush/webpush.go b/coderd/webpush/webpush.go index 177730db77..94f7d8da24 100644 --- a/coderd/webpush/webpush.go +++ b/coderd/webpush/webpush.go @@ -9,18 +9,23 @@ import ( "net/http" "slices" "sync" + "time" "github.com/SherClockHolmes/webpush-go" "github.com/google/uuid" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" + "tailscale.com/util/singleflight" "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/codersdk" + "github.com/coder/quartz" ) +const defaultSubscriptionCacheTTL = 3 * time.Minute + // Dispatcher is an interface that can be used to dispatch // web push notifications to clients such as browsers. type Dispatcher interface { @@ -33,6 +38,36 @@ type Dispatcher interface { PublicKey() string } +// SubscriptionCacheInvalidator is an optional interface that lets local +// subscription mutation handlers invalidate cached subscriptions. +type SubscriptionCacheInvalidator interface { + InvalidateUser(userID uuid.UUID) +} + +type options struct { + clock quartz.Clock + subscriptionCacheTTL time.Duration +} + +// Option configures optional behavior for a Webpusher. +type Option func(*options) + +// WithClock sets the clock used by the subscription cache. Defaults to a real +// clock when not provided. +func WithClock(clock quartz.Clock) Option { + return func(o *options) { + o.clock = clock + } +} + +// WithSubscriptionCacheTTL sets the in-memory subscription cache TTL. Defaults +// to three minutes when not provided or when given a non-positive duration. +func WithSubscriptionCacheTTL(ttl time.Duration) Option { + return func(o *options) { + o.subscriptionCacheTTL = ttl + } +} + // New creates a new Dispatcher to dispatch web push notifications. // // This is *not* integrated into the enqueue system unfortunately. @@ -41,7 +76,21 @@ type Dispatcher interface { // for updates inside of a workspace, which we want to be immediate. // // See: https://github.com/coder/internal/issues/528 -func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string) (Dispatcher, error) { +func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string, opts ...Option) (Dispatcher, error) { + cfg := options{ + clock: quartz.NewReal(), + subscriptionCacheTTL: defaultSubscriptionCacheTTL, + } + for _, opt := range opts { + opt(&cfg) + } + if cfg.clock == nil { + cfg.clock = quartz.NewReal() + } + if cfg.subscriptionCacheTTL <= 0 { + cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL + } + keys, err := db.GetWebpushVAPIDKeys(ctx) if err != nil { if !errors.Is(err, sql.ErrNoRows) { @@ -63,14 +112,23 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri } return &Webpusher{ - vapidSub: vapidSub, - store: db, - log: log, - VAPIDPublicKey: keys.VapidPublicKey, - VAPIDPrivateKey: keys.VapidPrivateKey, + vapidSub: vapidSub, + store: db, + log: log, + VAPIDPublicKey: keys.VapidPublicKey, + VAPIDPrivateKey: keys.VapidPrivateKey, + clock: cfg.clock, + subscriptionCacheTTL: cfg.subscriptionCacheTTL, + subscriptionCache: make(map[uuid.UUID]cachedSubscriptions), + subscriptionGenerations: make(map[uuid.UUID]uint64), }, nil } +type cachedSubscriptions struct { + subscriptions []database.WebpushSubscription + expiresAt time.Time +} + type Webpusher struct { store database.Store log *slog.Logger @@ -83,10 +141,18 @@ type Webpusher struct { // the message payload. VAPIDPublicKey string VAPIDPrivateKey string + + clock quartz.Clock + + cacheMu sync.RWMutex + subscriptionCache map[uuid.UUID]cachedSubscriptions + subscriptionGenerations map[uuid.UUID]uint64 + subscriptionCacheTTL time.Duration + subscriptionFetches singleflight.Group[string, []database.WebpushSubscription] } func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error { - subscriptions, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID) + subscriptions, err := n.subscriptionsForUser(ctx, userID) if err != nil { return xerrors.Errorf("get web push subscriptions by user ID: %w", err) } @@ -142,12 +208,129 @@ func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk err = n.store.DeleteWebpushSubscriptions(dbauthz.AsNotifier(ctx), cleanupSubscriptions) if err != nil { n.log.Error(ctx, "failed to delete stale push subscriptions", slog.Error(err)) + } else { + n.pruneSubscriptions(userID, cleanupSubscriptions) } } return nil } +func (n *Webpusher) subscriptionsForUser(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) { + if subscriptions, ok := n.cachedSubscriptions(userID); ok { + return subscriptions, nil + } + + subscriptions, err, _ := n.subscriptionFetches.Do(userID.String(), func() ([]database.WebpushSubscription, error) { + if cached, ok := n.cachedSubscriptions(userID); ok { + return cached, nil + } + + generation := n.subscriptionGeneration(userID) + fetched, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID) + if err != nil { + return nil, err + } + n.storeSubscriptions(userID, generation, fetched) + return slices.Clone(fetched), nil + }) + if err != nil { + return nil, err + } + + return slices.Clone(subscriptions), nil +} + +func (n *Webpusher) cachedSubscriptions(userID uuid.UUID) ([]database.WebpushSubscription, bool) { + n.cacheMu.RLock() + entry, ok := n.subscriptionCache[userID] + n.cacheMu.RUnlock() + if !ok { + return nil, false + } + if n.clock.Now().Before(entry.expiresAt) { + return slices.Clone(entry.subscriptions), true + } + + n.cacheMu.Lock() + if current, ok := n.subscriptionCache[userID]; ok && !n.clock.Now().Before(current.expiresAt) { + delete(n.subscriptionCache, userID) + } + n.cacheMu.Unlock() + + return nil, false +} + +func (n *Webpusher) subscriptionGeneration(userID uuid.UUID) uint64 { + n.cacheMu.RLock() + generation := n.subscriptionGenerations[userID] + n.cacheMu.RUnlock() + return generation +} + +func (n *Webpusher) storeSubscriptions(userID uuid.UUID, generation uint64, subscriptions []database.WebpushSubscription) { + n.cacheMu.Lock() + defer n.cacheMu.Unlock() + + if n.subscriptionGenerations[userID] != generation { + return + } + + n.subscriptionCache[userID] = cachedSubscriptions{ + subscriptions: slices.Clone(subscriptions), + expiresAt: n.clock.Now().Add(n.subscriptionCacheTTL), + } +} + +func (n *Webpusher) pruneSubscriptions(userID uuid.UUID, staleIDs []uuid.UUID) { + if len(staleIDs) == 0 { + return + } + + stale := make(map[uuid.UUID]struct{}, len(staleIDs)) + for _, id := range staleIDs { + stale[id] = struct{}{} + } + + n.cacheMu.Lock() + defer n.cacheMu.Unlock() + + entry, ok := n.subscriptionCache[userID] + if !ok { + return + } + if !n.clock.Now().Before(entry.expiresAt) { + delete(n.subscriptionCache, userID) + return + } + + filtered := make([]database.WebpushSubscription, 0, len(entry.subscriptions)) + for _, subscription := range entry.subscriptions { + if _, shouldDelete := stale[subscription.ID]; shouldDelete { + continue + } + filtered = append(filtered, subscription) + } + if len(filtered) == 0 { + delete(n.subscriptionCache, userID) + return + } + + entry.subscriptions = filtered + n.subscriptionCache[userID] = entry +} + +// InvalidateUser clears the cached subscriptions for a user and advances +// its invalidation generation. Local subscribe and unsubscribe handlers call +// this after mutating subscriptions in the same process. +func (n *Webpusher) InvalidateUser(userID uuid.UUID) { + n.cacheMu.Lock() + delete(n.subscriptionCache, userID) + n.subscriptionGenerations[userID]++ + n.cacheMu.Unlock() + n.subscriptionFetches.Forget(userID.String()) +} + func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string, keys webpush.Keys) (int, []byte, error) { // Copy the message to avoid modifying the original. cpy := slices.Clone(msg) diff --git a/coderd/webpush/webpush_test.go b/coderd/webpush/webpush_test.go index bfb2b39c20..fdd394b286 100644 --- a/coderd/webpush/webpush_test.go +++ b/coderd/webpush/webpush_test.go @@ -6,7 +6,9 @@ import ( "io" "net/http" "net/http/httptest" + "sync/atomic" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -21,6 +23,7 @@ import ( "github.com/coder/coder/v2/coderd/webpush" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" ) const ( @@ -28,6 +31,20 @@ const ( validEndpointP256dhKey = "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk=" ) +type countingWebpushStore struct { + database.Store + getSubscriptionsCalls atomic.Int32 +} + +func (s *countingWebpushStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) { + s.getSubscriptionsCalls.Add(1) + return s.Store.GetWebpushSubscriptionsByUserID(ctx, userID) +} + +func (s *countingWebpushStore) getCallCount() int32 { + return s.getSubscriptionsCalls.Load() +} + func TestPush(t *testing.T) { t.Parallel() @@ -216,6 +233,131 @@ func TestPush(t *testing.T) { require.NoError(t, err) assert.Empty(t, subscriptions, "No subscriptions should be returned") }) + + t.Run("CachesSubscriptionsWithinTTL", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + clock := quartz.NewMock(t) + rawStore, _ := dbtestutil.NewDB(t) + store := &countingWebpushStore{Store: rawStore} + var delivered atomic.Int32 + manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) { + delivered.Add(1) + assertWebpushPayload(t, r) + w.WriteHeader(http.StatusOK) + }, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute)) + + user := dbgen.User(t, rawStore, database.User{}) + _, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{ + CreatedAt: dbtime.Now(), + UserID: user.ID, + Endpoint: serverURL, + EndpointAuthKey: validEndpointAuthKey, + EndpointP256dhKey: validEndpointP256dhKey, + }) + require.NoError(t, err) + + msg := randomWebpushMessage(t) + err = manager.Dispatch(ctx, user.ID, msg) + require.NoError(t, err) + err = manager.Dispatch(ctx, user.ID, msg) + require.NoError(t, err) + + require.Equal(t, int32(1), store.getCallCount(), "subscriptions should be read once within the TTL") + require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification") + }) + + t.Run("RefreshesSubscriptionsAfterTTLExpires", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + clock := quartz.NewMock(t) + rawStore, _ := dbtestutil.NewDB(t) + store := &countingWebpushStore{Store: rawStore} + var delivered atomic.Int32 + manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) { + delivered.Add(1) + assertWebpushPayload(t, r) + w.WriteHeader(http.StatusOK) + }, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute)) + + user := dbgen.User(t, rawStore, database.User{}) + _, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{ + CreatedAt: dbtime.Now(), + UserID: user.ID, + Endpoint: serverURL, + EndpointAuthKey: validEndpointAuthKey, + EndpointP256dhKey: validEndpointP256dhKey, + }) + require.NoError(t, err) + + msg := randomWebpushMessage(t) + err = manager.Dispatch(ctx, user.ID, msg) + require.NoError(t, err) + clock.Advance(time.Minute) + err = manager.Dispatch(ctx, user.ID, msg) + require.NoError(t, err) + + require.Equal(t, int32(2), store.getCallCount(), "dispatch should refresh subscriptions after the TTL expires") + require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification") + }) + + t.Run("PrunesStaleSubscriptionsFromCache", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + clock := quartz.NewMock(t) + rawStore, _ := dbtestutil.NewDB(t) + store := &countingWebpushStore{Store: rawStore} + var okCalls atomic.Int32 + var goneCalls atomic.Int32 + manager, _, okServerURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) { + okCalls.Add(1) + assertWebpushPayload(t, r) + w.WriteHeader(http.StatusOK) + }, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute)) + + goneServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + goneCalls.Add(1) + assertWebpushPayload(t, r) + w.WriteHeader(http.StatusGone) + })) + defer goneServer.Close() + + user := dbgen.User(t, rawStore, database.User{}) + okSubscription, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{ + CreatedAt: dbtime.Now(), + UserID: user.ID, + Endpoint: okServerURL, + EndpointAuthKey: validEndpointAuthKey, + EndpointP256dhKey: validEndpointP256dhKey, + }) + require.NoError(t, err) + _, err = rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{ + CreatedAt: dbtime.Now(), + UserID: user.ID, + Endpoint: goneServer.URL, + EndpointAuthKey: validEndpointAuthKey, + EndpointP256dhKey: validEndpointP256dhKey, + }) + require.NoError(t, err) + + msg := randomWebpushMessage(t) + err = manager.Dispatch(ctx, user.ID, msg) + require.NoError(t, err) + err = manager.Dispatch(ctx, user.ID, msg) + require.NoError(t, err) + + require.Equal(t, int32(1), store.getCallCount(), "stale subscription cleanup should not force a second DB read within the TTL") + require.Equal(t, int32(2), okCalls.Load(), "the healthy endpoint should receive both dispatches") + require.Equal(t, int32(1), goneCalls.Load(), "the stale endpoint should be pruned from the cache after the first dispatch") + + subscriptions, err := rawStore.GetWebpushSubscriptionsByUserID(ctx, user.ID) + require.NoError(t, err) + require.Len(t, subscriptions, 1, "only the healthy subscription should remain") + require.Equal(t, okSubscription.ID, subscriptions[0].ID) + }) } func randomWebpushMessage(t testing.TB) codersdk.WebpushMessage { @@ -244,16 +386,21 @@ func assertWebpushPayload(t testing.TB, r *http.Request) { assert.Error(t, json.NewDecoder(r.Body).Decode(io.Discard)) } -// setupPushTest creates a common test setup for webpush notification tests +// setupPushTest creates a common test setup for webpush notification tests. func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) { t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) db, _ := dbtestutil.NewDB(t) + return setupPushTestWithOptions(ctx, t, db, handlerFunc) +} + +func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Store, handlerFunc func(w http.ResponseWriter, r *http.Request), opts ...webpush.Option) (webpush.Dispatcher, database.Store, string) { + t.Helper() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) server := httptest.NewServer(http.HandlerFunc(handlerFunc)) t.Cleanup(server.Close) - manager, err := webpush.New(ctx, &logger, db, "http://example.com") + manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...) require.NoError(t, err, "Failed to create webpush manager") return manager, db, server.URL diff --git a/coderd/webpush_test.go b/coderd/webpush_test.go index b36af9e236..353cc676b4 100644 --- a/coderd/webpush_test.go +++ b/coderd/webpush_test.go @@ -35,31 +35,42 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) { memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) _, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) - handlerCalled := make(chan bool, 1) + var handlerCalls atomic.Int32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusCreated) - handlerCalled <- true + handlerCalls.Add(1) })) defer server.Close() - err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{ + // Seed the dispatcher cache with an empty subscription set. Creating the + // subscription should invalidate that entry so the next dispatch sees the new + // subscription immediately. + err := memberClient.PostTestWebpushMessage(ctx) + require.NoError(t, err, "test webpush message without a subscription") + require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push") + + err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{ Endpoint: server.URL, AuthKey: validEndpointAuthKey, P256DHKey: validEndpointP256dhKey, }) require.NoError(t, err, "create webpush subscription") - require.True(t, <-handlerCalled, "handler should have been called") + require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once") err = memberClient.PostTestWebpushMessage(ctx) - require.NoError(t, err, "test webpush message") - require.True(t, <-handlerCalled, "handler should have been called again") + require.NoError(t, err, "test webpush message after subscribing") + require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing") err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{ Endpoint: server.URL, }) require.NoError(t, err, "delete webpush subscription") - // Deleting the subscription for a non-existent endpoint should return a 404 + err = memberClient.PostTestWebpushMessage(ctx) + require.NoError(t, err, "test webpush message after unsubscribing") + require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing") + + // Deleting the subscription for a non-existent endpoint should return a 404. err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{ Endpoint: server.URL, })