mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd): validate webpush subscription endpoints (#24347)
Co-authored-by: Cian Johnston <cian@coder.com>
This commit is contained in:
@@ -4,7 +4,12 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
@@ -33,6 +38,13 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if err := validateWebpushEndpoint(req.Endpoint); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid webpush endpoint.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.WebpushDispatcher.Test(ctx, req); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -62,6 +74,42 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func validateWebpushEndpoint(rawEndpoint string) error {
|
||||
endpoint, err := url.Parse(rawEndpoint)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse endpoint URL: %w", err)
|
||||
}
|
||||
if !endpoint.IsAbs() {
|
||||
return xerrors.New("endpoint must be an absolute URL")
|
||||
}
|
||||
if endpoint.Scheme != "https" {
|
||||
return xerrors.New("endpoint URL scheme must be https")
|
||||
}
|
||||
if endpoint.Host == "" {
|
||||
return xerrors.New("endpoint host is required")
|
||||
}
|
||||
if endpoint.User != nil {
|
||||
return xerrors.New("endpoint URL must not include userinfo")
|
||||
}
|
||||
|
||||
hostname := strings.ToLower(endpoint.Hostname())
|
||||
if hostname == "" {
|
||||
return xerrors.New("endpoint hostname is required")
|
||||
}
|
||||
if hostname == "localhost" || strings.HasSuffix(hostname, ".localhost") {
|
||||
return xerrors.New("endpoint hostname must not be localhost")
|
||||
}
|
||||
|
||||
if ip, err := netip.ParseAddr(hostname); err == nil &&
|
||||
(ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() || ip.IsMulticast() ||
|
||||
ip.IsUnspecified()) {
|
||||
return xerrors.New("endpoint IP must not be private, loopback, link-local, multicast, or unspecified")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// @Summary Delete user webpush subscription
|
||||
// @ID delete-user-webpush-subscription
|
||||
// @Security CoderSessionToken
|
||||
|
||||
@@ -6,9 +6,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
@@ -47,6 +50,7 @@ type SubscriptionCacheInvalidator interface {
|
||||
type options struct {
|
||||
clock quartz.Clock
|
||||
subscriptionCacheTTL time.Duration
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// Option configures optional behavior for a Webpusher.
|
||||
@@ -68,6 +72,15 @@ func WithSubscriptionCacheTTL(ttl time.Duration) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient overrides the default SSRF-safe HTTP client used to deliver
|
||||
// push notifications. This is intended for tests that need to deliver to
|
||||
// localhost test servers.
|
||||
func WithHTTPClient(client *http.Client) Option {
|
||||
return func(o *options) {
|
||||
o.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new Dispatcher to dispatch web push notifications.
|
||||
//
|
||||
// This is *not* integrated into the enqueue system unfortunately.
|
||||
@@ -90,6 +103,9 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
|
||||
if cfg.subscriptionCacheTTL <= 0 {
|
||||
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
|
||||
}
|
||||
if cfg.httpClient == nil {
|
||||
cfg.httpClient = newSSRFSafeHTTPClient()
|
||||
}
|
||||
|
||||
keys, err := db.GetWebpushVAPIDKeys(ctx)
|
||||
if err != nil {
|
||||
@@ -121,6 +137,7 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
|
||||
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
|
||||
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
|
||||
subscriptionGenerations: make(map[uuid.UUID]uint64),
|
||||
httpClient: cfg.httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -142,6 +159,12 @@ type Webpusher struct {
|
||||
VAPIDPublicKey string
|
||||
VAPIDPrivateKey string
|
||||
|
||||
// httpClient is an SSRF-safe HTTP client that rejects connections to
|
||||
// private, loopback, and link-local IP addresses at dial time. This
|
||||
// closes the DNS rebinding TOCTOU gap where a hostname passes URL
|
||||
// validation but resolves to a private IP when the connection is made.
|
||||
httpClient *http.Client
|
||||
|
||||
clock quartz.Clock
|
||||
|
||||
cacheMu sync.RWMutex
|
||||
@@ -338,6 +361,7 @@ func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string
|
||||
Endpoint: endpoint,
|
||||
Keys: keys,
|
||||
}, &webpush.Options{
|
||||
HTTPClient: n.httpClient,
|
||||
Subscriber: n.vapidSub,
|
||||
VAPIDPublicKey: n.VAPIDPublicKey,
|
||||
VAPIDPrivateKey: n.VAPIDPrivateKey,
|
||||
@@ -407,6 +431,37 @@ func (*NoopWebpusher) PublicKey() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// newSSRFSafeHTTPClient returns an HTTP client that rejects connections to
|
||||
// private, loopback, link-local, multicast, and unspecified IP addresses.
|
||||
// This prevents DNS rebinding attacks where a hostname passes URL-level
|
||||
// validation but resolves to an internal IP at dial time.
|
||||
func newSSRFSafeHTTPClient() *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Control: func(_ string, address string, _ syscall.RawConn) error {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("split host/port: %w", err)
|
||||
}
|
||||
ip, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse resolved IP: %w", err)
|
||||
}
|
||||
if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() || ip.IsMulticast() ||
|
||||
ip.IsUnspecified() {
|
||||
return xerrors.Errorf(
|
||||
"webpush endpoint resolved to non-public address %s", ip.String(),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RegenerateVAPIDKeys regenerates the VAPID keys and deletes all existing
|
||||
// push subscriptions as part of the transaction, as they are no longer valid.
|
||||
func RegenerateVAPIDKeys(ctx context.Context, db database.Store) (newPrivateKey string, newPublicKey string, err error) {
|
||||
|
||||
@@ -387,6 +387,8 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
|
||||
}
|
||||
|
||||
// setupPushTest creates a common test setup for webpush notification tests.
|
||||
// The test HTTP client bypasses SSRF protection so that httptest.Server
|
||||
// (bound to 127.0.0.1) can be reached.
|
||||
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
|
||||
t.Helper()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
@@ -400,6 +402,9 @@ func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Sto
|
||||
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
// Use an unrestricted HTTP client for tests. The default SSRF-safe
|
||||
// client rejects loopback addresses, which blocks httptest.Server.
|
||||
opts = append(opts, webpush.WithHTTPClient(http.DefaultClient))
|
||||
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
|
||||
require.NoError(t, err, "Failed to create webpush manager")
|
||||
|
||||
@@ -423,3 +428,56 @@ func TestNoopWebpusher(t *testing.T) {
|
||||
|
||||
require.Empty(t, noop.PublicKey())
|
||||
}
|
||||
|
||||
// TestSSRFPrevention verifies that the default SSRF-safe HTTP client blocks
|
||||
// webpush delivery to loopback (and other non-public) addresses. This
|
||||
// reproduces the attack vector from the original SSRF PoC: an authenticated
|
||||
// user supplies a localhost endpoint in their webpush subscription, and the
|
||||
// server must refuse to connect.
|
||||
func TestSSRFPrevention(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Start a server that records whether it received a request.
|
||||
var received atomic.Bool
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
received.Store(true)
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a dispatcher via New() WITHOUT WithHTTPClient so it
|
||||
// uses the default SSRF-safe client that blocks loopback.
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test() calls webpushSend directly with the supplied endpoint.
|
||||
err = manager.Test(ctx, codersdk.WebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
AuthKey: validEndpointAuthKey,
|
||||
P256DHKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.Error(t, err, "SSRF-safe client should reject Test() to loopback address")
|
||||
assert.False(t, received.Load(), "Test() request should not reach the localhost server")
|
||||
|
||||
// Dispatch() goes through the subscription cache → webpushSend path.
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err = db.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||
CreatedAt: dbtime.Now(),
|
||||
UserID: user.ID,
|
||||
Endpoint: server.URL,
|
||||
EndpointAuthKey: validEndpointAuthKey,
|
||||
EndpointP256dhKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager.Dispatch(ctx, user.ID, codersdk.WebpushMessage{
|
||||
Title: "SSRF test",
|
||||
Body: "This should not arrive.",
|
||||
})
|
||||
require.Error(t, err, "SSRF-safe client should reject Dispatch() to loopback address")
|
||||
assert.False(t, received.Load(), "Dispatch() request should not reach the localhost server")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateWebpushEndpoint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
wantErr bool
|
||||
errSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid https endpoint",
|
||||
endpoint: "https://fcm.googleapis.com/fcm/send/abc123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid https endpoint with port",
|
||||
endpoint: "https://push.example.com:8443/subscription",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "relative URL",
|
||||
endpoint: "/push/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "absolute URL",
|
||||
},
|
||||
{
|
||||
name: "http scheme rejected",
|
||||
endpoint: "http://push.example.com/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "scheme must be https",
|
||||
},
|
||||
{
|
||||
name: "custom scheme rejected",
|
||||
endpoint: "ws://push.example.com/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "scheme must be https",
|
||||
},
|
||||
{
|
||||
name: "empty host",
|
||||
endpoint: "https:///path",
|
||||
wantErr: true,
|
||||
errSubstr: "host is required",
|
||||
},
|
||||
{
|
||||
name: "userinfo rejected",
|
||||
endpoint: "https://user:pass@push.example.com/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not include userinfo",
|
||||
},
|
||||
{
|
||||
name: "localhost rejected",
|
||||
endpoint: "https://localhost/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be localhost",
|
||||
},
|
||||
{
|
||||
name: "subdomain of localhost rejected",
|
||||
endpoint: "https://foo.localhost/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be localhost",
|
||||
},
|
||||
{
|
||||
name: "loopback IPv4 rejected",
|
||||
endpoint: "https://127.0.0.1/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "private 10.x rejected",
|
||||
endpoint: "https://10.0.0.1/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "private 192.168.x rejected",
|
||||
endpoint: "https://192.168.1.1/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "private 172.16.x rejected",
|
||||
endpoint: "https://172.16.0.1/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "link-local IPv4 rejected",
|
||||
endpoint: "https://169.254.1.1/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "unspecified IPv4 rejected",
|
||||
endpoint: "https://0.0.0.0/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "loopback IPv6 rejected",
|
||||
endpoint: "https://[::1]/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "unspecified IPv6 rejected",
|
||||
endpoint: "https://[::]/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "link-local IPv6 rejected",
|
||||
endpoint: "https://[fe80::1]/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "multicast IPv4 rejected",
|
||||
endpoint: "https://224.0.0.1/subscription",
|
||||
wantErr: true,
|
||||
errSubstr: "must not be private",
|
||||
},
|
||||
{
|
||||
name: "public IPv4 allowed",
|
||||
endpoint: "https://203.0.113.1/subscription",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := validateWebpushEndpoint(tt.endpoint)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errSubstr,
|
||||
"error should mention %q", tt.errSubstr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+74
-18
@@ -3,7 +3,7 @@ package coderd_test
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -30,49 +30,48 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{})
|
||||
dispatcher := &testWebpushDispatcher{}
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
WebpushDispatcher: dispatcher,
|
||||
})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
var handlerCalls atomic.Int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
handlerCalls.Add(1)
|
||||
}))
|
||||
defer server.Close()
|
||||
endpoint := "https://push.example.com/subscription/abc123"
|
||||
|
||||
// Seed the dispatcher cache with an empty subscription set. Creating the
|
||||
// subscription should invalidate that entry so the next dispatch sees the new
|
||||
// subscription immediately.
|
||||
err := memberClient.PostTestWebpushMessage(ctx)
|
||||
require.NoError(t, err, "test webpush message without a subscription")
|
||||
require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push")
|
||||
require.Equal(t, int32(1), dispatcher.dispatchCalls.Load(), "dispatch should be called even with no subscriptions")
|
||||
|
||||
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
Endpoint: endpoint,
|
||||
AuthKey: validEndpointAuthKey,
|
||||
P256DHKey: validEndpointP256dhKey,
|
||||
})
|
||||
require.NoError(t, err, "create webpush subscription")
|
||||
require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once")
|
||||
require.Equal(t, int32(1), dispatcher.testCalls.Load(), "subscription validation should call dispatcher test once")
|
||||
require.Equal(t, 1, dispatcher.invalidateCount(), "subscribing should invalidate the user's cached subscriptions")
|
||||
|
||||
err = memberClient.PostTestWebpushMessage(ctx)
|
||||
require.NoError(t, err, "test webpush message after subscribing")
|
||||
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing")
|
||||
require.Equal(t, int32(2), dispatcher.dispatchCalls.Load(), "dispatch should be called after subscribing")
|
||||
|
||||
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
Endpoint: endpoint,
|
||||
})
|
||||
require.NoError(t, err, "delete webpush subscription")
|
||||
require.Equal(t, 2, dispatcher.invalidateCount(), "unsubscribing should invalidate the user's cached subscriptions")
|
||||
|
||||
err = memberClient.PostTestWebpushMessage(ctx)
|
||||
require.NoError(t, err, "test webpush message after unsubscribing")
|
||||
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing")
|
||||
require.Equal(t, int32(3), dispatcher.dispatchCalls.Load(), "dispatch should be called after unsubscribing")
|
||||
|
||||
// Deleting the subscription for a non-existent endpoint should return a 404.
|
||||
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
Endpoint: endpoint,
|
||||
})
|
||||
var sdkError *codersdk.Error
|
||||
require.Error(t, err)
|
||||
@@ -81,7 +80,7 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
|
||||
|
||||
// Creating a subscription for another user should not be allowed.
|
||||
err = memberClient.PostWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.WebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
Endpoint: endpoint,
|
||||
AuthKey: validEndpointAuthKey,
|
||||
P256DHKey: validEndpointP256dhKey,
|
||||
})
|
||||
@@ -89,11 +88,33 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
|
||||
|
||||
// Deleting a subscription for another user should not be allowed.
|
||||
err = memberClient.DeleteWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.DeleteWebpushSubscription{
|
||||
Endpoint: server.URL,
|
||||
Endpoint: endpoint,
|
||||
})
|
||||
require.Error(t, err, "delete webpush subscription for another user")
|
||||
}
|
||||
|
||||
func TestWebpushSubscribeRejectsInvalidEndpoint(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
WebpushDispatcher: &testWebpushDispatcher{},
|
||||
})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
|
||||
Endpoint: "http://127.0.0.1:8080/subscription",
|
||||
AuthKey: validEndpointAuthKey,
|
||||
P256DHKey: validEndpointP256dhKey,
|
||||
})
|
||||
var sdkError *codersdk.Error
|
||||
require.Error(t, err)
|
||||
require.ErrorAsf(t, err, &sdkError, "error should be of type *codersdk.Error")
|
||||
require.Equal(t, http.StatusBadRequest, sdkError.StatusCode())
|
||||
require.Contains(t, sdkError.Error(), "endpoint URL scheme must be https")
|
||||
}
|
||||
|
||||
// testWebpushErrorStore wraps a real database.Store and allows injecting
|
||||
// errors into GetWebpushSubscriptionsByUserID.
|
||||
type testWebpushErrorStore struct {
|
||||
@@ -101,6 +122,41 @@ type testWebpushErrorStore struct {
|
||||
getWebpushSubscriptionsErr atomic.Pointer[error]
|
||||
}
|
||||
|
||||
type testWebpushDispatcher struct {
|
||||
testCalls atomic.Int32
|
||||
dispatchCalls atomic.Int32
|
||||
invalidateUserIDs []uuid.UUID
|
||||
invalidateUserLock sync.Mutex
|
||||
}
|
||||
|
||||
func (d *testWebpushDispatcher) Dispatch(_ context.Context, _ uuid.UUID, _ codersdk.WebpushMessage) error {
|
||||
d.dispatchCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *testWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error {
|
||||
d.testCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*testWebpushDispatcher) PublicKey() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// InvalidateUser implements webpush.SubscriptionCacheInvalidator so the
|
||||
// handler exercises the cache-invalidation path on subscribe/unsubscribe.
|
||||
func (d *testWebpushDispatcher) InvalidateUser(userID uuid.UUID) {
|
||||
d.invalidateUserLock.Lock()
|
||||
defer d.invalidateUserLock.Unlock()
|
||||
d.invalidateUserIDs = append(d.invalidateUserIDs, userID)
|
||||
}
|
||||
|
||||
func (d *testWebpushDispatcher) invalidateCount() int {
|
||||
d.invalidateUserLock.Lock()
|
||||
defer d.invalidateUserLock.Unlock()
|
||||
return len(d.invalidateUserIDs)
|
||||
}
|
||||
|
||||
func (s *testWebpushErrorStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
|
||||
if err := s.getWebpushSubscriptionsErr.Load(); err != nil {
|
||||
return nil, *err
|
||||
|
||||
Reference in New Issue
Block a user