From 5812f84e1cc9c312436d3ef0e44074fcc4f8219f Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Wed, 15 Apr 2026 11:31:43 +0200 Subject: [PATCH] fix(coderd): validate webpush subscription endpoints (#24347) Co-authored-by: Cian Johnston --- coderd/webpush.go | 48 ++++++++++ coderd/webpush/webpush.go | 55 ++++++++++++ coderd/webpush/webpush_test.go | 58 ++++++++++++ coderd/webpush_internal_test.go | 151 ++++++++++++++++++++++++++++++++ coderd/webpush_test.go | 92 +++++++++++++++---- 5 files changed, 386 insertions(+), 18 deletions(-) create mode 100644 coderd/webpush_internal_test.go diff --git a/coderd/webpush.go b/coderd/webpush.go index e275873400..adb9b93107 100644 --- a/coderd/webpush.go +++ b/coderd/webpush.go @@ -4,7 +4,12 @@ import ( "database/sql" "errors" "net/http" + "net/netip" + "net/url" "slices" + "strings" + + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" @@ -33,6 +38,13 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ if !httpapi.Read(ctx, rw, r, &req) { return } + if err := validateWebpushEndpoint(req.Endpoint); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid webpush endpoint.", + Detail: err.Error(), + }) + return + } if err := api.WebpushDispatcher.Test(ctx, req); err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ @@ -62,6 +74,42 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ rw.WriteHeader(http.StatusNoContent) } +func validateWebpushEndpoint(rawEndpoint string) error { + endpoint, err := url.Parse(rawEndpoint) + if err != nil { + return xerrors.Errorf("parse endpoint URL: %w", err) + } + if !endpoint.IsAbs() { + return xerrors.New("endpoint must be an absolute URL") + } + if endpoint.Scheme != "https" { + return xerrors.New("endpoint URL scheme must be https") + } + if endpoint.Host == "" { + return xerrors.New("endpoint host is required") + } + if endpoint.User != nil { + return xerrors.New("endpoint URL must not include userinfo") + } + + hostname := strings.ToLower(endpoint.Hostname()) + if hostname == "" { + return xerrors.New("endpoint hostname is required") + } + if hostname == "localhost" || strings.HasSuffix(hostname, ".localhost") { + return xerrors.New("endpoint hostname must not be localhost") + } + + if ip, err := netip.ParseAddr(hostname); err == nil && + (ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || ip.IsMulticast() || + ip.IsUnspecified()) { + return xerrors.New("endpoint IP must not be private, loopback, link-local, multicast, or unspecified") + } + + return nil +} + // @Summary Delete user webpush subscription // @ID delete-user-webpush-subscription // @Security CoderSessionToken diff --git a/coderd/webpush/webpush.go b/coderd/webpush/webpush.go index ee7ad47075..dc91d1eb9f 100644 --- a/coderd/webpush/webpush.go +++ b/coderd/webpush/webpush.go @@ -6,9 +6,12 @@ import ( "encoding/json" "errors" "io" + "net" "net/http" + "net/netip" "slices" "sync" + "syscall" "time" "github.com/SherClockHolmes/webpush-go" @@ -47,6 +50,7 @@ type SubscriptionCacheInvalidator interface { type options struct { clock quartz.Clock subscriptionCacheTTL time.Duration + httpClient *http.Client } // Option configures optional behavior for a Webpusher. @@ -68,6 +72,15 @@ func WithSubscriptionCacheTTL(ttl time.Duration) Option { } } +// WithHTTPClient overrides the default SSRF-safe HTTP client used to deliver +// push notifications. This is intended for tests that need to deliver to +// localhost test servers. +func WithHTTPClient(client *http.Client) Option { + return func(o *options) { + o.httpClient = client + } +} + // New creates a new Dispatcher to dispatch web push notifications. // // This is *not* integrated into the enqueue system unfortunately. @@ -90,6 +103,9 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri if cfg.subscriptionCacheTTL <= 0 { cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL } + if cfg.httpClient == nil { + cfg.httpClient = newSSRFSafeHTTPClient() + } keys, err := db.GetWebpushVAPIDKeys(ctx) if err != nil { @@ -121,6 +137,7 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri subscriptionCacheTTL: cfg.subscriptionCacheTTL, subscriptionCache: make(map[uuid.UUID]cachedSubscriptions), subscriptionGenerations: make(map[uuid.UUID]uint64), + httpClient: cfg.httpClient, }, nil } @@ -142,6 +159,12 @@ type Webpusher struct { VAPIDPublicKey string VAPIDPrivateKey string + // httpClient is an SSRF-safe HTTP client that rejects connections to + // private, loopback, and link-local IP addresses at dial time. This + // closes the DNS rebinding TOCTOU gap where a hostname passes URL + // validation but resolves to a private IP when the connection is made. + httpClient *http.Client + clock quartz.Clock cacheMu sync.RWMutex @@ -338,6 +361,7 @@ func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string Endpoint: endpoint, Keys: keys, }, &webpush.Options{ + HTTPClient: n.httpClient, Subscriber: n.vapidSub, VAPIDPublicKey: n.VAPIDPublicKey, VAPIDPrivateKey: n.VAPIDPrivateKey, @@ -407,6 +431,37 @@ func (*NoopWebpusher) PublicKey() string { return "" } +// newSSRFSafeHTTPClient returns an HTTP client that rejects connections to +// private, loopback, link-local, multicast, and unspecified IP addresses. +// This prevents DNS rebinding attacks where a hostname passes URL-level +// validation but resolves to an internal IP at dial time. +func newSSRFSafeHTTPClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Control: func(_ string, address string, _ syscall.RawConn) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return xerrors.Errorf("split host/port: %w", err) + } + ip, err := netip.ParseAddr(host) + if err != nil { + return xerrors.Errorf("parse resolved IP: %w", err) + } + if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || ip.IsMulticast() || + ip.IsUnspecified() { + return xerrors.Errorf( + "webpush endpoint resolved to non-public address %s", ip.String(), + ) + } + return nil + }, + }).DialContext, + }, + } +} + // RegenerateVAPIDKeys regenerates the VAPID keys and deletes all existing // push subscriptions as part of the transaction, as they are no longer valid. func RegenerateVAPIDKeys(ctx context.Context, db database.Store) (newPrivateKey string, newPublicKey string, err error) { diff --git a/coderd/webpush/webpush_test.go b/coderd/webpush/webpush_test.go index 1da6fcdd54..61c83d4bd4 100644 --- a/coderd/webpush/webpush_test.go +++ b/coderd/webpush/webpush_test.go @@ -387,6 +387,8 @@ func assertWebpushPayload(t testing.TB, r *http.Request) { } // setupPushTest creates a common test setup for webpush notification tests. +// The test HTTP client bypasses SSRF protection so that httptest.Server +// (bound to 127.0.0.1) can be reached. func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) { t.Helper() db, _ := dbtestutil.NewDB(t) @@ -400,6 +402,9 @@ func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Sto server := httptest.NewServer(http.HandlerFunc(handlerFunc)) t.Cleanup(server.Close) + // Use an unrestricted HTTP client for tests. The default SSRF-safe + // client rejects loopback addresses, which blocks httptest.Server. + opts = append(opts, webpush.WithHTTPClient(http.DefaultClient)) manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...) require.NoError(t, err, "Failed to create webpush manager") @@ -423,3 +428,56 @@ func TestNoopWebpusher(t *testing.T) { require.Empty(t, noop.PublicKey()) } + +// TestSSRFPrevention verifies that the default SSRF-safe HTTP client blocks +// webpush delivery to loopback (and other non-public) addresses. This +// reproduces the attack vector from the original SSRF PoC: an authenticated +// user supplies a localhost endpoint in their webpush subscription, and the +// server must refuse to connect. +func TestSSRFPrevention(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + + // Start a server that records whether it received a request. + var received atomic.Bool + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + received.Store(true) + w.WriteHeader(http.StatusCreated) + })) + defer server.Close() + + // Create a dispatcher via New() WITHOUT WithHTTPClient so it + // uses the default SSRF-safe client that blocks loopback. + db, _ := dbtestutil.NewDB(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + manager, err := webpush.New(ctx, &logger, db, "http://example.com") + require.NoError(t, err) + + // Test() calls webpushSend directly with the supplied endpoint. + err = manager.Test(ctx, codersdk.WebpushSubscription{ + Endpoint: server.URL, + AuthKey: validEndpointAuthKey, + P256DHKey: validEndpointP256dhKey, + }) + require.Error(t, err, "SSRF-safe client should reject Test() to loopback address") + assert.False(t, received.Load(), "Test() request should not reach the localhost server") + + // Dispatch() goes through the subscription cache → webpushSend path. + user := dbgen.User(t, db, database.User{}) + _, err = db.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{ + CreatedAt: dbtime.Now(), + UserID: user.ID, + Endpoint: server.URL, + EndpointAuthKey: validEndpointAuthKey, + EndpointP256dhKey: validEndpointP256dhKey, + }) + require.NoError(t, err) + + err = manager.Dispatch(ctx, user.ID, codersdk.WebpushMessage{ + Title: "SSRF test", + Body: "This should not arrive.", + }) + require.Error(t, err, "SSRF-safe client should reject Dispatch() to loopback address") + assert.False(t, received.Load(), "Dispatch() request should not reach the localhost server") +} diff --git a/coderd/webpush_internal_test.go b/coderd/webpush_internal_test.go new file mode 100644 index 0000000000..6f6d45987d --- /dev/null +++ b/coderd/webpush_internal_test.go @@ -0,0 +1,151 @@ +package coderd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateWebpushEndpoint(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + endpoint string + wantErr bool + errSubstr string + }{ + { + name: "valid https endpoint", + endpoint: "https://fcm.googleapis.com/fcm/send/abc123", + wantErr: false, + }, + { + name: "valid https endpoint with port", + endpoint: "https://push.example.com:8443/subscription", + wantErr: false, + }, + { + name: "relative URL", + endpoint: "/push/subscription", + wantErr: true, + errSubstr: "absolute URL", + }, + { + name: "http scheme rejected", + endpoint: "http://push.example.com/subscription", + wantErr: true, + errSubstr: "scheme must be https", + }, + { + name: "custom scheme rejected", + endpoint: "ws://push.example.com/subscription", + wantErr: true, + errSubstr: "scheme must be https", + }, + { + name: "empty host", + endpoint: "https:///path", + wantErr: true, + errSubstr: "host is required", + }, + { + name: "userinfo rejected", + endpoint: "https://user:pass@push.example.com/subscription", + wantErr: true, + errSubstr: "must not include userinfo", + }, + { + name: "localhost rejected", + endpoint: "https://localhost/subscription", + wantErr: true, + errSubstr: "must not be localhost", + }, + { + name: "subdomain of localhost rejected", + endpoint: "https://foo.localhost/subscription", + wantErr: true, + errSubstr: "must not be localhost", + }, + { + name: "loopback IPv4 rejected", + endpoint: "https://127.0.0.1/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "private 10.x rejected", + endpoint: "https://10.0.0.1/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "private 192.168.x rejected", + endpoint: "https://192.168.1.1/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "private 172.16.x rejected", + endpoint: "https://172.16.0.1/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "link-local IPv4 rejected", + endpoint: "https://169.254.1.1/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "unspecified IPv4 rejected", + endpoint: "https://0.0.0.0/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "loopback IPv6 rejected", + endpoint: "https://[::1]/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "unspecified IPv6 rejected", + endpoint: "https://[::]/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "link-local IPv6 rejected", + endpoint: "https://[fe80::1]/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "multicast IPv4 rejected", + endpoint: "https://224.0.0.1/subscription", + wantErr: true, + errSubstr: "must not be private", + }, + { + name: "public IPv4 allowed", + endpoint: "https://203.0.113.1/subscription", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateWebpushEndpoint(tt.endpoint) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errSubstr, + "error should mention %q", tt.errSubstr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/coderd/webpush_test.go b/coderd/webpush_test.go index 353cc676b4..696a12052a 100644 --- a/coderd/webpush_test.go +++ b/coderd/webpush_test.go @@ -3,7 +3,7 @@ package coderd_test import ( "context" "net/http" - "net/http/httptest" + "sync" "sync/atomic" "testing" @@ -30,49 +30,48 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - client := coderdtest.New(t, &coderdtest.Options{}) + dispatcher := &testWebpushDispatcher{} + client := coderdtest.New(t, &coderdtest.Options{ + WebpushDispatcher: dispatcher, + }) owner := coderdtest.CreateFirstUser(t, client) memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) _, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) - - var handlerCalls atomic.Int32 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusCreated) - handlerCalls.Add(1) - })) - defer server.Close() + endpoint := "https://push.example.com/subscription/abc123" // Seed the dispatcher cache with an empty subscription set. Creating the // subscription should invalidate that entry so the next dispatch sees the new // subscription immediately. err := memberClient.PostTestWebpushMessage(ctx) require.NoError(t, err, "test webpush message without a subscription") - require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push") + require.Equal(t, int32(1), dispatcher.dispatchCalls.Load(), "dispatch should be called even with no subscriptions") err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{ - Endpoint: server.URL, + Endpoint: endpoint, AuthKey: validEndpointAuthKey, P256DHKey: validEndpointP256dhKey, }) require.NoError(t, err, "create webpush subscription") - require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once") + require.Equal(t, int32(1), dispatcher.testCalls.Load(), "subscription validation should call dispatcher test once") + require.Equal(t, 1, dispatcher.invalidateCount(), "subscribing should invalidate the user's cached subscriptions") err = memberClient.PostTestWebpushMessage(ctx) require.NoError(t, err, "test webpush message after subscribing") - require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing") + require.Equal(t, int32(2), dispatcher.dispatchCalls.Load(), "dispatch should be called after subscribing") err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{ - Endpoint: server.URL, + Endpoint: endpoint, }) require.NoError(t, err, "delete webpush subscription") + require.Equal(t, 2, dispatcher.invalidateCount(), "unsubscribing should invalidate the user's cached subscriptions") err = memberClient.PostTestWebpushMessage(ctx) require.NoError(t, err, "test webpush message after unsubscribing") - require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing") + require.Equal(t, int32(3), dispatcher.dispatchCalls.Load(), "dispatch should be called after unsubscribing") // Deleting the subscription for a non-existent endpoint should return a 404. err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{ - Endpoint: server.URL, + Endpoint: endpoint, }) var sdkError *codersdk.Error require.Error(t, err) @@ -81,7 +80,7 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) { // Creating a subscription for another user should not be allowed. err = memberClient.PostWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.WebpushSubscription{ - Endpoint: server.URL, + Endpoint: endpoint, AuthKey: validEndpointAuthKey, P256DHKey: validEndpointP256dhKey, }) @@ -89,11 +88,33 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) { // Deleting a subscription for another user should not be allowed. err = memberClient.DeleteWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.DeleteWebpushSubscription{ - Endpoint: server.URL, + Endpoint: endpoint, }) require.Error(t, err, "delete webpush subscription for another user") } +func TestWebpushSubscribeRejectsInvalidEndpoint(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + client := coderdtest.New(t, &coderdtest.Options{ + WebpushDispatcher: &testWebpushDispatcher{}, + }) + owner := coderdtest.CreateFirstUser(t, client) + memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) + + err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{ + Endpoint: "http://127.0.0.1:8080/subscription", + AuthKey: validEndpointAuthKey, + P256DHKey: validEndpointP256dhKey, + }) + var sdkError *codersdk.Error + require.Error(t, err) + require.ErrorAsf(t, err, &sdkError, "error should be of type *codersdk.Error") + require.Equal(t, http.StatusBadRequest, sdkError.StatusCode()) + require.Contains(t, sdkError.Error(), "endpoint URL scheme must be https") +} + // testWebpushErrorStore wraps a real database.Store and allows injecting // errors into GetWebpushSubscriptionsByUserID. type testWebpushErrorStore struct { @@ -101,6 +122,41 @@ type testWebpushErrorStore struct { getWebpushSubscriptionsErr atomic.Pointer[error] } +type testWebpushDispatcher struct { + testCalls atomic.Int32 + dispatchCalls atomic.Int32 + invalidateUserIDs []uuid.UUID + invalidateUserLock sync.Mutex +} + +func (d *testWebpushDispatcher) Dispatch(_ context.Context, _ uuid.UUID, _ codersdk.WebpushMessage) error { + d.dispatchCalls.Add(1) + return nil +} + +func (d *testWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error { + d.testCalls.Add(1) + return nil +} + +func (*testWebpushDispatcher) PublicKey() string { + return "" +} + +// InvalidateUser implements webpush.SubscriptionCacheInvalidator so the +// handler exercises the cache-invalidation path on subscribe/unsubscribe. +func (d *testWebpushDispatcher) InvalidateUser(userID uuid.UUID) { + d.invalidateUserLock.Lock() + defer d.invalidateUserLock.Unlock() + d.invalidateUserIDs = append(d.invalidateUserIDs, userID) +} + +func (d *testWebpushDispatcher) invalidateCount() int { + d.invalidateUserLock.Lock() + defer d.invalidateUserLock.Unlock() + return len(d.invalidateUserIDs) +} + func (s *testWebpushErrorStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) { if err := s.getWebpushSubscriptionsErr.Load(); err != nil { return nil, *err