Files
coder/coderd/webpush/webpush_test.go
T
Ethan 04fca84872 perf(coderd): reduce duplicated reads in push and webpush paths (#23115)
## Background

A 5000-chat scaletest (~50k turns, ~2m45s wall time) completed successfully,
but the main bottleneck was **DB pool starvation from repeated reads**, not
individually expensive SQL. The push/webpush path showed a few especially
noisy reads:

- `GetLastChatMessageByRole` for push body generation
- `GetEnabledChatProviders` + `GetChatModelConfigByID` for push summary model
  resolution
- `GetWebpushSubscriptionsByUserID` for every webpush dispatch

This PR keeps the optimizations that remove those duplicate reads while
leaving stream behavior unchanged.

## What changes in this PR

### 1. Reuse resolved chat state for push notifications

`maybeSendPushNotification` used to re-read the last assistant message and
re-resolve the chat model/provider after `runChat` had already done that work.

Now `runChat` returns the final assistant text plus the already-resolved model
and provider keys, and the push goroutine uses that state directly.

That removes the extra push-path reads for:

- `GetLastChatMessageByRole`
- the second `resolveChatModel` path
- the provider/model lookups that came with that second resolution

### 2. Cache webpush subscriptions during dispatch

`Dispatch()` previously hit `GetWebpushSubscriptionsByUserID` on every push. A
small per-user in-memory cache now avoids those repeated reads.

The follow-up fix keeps that optimization correct: `InvalidateUser()` bumps a
per-user generation so an older in-flight fetch cannot repopulate the cache
with pre-mutation data after subscribe/unsubscribe.

That preserves the cache win without letting local subscription changes be
silently overwritten by stale fetch results.

## Why this is safe

- The push change only reuses data already produced during the same chat run.
  It does not change notification semantics; if there is no assistant text to
  summarize, the existing fallback body still applies.
- The webpush change keeps the existing TTL and `410 Gone` cleanup behavior.
  The generation guard only prevents stale in-flight fetches from poisoning the
  shared cache after invalidation.
- The final PR does **not** change stream setup, pubsub/relay behavior, or chat
  status snapshot timing.

## Deliberately not included

- No stream-path optimization in `Subscribe`.
- No inline pubsub message payloads.
- No distributed cross-replica webpush cache invalidation.
2026-03-17 13:50:47 +11:00

408 lines
14 KiB
Go

package webpush_test
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// Key material attached to every webpush subscription the tests insert
// (passed as EndpointAuthKey and EndpointP256dhKey).
// NOTE(review): presumably the base64-encoded client auth secret and P-256
// ECDH public key that Web Push payload encryption expects — confirm against
// the webpush package.
const (
validEndpointAuthKey = "zqbxT6JKstKSY9JKibZLSQ=="
validEndpointP256dhKey = "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk="
)
// countingWebpushStore wraps a database.Store and counts calls to
// GetWebpushSubscriptionsByUserID, letting tests assert how many times the
// subscription list was read through this store.
type countingWebpushStore struct {
// Embedded store; its GetWebpushSubscriptionsByUserID is the one the
// counting override delegates to.
database.Store
// Number of GetWebpushSubscriptionsByUserID calls observed so far. Atomic
// so it can be bumped and read from different goroutines.
getSubscriptionsCalls atomic.Int32
}
// GetWebpushSubscriptionsByUserID records the call and then delegates to the
// embedded store, returning its result unchanged. The counter is incremented
// before the lookup, so calls that end in an error are counted as well.
func (s *countingWebpushStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
	s.getSubscriptionsCalls.Add(1)

	subscriptions, err := s.Store.GetWebpushSubscriptionsByUserID(ctx, userID)
	return subscriptions, err
}
// getCallCount reports how many times GetWebpushSubscriptionsByUserID has been
// invoked on this store so far.
func (s *countingWebpushStore) getCallCount() int32 {
	calls := s.getSubscriptionsCalls.Load()
	return calls
}
// TestPush exercises webpush dispatch against local httptest servers that
// stand in for the push service. Every subtest gets its own database and
// dispatcher from setupPushTest / setupPushTestWithOptions and runs in
// parallel.
//
// The first group of subtests covers delivery outcomes (200, 410 Gone, 400,
// mixed endpoints, no subscriptions). The last three cover the subscription
// cache using a mock clock and a counting store wrapper.
func TestPush(t *testing.T) {
	t.Parallel()

	// A 200 from the endpoint delivers the message and keeps the subscription.
	t.Run("SuccessfulDelivery", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		msg := randomWebpushMessage(t)
		manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) {
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusOK)
		})
		user := dbgen.User(t, store, database.User{})
		sub, err := store.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			UserID:            user.ID,
			Endpoint:          serverURL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
			CreatedAt:         dbtime.Now(),
		})
		require.NoError(t, err)

		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err)

		subscriptions, err := store.GetWebpushSubscriptionsByUserID(ctx, user.ID)
		require.NoError(t, err)
		// require (not assert) so the index below cannot panic on an empty slice.
		require.Len(t, subscriptions, 1, "One subscription should be returned")
		assert.Equal(t, sub.ID, subscriptions[0].ID, "The subscription should not be deleted")
	})

	// A 410 Gone from the endpoint is not an error, and the subscription is removed.
	t.Run("ExpiredSubscription", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) {
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusGone)
		})
		user := dbgen.User(t, store, database.User{})
		_, err := store.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			UserID:            user.ID,
			Endpoint:          serverURL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
			CreatedAt:         dbtime.Now(),
		})
		require.NoError(t, err)

		msg := randomWebpushMessage(t)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err)

		subscriptions, err := store.GetWebpushSubscriptionsByUserID(ctx, user.ID)
		require.NoError(t, err)
		assert.Len(t, subscriptions, 0, "No subscriptions should be returned")
	})

	// A 400 surfaces the response body in the returned error and keeps the subscription.
	t.Run("FailedDelivery", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) {
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusBadRequest)
			// Best-effort write; a failure here would show up as a missing
			// "Invalid request" in the assertion below.
			_, _ = w.Write([]byte("Invalid request"))
		})
		user := dbgen.User(t, store, database.User{})
		sub, err := store.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			UserID:            user.ID,
			Endpoint:          serverURL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
			CreatedAt:         dbtime.Now(),
		})
		require.NoError(t, err)

		msg := randomWebpushMessage(t)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "Invalid request")

		subscriptions, err := store.GetWebpushSubscriptionsByUserID(ctx, user.ID)
		require.NoError(t, err)
		// require (not assert) so the index below cannot panic on an empty slice.
		require.Len(t, subscriptions, 1, "One subscription should be returned")
		assert.Equal(t, sub.ID, subscriptions[0].ID, "The subscription should not be deleted")
	})

	// One healthy and one 410 endpoint: both are called, only the healthy
	// subscription survives.
	t.Run("MultipleSubscriptions", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		// The handlers run on the HTTP servers' goroutines, so the flags are
		// atomic to avoid racing with the reads on the test goroutine.
		var okEndpointCalled atomic.Bool
		var goneEndpointCalled atomic.Bool
		manager, store, serverOKURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) {
			okEndpointCalled.Store(true)
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusOK)
		})
		serverGone := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			goneEndpointCalled.Store(true)
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusGone)
		}))
		defer serverGone.Close()
		serverGoneURL := serverGone.URL

		// Setup subscriptions pointing to our test servers
		user := dbgen.User(t, store, database.User{})
		sub1, err := store.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			UserID:            user.ID,
			Endpoint:          serverOKURL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
			CreatedAt:         dbtime.Now(),
		})
		require.NoError(t, err)
		_, err = store.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			UserID:            user.ID,
			Endpoint:          serverGoneURL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
			CreatedAt:         dbtime.Now(),
		})
		require.NoError(t, err)

		msg := randomWebpushMessage(t)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err)
		assert.True(t, okEndpointCalled.Load(), "The valid endpoint should be called")
		assert.True(t, goneEndpointCalled.Load(), "The expired endpoint should be called")

		// Assert that sub1 was not deleted.
		subscriptions, err := store.GetWebpushSubscriptionsByUserID(ctx, user.ID)
		require.NoError(t, err)
		if assert.Len(t, subscriptions, 1, "One subscription should be returned") {
			assert.Equal(t, sub1.ID, subscriptions[0].ID, "The valid subscription should not be deleted")
		}
	})

	// The endpoint actually receives a request that passes assertWebpushPayload.
	t.Run("NotificationPayload", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		// Written from the server's handler goroutine, read from the test goroutine.
		var requestReceived atomic.Bool
		manager, store, serverURL := setupPushTest(ctx, t, func(w http.ResponseWriter, r *http.Request) {
			requestReceived.Store(true)
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusOK)
		})
		user := dbgen.User(t, store, database.User{})
		_, err := store.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			CreatedAt:         dbtime.Now(),
			UserID:            user.ID,
			Endpoint:          serverURL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
		})
		require.NoError(t, err, "Failed to insert push subscription")

		msg := randomWebpushMessage(t)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err, "The push notification should be dispatched successfully")
		require.True(t, requestReceived.Load(), "The push notification request should have been received by the server")
	})

	// Dispatching to a user ID with no subscriptions succeeds and creates none.
	t.Run("NoSubscriptions", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		manager, store, _ := setupPushTest(ctx, t, func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusOK)
		})
		userID := uuid.New()
		notification := codersdk.WebpushMessage{
			Title: "Test Title",
			Body:  "Test Body",
		}

		err := manager.Dispatch(ctx, userID, notification)
		require.NoError(t, err)

		subscriptions, err := store.GetWebpushSubscriptionsByUserID(ctx, userID)
		require.NoError(t, err)
		assert.Empty(t, subscriptions, "No subscriptions should be returned")
	})

	// Two dispatches without advancing the mock clock read subscriptions from
	// the store once but deliver twice.
	t.Run("CachesSubscriptionsWithinTTL", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		clock := quartz.NewMock(t)
		rawStore, _ := dbtestutil.NewDB(t)
		store := &countingWebpushStore{Store: rawStore}
		var delivered atomic.Int32
		manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
			delivered.Add(1)
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusOK)
		}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
		// Fixtures go through rawStore so they do not affect the call counter.
		user := dbgen.User(t, rawStore, database.User{})
		_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			CreatedAt:         dbtime.Now(),
			UserID:            user.ID,
			Endpoint:          serverURL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
		})
		require.NoError(t, err)

		msg := randomWebpushMessage(t)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err)

		require.Equal(t, int32(1), store.getCallCount(), "subscriptions should be read once within the TTL")
		require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
	})

	// Advancing the mock clock by the full TTL between dispatches forces a
	// second read from the store.
	t.Run("RefreshesSubscriptionsAfterTTLExpires", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		clock := quartz.NewMock(t)
		rawStore, _ := dbtestutil.NewDB(t)
		store := &countingWebpushStore{Store: rawStore}
		var delivered atomic.Int32
		manager, _, serverURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
			delivered.Add(1)
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusOK)
		}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
		// Fixtures go through rawStore so they do not affect the call counter.
		user := dbgen.User(t, rawStore, database.User{})
		_, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			CreatedAt:         dbtime.Now(),
			UserID:            user.ID,
			Endpoint:          serverURL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
		})
		require.NoError(t, err)

		msg := randomWebpushMessage(t)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err)
		clock.Advance(time.Minute)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err)

		require.Equal(t, int32(2), store.getCallCount(), "dispatch should refresh subscriptions after the TTL expires")
		require.Equal(t, int32(2), delivered.Load(), "both dispatches should send a notification")
	})

	// A subscription whose endpoint answers 410 Gone is dropped from the
	// cached list without triggering another store read within the TTL.
	t.Run("PrunesStaleSubscriptionsFromCache", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		clock := quartz.NewMock(t)
		rawStore, _ := dbtestutil.NewDB(t)
		store := &countingWebpushStore{Store: rawStore}
		var okCalls atomic.Int32
		var goneCalls atomic.Int32
		manager, _, okServerURL := setupPushTestWithOptions(ctx, t, store, func(w http.ResponseWriter, r *http.Request) {
			okCalls.Add(1)
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusOK)
		}, webpush.WithClock(clock), webpush.WithSubscriptionCacheTTL(time.Minute))
		goneServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			goneCalls.Add(1)
			assertWebpushPayload(t, r)
			w.WriteHeader(http.StatusGone)
		}))
		defer goneServer.Close()

		// Fixtures and the final verification go through rawStore so they do
		// not affect the call counter.
		user := dbgen.User(t, rawStore, database.User{})
		okSubscription, err := rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			CreatedAt:         dbtime.Now(),
			UserID:            user.ID,
			Endpoint:          okServerURL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
		})
		require.NoError(t, err)
		_, err = rawStore.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
			CreatedAt:         dbtime.Now(),
			UserID:            user.ID,
			Endpoint:          goneServer.URL,
			EndpointAuthKey:   validEndpointAuthKey,
			EndpointP256dhKey: validEndpointP256dhKey,
		})
		require.NoError(t, err)

		msg := randomWebpushMessage(t)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err)
		err = manager.Dispatch(ctx, user.ID, msg)
		require.NoError(t, err)

		require.Equal(t, int32(1), store.getCallCount(), "stale subscription cleanup should not force a second DB read within the TTL")
		require.Equal(t, int32(2), okCalls.Load(), "the healthy endpoint should receive both dispatches")
		require.Equal(t, int32(1), goneCalls.Load(), "the stale endpoint should be pruned from the cache after the first dispatch")

		subscriptions, err := rawStore.GetWebpushSubscriptionsByUserID(ctx, user.ID)
		require.NoError(t, err)
		require.Len(t, subscriptions, 1, "only the healthy subscription should remain")
		require.Equal(t, okSubscription.ID, subscriptions[0].ID)
	})
}
// randomWebpushMessage builds a webpush message for tests. Title and Body are
// each taken from testutil.GetRandomName; the two actions and the icon URL
// are fixed.
func randomWebpushMessage(t testing.TB) codersdk.WebpushMessage {
	t.Helper()

	// Title is generated before Body, matching field order in the message.
	title := testutil.GetRandomName(t)
	body := testutil.GetRandomName(t)
	actions := []codersdk.WebpushMessageAction{
		{Label: "A", URL: "https://example.com/a"},
		{Label: "B", URL: "https://example.com/b"},
	}

	return codersdk.WebpushMessage{
		Title:   title,
		Body:    body,
		Actions: actions,
		Icon:    "https://example.com/icon.png",
	}
}
// assertWebpushPayload checks that r looks like an encrypted webpush request:
// a POST with Content-Type application/octet-stream, Content-Encoding
// aes128gcm, an Authorization header containing "vapid", and a body that is
// not valid JSON. Failures are reported through t with assert, so the calling
// handler keeps running and can still write its response.
func assertWebpushPayload(t testing.TB, r *http.Request) {
	t.Helper()
	assert.Equal(t, http.MethodPost, r.Method)
	assert.Equal(t, "application/octet-stream", r.Header.Get("Content-Type"))
	assert.Equal(t, "aes128gcm", r.Header.Get("content-encoding"))
	assert.Contains(t, r.Header.Get("Authorization"), "vapid")

	// The body is encrypted, so it must not parse as JSON. The check uses
	// json.Valid on the full body rather than Decode: decoding into a
	// non-pointer target returns an error even for valid JSON, which would
	// make this assertion pass unconditionally.
	body, err := io.ReadAll(r.Body)
	if !assert.NoError(t, err, "reading the request body should succeed") {
		return
	}
	assert.False(t, json.Valid(body), "the encrypted request body should not be valid JSON")
}
// setupPushTest is the default fixture for webpush notification tests. It
// creates a fresh test database with dbtestutil.NewDB and passes it to
// setupPushTestWithOptions without any extra options, returning that helper's
// dispatcher, store and endpoint URL as-is.
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
	t.Helper()

	store, _ := dbtestutil.NewDB(t)
	dispatcher, db, endpoint := setupPushTestWithOptions(ctx, t, store, handlerFunc)
	return dispatcher, db, endpoint
}
// setupPushTestWithOptions wires up a webpush dispatcher for a test using the
// supplied store. It starts an httptest server that serves handlerFunc
// (closed via t.Cleanup), builds a debug-level test logger, and constructs
// the dispatcher with webpush.New, forwarding opts. It returns the
// dispatcher, the same db that was passed in, and the test server's URL for
// use as a subscription endpoint. The test fails immediately if the
// dispatcher cannot be created.
func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Store, handlerFunc func(w http.ResponseWriter, r *http.Request), opts ...webpush.Option) (webpush.Dispatcher, database.Store, string) {
	t.Helper()

	log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)

	srv := httptest.NewServer(http.HandlerFunc(handlerFunc))
	t.Cleanup(srv.Close)

	dispatcher, err := webpush.New(ctx, &log, db, "http://example.com", opts...)
	require.NoError(t, err, "Failed to create webpush manager")

	return dispatcher, db, srv.URL
}