Files

537 lines
17 KiB
Go

package webpush
import (
"context"
"database/sql"
"encoding/json"
"errors"
"io"
"net"
"net/http"
"net/netip"
"slices"
"sync"
"syscall"
"time"
"github.com/SherClockHolmes/webpush-go"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"tailscale.com/util/singleflight"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
// defaultSubscriptionCacheTTL is how long a user's subscriptions stay in the
// in-memory cache when no positive TTL is supplied through
// WithSubscriptionCacheTTL.
const defaultSubscriptionCacheTTL = 3 * time.Minute
// isStaleSubscriptionStatus reports whether statusCode, as returned by a push
// service, means the subscription is permanently invalid and its row should
// be deleted. Every other status, including rate limits and transient 4xx/5xx
// failures, keeps the subscription so a later dispatch can try again.
//
// The permanent statuses are:
//   - 400: malformed subscription per the push service.
//   - 403: Apple BadJwtToken / VAPID rejected, key rotation.
//   - 404: FCM/Mozilla endpoint no longer valid.
//   - 410: standard "subscription expired" signal.
func isStaleSubscriptionStatus(statusCode int) bool {
	return statusCode == http.StatusBadRequest ||
		statusCode == http.StatusForbidden ||
		statusCode == http.StatusNotFound ||
		statusCode == http.StatusGone
}
// Dispatcher is an interface that can be used to dispatch
// web push notifications to clients such as browsers.
type Dispatcher interface {
// Dispatch sends a web push notification to all subscriptions
// for a user. The implementations in this file do not retry failed
// deliveries; Webpusher reports a failed delivery through the returned
// error.
Dispatch(ctx context.Context, userID uuid.UUID, notification codersdk.WebpushMessage) error
// Test sends a test web push notification to a subscription to ensure it is valid.
Test(ctx context.Context, req codersdk.WebpushSubscription) error
// PublicKey returns the VAPID public key for the webpush dispatcher.
PublicKey() string
}
// SubscriptionCacheInvalidator is an optional interface that lets local
// subscription mutation handlers invalidate cached subscriptions.
// Webpusher implements it.
type SubscriptionCacheInvalidator interface {
// InvalidateUser discards any cached subscriptions held for userID.
InvalidateUser(userID uuid.UUID)
}
// options collects the optional settings that Option functions apply before
// New builds a Webpusher.
type options struct {
// clock is the time source for cache expiry; New substitutes a real clock
// when it is nil.
clock quartz.Clock
// subscriptionCacheTTL is the cache entry lifetime; New falls back to
// defaultSubscriptionCacheTTL when it is not positive.
subscriptionCacheTTL time.Duration
// httpClient delivers push requests; New installs the SSRF-safe client when
// it is nil.
httpClient *http.Client
}
// Option configures optional behavior for a Webpusher. Options are applied
// by New in the order they are passed.
type Option func(*options)
// WithClock returns an Option that replaces the time source the subscription
// cache uses for expiry. When the option is omitted, or clock is nil, New
// uses a real clock.
func WithClock(clock quartz.Clock) Option {
	return func(cfg *options) { cfg.clock = clock }
}
// WithSubscriptionCacheTTL returns an Option that sets how long fetched
// subscriptions stay in the in-memory cache. New replaces a missing or
// non-positive ttl with the three-minute default.
func WithSubscriptionCacheTTL(ttl time.Duration) Option {
	return func(cfg *options) { cfg.subscriptionCacheTTL = ttl }
}
// WithHTTPClient returns an Option that swaps the default SSRF-safe HTTP
// client for client when delivering push notifications. It exists for tests
// that must reach localhost test servers, which the default client refuses.
func WithHTTPClient(client *http.Client) Option {
	return func(cfg *options) { cfg.httpClient = client }
}
// New builds a Dispatcher that delivers web push notifications using the
// VAPID key pair stored in db. When no complete key pair is stored, a new
// one is generated and persisted via RegenerateVAPIDKeys, which also removes
// every existing subscription because they were bound to the old keys.
//
// Web push is deliberately *not* wired into the notifications enqueue
// system: at the time of implementation push notifications are used for
// updates inside of a workspace, which need to be delivered immediately
// rather than queued.
//
// See: https://github.com/coder/internal/issues/528
func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub string, opts ...Option) (Dispatcher, error) {
	// Start from the defaults, then let the caller's options override them.
	cfg := options{
		clock:                quartz.NewReal(),
		subscriptionCacheTTL: defaultSubscriptionCacheTTL,
	}
	for _, apply := range opts {
		apply(&cfg)
	}

	// An option may have left a field unusable; restore a working value.
	if cfg.clock == nil {
		cfg.clock = quartz.NewReal()
	}
	if cfg.subscriptionCacheTTL <= 0 {
		cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
	}
	if cfg.httpClient == nil {
		cfg.httpClient = newSSRFSafeHTTPClient()
	}

	// A missing row is not an error: it just means keys must be generated.
	keys, err := db.GetWebpushVAPIDKeys(ctx)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, xerrors.Errorf("get notification vapid keys: %w", err)
	}

	publicKey, privateKey := keys.VapidPublicKey, keys.VapidPrivateKey
	if publicKey == "" || privateKey == "" {
		// Generating new VAPID keys also deletes all existing push
		// subscriptions in the same transaction, as they are no longer
		// valid.
		privateKey, publicKey, err = RegenerateVAPIDKeys(ctx, db)
		if err != nil {
			return nil, xerrors.Errorf("regenerate vapid keys: %w", err)
		}
	}

	pusher := &Webpusher{
		vapidSub:                vapidSub,
		store:                   db,
		log:                     log,
		VAPIDPublicKey:          publicKey,
		VAPIDPrivateKey:         privateKey,
		clock:                   cfg.clock,
		subscriptionCacheTTL:    cfg.subscriptionCacheTTL,
		subscriptionCache:       make(map[uuid.UUID]cachedSubscriptions),
		subscriptionGenerations: make(map[uuid.UUID]uint64),
		httpClient:              cfg.httpClient,
	}
	return pusher, nil
}
// cachedSubscriptions is one user's entry in the Webpusher subscription
// cache.
type cachedSubscriptions struct {
// subscriptions is the cached copy of the user's subscription rows.
subscriptions []database.WebpushSubscription
// expiresAt is the instant from which the entry is treated as expired.
expiresAt time.Time
}
// Webpusher is the Dispatcher returned by New. It sends web push
// notifications signed with a VAPID key pair and keeps a short-lived,
// per-user in-memory cache of subscriptions.
type Webpusher struct {
store database.Store
log *slog.Logger
// VAPID allows us to identify the sender of the message.
// This must be a https:// URL or an email address.
// Some push services (such as Apple's) require this to be set.
vapidSub string
// public and private keys for VAPID. These are used to sign and encrypt
// the message payload.
VAPIDPublicKey string
VAPIDPrivateKey string
// httpClient is an SSRF-safe HTTP client that rejects connections to
// private, loopback, and link-local IP addresses at dial time. This
// closes the DNS rebinding TOCTOU gap where a hostname passes URL
// validation but resolves to a private IP when the connection is made.
httpClient *http.Client
// clock supplies the current time for cache expiry checks.
clock quartz.Clock
// cacheMu guards subscriptionCache and subscriptionGenerations.
cacheMu sync.RWMutex
// subscriptionCache maps a user ID to that user's cached subscriptions.
subscriptionCache map[uuid.UUID]cachedSubscriptions
// subscriptionGenerations holds a per-user counter that InvalidateUser
// increments. storeSubscriptions discards a fetch result whose recorded
// generation no longer matches, so a fetch that raced with an invalidation
// cannot repopulate the cache.
subscriptionGenerations map[uuid.UUID]uint64
// subscriptionCacheTTL is the lifetime given to each new cache entry.
subscriptionCacheTTL time.Duration
// subscriptionFetches is a singleflight group, keyed by the user ID
// string, through which database fetches of subscriptions are issued.
subscriptionFetches singleflight.Group[string, []database.WebpushSubscription]
}
// Dispatch delivers msg to every web push subscription registered for
// userID, sending to all subscriptions concurrently. It returns nil when the
// user has no subscriptions.
//
// Subscriptions that the push service reports as permanently invalid (see
// isStaleSubscriptionStatus) are deleted after all sends have finished,
// whether or not any delivery failed. The returned error wraps the first
// delivery failure reported by the send group, if there was one.
func (n *Webpusher) Dispatch(ctx context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
	subs, err := n.subscriptionsForUser(ctx, userID)
	if err != nil {
		return xerrors.Errorf("get web push subscriptions by user ID: %w", err)
	}
	if len(subs) == 0 {
		return nil
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return xerrors.Errorf("marshal webpush notification: %w", err)
	}

	var (
		group    errgroup.Group
		staleMu  sync.Mutex // guards staleIDs across the send goroutines
		staleIDs = make([]uuid.UUID, 0)
	)
	for _, sub := range subs {
		group.Go(func() error {
			// TODO: Implement some retry logic here. For now, this is just a
			// best-effort attempt.
			keys := webpush.Keys{
				Auth:   sub.EndpointAuthKey,
				P256dh: sub.EndpointP256dhKey,
			}
			status, respBody, sendErr := n.webpushSend(ctx, payload, sub.Endpoint, keys)
			if sendErr != nil {
				return xerrors.Errorf("send webpush notification: %w", sendErr)
			}

			// Record subscriptions the push service has marked as
			// permanently invalid: Apple answers 403 BadJwtToken and 404 for
			// invalidated subscriptions, FCM answers 404 for expired
			// endpoints, every push service answers 410 once a subscription
			// is gone for good, and 400 means a malformed subscription that
			// cannot be retried. Left in place, these rows pile up after PWA
			// reinstalls and the in-memory cache keeps delivering to dead
			// subscriptions.
			if isStaleSubscriptionStatus(status) {
				staleMu.Lock()
				staleIDs = append(staleIDs, sub.ID)
				staleMu.Unlock()
			}

			switch {
			case status == http.StatusGone:
				// 410 Gone is informational, not a delivery error.
				return nil
			case status > http.StatusAccepted:
				// 200, 201, and 202 are the common success codes; anything
				// above likely means the delivery failed for some reason.
				return xerrors.Errorf("web push dispatch failed with status code %d: %s", status, string(respBody))
			default:
				return nil
			}
		})
	}
	firstErr := group.Wait()

	// Cleanup runs before the error return on purpose. A transient failure
	// on one subscription must not block deletion of a sibling that came
	// back 410/404/403/400; otherwise stale rows accumulate after PWA
	// reinstalls and silently mask the new subscription on every later
	// dispatch.
	n.cleanupStaleSubscriptions(ctx, userID, staleIDs)

	if firstErr != nil {
		return xerrors.Errorf("send webpush notifications: %w", firstErr)
	}
	return nil
}
// cleanupStaleSubscriptions removes the subscription rows in ids, which the
// push service reported as permanently invalid (see
// isStaleSubscriptionStatus), and then drops them from userID's cache entry.
// It does nothing when ids is empty.
//
// A failed delete is logged at error level instead of being returned, so it
// cannot shadow the delivery error the caller is about to return. The cache
// is pruned only after the database delete succeeds, which keeps the cache
// from diverging from rows that still exist.
func (n *Webpusher) cleanupStaleSubscriptions(ctx context.Context, userID uuid.UUID, ids []uuid.UUID) {
	if len(ids) == 0 {
		return
	}
	// nolint:gocritic // These are known to be invalid subscriptions.
	deleteErr := n.store.DeleteWebpushSubscriptions(dbauthz.AsNotifier(ctx), ids)
	if deleteErr == nil {
		n.pruneSubscriptions(userID, ids)
		return
	}
	n.log.Error(ctx, "failed to delete stale push subscriptions", slog.Error(deleteErr))
}
// subscriptionsForUser returns userID's web push subscriptions, serving them
// from the in-memory cache when a live entry exists and otherwise loading
// them from the store through the singleflight group. The returned slice is
// a copy the caller may modify.
func (n *Webpusher) subscriptionsForUser(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
	// Fast path: a live cache entry needs no database round trip.
	if hit, ok := n.cachedSubscriptions(userID); ok {
		return hit, nil
	}

	load := func() ([]database.WebpushSubscription, error) {
		// Re-check the cache: it may have been filled between the miss
		// above and this function starting to run.
		if hit, ok := n.cachedSubscriptions(userID); ok {
			return hit, nil
		}
		// Capture the generation before querying so that an InvalidateUser
		// call landing during the query makes storeSubscriptions discard
		// this result instead of caching data that may already be outdated.
		gen := n.subscriptionGeneration(userID)
		rows, err := n.store.GetWebpushSubscriptionsByUserID(ctx, userID)
		if err != nil {
			return nil, err
		}
		n.storeSubscriptions(userID, gen, rows)
		return slices.Clone(rows), nil
	}

	shared, err, _ := n.subscriptionFetches.Do(userID.String(), load)
	if err != nil {
		return nil, err
	}
	// Hand back a private copy of the slice the singleflight call returned.
	return slices.Clone(shared), nil
}
// cachedSubscriptions looks up userID in the subscription cache. It returns
// a copy of the cached subscriptions and true when an unexpired entry
// exists. Otherwise it returns nil and false, removing the entry if it is
// present and expired.
func (n *Webpusher) cachedSubscriptions(userID uuid.UUID) ([]database.WebpushSubscription, bool) {
	n.cacheMu.RLock()
	cached, found := n.subscriptionCache[userID]
	n.cacheMu.RUnlock()

	switch {
	case !found:
		return nil, false
	case n.clock.Now().Before(cached.expiresAt):
		return slices.Clone(cached.subscriptions), true
	}

	// The entry looked expired under the read lock. Check it again under
	// the write lock before evicting, because it may have been replaced by
	// a fresh entry in the meantime.
	n.cacheMu.Lock()
	defer n.cacheMu.Unlock()
	latest, stillPresent := n.subscriptionCache[userID]
	if stillPresent && !n.clock.Now().Before(latest.expiresAt) {
		delete(n.subscriptionCache, userID)
	}
	return nil, false
}
// subscriptionGeneration returns the current invalidation generation for
// userID, which is zero until InvalidateUser has been called for that user.
func (n *Webpusher) subscriptionGeneration(userID uuid.UUID) uint64 {
	n.cacheMu.RLock()
	defer n.cacheMu.RUnlock()
	return n.subscriptionGenerations[userID]
}
// storeSubscriptions caches a copy of subscriptions for userID with a fresh
// TTL, but only if the user's invalidation generation still equals
// generation. A mismatch means InvalidateUser ran after the caller read the
// generation, so the result is dropped.
func (n *Webpusher) storeSubscriptions(userID uuid.UUID, generation uint64, subscriptions []database.WebpushSubscription) {
	n.cacheMu.Lock()
	defer n.cacheMu.Unlock()

	if current := n.subscriptionGenerations[userID]; current == generation {
		n.subscriptionCache[userID] = cachedSubscriptions{
			subscriptions: slices.Clone(subscriptions),
			expiresAt:     n.clock.Now().Add(n.subscriptionCacheTTL),
		}
	}
}
// pruneSubscriptions removes the subscriptions named in staleIDs from
// userID's cache entry. The entry is deleted outright when it has already
// expired or when nothing remains after filtering. The entry's expiry time
// is left unchanged.
func (n *Webpusher) pruneSubscriptions(userID uuid.UUID, staleIDs []uuid.UUID) {
	if len(staleIDs) == 0 {
		return
	}
	// Build the lookup set before taking the lock.
	staleSet := make(map[uuid.UUID]struct{}, len(staleIDs))
	for i := range staleIDs {
		staleSet[staleIDs[i]] = struct{}{}
	}

	n.cacheMu.Lock()
	defer n.cacheMu.Unlock()

	cached, found := n.subscriptionCache[userID]
	if !found {
		return
	}
	if expired := !n.clock.Now().Before(cached.expiresAt); expired {
		delete(n.subscriptionCache, userID)
		return
	}

	// Filter into a new slice rather than in place: cachedSubscriptions
	// copies an entry's slice after releasing the read lock, so the old
	// backing array must not be modified.
	kept := make([]database.WebpushSubscription, 0, len(cached.subscriptions))
	for i := range cached.subscriptions {
		if _, isStale := staleSet[cached.subscriptions[i].ID]; !isStale {
			kept = append(kept, cached.subscriptions[i])
		}
	}

	if len(kept) > 0 {
		cached.subscriptions = kept
		n.subscriptionCache[userID] = cached
		return
	}
	delete(n.subscriptionCache, userID)
}
// InvalidateUser drops userID's cached subscriptions and bumps the user's
// invalidation generation, so that a fetch already in progress cannot write
// its result back into the cache. It then tells the singleflight group to
// forget the user's key. Local subscribe and unsubscribe handlers call this
// after mutating subscriptions in the same process.
func (n *Webpusher) InvalidateUser(userID uuid.UUID) {
	func() {
		n.cacheMu.Lock()
		defer n.cacheMu.Unlock()
		n.subscriptionGenerations[userID]++
		delete(n.subscriptionCache, userID)
	}()
	// Forget runs after the cache lock has been released.
	n.subscriptionFetches.Forget(userID.String())
}
// webpushSend delivers msg to a single push subscription identified by
// endpoint and keys, using the dispatcher's VAPID credentials and HTTP
// client.
//
// It returns the push service's HTTP status code together with the response
// body, truncated to at most maxResponseBodyBytes. When the request cannot
// be sent or the body cannot be read, it returns -1, a nil body, and the
// error; a send failure is also logged.
func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string, keys webpush.Keys) (int, []byte, error) {
	// maxResponseBodyBytes caps how much of the response is buffered. The
	// endpoint comes from a client-supplied subscription and the body is
	// only used in error messages, so an unbounded read would let a hostile
	// or broken endpoint consume arbitrary memory.
	const maxResponseBodyBytes = 64 << 10

	// Copy the message to avoid modifying the original.
	cpy := slices.Clone(msg)
	resp, err := webpush.SendNotificationWithContext(ctx, cpy, &webpush.Subscription{
		Endpoint: endpoint,
		Keys:     keys,
	}, &webpush.Options{
		HTTPClient:      n.httpClient,
		Subscriber:      n.vapidSub,
		VAPIDPublicKey:  n.VAPIDPublicKey,
		VAPIDPrivateKey: n.VAPIDPrivateKey,
	})
	if err != nil {
		n.log.Error(ctx, "failed to send webpush notification", slog.Error(err), slog.F("endpoint", endpoint))
		return -1, nil, xerrors.Errorf("send webpush notification: %w", err)
	}
	defer resp.Body.Close()

	// Reaching the limit is not an error: the body is simply truncated.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
	if err != nil {
		return -1, nil, xerrors.Errorf("read response body: %w", err)
	}
	return resp.StatusCode, body, nil
}
// Test sends a fixed "It's working!" notification to the subscription
// described by req to check that it can receive pushes. It returns an error
// when the message cannot be marshaled or sent, or when the push service
// answers with a status code above 202.
func (n *Webpusher) Test(ctx context.Context, req codersdk.WebpushSubscription) error {
	testMsg := codersdk.WebpushMessage{
		Title: "It's working!",
		Body:  "You've subscribed to push notifications.",
	}
	payload, err := json.Marshal(testMsg)
	if err != nil {
		return xerrors.Errorf("marshal webpush notification: %w", err)
	}

	keys := webpush.Keys{
		Auth:   req.AuthKey,
		P256dh: req.P256DHKey,
	}
	status, respBody, err := n.webpushSend(ctx, payload, req.Endpoint, keys)
	if err != nil {
		return xerrors.Errorf("send test webpush notification: %w", err)
	}

	// 200, 201, and 202 are common for successful delivery.
	if status <= http.StatusAccepted {
		return nil
	}
	// It's likely the subscription failed to deliver for some reason.
	return xerrors.Errorf("web push dispatch failed with status code %d: %s", status, string(respBody))
}
// PublicKey returns the VAPID public key for the webpush dispatcher, as
// loaded or generated by New.
// Clients need this, so it's exposed via the BuildInfo endpoint.
func (n *Webpusher) PublicKey() string {
return n.VAPIDPublicKey
}
// NoopWebpusher is a Dispatcher that always fails, returning Msg as
// the error. It is used as a fallback when VAPID key setup fails.
// The underlying error is not included to avoid leaking internal
// details (e.g. database errors) in API responses; it is logged at
// the call site instead.
type NoopWebpusher struct {
// Msg is the text of the error returned by Dispatch and Test.
Msg string
}
// Dispatch ignores its arguments and always returns a new error whose
// message is n.Msg.
func (n *NoopWebpusher) Dispatch(context.Context, uuid.UUID, codersdk.WebpushMessage) error {
return xerrors.New(n.Msg)
}
// Test ignores its arguments and always returns a new error whose message
// is n.Msg.
func (n *NoopWebpusher) Test(context.Context, codersdk.WebpushSubscription) error {
return xerrors.New(n.Msg)
}
// PublicKey always returns the empty string, since a NoopWebpusher holds no
// VAPID keys.
func (*NoopWebpusher) PublicKey() string {
return ""
}
// newSSRFSafeHTTPClient returns an HTTP client that rejects connections to
// private, loopback, link-local, multicast, and unspecified IP addresses.
// This prevents DNS rebinding attacks where a hostname passes URL-level
// validation but resolves to an internal IP at dial time.
//
// The check runs in the dialer's Control hook, which receives the literal
// IP address being connected to, so it applies to every connection the
// client opens, including those made while following redirects.
//
// The client also bounds how long a delivery may take. Push endpoints are
// supplied by subscribers, and without these limits a stalled endpoint
// could hold a dispatch goroutine for as long as the caller's context
// lives.
func newSSRFSafeHTTPClient() *http.Client {
	const (
		// dialTimeout bounds TCP connection establishment.
		dialTimeout = 10 * time.Second
		// tlsHandshakeTimeout bounds the TLS handshake.
		tlsHandshakeTimeout = 10 * time.Second
		// requestTimeout bounds the whole exchange, including reading the
		// response body.
		requestTimeout = 30 * time.Second
	)

	dialer := &net.Dialer{
		Timeout: dialTimeout,
		Control: func(_ string, address string, _ syscall.RawConn) error {
			host, _, err := net.SplitHostPort(address)
			if err != nil {
				return xerrors.Errorf("split host/port: %w", err)
			}
			ip, err := netip.ParseAddr(host)
			if err != nil {
				return xerrors.Errorf("parse resolved IP: %w", err)
			}
			// Convert IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) to plain
			// IPv4 so the checks below evaluate the embedded IPv4 address.
			ip = ip.Unmap()
			if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
				ip.IsLinkLocalMulticast() || ip.IsMulticast() ||
				ip.IsUnspecified() {
				return xerrors.Errorf(
					"webpush endpoint resolved to non-public address %s", ip.String(),
				)
			}
			return nil
		},
	}

	return &http.Client{
		Timeout: requestTimeout,
		Transport: &http.Transport{
			DialContext:         dialer.DialContext,
			TLSHandshakeTimeout: tlsHandshakeTimeout,
		},
	}
}
// RegenerateVAPIDKeys creates a fresh VAPID key pair and persists it. In the
// same database transaction it deletes every existing push subscription,
// because they are no longer valid once the keys change. It returns the new
// private and public keys, or empty strings and an error if key generation
// or the transaction fails.
func RegenerateVAPIDKeys(ctx context.Context, db database.Store) (newPrivateKey string, newPublicKey string, err error) {
	privateKey, publicKey, genErr := webpush.GenerateVAPIDKeys()
	if genErr != nil {
		return "", "", xerrors.Errorf("generate new vapid keypair: %w", genErr)
	}

	// replaceKeys deletes all subscriptions and then stores the new key
	// pair using the transaction's store.
	replaceKeys := func(tx database.Store) error {
		if delErr := tx.DeleteAllWebpushSubscriptions(ctx); delErr != nil {
			return xerrors.Errorf("delete all webpush subscriptions: %w", delErr)
		}
		params := database.UpsertWebpushVAPIDKeysParams{
			VapidPrivateKey: privateKey,
			VapidPublicKey:  publicKey,
		}
		if upsertErr := tx.UpsertWebpushVAPIDKeys(ctx, params); upsertErr != nil {
			return xerrors.Errorf("upsert notification vapid key: %w", upsertErr)
		}
		return nil
	}

	if txErr := db.InTx(replaceKeys, nil); txErr != nil {
		return "", "", xerrors.Errorf("regenerate vapid keypair: %w", txErr)
	}
	return privateKey, publicKey, nil
}