perf(coderd): reduce duplicated reads in push and webpush paths (#23115)

## Background

A 5000-chat scaletest (~50k turns, ~2m45s wall time) completed
successfully, but the main bottleneck was **DB pool starvation from
repeated reads**, not individually expensive SQL. The push/webpush path
showed a few especially noisy reads:

- `GetLastChatMessageByRole` for push body generation
- `GetEnabledChatProviders` + `GetChatModelConfigByID` for push summary
  model resolution
- `GetWebpushSubscriptionsByUserID` for every webpush dispatch

This PR keeps the optimizations that remove those duplicate reads while
leaving stream behavior unchanged.

## What changes in this PR

### 1. Reuse resolved chat state for push notifications

`maybeSendPushNotification` used to re-read the last assistant message
and re-resolve the chat model/provider after `runChat` had already done
that work.

Now `runChat` returns the final assistant text plus the already-resolved
model and provider keys, and the push goroutine uses that state directly.

That removes the extra push-path reads for:

- `GetLastChatMessageByRole`
- the second `resolveChatModel` path
- the provider/model lookups that came with that second resolution

### 2. Cache webpush subscriptions during dispatch

`Dispatch()` previously hit `GetWebpushSubscriptionsByUserID` on every
push. A small per-user in-memory cache now avoids those repeated reads.

The follow-up fix keeps that optimization correct: `InvalidateUser()`
bumps a per-user generation so an older in-flight fetch cannot
repopulate the cache with pre-mutation data after subscribe/unsubscribe.

That preserves the cache win without letting local subscription changes
be silently overwritten by stale fetch results.

## Why this is safe

- The push change only reuses data already produced during the same chat
  run. It does not change notification semantics; if there is no
  assistant text to summarize, the existing fallback body still applies.
- The webpush change keeps the existing TTL and `410 Gone` cleanup
  behavior. The generation guard only prevents stale in-flight fetches
  from poisoning the shared cache after invalidation.
- The final PR does **not** change stream setup, pubsub/relay behavior,
  or chat status snapshot timing.

## Deliberately not included

- No stream-path optimization in `Subscribe`.
- No inline pubsub message payloads.
- No distributed cross-replica webpush cache invalidation.
This commit is contained in:
Ethan
2026-03-17 13:50:47 +11:00
committed by GitHub
parent 7cca2b6176
commit 04fca84872
6 changed files with 465 additions and 49 deletions
+190 -7
View File
@@ -9,18 +9,23 @@ import (
"net/http"
"slices"
"sync"
"time"
"github.com/SherClockHolmes/webpush-go"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"tailscale.com/util/singleflight"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
// defaultSubscriptionCacheTTL is the subscription cache TTL that New applies
// when no positive TTL has been configured.
const defaultSubscriptionCacheTTL = 3 * time.Minute
// Dispatcher is an interface that can be used to dispatch
// web push notifications to clients such as browsers.
type Dispatcher interface {
@@ -33,6 +38,36 @@ type Dispatcher interface {
PublicKey() string
}
// SubscriptionCacheInvalidator is an optional interface that lets local
// subscription mutation handlers invalidate cached subscriptions.
type SubscriptionCacheInvalidator interface {
// InvalidateUser invalidates any cached subscriptions for userID.
InvalidateUser(userID uuid.UUID)
}
// options holds the optional settings that Option values can modify.
type options struct {
// clock is the time source used by the subscription cache.
clock quartz.Clock
// subscriptionCacheTTL is how long a cached subscription list stays valid.
subscriptionCacheTTL time.Duration
}
// Option configures optional behavior for a Webpusher. Options are passed to
// New, which applies them in order on top of its defaults.
type Option func(*options)
// WithClock overrides the clock that the subscription cache consults when
// deciding whether an entry has expired. A real clock is used when this
// option is omitted.
func WithClock(clock quartz.Clock) Option {
	apply := func(cfg *options) {
		cfg.clock = clock
	}
	return apply
}
// WithSubscriptionCacheTTL overrides how long fetched subscriptions stay in
// the in-memory cache. The three minute default applies when this option is
// omitted or when ttl is not positive.
func WithSubscriptionCacheTTL(ttl time.Duration) Option {
	apply := func(cfg *options) {
		cfg.subscriptionCacheTTL = ttl
	}
	return apply
}
// New creates a new Dispatcher to dispatch web push notifications.
//
// This is *not* integrated into the enqueue system unfortunately.
@@ -41,7 +76,21 @@ type Dispatcher interface {
// for updates inside of a workspace, which we want to be immediate.
//
// See: https://github.com/coder/internal/issues/528
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string) (Dispatcher, error) {
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string, opts ...Option) (Dispatcher, error) {
cfg := options{
clock: quartz.NewReal(),
subscriptionCacheTTL: defaultSubscriptionCacheTTL,
}
for _, opt := range opts {
opt(&cfg)
}
if cfg.clock == nil {
cfg.clock = quartz.NewReal()
}
if cfg.subscriptionCacheTTL <= 0 {
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
}
keys, err := db.GetWebpushVAPIDKeys(ctx)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
@@ -63,14 +112,23 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
}
return &Webpusher{
vapidSub: vapidSub,
store: db,
log: log,
VAPIDPublicKey: keys.VapidPublicKey,
VAPIDPrivateKey: keys.VapidPrivateKey,
vapidSub: vapidSub,
store: db,
log: log,
VAPIDPublicKey: keys.VapidPublicKey,
VAPIDPrivateKey: keys.VapidPrivateKey,
clock: cfg.clock,
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
subscriptionGenerations: make(map[uuid.UUID]uint64),
}, nil
}
// cachedSubscriptions is a single per-user entry in the subscription cache.
type cachedSubscriptions struct {
// subscriptions is the cached list of the user's web push subscriptions.
subscriptions []database.WebpushSubscription
// expiresAt is the instant from which the entry is no longer served.
expiresAt time.Time
}
type Webpusher struct {
store database.Store
log *slog.Logger
@@ -83,10 +141,18 @@ type Webpusher struct {
// the message payload.
VAPIDPublicKey string
VAPIDPrivateKey string
clock quartz.Clock
cacheMu sync.RWMutex
subscriptionCache map[uuid.UUID]cachedSubscriptions
subscriptionGenerations map[uuid.UUID]uint64
subscriptionCacheTTL time.Duration
subscriptionFetches singleflight.Group[string, []database.WebpushSubscription]
}
func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
subscriptions, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
subscriptions, err := n.subscriptionsForUser(ctx, userID)
if err != nil {
return xerrors.Errorf("get web push subscriptions by user ID: %w", err)
}
@@ -142,12 +208,129 @@ func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk
err = n.store.DeleteWebpushSubscriptions(dbauthz.AsNotifier(ctx), cleanupSubscriptions)
if err != nil {
n.log.Error(ctx, "failed to delete stale push subscriptions", slog.Error(err))
} else {
n.pruneSubscriptions(userID, cleanupSubscriptions)
}
}
return nil
}
// subscriptionsForUser returns the web push subscriptions for userID. A live
// cache entry is served directly; otherwise the subscriptions are read from
// the store through subscriptionFetches, which coalesces concurrent reads
// for the same user. The returned slice is a copy.
func (n *Webpusher) subscriptionsForUser(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
	if hit, found := n.cachedSubscriptions(userID); found {
		return hit, nil
	}

	// NOTE(review): the coalesced read runs with the ctx of whichever caller
	// started it, so that caller's cancellation is reported to every waiter.
	// Confirm this is acceptable for Dispatch.
	load := func() ([]database.WebpushSubscription, error) {
		// Re-check the cache: it may have been filled between the lookup
		// above and the moment this function got to run.
		if hit, found := n.cachedSubscriptions(userID); found {
			return hit, nil
		}

		// The generation must be captured before the read so that
		// storeSubscriptions can discard the result if the user was
		// invalidated while the read was in flight.
		gen := n.subscriptionGeneration(userID)
		rows, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
		if err != nil {
			return nil, err
		}
		n.storeSubscriptions(userID, gen, rows)
		return slices.Clone(rows), nil
	}

	shared, err, _ := n.subscriptionFetches.Do(userID.String(), load)
	if err != nil {
		return nil, err
	}
	// The value returned by Do may be handed to other callers as well, so
	// give this caller a private copy.
	return slices.Clone(shared), nil
}
// cachedSubscriptions looks up the cache entry for userID. It returns a copy
// of the cached subscriptions and true when an unexpired entry exists, and
// nil and false otherwise. An expired entry is removed from the cache.
func (n *Webpusher) cachedSubscriptions(userID uuid.UUID) ([]database.WebpushSubscription, bool) {
	n.cacheMu.RLock()
	cached, present := n.subscriptionCache[userID]
	n.cacheMu.RUnlock()

	switch {
	case !present:
		return nil, false
	case n.clock.Now().Before(cached.expiresAt):
		return slices.Clone(cached.subscriptions), true
	}

	// The entry seen under the read lock has expired. Re-read it under the
	// write lock before deleting, because it may have been replaced in the
	// meantime.
	n.cacheMu.Lock()
	defer n.cacheMu.Unlock()
	latest, stillPresent := n.subscriptionCache[userID]
	if stillPresent && !n.clock.Now().Before(latest.expiresAt) {
		delete(n.subscriptionCache, userID)
	}
	return nil, false
}
// subscriptionGeneration returns the current invalidation generation for
// userID, which is zero for a user that has no recorded generation.
func (n *Webpusher) subscriptionGeneration(userID uuid.UUID) uint64 {
	n.cacheMu.RLock()
	defer n.cacheMu.RUnlock()
	return n.subscriptionGenerations[userID]
}
// storeSubscriptions caches a copy of subscriptions for userID with a fresh
// expiry, provided generation still matches the user's current generation.
// A mismatch means the user was invalidated after generation was read, so
// the result is discarded.
func (n *Webpusher) storeSubscriptions(userID uuid.UUID, generation uint64, subscriptions []database.WebpushSubscription) {
	n.cacheMu.Lock()
	defer n.cacheMu.Unlock()

	if current := n.subscriptionGenerations[userID]; current == generation {
		n.subscriptionCache[userID] = cachedSubscriptions{
			subscriptions: slices.Clone(subscriptions),
			expiresAt:     n.clock.Now().Add(n.subscriptionCacheTTL),
		}
	}
}
// pruneSubscriptions removes the subscriptions identified by staleIDs from
// the cached entry for userID. The entry is dropped entirely when it has
// already expired or when nothing remains after filtering. A user without a
// cached entry is left untouched.
func (n *Webpusher) pruneSubscriptions(userID uuid.UUID, staleIDs []uuid.UUID) {
	if len(staleIDs) == 0 {
		return
	}

	// Build the lookup set before taking the lock.
	drop := make(map[uuid.UUID]struct{}, len(staleIDs))
	for i := range staleIDs {
		drop[staleIDs[i]] = struct{}{}
	}

	n.cacheMu.Lock()
	defer n.cacheMu.Unlock()

	cached, present := n.subscriptionCache[userID]
	if !present {
		return
	}

	// kept stays empty for an expired entry, so the expired case and the
	// "everything was stale" case both fall into the delete below.
	var kept []database.WebpushSubscription
	if n.clock.Now().Before(cached.expiresAt) {
		// Filter into a new slice rather than in place: cachedSubscriptions
		// copies an entry's slice after releasing the lock, so the slice
		// backing an existing entry must never be modified.
		kept = make([]database.WebpushSubscription, 0, len(cached.subscriptions))
		for _, sub := range cached.subscriptions {
			if _, isStale := drop[sub.ID]; !isStale {
				kept = append(kept, sub)
			}
		}
	}
	if len(kept) == 0 {
		delete(n.subscriptionCache, userID)
		return
	}

	cached.subscriptions = kept
	n.subscriptionCache[userID] = cached
}
// InvalidateUser discards the cached subscriptions for userID and bumps the
// user's invalidation generation, so a fetch that captured the previous
// generation cannot store its result. It then forgets the user's key in
// subscriptionFetches. Local subscribe and unsubscribe handlers call this
// after mutating subscriptions in the same process.
func (n *Webpusher) InvalidateUser(userID uuid.UUID) {
	func() {
		n.cacheMu.Lock()
		defer n.cacheMu.Unlock()
		n.subscriptionGenerations[userID]++
		delete(n.subscriptionCache, userID)
	}()

	// Forget is called only after the cache lock has been released.
	n.subscriptionFetches.Forget(userID.String())
}
func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string, keys webpush.Keys) (int, []byte, error) {
// Copy the message to avoid modifying the original.
cpy := slices.Clone(msg)
+150 -3
View File
@@ -6,7 +6,9 @@ import (
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -21,6 +23,7 @@ import (
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
const (
@@ -28,6 +31,20 @@ const (
validEndpointP256dhKey = "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk="
)
// countingWebpushStore wraps a database.Store and counts the calls made to
// GetWebpushSubscriptionsByUserID.
type countingWebpushStore struct {
database.Store
// getSubscriptionsCalls is the number of GetWebpushSubscriptionsByUserID
// calls observed so far.
getSubscriptionsCalls atomic.Int32
}
// GetWebpushSubscriptionsByUserID records the call and then delegates to the
// embedded store.
func (s *countingWebpushStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
	_ = s.getSubscriptionsCalls.Add(1)
	subs, err := s.Store.GetWebpushSubscriptionsByUserID(ctx, userID)
	return subs, err
}
// getCallCount reports how many times GetWebpushSubscriptionsByUserID has
// been called on this store.
func (s *countingWebpushStore) getCallCount() int32 {
	calls := s.getSubscriptionsCalls.Load()
	return calls
}
func TestPush(t *testing.T) {
t.Parallel()
@@ -216,6 +233,131 @@ func TestPush(t *testing.T) {
require.NoError(t, err)
assert.Empty(t, subscriptions, "No subscriptions should be returned")
})
t.Run("CachesSubscriptionsWithinTTL", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
rawStore, _ := dbtestutil.NewDB(t)
store := &countingWebpushStore{Store: rawStore}
var delivered atomic.Int32
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
delivered.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusOK)
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
user := dbgen.User(t, rawStore, database.User{})
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: serverURL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
msg := randomWebpushMessage(t)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
require.Equal(t, int32(1), store.getCallCount(), "subscriptions should be read once within the TTL")
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
})
t.Run("RefreshesSubscriptionsAfterTTLExpires", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
rawStore, _ := dbtestutil.NewDB(t)
store := &countingWebpushStore{Store: rawStore}
var delivered atomic.Int32
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
delivered.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusOK)
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
user := dbgen.User(t, rawStore, database.User{})
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: serverURL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
msg := randomWebpushMessage(t)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
clock.Advance(time.Minute)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
require.Equal(t, int32(2), store.getCallCount(), "dispatch should refresh subscriptions after the TTL expires")
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
})
t.Run("PrunesStaleSubscriptionsFromCache", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
rawStore, _ := dbtestutil.NewDB(t)
store := &countingWebpushStore{Store: rawStore}
var okCalls atomic.Int32
var goneCalls atomic.Int32
manager, _, okServerURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
okCalls.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusOK)
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
goneServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
goneCalls.Add(1)
assertWebpushPayload(t, r)
w.WriteHeader(http.StatusGone)
}))
defer goneServer.Close()
user := dbgen.User(t, rawStore, database.User{})
okSubscription, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: okServerURL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
_, err = rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: goneServer.URL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
msg := randomWebpushMessage(t)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
err = manager.Dispatch(ctx, user.ID, msg)
require.NoError(t, err)
require.Equal(t, int32(1), store.getCallCount(), "stale subscription cleanup should not force a second DB read within the TTL")
require.Equal(t, int32(2), okCalls.Load(), "the healthy endpoint should receive both dispatches")
require.Equal(t, int32(1), goneCalls.Load(), "the stale endpoint should be pruned from the cache after the first dispatch")
subscriptions, err := rawStore.GetWebpushSubscriptionsByUserID(ctx, user.ID)
require.NoError(t, err)
require.Len(t, subscriptions, 1, "only the healthy subscription should remain")
require.Equal(t, okSubscription.ID, subscriptions[0].ID)
})
}
func randomWebpushMessage(t testing.TB) codersdk.WebpushMessage {
@@ -244,16 +386,21 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
assert.Error(t, json.NewDecoder(r.Body).Decode(io.Discard))
}
// setupPushTest creates a common test setup for webpush notification tests
// setupPushTest creates a common test setup for webpush notification tests.
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
db, _ := dbtestutil.NewDB(t)
return setupPushTestWithOptions(ctx, t, db, handlerFunc)
}
func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Store, handlerFunc func(w http.ResponseWriter, r *http.Request), opts ...webpush.Option) (webpush.Dispatcher, database.Store, string) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
t.Cleanup(server.Close)
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
require.NoError(t, err, "Failed to create webpush manager")
return manager, db, server.URL