mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd): validate webpush subscription endpoints (#24347)
Co-authored-by: Cian Johnston <cian@coder.com>
This commit is contained in:
@@ -4,7 +4,12 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||||
@@ -33,6 +38,13 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
|
|||||||
if !httpapi.Read(ctx, rw, r, &req) {
|
if !httpapi.Read(ctx, rw, r, &req) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := validateWebpushEndpoint(req.Endpoint); err != nil {
|
||||||
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Invalid webpush endpoint.",
|
||||||
|
Detail: err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := api.WebpushDispatcher.Test(ctx, req); err != nil {
|
if err := api.WebpushDispatcher.Test(ctx, req); err != nil {
|
||||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
@@ -62,6 +74,42 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
|
|||||||
rw.WriteHeader(http.StatusNoContent)
|
rw.WriteHeader(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateWebpushEndpoint(rawEndpoint string) error {
|
||||||
|
endpoint, err := url.Parse(rawEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("parse endpoint URL: %w", err)
|
||||||
|
}
|
||||||
|
if !endpoint.IsAbs() {
|
||||||
|
return xerrors.New("endpoint must be an absolute URL")
|
||||||
|
}
|
||||||
|
if endpoint.Scheme != "https" {
|
||||||
|
return xerrors.New("endpoint URL scheme must be https")
|
||||||
|
}
|
||||||
|
if endpoint.Host == "" {
|
||||||
|
return xerrors.New("endpoint host is required")
|
||||||
|
}
|
||||||
|
if endpoint.User != nil {
|
||||||
|
return xerrors.New("endpoint URL must not include userinfo")
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := strings.ToLower(endpoint.Hostname())
|
||||||
|
if hostname == "" {
|
||||||
|
return xerrors.New("endpoint hostname is required")
|
||||||
|
}
|
||||||
|
if hostname == "localhost" || strings.HasSuffix(hostname, ".localhost") {
|
||||||
|
return xerrors.New("endpoint hostname must not be localhost")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip, err := netip.ParseAddr(hostname); err == nil &&
|
||||||
|
(ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
|
||||||
|
ip.IsLinkLocalMulticast() || ip.IsMulticast() ||
|
||||||
|
ip.IsUnspecified()) {
|
||||||
|
return xerrors.New("endpoint IP must not be private, loopback, link-local, multicast, or unspecified")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// @Summary Delete user webpush subscription
|
// @Summary Delete user webpush subscription
|
||||||
// @ID delete-user-webpush-subscription
|
// @ID delete-user-webpush-subscription
|
||||||
// @Security CoderSessionToken
|
// @Security CoderSessionToken
|
||||||
|
|||||||
@@ -6,9 +6,12 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/SherClockHolmes/webpush-go"
|
"github.com/SherClockHolmes/webpush-go"
|
||||||
@@ -47,6 +50,7 @@ type SubscriptionCacheInvalidator interface {
|
|||||||
type options struct {
|
type options struct {
|
||||||
clock quartz.Clock
|
clock quartz.Clock
|
||||||
subscriptionCacheTTL time.Duration
|
subscriptionCacheTTL time.Duration
|
||||||
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option configures optional behavior for a Webpusher.
|
// Option configures optional behavior for a Webpusher.
|
||||||
@@ -68,6 +72,15 @@ func WithSubscriptionCacheTTL(ttl time.Duration) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithHTTPClient overrides the default SSRF-safe HTTP client used to deliver
|
||||||
|
// push notifications. This is intended for tests that need to deliver to
|
||||||
|
// localhost test servers.
|
||||||
|
func WithHTTPClient(client *http.Client) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.httpClient = client
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// New creates a new Dispatcher to dispatch web push notifications.
|
// New creates a new Dispatcher to dispatch web push notifications.
|
||||||
//
|
//
|
||||||
// This is *not* integrated into the enqueue system unfortunately.
|
// This is *not* integrated into the enqueue system unfortunately.
|
||||||
@@ -90,6 +103,9 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
|
|||||||
if cfg.subscriptionCacheTTL <= 0 {
|
if cfg.subscriptionCacheTTL <= 0 {
|
||||||
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
|
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
|
||||||
}
|
}
|
||||||
|
if cfg.httpClient == nil {
|
||||||
|
cfg.httpClient = newSSRFSafeHTTPClient()
|
||||||
|
}
|
||||||
|
|
||||||
keys, err := db.GetWebpushVAPIDKeys(ctx)
|
keys, err := db.GetWebpushVAPIDKeys(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -121,6 +137,7 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
|
|||||||
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
|
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
|
||||||
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
|
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
|
||||||
subscriptionGenerations: make(map[uuid.UUID]uint64),
|
subscriptionGenerations: make(map[uuid.UUID]uint64),
|
||||||
|
httpClient: cfg.httpClient,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,6 +159,12 @@ type Webpusher struct {
|
|||||||
VAPIDPublicKey string
|
VAPIDPublicKey string
|
||||||
VAPIDPrivateKey string
|
VAPIDPrivateKey string
|
||||||
|
|
||||||
|
// httpClient is an SSRF-safe HTTP client that rejects connections to
|
||||||
|
// private, loopback, and link-local IP addresses at dial time. This
|
||||||
|
// closes the DNS rebinding TOCTOU gap where a hostname passes URL
|
||||||
|
// validation but resolves to a private IP when the connection is made.
|
||||||
|
httpClient *http.Client
|
||||||
|
|
||||||
clock quartz.Clock
|
clock quartz.Clock
|
||||||
|
|
||||||
cacheMu sync.RWMutex
|
cacheMu sync.RWMutex
|
||||||
@@ -338,6 +361,7 @@ func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string
|
|||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
Keys: keys,
|
Keys: keys,
|
||||||
}, &webpush.Options{
|
}, &webpush.Options{
|
||||||
|
HTTPClient: n.httpClient,
|
||||||
Subscriber: n.vapidSub,
|
Subscriber: n.vapidSub,
|
||||||
VAPIDPublicKey: n.VAPIDPublicKey,
|
VAPIDPublicKey: n.VAPIDPublicKey,
|
||||||
VAPIDPrivateKey: n.VAPIDPrivateKey,
|
VAPIDPrivateKey: n.VAPIDPrivateKey,
|
||||||
@@ -407,6 +431,37 @@ func (*NoopWebpusher) PublicKey() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newSSRFSafeHTTPClient returns an HTTP client that rejects connections to
|
||||||
|
// private, loopback, link-local, multicast, and unspecified IP addresses.
|
||||||
|
// This prevents DNS rebinding attacks where a hostname passes URL-level
|
||||||
|
// validation but resolves to an internal IP at dial time.
|
||||||
|
func newSSRFSafeHTTPClient() *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Control: func(_ string, address string, _ syscall.RawConn) error {
|
||||||
|
host, _, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("split host/port: %w", err)
|
||||||
|
}
|
||||||
|
ip, err := netip.ParseAddr(host)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("parse resolved IP: %w", err)
|
||||||
|
}
|
||||||
|
if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
|
||||||
|
ip.IsLinkLocalMulticast() || ip.IsMulticast() ||
|
||||||
|
ip.IsUnspecified() {
|
||||||
|
return xerrors.Errorf(
|
||||||
|
"webpush endpoint resolved to non-public address %s", ip.String(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}).DialContext,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RegenerateVAPIDKeys regenerates the VAPID keys and deletes all existing
|
// RegenerateVAPIDKeys regenerates the VAPID keys and deletes all existing
|
||||||
// push subscriptions as part of the transaction, as they are no longer valid.
|
// push subscriptions as part of the transaction, as they are no longer valid.
|
||||||
func RegenerateVAPIDKeys(ctx context.Context, db database.Store) (newPrivateKey string, newPublicKey string, err error) {
|
func RegenerateVAPIDKeys(ctx context.Context, db database.Store) (newPrivateKey string, newPublicKey string, err error) {
|
||||||
|
|||||||
@@ -387,6 +387,8 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// setupPushTest creates a common test setup for webpush notification tests.
|
// setupPushTest creates a common test setup for webpush notification tests.
|
||||||
|
// The test HTTP client bypasses SSRF protection so that httptest.Server
|
||||||
|
// (bound to 127.0.0.1) can be reached.
|
||||||
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
|
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
db, _ := dbtestutil.NewDB(t)
|
db, _ := dbtestutil.NewDB(t)
|
||||||
@@ -400,6 +402,9 @@ func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Sto
|
|||||||
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
|
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
|
||||||
t.Cleanup(server.Close)
|
t.Cleanup(server.Close)
|
||||||
|
|
||||||
|
// Use an unrestricted HTTP client for tests. The default SSRF-safe
|
||||||
|
// client rejects loopback addresses, which blocks httptest.Server.
|
||||||
|
opts = append(opts, webpush.WithHTTPClient(http.DefaultClient))
|
||||||
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
|
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
|
||||||
require.NoError(t, err, "Failed to create webpush manager")
|
require.NoError(t, err, "Failed to create webpush manager")
|
||||||
|
|
||||||
@@ -423,3 +428,56 @@ func TestNoopWebpusher(t *testing.T) {
|
|||||||
|
|
||||||
require.Empty(t, noop.PublicKey())
|
require.Empty(t, noop.PublicKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSSRFPrevention verifies that the default SSRF-safe HTTP client blocks
|
||||||
|
// webpush delivery to loopback (and other non-public) addresses. This
|
||||||
|
// reproduces the attack vector from the original SSRF PoC: an authenticated
|
||||||
|
// user supplies a localhost endpoint in their webpush subscription, and the
|
||||||
|
// server must refuse to connect.
|
||||||
|
func TestSSRFPrevention(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
|
||||||
|
// Start a server that records whether it received a request.
|
||||||
|
var received atomic.Bool
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
received.Store(true)
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a dispatcher via New() WITHOUT WithHTTPClient so it
|
||||||
|
// uses the default SSRF-safe client that blocks loopback.
|
||||||
|
db, _ := dbtestutil.NewDB(t)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||||
|
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test() calls webpushSend directly with the supplied endpoint.
|
||||||
|
err = manager.Test(ctx, codersdk.WebpushSubscription{
|
||||||
|
Endpoint: server.URL,
|
||||||
|
AuthKey: validEndpointAuthKey,
|
||||||
|
P256DHKey: validEndpointP256dhKey,
|
||||||
|
})
|
||||||
|
require.Error(t, err, "SSRF-safe client should reject Test() to loopback address")
|
||||||
|
assert.False(t, received.Load(), "Test() request should not reach the localhost server")
|
||||||
|
|
||||||
|
// Dispatch() goes through the subscription cache → webpushSend path.
|
||||||
|
user := dbgen.User(t, db, database.User{})
|
||||||
|
_, err = db.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
|
||||||
|
CreatedAt: dbtime.Now(),
|
||||||
|
UserID: user.ID,
|
||||||
|
Endpoint: server.URL,
|
||||||
|
EndpointAuthKey: validEndpointAuthKey,
|
||||||
|
EndpointP256dhKey: validEndpointP256dhKey,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = manager.Dispatch(ctx, user.ID, codersdk.WebpushMessage{
|
||||||
|
Title: "SSRF test",
|
||||||
|
Body: "This should not arrive.",
|
||||||
|
})
|
||||||
|
require.Error(t, err, "SSRF-safe client should reject Dispatch() to loopback address")
|
||||||
|
assert.False(t, received.Load(), "Dispatch() request should not reach the localhost server")
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,151 @@
|
|||||||
|
package coderd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateWebpushEndpoint(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
endpoint string
|
||||||
|
wantErr bool
|
||||||
|
errSubstr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid https endpoint",
|
||||||
|
endpoint: "https://fcm.googleapis.com/fcm/send/abc123",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid https endpoint with port",
|
||||||
|
endpoint: "https://push.example.com:8443/subscription",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "relative URL",
|
||||||
|
endpoint: "/push/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "absolute URL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "http scheme rejected",
|
||||||
|
endpoint: "http://push.example.com/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "scheme must be https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom scheme rejected",
|
||||||
|
endpoint: "ws://push.example.com/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "scheme must be https",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty host",
|
||||||
|
endpoint: "https:///path",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "host is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "userinfo rejected",
|
||||||
|
endpoint: "https://user:pass@push.example.com/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not include userinfo",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "localhost rejected",
|
||||||
|
endpoint: "https://localhost/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain of localhost rejected",
|
||||||
|
endpoint: "https://foo.localhost/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "loopback IPv4 rejected",
|
||||||
|
endpoint: "https://127.0.0.1/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "private 10.x rejected",
|
||||||
|
endpoint: "https://10.0.0.1/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "private 192.168.x rejected",
|
||||||
|
endpoint: "https://192.168.1.1/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "private 172.16.x rejected",
|
||||||
|
endpoint: "https://172.16.0.1/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "link-local IPv4 rejected",
|
||||||
|
endpoint: "https://169.254.1.1/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unspecified IPv4 rejected",
|
||||||
|
endpoint: "https://0.0.0.0/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "loopback IPv6 rejected",
|
||||||
|
endpoint: "https://[::1]/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unspecified IPv6 rejected",
|
||||||
|
endpoint: "https://[::]/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "link-local IPv6 rejected",
|
||||||
|
endpoint: "https://[fe80::1]/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multicast IPv4 rejected",
|
||||||
|
endpoint: "https://224.0.0.1/subscription",
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "must not be private",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "public IPv4 allowed",
|
||||||
|
endpoint: "https://203.0.113.1/subscription",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
err := validateWebpushEndpoint(tt.endpoint)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errSubstr,
|
||||||
|
"error should mention %q", tt.errSubstr)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
+74
-18
@@ -3,7 +3,7 @@ package coderd_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -30,49 +30,48 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
|
|||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitShort)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
|
||||||
client := coderdtest.New(t, &coderdtest.Options{})
|
dispatcher := &testWebpushDispatcher{}
|
||||||
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
|
WebpushDispatcher: dispatcher,
|
||||||
|
})
|
||||||
owner := coderdtest.CreateFirstUser(t, client)
|
owner := coderdtest.CreateFirstUser(t, client)
|
||||||
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||||
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||||
|
endpoint := "https://push.example.com/subscription/abc123"
|
||||||
var handlerCalls atomic.Int32
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
|
||||||
handlerCalls.Add(1)
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Seed the dispatcher cache with an empty subscription set. Creating the
|
// Seed the dispatcher cache with an empty subscription set. Creating the
|
||||||
// subscription should invalidate that entry so the next dispatch sees the new
|
// subscription should invalidate that entry so the next dispatch sees the new
|
||||||
// subscription immediately.
|
// subscription immediately.
|
||||||
err := memberClient.PostTestWebpushMessage(ctx)
|
err := memberClient.PostTestWebpushMessage(ctx)
|
||||||
require.NoError(t, err, "test webpush message without a subscription")
|
require.NoError(t, err, "test webpush message without a subscription")
|
||||||
require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push")
|
require.Equal(t, int32(1), dispatcher.dispatchCalls.Load(), "dispatch should be called even with no subscriptions")
|
||||||
|
|
||||||
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
|
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
|
||||||
Endpoint: server.URL,
|
Endpoint: endpoint,
|
||||||
AuthKey: validEndpointAuthKey,
|
AuthKey: validEndpointAuthKey,
|
||||||
P256DHKey: validEndpointP256dhKey,
|
P256DHKey: validEndpointP256dhKey,
|
||||||
})
|
})
|
||||||
require.NoError(t, err, "create webpush subscription")
|
require.NoError(t, err, "create webpush subscription")
|
||||||
require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once")
|
require.Equal(t, int32(1), dispatcher.testCalls.Load(), "subscription validation should call dispatcher test once")
|
||||||
|
require.Equal(t, 1, dispatcher.invalidateCount(), "subscribing should invalidate the user's cached subscriptions")
|
||||||
|
|
||||||
err = memberClient.PostTestWebpushMessage(ctx)
|
err = memberClient.PostTestWebpushMessage(ctx)
|
||||||
require.NoError(t, err, "test webpush message after subscribing")
|
require.NoError(t, err, "test webpush message after subscribing")
|
||||||
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing")
|
require.Equal(t, int32(2), dispatcher.dispatchCalls.Load(), "dispatch should be called after subscribing")
|
||||||
|
|
||||||
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
||||||
Endpoint: server.URL,
|
Endpoint: endpoint,
|
||||||
})
|
})
|
||||||
require.NoError(t, err, "delete webpush subscription")
|
require.NoError(t, err, "delete webpush subscription")
|
||||||
|
require.Equal(t, 2, dispatcher.invalidateCount(), "unsubscribing should invalidate the user's cached subscriptions")
|
||||||
|
|
||||||
err = memberClient.PostTestWebpushMessage(ctx)
|
err = memberClient.PostTestWebpushMessage(ctx)
|
||||||
require.NoError(t, err, "test webpush message after unsubscribing")
|
require.NoError(t, err, "test webpush message after unsubscribing")
|
||||||
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing")
|
require.Equal(t, int32(3), dispatcher.dispatchCalls.Load(), "dispatch should be called after unsubscribing")
|
||||||
|
|
||||||
// Deleting the subscription for a non-existent endpoint should return a 404.
|
// Deleting the subscription for a non-existent endpoint should return a 404.
|
||||||
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
|
||||||
Endpoint: server.URL,
|
Endpoint: endpoint,
|
||||||
})
|
})
|
||||||
var sdkError *codersdk.Error
|
var sdkError *codersdk.Error
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -81,7 +80,7 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
|
|||||||
|
|
||||||
// Creating a subscription for another user should not be allowed.
|
// Creating a subscription for another user should not be allowed.
|
||||||
err = memberClient.PostWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.WebpushSubscription{
|
err = memberClient.PostWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.WebpushSubscription{
|
||||||
Endpoint: server.URL,
|
Endpoint: endpoint,
|
||||||
AuthKey: validEndpointAuthKey,
|
AuthKey: validEndpointAuthKey,
|
||||||
P256DHKey: validEndpointP256dhKey,
|
P256DHKey: validEndpointP256dhKey,
|
||||||
})
|
})
|
||||||
@@ -89,11 +88,33 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
|
|||||||
|
|
||||||
// Deleting a subscription for another user should not be allowed.
|
// Deleting a subscription for another user should not be allowed.
|
||||||
err = memberClient.DeleteWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.DeleteWebpushSubscription{
|
err = memberClient.DeleteWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.DeleteWebpushSubscription{
|
||||||
Endpoint: server.URL,
|
Endpoint: endpoint,
|
||||||
})
|
})
|
||||||
require.Error(t, err, "delete webpush subscription for another user")
|
require.Error(t, err, "delete webpush subscription for another user")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebpushSubscribeRejectsInvalidEndpoint(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
client := coderdtest.New(t, &coderdtest.Options{
|
||||||
|
WebpushDispatcher: &testWebpushDispatcher{},
|
||||||
|
})
|
||||||
|
owner := coderdtest.CreateFirstUser(t, client)
|
||||||
|
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||||
|
|
||||||
|
err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
|
||||||
|
Endpoint: "http://127.0.0.1:8080/subscription",
|
||||||
|
AuthKey: validEndpointAuthKey,
|
||||||
|
P256DHKey: validEndpointP256dhKey,
|
||||||
|
})
|
||||||
|
var sdkError *codersdk.Error
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorAsf(t, err, &sdkError, "error should be of type *codersdk.Error")
|
||||||
|
require.Equal(t, http.StatusBadRequest, sdkError.StatusCode())
|
||||||
|
require.Contains(t, sdkError.Error(), "endpoint URL scheme must be https")
|
||||||
|
}
|
||||||
|
|
||||||
// testWebpushErrorStore wraps a real database.Store and allows injecting
|
// testWebpushErrorStore wraps a real database.Store and allows injecting
|
||||||
// errors into GetWebpushSubscriptionsByUserID.
|
// errors into GetWebpushSubscriptionsByUserID.
|
||||||
type testWebpushErrorStore struct {
|
type testWebpushErrorStore struct {
|
||||||
@@ -101,6 +122,41 @@ type testWebpushErrorStore struct {
|
|||||||
getWebpushSubscriptionsErr atomic.Pointer[error]
|
getWebpushSubscriptionsErr atomic.Pointer[error]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testWebpushDispatcher struct {
|
||||||
|
testCalls atomic.Int32
|
||||||
|
dispatchCalls atomic.Int32
|
||||||
|
invalidateUserIDs []uuid.UUID
|
||||||
|
invalidateUserLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *testWebpushDispatcher) Dispatch(_ context.Context, _ uuid.UUID, _ codersdk.WebpushMessage) error {
|
||||||
|
d.dispatchCalls.Add(1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *testWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error {
|
||||||
|
d.testCalls.Add(1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*testWebpushDispatcher) PublicKey() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateUser implements webpush.SubscriptionCacheInvalidator so the
|
||||||
|
// handler exercises the cache-invalidation path on subscribe/unsubscribe.
|
||||||
|
func (d *testWebpushDispatcher) InvalidateUser(userID uuid.UUID) {
|
||||||
|
d.invalidateUserLock.Lock()
|
||||||
|
defer d.invalidateUserLock.Unlock()
|
||||||
|
d.invalidateUserIDs = append(d.invalidateUserIDs, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *testWebpushDispatcher) invalidateCount() int {
|
||||||
|
d.invalidateUserLock.Lock()
|
||||||
|
defer d.invalidateUserLock.Unlock()
|
||||||
|
return len(d.invalidateUserIDs)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *testWebpushErrorStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
|
func (s *testWebpushErrorStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
|
||||||
if err := s.getWebpushSubscriptionsErr.Load(); err != nil {
|
if err := s.getWebpushSubscriptionsErr.Load(); err != nil {
|
||||||
return nil, *err
|
return nil, *err
|
||||||
|
|||||||
Reference in New Issue
Block a user