fix(coderd): validate webpush subscription endpoints (#24347)

Co-authored-by: Cian Johnston <cian@coder.com>
This commit is contained in:
Thomas Kosiewski
2026-04-15 11:31:43 +02:00
committed by GitHub
parent e317f3b239
commit 5812f84e1c
5 changed files with 386 additions and 18 deletions
+48
View File
@@ -4,7 +4,12 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"net/http" "net/http"
"net/netip"
"net/url"
"slices" "slices"
"strings"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/dbtime"
@@ -33,6 +38,13 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
if !httpapi.Read(ctx, rw, r, &req) { if !httpapi.Read(ctx, rw, r, &req) {
return return
} }
if err := validateWebpushEndpoint(req.Endpoint); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid webpush endpoint.",
Detail: err.Error(),
})
return
}
if err := api.WebpushDispatcher.Test(ctx, req); err != nil { if err := api.WebpushDispatcher.Test(ctx, req); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@@ -62,6 +74,42 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
rw.WriteHeader(http.StatusNoContent) rw.WriteHeader(http.StatusNoContent)
} }
func validateWebpushEndpoint(rawEndpoint string) error {
endpoint, err := url.Parse(rawEndpoint)
if err != nil {
return xerrors.Errorf("parse endpoint URL: %w", err)
}
if !endpoint.IsAbs() {
return xerrors.New("endpoint must be an absolute URL")
}
if endpoint.Scheme != "https" {
return xerrors.New("endpoint URL scheme must be https")
}
if endpoint.Host == "" {
return xerrors.New("endpoint host is required")
}
if endpoint.User != nil {
return xerrors.New("endpoint URL must not include userinfo")
}
hostname := strings.ToLower(endpoint.Hostname())
if hostname == "" {
return xerrors.New("endpoint hostname is required")
}
if hostname == "localhost" || strings.HasSuffix(hostname, ".localhost") {
return xerrors.New("endpoint hostname must not be localhost")
}
if ip, err := netip.ParseAddr(hostname); err == nil &&
(ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsMulticast() ||
ip.IsUnspecified()) {
return xerrors.New("endpoint IP must not be private, loopback, link-local, multicast, or unspecified")
}
return nil
}
// @Summary Delete user webpush subscription // @Summary Delete user webpush subscription
// @ID delete-user-webpush-subscription // @ID delete-user-webpush-subscription
// @Security CoderSessionToken // @Security CoderSessionToken
+55
View File
@@ -6,9 +6,12 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
"net"
"net/http" "net/http"
"net/netip"
"slices" "slices"
"sync" "sync"
"syscall"
"time" "time"
"github.com/SherClockHolmes/webpush-go" "github.com/SherClockHolmes/webpush-go"
@@ -47,6 +50,7 @@ type SubscriptionCacheInvalidator interface {
type options struct { type options struct {
clock quartz.Clock clock quartz.Clock
subscriptionCacheTTL time.Duration subscriptionCacheTTL time.Duration
httpClient *http.Client
} }
// Option configures optional behavior for a Webpusher. // Option configures optional behavior for a Webpusher.
@@ -68,6 +72,15 @@ func WithSubscriptionCacheTTL(ttl time.Duration) Option {
} }
} }
// WithHTTPClient overrides the default SSRF-safe HTTP client used to deliver
// push notifications. This is intended for tests that need to deliver to
// localhost test servers.
func WithHTTPClient(client *http.Client) Option {
return func(o *options) {
o.httpClient = client
}
}
// New creates a new Dispatcher to dispatch web push notifications. // New creates a new Dispatcher to dispatch web push notifications.
// //
// This is *not* integrated into the enqueue system unfortunately. // This is *not* integrated into the enqueue system unfortunately.
@@ -90,6 +103,9 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
if cfg.subscriptionCacheTTL <= 0 { if cfg.subscriptionCacheTTL <= 0 {
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
} }
if cfg.httpClient == nil {
cfg.httpClient = newSSRFSafeHTTPClient()
}
keys, err := db.GetWebpushVAPIDKeys(ctx) keys, err := db.GetWebpushVAPIDKeys(ctx)
if err != nil { if err != nil {
@@ -121,6 +137,7 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
subscriptionCacheTTL: cfg.subscriptionCacheTTL, subscriptionCacheTTL: cfg.subscriptionCacheTTL,
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions), subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
subscriptionGenerations: make(map[uuid.UUID]uint64), subscriptionGenerations: make(map[uuid.UUID]uint64),
httpClient: cfg.httpClient,
}, nil }, nil
} }
@@ -142,6 +159,12 @@ type Webpusher struct {
VAPIDPublicKey string VAPIDPublicKey string
VAPIDPrivateKey string VAPIDPrivateKey string
// httpClient is an SSRF-safe HTTP client that rejects connections to
// private, loopback, and link-local IP addresses at dial time. This
// closes the DNS rebinding TOCTOU gap where a hostname passes URL
// validation but resolves to a private IP when the connection is made.
httpClient *http.Client
clock quartz.Clock clock quartz.Clock
cacheMu sync.RWMutex cacheMu sync.RWMutex
@@ -338,6 +361,7 @@ func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string
Endpoint: endpoint, Endpoint: endpoint,
Keys: keys, Keys: keys,
}, &webpush.Options{ }, &webpush.Options{
HTTPClient: n.httpClient,
Subscriber: n.vapidSub, Subscriber: n.vapidSub,
VAPIDPublicKey: n.VAPIDPublicKey, VAPIDPublicKey: n.VAPIDPublicKey,
VAPIDPrivateKey: n.VAPIDPrivateKey, VAPIDPrivateKey: n.VAPIDPrivateKey,
@@ -407,6 +431,37 @@ func (*NoopWebpusher) PublicKey() string {
return "" return ""
} }
// newSSRFSafeHTTPClient returns an HTTP client that rejects connections to
// private, loopback, link-local, multicast, and unspecified IP addresses.
// This prevents DNS rebinding attacks where a hostname passes URL-level
// validation but resolves to an internal IP at dial time.
func newSSRFSafeHTTPClient() *http.Client {
return &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Control: func(_ string, address string, _ syscall.RawConn) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return xerrors.Errorf("split host/port: %w", err)
}
ip, err := netip.ParseAddr(host)
if err != nil {
return xerrors.Errorf("parse resolved IP: %w", err)
}
if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsMulticast() ||
ip.IsUnspecified() {
return xerrors.Errorf(
"webpush endpoint resolved to non-public address %s", ip.String(),
)
}
return nil
},
}).DialContext,
},
}
}
// RegenerateVAPIDKeys regenerates the VAPID keys and deletes all existing // RegenerateVAPIDKeys regenerates the VAPID keys and deletes all existing
// push subscriptions as part of the transaction, as they are no longer valid. // push subscriptions as part of the transaction, as they are no longer valid.
func RegenerateVAPIDKeys(ctx context.Context, db database.Store) (newPrivateKey string, newPublicKey string, err error) { func RegenerateVAPIDKeys(ctx context.Context, db database.Store) (newPrivateKey string, newPublicKey string, err error) {
+58
View File
@@ -387,6 +387,8 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
} }
// setupPushTest creates a common test setup for webpush notification tests. // setupPushTest creates a common test setup for webpush notification tests.
// The test HTTP client bypasses SSRF protection so that httptest.Server
// (bound to 127.0.0.1) can be reached.
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) { func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
t.Helper() t.Helper()
db, _ := dbtestutil.NewDB(t) db, _ := dbtestutil.NewDB(t)
@@ -400,6 +402,9 @@ func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Sto
server := httptest.NewServer(http.HandlerFunc(handlerFunc)) server := httptest.NewServer(http.HandlerFunc(handlerFunc))
t.Cleanup(server.Close) t.Cleanup(server.Close)
// Use an unrestricted HTTP client for tests. The default SSRF-safe
// client rejects loopback addresses, which blocks httptest.Server.
opts = append(opts, webpush.WithHTTPClient(http.DefaultClient))
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...) manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
require.NoError(t, err, "Failed to create webpush manager") require.NoError(t, err, "Failed to create webpush manager")
@@ -423,3 +428,56 @@ func TestNoopWebpusher(t *testing.T) {
require.Empty(t, noop.PublicKey()) require.Empty(t, noop.PublicKey())
} }
// TestSSRFPrevention verifies that the default SSRF-safe HTTP client blocks
// webpush delivery to loopback (and other non-public) addresses. This
// reproduces the attack vector from the original SSRF PoC: an authenticated
// user supplies a localhost endpoint in their webpush subscription, and the
// server must refuse to connect.
func TestSSRFPrevention(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
// Start a server that records whether it received a request.
var received atomic.Bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
received.Store(true)
w.WriteHeader(http.StatusCreated)
}))
defer server.Close()
// Create a dispatcher via New() WITHOUT WithHTTPClient so it
// uses the default SSRF-safe client that blocks loopback.
db, _ := dbtestutil.NewDB(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
require.NoError(t, err)
// Test() calls webpushSend directly with the supplied endpoint.
err = manager.Test(ctx, codersdk.WebpushSubscription{
Endpoint: server.URL,
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
require.Error(t, err, "SSRF-safe client should reject Test() to loopback address")
assert.False(t, received.Load(), "Test() request should not reach the localhost server")
// Dispatch() goes through the subscription cache → webpushSend path.
user := dbgen.User(t, db, database.User{})
_, err = db.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: server.URL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
err = manager.Dispatch(ctx, user.ID, codersdk.WebpushMessage{
Title: "SSRF test",
Body: "This should not arrive.",
})
require.Error(t, err, "SSRF-safe client should reject Dispatch() to loopback address")
assert.False(t, received.Load(), "Dispatch() request should not reach the localhost server")
}
+151
View File
@@ -0,0 +1,151 @@
package coderd
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateWebpushEndpoint(t *testing.T) {
t.Parallel()
tests := []struct {
name string
endpoint string
wantErr bool
errSubstr string
}{
{
name: "valid https endpoint",
endpoint: "https://fcm.googleapis.com/fcm/send/abc123",
wantErr: false,
},
{
name: "valid https endpoint with port",
endpoint: "https://push.example.com:8443/subscription",
wantErr: false,
},
{
name: "relative URL",
endpoint: "/push/subscription",
wantErr: true,
errSubstr: "absolute URL",
},
{
name: "http scheme rejected",
endpoint: "http://push.example.com/subscription",
wantErr: true,
errSubstr: "scheme must be https",
},
{
name: "custom scheme rejected",
endpoint: "ws://push.example.com/subscription",
wantErr: true,
errSubstr: "scheme must be https",
},
{
name: "empty host",
endpoint: "https:///path",
wantErr: true,
errSubstr: "host is required",
},
{
name: "userinfo rejected",
endpoint: "https://user:pass@push.example.com/subscription",
wantErr: true,
errSubstr: "must not include userinfo",
},
{
name: "localhost rejected",
endpoint: "https://localhost/subscription",
wantErr: true,
errSubstr: "must not be localhost",
},
{
name: "subdomain of localhost rejected",
endpoint: "https://foo.localhost/subscription",
wantErr: true,
errSubstr: "must not be localhost",
},
{
name: "loopback IPv4 rejected",
endpoint: "https://127.0.0.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "private 10.x rejected",
endpoint: "https://10.0.0.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "private 192.168.x rejected",
endpoint: "https://192.168.1.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "private 172.16.x rejected",
endpoint: "https://172.16.0.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "link-local IPv4 rejected",
endpoint: "https://169.254.1.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "unspecified IPv4 rejected",
endpoint: "https://0.0.0.0/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "loopback IPv6 rejected",
endpoint: "https://[::1]/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "unspecified IPv6 rejected",
endpoint: "https://[::]/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "link-local IPv6 rejected",
endpoint: "https://[fe80::1]/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "multicast IPv4 rejected",
endpoint: "https://224.0.0.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "public IPv4 allowed",
endpoint: "https://203.0.113.1/subscription",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := validateWebpushEndpoint(tt.endpoint)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errSubstr,
"error should mention %q", tt.errSubstr)
} else {
require.NoError(t, err)
}
})
}
}
+74 -18
View File
@@ -3,7 +3,7 @@ package coderd_test
import ( import (
"context" "context"
"net/http" "net/http"
"net/http/httptest" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
@@ -30,49 +30,48 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort) ctx := testutil.Context(t, testutil.WaitShort)
client := coderdtest.New(t, &coderdtest.Options{}) dispatcher := &testWebpushDispatcher{}
client := coderdtest.New(t, &coderdtest.Options{
WebpushDispatcher: dispatcher,
})
owner := coderdtest.CreateFirstUser(t, client) owner := coderdtest.CreateFirstUser(t, client)
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) _, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
endpoint := "https://push.example.com/subscription/abc123"
var handlerCalls atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
handlerCalls.Add(1)
}))
defer server.Close()
// Seed the dispatcher cache with an empty subscription set. Creating the // Seed the dispatcher cache with an empty subscription set. Creating the
// subscription should invalidate that entry so the next dispatch sees the new // subscription should invalidate that entry so the next dispatch sees the new
// subscription immediately. // subscription immediately.
err := memberClient.PostTestWebpushMessage(ctx) err := memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message without a subscription") require.NoError(t, err, "test webpush message without a subscription")
require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push") require.Equal(t, int32(1), dispatcher.dispatchCalls.Load(), "dispatch should be called even with no subscriptions")
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{ err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
Endpoint: server.URL, Endpoint: endpoint,
AuthKey: validEndpointAuthKey, AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey, P256DHKey: validEndpointP256dhKey,
}) })
require.NoError(t, err, "create webpush subscription") require.NoError(t, err, "create webpush subscription")
require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once") require.Equal(t, int32(1), dispatcher.testCalls.Load(), "subscription validation should call dispatcher test once")
require.Equal(t, 1, dispatcher.invalidateCount(), "subscribing should invalidate the user's cached subscriptions")
err = memberClient.PostTestWebpushMessage(ctx) err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message after subscribing") require.NoError(t, err, "test webpush message after subscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing") require.Equal(t, int32(2), dispatcher.dispatchCalls.Load(), "dispatch should be called after subscribing")
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{ err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL, Endpoint: endpoint,
}) })
require.NoError(t, err, "delete webpush subscription") require.NoError(t, err, "delete webpush subscription")
require.Equal(t, 2, dispatcher.invalidateCount(), "unsubscribing should invalidate the user's cached subscriptions")
err = memberClient.PostTestWebpushMessage(ctx) err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message after unsubscribing") require.NoError(t, err, "test webpush message after unsubscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing") require.Equal(t, int32(3), dispatcher.dispatchCalls.Load(), "dispatch should be called after unsubscribing")
// Deleting the subscription for a non-existent endpoint should return a 404. // Deleting the subscription for a non-existent endpoint should return a 404.
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{ err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL, Endpoint: endpoint,
}) })
var sdkError *codersdk.Error var sdkError *codersdk.Error
require.Error(t, err) require.Error(t, err)
@@ -81,7 +80,7 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
// Creating a subscription for another user should not be allowed. // Creating a subscription for another user should not be allowed.
err = memberClient.PostWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.WebpushSubscription{ err = memberClient.PostWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.WebpushSubscription{
Endpoint: server.URL, Endpoint: endpoint,
AuthKey: validEndpointAuthKey, AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey, P256DHKey: validEndpointP256dhKey,
}) })
@@ -89,11 +88,33 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
// Deleting a subscription for another user should not be allowed. // Deleting a subscription for another user should not be allowed.
err = memberClient.DeleteWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.DeleteWebpushSubscription{ err = memberClient.DeleteWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.DeleteWebpushSubscription{
Endpoint: server.URL, Endpoint: endpoint,
}) })
require.Error(t, err, "delete webpush subscription for another user") require.Error(t, err, "delete webpush subscription for another user")
} }
func TestWebpushSubscribeRejectsInvalidEndpoint(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
client := coderdtest.New(t, &coderdtest.Options{
WebpushDispatcher: &testWebpushDispatcher{},
})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
Endpoint: "http://127.0.0.1:8080/subscription",
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
var sdkError *codersdk.Error
require.Error(t, err)
require.ErrorAsf(t, err, &sdkError, "error should be of type *codersdk.Error")
require.Equal(t, http.StatusBadRequest, sdkError.StatusCode())
require.Contains(t, sdkError.Error(), "endpoint URL scheme must be https")
}
// testWebpushErrorStore wraps a real database.Store and allows injecting // testWebpushErrorStore wraps a real database.Store and allows injecting
// errors into GetWebpushSubscriptionsByUserID. // errors into GetWebpushSubscriptionsByUserID.
type testWebpushErrorStore struct { type testWebpushErrorStore struct {
@@ -101,6 +122,41 @@ type testWebpushErrorStore struct {
getWebpushSubscriptionsErr atomic.Pointer[error] getWebpushSubscriptionsErr atomic.Pointer[error]
} }
type testWebpushDispatcher struct {
testCalls atomic.Int32
dispatchCalls atomic.Int32
invalidateUserIDs []uuid.UUID
invalidateUserLock sync.Mutex
}
func (d *testWebpushDispatcher) Dispatch(_ context.Context, _ uuid.UUID, _ codersdk.WebpushMessage) error {
d.dispatchCalls.Add(1)
return nil
}
func (d *testWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error {
d.testCalls.Add(1)
return nil
}
func (*testWebpushDispatcher) PublicKey() string {
return ""
}
// InvalidateUser implements webpush.SubscriptionCacheInvalidator so the
// handler exercises the cache-invalidation path on subscribe/unsubscribe.
func (d *testWebpushDispatcher) InvalidateUser(userID uuid.UUID) {
d.invalidateUserLock.Lock()
defer d.invalidateUserLock.Unlock()
d.invalidateUserIDs = append(d.invalidateUserIDs, userID)
}
func (d *testWebpushDispatcher) invalidateCount() int {
d.invalidateUserLock.Lock()
defer d.invalidateUserLock.Unlock()
return len(d.invalidateUserIDs)
}
func (s *testWebpushErrorStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) { func (s *testWebpushErrorStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
if err := s.getWebpushSubscriptionsErr.Load(); err != nil { if err := s.getWebpushSubscriptionsErr.Load(); err != nil {
return nil, *err return nil, *err