mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
perf(coderd): reduce duplicated reads in push and webpush paths (#23115)
## Background A 5000-chat scaletest (~50k turns, ~2m45s wall time) completed successfully, but the main bottleneck was **DB pool starvation from repeated reads**, not individually expensive SQL. The push/webpush path showed a few especially noisy reads: - `GetLastChatMessageByRole` for push body generation - `GetEnabledChatProviders` + `GetChatModelConfigByID` for push summary model resolution - `GetWebpushSubscriptionsByUserID` for every webpush dispatch This PR keeps the optimizations that remove those duplicate reads while leaving stream behavior unchanged. ## What changes in this PR ### 1. Reuse resolved chat state for push notifications `maybeSendPushNotification` used to re-read the last assistant message and re-resolve the chat model/provider after `runChat` had already done that work. Now `runChat` returns the final assistant text plus the already-resolved model and provider keys, and the push goroutine uses that state directly. That removes the extra push-path reads for: - `GetLastChatMessageByRole` - the second `resolveChatModel` path - the provider/model lookups that came with that second resolution ### 2. Cache webpush subscriptions during dispatch `Dispatch()` previously hit `GetWebpushSubscriptionsByUserID` on every push. A small per-user in-memory cache now avoids those repeated reads. The follow-up fix keeps that optimization correct: `InvalidateUser()` bumps a per-user generation so an older in-flight fetch cannot repopulate the cache with pre-mutation data after subscribe/unsubscribe. That preserves the cache win without letting local subscription changes be silently overwritten by stale fetch results. ## Why this is safe - The push change only reuses data already produced during the same chat run. It does not change notification semantics; if there is no assistant text to summarize, the existing fallback body still applies. - The webpush change keeps the existing TTL and `410 Gone` cleanup behavior. The generation guard only prevents stale in-flight fetches from poisoning the shared cache after invalidation. - The final PR does **not** change stream setup, pubsub/relay behavior, or chat status snapshot timing. ## Deliberately not included - No stream-path optimization in `Subscribe`. - No inline pubsub message payloads. - No distributed cross-replica webpush cache invalidation.
This commit is contained in:
+190
-7
@@ -9,18 +9,23 @@ import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/util/singleflight"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const defaultSubscriptionCacheTTL = 3 * time.Minute
|
||||
|
||||
// Dispatcher is an interface that can be used to dispatch
|
||||
// web push notifications to clients such as browsers.
|
||||
type Dispatcher interface {
|
||||
@@ -33,6 +38,36 @@ type Dispatcher interface {
|
||||
PublicKey() string
|
||||
}
|
||||
|
||||
// SubscriptionCacheInvalidator is an optional interface that lets local
|
||||
// subscription mutation handlers invalidate cached subscriptions.
|
||||
type SubscriptionCacheInvalidator interface {
|
||||
InvalidateUser(userID uuid.UUID)
|
||||
}
|
||||
|
||||
type options struct {
|
||||
clock quartz.Clock
|
||||
subscriptionCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// Option configures optional behavior for a Webpusher.
|
||||
type Option func(*options)
|
||||
|
||||
// WithClock sets the clock used by the subscription cache. Defaults to a real
|
||||
// clock when not provided.
|
||||
func WithClock(clock quartz.Clock) Option {
|
||||
return func(o *options) {
|
||||
o.clock = clock
|
||||
}
|
||||
}
|
||||
|
||||
// WithSubscriptionCacheTTL sets the in-memory subscription cache TTL. Defaults
|
||||
// to three minutes when not provided or when given a non-positive duration.
|
||||
func WithSubscriptionCacheTTL(ttl time.Duration) Option {
|
||||
return func(o *options) {
|
||||
o.subscriptionCacheTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new Dispatcher to dispatch web push notifications.
|
||||
//
|
||||
// This is *not* integrated into the enqueue system unfortunately.
|
||||
@@ -41,7 +76,21 @@ type Dispatcher interface {
|
||||
// for updates inside of a workspace, which we want to be immediate.
|
||||
//
|
||||
// See: https://github.com/coder/internal/issues/528
|
||||
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string) (Dispatcher, error) {
|
||||
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string, opts ...Option) (Dispatcher, error) {
|
||||
cfg := options{
|
||||
clock: quartz.NewReal(),
|
||||
subscriptionCacheTTL: defaultSubscriptionCacheTTL,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&cfg)
|
||||
}
|
||||
if cfg.clock == nil {
|
||||
cfg.clock = quartz.NewReal()
|
||||
}
|
||||
if cfg.subscriptionCacheTTL <= 0 {
|
||||
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
|
||||
}
|
||||
|
||||
keys, err := db.GetWebpushVAPIDKeys(ctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -63,14 +112,23 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
|
||||
}
|
||||
|
||||
return &Webpusher{
|
||||
vapidSub: vapidSub,
|
||||
store: db,
|
||||
log: log,
|
||||
VAPIDPublicKey: keys.VapidPublicKey,
|
||||
VAPIDPrivateKey: keys.VapidPrivateKey,
|
||||
vapidSub: vapidSub,
|
||||
store: db,
|
||||
log: log,
|
||||
VAPIDPublicKey: keys.VapidPublicKey,
|
||||
VAPIDPrivateKey: keys.VapidPrivateKey,
|
||||
clock: cfg.clock,
|
||||
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
|
||||
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
|
||||
subscriptionGenerations: make(map[uuid.UUID]uint64),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type cachedSubscriptions struct {
|
||||
subscriptions []database.WebpushSubscription
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
type Webpusher struct {
|
||||
store database.Store
|
||||
log *slog.Logger
|
||||
@@ -83,10 +141,18 @@ type Webpusher struct {
|
||||
// the message payload.
|
||||
VAPIDPublicKey string
|
||||
VAPIDPrivateKey string
|
||||
|
||||
clock quartz.Clock
|
||||
|
||||
cacheMu sync.RWMutex
|
||||
subscriptionCache map[uuid.UUID]cachedSubscriptions
|
||||
subscriptionGenerations map[uuid.UUID]uint64
|
||||
subscriptionCacheTTL time.Duration
|
||||
subscriptionFetches singleflight.Group[string, []database.WebpushSubscription]
|
||||
}
|
||||
|
||||
func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
|
||||
subscriptions, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
|
||||
subscriptions, err := n.subscriptionsForUser(ctx, userID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get web push subscriptions by user ID: %w", err)
|
||||
}
|
||||
@@ -142,12 +208,129 @@ func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk
|
||||
err = n.store.DeleteWebpushSubscriptions(dbauthz.AsNotifier(ctx), cleanupSubscriptions)
|
||||
if err != nil {
|
||||
n.log.Error(ctx, "failed to delete stale push subscriptions", slog.Error(err))
|
||||
} else {
|
||||
n.pruneSubscriptions(userID, cleanupSubscriptions)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Webpusher) subscriptionsForUser(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
|
||||
if subscriptions, ok := n.cachedSubscriptions(userID); ok {
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
||||
subscriptions, err, _ := n.subscriptionFetches.Do(userID.String(), func() ([]database.WebpushSubscription, error) {
|
||||
if cached, ok := n.cachedSubscriptions(userID); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
generation := n.subscriptionGeneration(userID)
|
||||
fetched, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n.storeSubscriptions(userID, generation, fetched)
|
||||
return slices.Clone(fetched), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return slices.Clone(subscriptions), nil
|
||||
}
|
||||
|
||||
func (n *Webpusher) cachedSubscriptions(userID uuid.UUID) ([]database.WebpushSubscription, bool) {
|
||||
n.cacheMu.RLock()
|
||||
entry, ok := n.subscriptionCache[userID]
|
||||
n.cacheMu.RUnlock()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if n.clock.Now().Before(entry.expiresAt) {
|
||||
return slices.Clone(entry.subscriptions), true
|
||||
}
|
||||
|
||||
n.cacheMu.Lock()
|
||||
if current, ok := n.subscriptionCache[userID]; ok && !n.clock.Now().Before(current.expiresAt) {
|
||||
delete(n.subscriptionCache, userID)
|
||||
}
|
||||
n.cacheMu.Unlock()
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (n *Webpusher) subscriptionGeneration(userID uuid.UUID) uint64 {
|
||||
n.cacheMu.RLock()
|
||||
generation := n.subscriptionGenerations[userID]
|
||||
n.cacheMu.RUnlock()
|
||||
return generation
|
||||
}
|
||||
|
||||
func (n *Webpusher) storeSubscriptions(userID uuid.UUID, generation uint64, subscriptions []database.WebpushSubscription) {
|
||||
n.cacheMu.Lock()
|
||||
defer n.cacheMu.Unlock()
|
||||
|
||||
if n.subscriptionGenerations[userID] != generation {
|
||||
return
|
||||
}
|
||||
|
||||
n.subscriptionCache[userID] = cachedSubscriptions{
|
||||
subscriptions: slices.Clone(subscriptions),
|
||||
expiresAt: n.clock.Now().Add(n.subscriptionCacheTTL),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Webpusher) pruneSubscriptions(userID uuid.UUID, staleIDs []uuid.UUID) {
|
||||
if len(staleIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
stale := make(map[uuid.UUID]struct{}, len(staleIDs))
|
||||
for _, id := range staleIDs {
|
||||
stale[id] = struct{}{}
|
||||
}
|
||||
|
||||
n.cacheMu.Lock()
|
||||
defer n.cacheMu.Unlock()
|
||||
|
||||
entry, ok := n.subscriptionCache[userID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !n.clock.Now().Before(entry.expiresAt) {
|
||||
delete(n.subscriptionCache, userID)
|
||||
return
|
||||
}
|
||||
|
||||
filtered := make([]database.WebpushSubscription, 0, len(entry.subscriptions))
|
||||
for _, subscription := range entry.subscriptions {
|
||||
if _, shouldDelete := stale[subscription.ID]; shouldDelete {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, subscription)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(n.subscriptionCache, userID)
|
||||
return
|
||||
}
|
||||
|
||||
entry.subscriptions = filtered
|
||||
n.subscriptionCache[userID] = entry
|
||||
}
|
||||
|
||||
// InvalidateUser clears the cached subscriptions for a user and advances
|
||||
// its invalidation generation. Local subscribe and unsubscribe handlers call
|
||||
// this after mutating subscriptions in the same process.
|
||||
func (n *Webpusher) InvalidateUser(userID uuid.UUID) {
|
||||
n.cacheMu.Lock()
|
||||
delete(n.subscriptionCache, userID)
|
||||
n.subscriptionGenerations[userID]++
|
||||
n.cacheMu.Unlock()
|
||||
n.subscriptionFetches.Forget(userID.String())
|
||||
}
|
||||
|
||||
func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string, keys webpush.Keys) (int, []byte, error) {
|
||||
// Copy the message to avoid modifying the original.
|
||||
cpy := slices.Clone(msg)
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,6 +31,20 @@ const (
|
||||
validEndpointP256dhKey = "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk="
|
||||
)
|
||||
|
||||
type countingWebpushStore struct {
|
||||
database.Store
|
||||
getSubscriptionsCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (s *countingWebpushStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
|
||||
s.getSubscriptionsCalls.Add(1)
|
||||
return s.Store.GetWebpushSubscriptionsByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *countingWebpushStore) getCallCount() int32 {
|
||||
return s.getSubscriptionsCalls.Load()
|
||||
}
|
||||
|
||||
func TestPush(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -216,6 +233,131 @@ func TestPush(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, subscriptions, "No subscriptions should be returned")
|
||||
})
|
||||
|
||||
t.Run("CachesSubscriptionsWithinTTL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
rawStore, _ := dbtestutil.NewDB(t)
|
||||
store := &countingWebpushStore{Store: rawStore}
|
||||
var delivered atomic.Int32
|
||||
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
|
||||
delivered.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
|
||||
|
||||
user := dbgen.User(t, rawStore, database.User{})
|
||||
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: serverURL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := randomWebpushMessage(t)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(1), store.getCallCount(), "subscriptions should be read once within the TTL")
|
||||
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
|
||||
})
|
||||
|
||||
t.Run("RefreshesSubscriptionsAfterTTLExpires", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
rawStore, _ := dbtestutil.NewDB(t)
|
||||
store := &countingWebpushStore{Store: rawStore}
|
||||
var delivered atomic.Int32
|
||||
manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
|
||||
delivered.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
|
||||
|
||||
user := dbgen.User(t, rawStore, database.User{})
|
||||
_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: serverURL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := randomWebpushMessage(t)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
clock.Advance(time.Minute)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(2), store.getCallCount(), "dispatch should refresh subscriptions after the TTL expires")
|
||||
require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
|
||||
})
|
||||
|
||||
t.Run("PrunesStaleSubscriptionsFromCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
rawStore, _ := dbtestutil.NewDB(t)
|
||||
store := &countingWebpushStore{Store: rawStore}
|
||||
var okCalls atomic.Int32
|
||||
var goneCalls atomic.Int32
|
||||
manager, _, okServerURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
|
||||
okCalls.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
|
||||
|
||||
goneServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
goneCalls.Add(1)
|
||||
assertWebpushPayload(t, r)
|
||||
w.WriteHeader(http.StatusGone)
|
||||
}))
|
||||
defer goneServer.Close()
|
||||
|
||||
user := dbgen.User(t, rawStore, database.User{})
|
||||
okSubscription, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: okServerURL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: goneServer.URL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := randomWebpushMessage(t)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
err = manager.Dispatch(ctx, user.ID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(1), store.getCallCount(), "stale subscription cleanup should not force a second DB read within the TTL")
|
||||
require.Equal(t, int32(2), okCalls.Load(), "the healthy endpoint should receive both dispatches")
|
||||
require.Equal(t, int32(1), goneCalls.Load(), "the stale endpoint should be pruned from the cache after the first dispatch")
|
||||
|
||||
subscriptions, err := rawStore.GetWebpushSubscriptionsByUserID(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, subscriptions, 1, "only the healthy subscription should remain")
|
||||
require.Equal(t, okSubscription.ID, subscriptions[0].ID)
|
||||
})
|
||||
}
|
||||
|
||||
func randomWebpushMessage(t testing.TB) codersdk.WebpushMessage {
|
||||
@@ -244,16 +386,21 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
|
||||
assert.Error(t, json.NewDecoder(r.Body).Decode(io.Discard))
|
||||
}
|
||||
|
||||
// setupPushTest creates a common test setup for webpush notification tests
|
||||
// setupPushTest creates a common test setup for webpush notification tests.
|
||||
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
return setupPushTestWithOptions(ctx, t, db, handlerFunc)
|
||||
}
|
||||
|
||||
func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Store, handlerFunc func(w http.ResponseWriter, r *http.Request), opts ...webpush.Option) (webpush.Dispatcher, database.Store, string) {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
|
||||
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
|
||||
require.NoError(t, err, "Failed to create webpush manager")
|
||||
|
||||
return manager, db, server.URL
|
||||
|
||||
Reference in New Issue
Block a user