fix(coderd): validate webpush subscription endpoints (#24347)

Co-authored-by: Cian Johnston <cian@coder.com>
This commit is contained in:
Thomas Kosiewski
2026-04-15 11:31:43 +02:00
committed by GitHub
parent e317f3b239
commit 5812f84e1c
5 changed files with 386 additions and 18 deletions
+48
View File
@@ -4,7 +4,12 @@ import (
"database/sql"
"errors"
"net/http"
"net/netip"
"net/url"
"slices"
"strings"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime"
@@ -33,6 +38,13 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if err := validateWebpushEndpoint(req.Endpoint); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid webpush endpoint.",
Detail: err.Error(),
})
return
}
if err := api.WebpushDispatcher.Test(ctx, req); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@@ -62,6 +74,42 @@ func (api *API) postUserWebpushSubscription(rw http.ResponseWriter, r *http.Requ
rw.WriteHeader(http.StatusNoContent)
}
func validateWebpushEndpoint(rawEndpoint string) error {
endpoint, err := url.Parse(rawEndpoint)
if err != nil {
return xerrors.Errorf("parse endpoint URL: %w", err)
}
if !endpoint.IsAbs() {
return xerrors.New("endpoint must be an absolute URL")
}
if endpoint.Scheme != "https" {
return xerrors.New("endpoint URL scheme must be https")
}
if endpoint.Host == "" {
return xerrors.New("endpoint host is required")
}
if endpoint.User != nil {
return xerrors.New("endpoint URL must not include userinfo")
}
hostname := strings.ToLower(endpoint.Hostname())
if hostname == "" {
return xerrors.New("endpoint hostname is required")
}
if hostname == "localhost" || strings.HasSuffix(hostname, ".localhost") {
return xerrors.New("endpoint hostname must not be localhost")
}
if ip, err := netip.ParseAddr(hostname); err == nil &&
(ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsMulticast() ||
ip.IsUnspecified()) {
return xerrors.New("endpoint IP must not be private, loopback, link-local, multicast, or unspecified")
}
return nil
}
// @Summary Delete user webpush subscription
// @ID delete-user-webpush-subscription
// @Security CoderSessionToken
+55
View File
@@ -6,9 +6,12 @@ import (
"encoding/json"
"errors"
"io"
"net"
"net/http"
"net/netip"
"slices"
"sync"
"syscall"
"time"
"github.com/SherClockHolmes/webpush-go"
@@ -47,6 +50,7 @@ type SubscriptionCacheInvalidator interface {
type options struct {
clock quartz.Clock
subscriptionCacheTTL time.Duration
httpClient *http.Client
}
// Option configures optional behavior for a Webpusher.
@@ -68,6 +72,15 @@ func WithSubscriptionCacheTTL(ttl time.Duration) Option {
}
}
// WithHTTPClient overrides the default SSRF-safe HTTP client used to deliver
// push notifications. This is intended for tests that need to deliver to
// localhost test servers.
func WithHTTPClient(client *http.Client) Option {
return func(o *options) {
o.httpClient = client
}
}
// New creates a new Dispatcher to dispatch web push notifications.
//
// This is *not* integrated into the enqueue system unfortunately.
@@ -90,6 +103,9 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
if cfg.subscriptionCacheTTL <= 0 {
cfg.subscriptionCacheTTL = defaultSubscriptionCacheTTL
}
if cfg.httpClient == nil {
cfg.httpClient = newSSRFSafeHTTPClient()
}
keys, err := db.GetWebpushVAPIDKeys(ctx)
if err != nil {
@@ -121,6 +137,7 @@ func New(ctx context.Context, log *slog.Logger, db database.Store, vapidSub stri
subscriptionCacheTTL: cfg.subscriptionCacheTTL,
subscriptionCache: make(map[uuid.UUID]cachedSubscriptions),
subscriptionGenerations: make(map[uuid.UUID]uint64),
httpClient: cfg.httpClient,
}, nil
}
@@ -142,6 +159,12 @@ type Webpusher struct {
VAPIDPublicKey string
VAPIDPrivateKey string
// httpClient is an SSRF-safe HTTP client that rejects connections to
// private, loopback, and link-local IP addresses at dial time. This
// closes the DNS rebinding TOCTOU gap where a hostname passes URL
// validation but resolves to a private IP when the connection is made.
httpClient *http.Client
clock quartz.Clock
cacheMu sync.RWMutex
@@ -338,6 +361,7 @@ func (n *Webpusher) webpushSend(ctx context.Context, msg []byte, endpoint string
Endpoint: endpoint,
Keys: keys,
}, &webpush.Options{
HTTPClient: n.httpClient,
Subscriber: n.vapidSub,
VAPIDPublicKey: n.VAPIDPublicKey,
VAPIDPrivateKey: n.VAPIDPrivateKey,
@@ -407,6 +431,37 @@ func (*NoopWebpusher) PublicKey() string {
return ""
}
// newSSRFSafeHTTPClient returns an HTTP client that rejects connections to
// private, loopback, link-local, multicast, and unspecified IP addresses.
// This prevents DNS rebinding attacks where a hostname passes URL-level
// validation but resolves to an internal IP at dial time.
func newSSRFSafeHTTPClient() *http.Client {
return &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Control: func(_ string, address string, _ syscall.RawConn) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return xerrors.Errorf("split host/port: %w", err)
}
ip, err := netip.ParseAddr(host)
if err != nil {
return xerrors.Errorf("parse resolved IP: %w", err)
}
if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsMulticast() ||
ip.IsUnspecified() {
return xerrors.Errorf(
"webpush endpoint resolved to non-public address %s", ip.String(),
)
}
return nil
},
}).DialContext,
},
}
}
// RegenerateVAPIDKeys regenerates the VAPID keys and deletes all existing
// push subscriptions as part of the transaction, as they are no longer valid.
func RegenerateVAPIDKeys(ctx context.Context, db database.Store) (newPrivateKey string, newPublicKey string, err error) {
+58
View File
@@ -387,6 +387,8 @@ func assertWebpushPayload(t testing.TB, r *http.Request) {
}
// setupPushTest creates a common test setup for webpush notification tests.
// The test HTTP client bypasses SSRF protection so that httptest.Server
// (bound to 127.0.0.1) can be reached.
func setupPushTest(ctx context.Context, t *testing.T, handlerFunc func(w http.ResponseWriter, r *http.Request)) (webpush.Dispatcher, database.Store, string) {
t.Helper()
db, _ := dbtestutil.NewDB(t)
@@ -400,6 +402,9 @@ func setupPushTestWithOptions(ctx context.Context, t *testing.T, db database.Sto
server := httptest.NewServer(http.HandlerFunc(handlerFunc))
t.Cleanup(server.Close)
// Use an unrestricted HTTP client for tests. The default SSRF-safe
// client rejects loopback addresses, which blocks httptest.Server.
opts = append(opts, webpush.WithHTTPClient(http.DefaultClient))
manager, err := webpush.New(ctx, &logger, db, "http://example.com", opts...)
require.NoError(t, err, "Failed to create webpush manager")
@@ -423,3 +428,56 @@ func TestNoopWebpusher(t *testing.T) {
require.Empty(t, noop.PublicKey())
}
// TestSSRFPrevention verifies that the default SSRF-safe HTTP client blocks
// webpush delivery to loopback (and other non-public) addresses. This
// reproduces the attack vector from the original SSRF PoC: an authenticated
// user supplies a localhost endpoint in their webpush subscription, and the
// server must refuse to connect.
func TestSSRFPrevention(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
// Start a server that records whether it received a request.
var received atomic.Bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
received.Store(true)
w.WriteHeader(http.StatusCreated)
}))
defer server.Close()
// Create a dispatcher via New() WITHOUT WithHTTPClient so it
// uses the default SSRF-safe client that blocks loopback.
db, _ := dbtestutil.NewDB(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
manager, err := webpush.New(ctx, &logger, db, "http://example.com")
require.NoError(t, err)
// Test() calls webpushSend directly with the supplied endpoint.
err = manager.Test(ctx, codersdk.WebpushSubscription{
Endpoint: server.URL,
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
require.Error(t, err, "SSRF-safe client should reject Test() to loopback address")
assert.False(t, received.Load(), "Test() request should not reach the localhost server")
// Dispatch() goes through the subscription cache → webpushSend path.
user := dbgen.User(t, db, database.User{})
_, err = db.InsertWebpushSubscription(ctx, database.InsertWebpushSubscriptionParams{
CreatedAt: dbtime.Now(),
UserID: user.ID,
Endpoint: server.URL,
EndpointAuthKey: validEndpointAuthKey,
EndpointP256dhKey: validEndpointP256dhKey,
})
require.NoError(t, err)
err = manager.Dispatch(ctx, user.ID, codersdk.WebpushMessage{
Title: "SSRF test",
Body: "This should not arrive.",
})
require.Error(t, err, "SSRF-safe client should reject Dispatch() to loopback address")
assert.False(t, received.Load(), "Dispatch() request should not reach the localhost server")
}
+151
View File
@@ -0,0 +1,151 @@
package coderd
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateWebpushEndpoint(t *testing.T) {
t.Parallel()
tests := []struct {
name string
endpoint string
wantErr bool
errSubstr string
}{
{
name: "valid https endpoint",
endpoint: "https://fcm.googleapis.com/fcm/send/abc123",
wantErr: false,
},
{
name: "valid https endpoint with port",
endpoint: "https://push.example.com:8443/subscription",
wantErr: false,
},
{
name: "relative URL",
endpoint: "/push/subscription",
wantErr: true,
errSubstr: "absolute URL",
},
{
name: "http scheme rejected",
endpoint: "http://push.example.com/subscription",
wantErr: true,
errSubstr: "scheme must be https",
},
{
name: "custom scheme rejected",
endpoint: "ws://push.example.com/subscription",
wantErr: true,
errSubstr: "scheme must be https",
},
{
name: "empty host",
endpoint: "https:///path",
wantErr: true,
errSubstr: "host is required",
},
{
name: "userinfo rejected",
endpoint: "https://user:pass@push.example.com/subscription",
wantErr: true,
errSubstr: "must not include userinfo",
},
{
name: "localhost rejected",
endpoint: "https://localhost/subscription",
wantErr: true,
errSubstr: "must not be localhost",
},
{
name: "subdomain of localhost rejected",
endpoint: "https://foo.localhost/subscription",
wantErr: true,
errSubstr: "must not be localhost",
},
{
name: "loopback IPv4 rejected",
endpoint: "https://127.0.0.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "private 10.x rejected",
endpoint: "https://10.0.0.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "private 192.168.x rejected",
endpoint: "https://192.168.1.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "private 172.16.x rejected",
endpoint: "https://172.16.0.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "link-local IPv4 rejected",
endpoint: "https://169.254.1.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "unspecified IPv4 rejected",
endpoint: "https://0.0.0.0/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "loopback IPv6 rejected",
endpoint: "https://[::1]/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "unspecified IPv6 rejected",
endpoint: "https://[::]/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "link-local IPv6 rejected",
endpoint: "https://[fe80::1]/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "multicast IPv4 rejected",
endpoint: "https://224.0.0.1/subscription",
wantErr: true,
errSubstr: "must not be private",
},
{
name: "public IPv4 allowed",
endpoint: "https://203.0.113.1/subscription",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := validateWebpushEndpoint(tt.endpoint)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errSubstr,
"error should mention %q", tt.errSubstr)
} else {
require.NoError(t, err)
}
})
}
}
+74 -18
View File
@@ -3,7 +3,7 @@ package coderd_test
import (
"context"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
@@ -30,49 +30,48 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
client := coderdtest.New(t, &coderdtest.Options{})
dispatcher := &testWebpushDispatcher{}
client := coderdtest.New(t, &coderdtest.Options{
WebpushDispatcher: dispatcher,
})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
var handlerCalls atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
handlerCalls.Add(1)
}))
defer server.Close()
endpoint := "https://push.example.com/subscription/abc123"
// Seed the dispatcher cache with an empty subscription set. Creating the
// subscription should invalidate that entry so the next dispatch sees the new
// subscription immediately.
err := memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message without a subscription")
require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push")
require.Equal(t, int32(1), dispatcher.dispatchCalls.Load(), "dispatch should be called even with no subscriptions")
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
require.NoError(t, err, "create webpush subscription")
require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once")
require.Equal(t, int32(1), dispatcher.testCalls.Load(), "subscription validation should call dispatcher test once")
require.Equal(t, 1, dispatcher.invalidateCount(), "subscribing should invalidate the user's cached subscriptions")
err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message after subscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing")
require.Equal(t, int32(2), dispatcher.dispatchCalls.Load(), "dispatch should be called after subscribing")
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
})
require.NoError(t, err, "delete webpush subscription")
require.Equal(t, 2, dispatcher.invalidateCount(), "unsubscribing should invalidate the user's cached subscriptions")
err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message after unsubscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing")
require.Equal(t, int32(3), dispatcher.dispatchCalls.Load(), "dispatch should be called after unsubscribing")
// Deleting the subscription for a non-existent endpoint should return a 404.
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
})
var sdkError *codersdk.Error
require.Error(t, err)
@@ -81,7 +80,7 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
// Creating a subscription for another user should not be allowed.
err = memberClient.PostWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.WebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
@@ -89,11 +88,33 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
// Deleting a subscription for another user should not be allowed.
err = memberClient.DeleteWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
})
require.Error(t, err, "delete webpush subscription for another user")
}
func TestWebpushSubscribeRejectsInvalidEndpoint(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
client := coderdtest.New(t, &coderdtest.Options{
WebpushDispatcher: &testWebpushDispatcher{},
})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
Endpoint: "http://127.0.0.1:8080/subscription",
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
var sdkError *codersdk.Error
require.Error(t, err)
require.ErrorAsf(t, err, &sdkError, "error should be of type *codersdk.Error")
require.Equal(t, http.StatusBadRequest, sdkError.StatusCode())
require.Contains(t, sdkError.Error(), "endpoint URL scheme must be https")
}
// testWebpushErrorStore wraps a real database.Store and allows injecting
// errors into GetWebpushSubscriptionsByUserID.
type testWebpushErrorStore struct {
@@ -101,6 +122,41 @@ type testWebpushErrorStore struct {
getWebpushSubscriptionsErr atomic.Pointer[error]
}
type testWebpushDispatcher struct {
testCalls atomic.Int32
dispatchCalls atomic.Int32
invalidateUserIDs []uuid.UUID
invalidateUserLock sync.Mutex
}
func (d *testWebpushDispatcher) Dispatch(_ context.Context, _ uuid.UUID, _ codersdk.WebpushMessage) error {
d.dispatchCalls.Add(1)
return nil
}
func (d *testWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error {
d.testCalls.Add(1)
return nil
}
func (*testWebpushDispatcher) PublicKey() string {
return ""
}
// InvalidateUser implements webpush.SubscriptionCacheInvalidator so the
// handler exercises the cache-invalidation path on subscribe/unsubscribe.
func (d *testWebpushDispatcher) InvalidateUser(userID uuid.UUID) {
d.invalidateUserLock.Lock()
defer d.invalidateUserLock.Unlock()
d.invalidateUserIDs = append(d.invalidateUserIDs, userID)
}
func (d *testWebpushDispatcher) invalidateCount() int {
d.invalidateUserLock.Lock()
defer d.invalidateUserLock.Unlock()
return len(d.invalidateUserIDs)
}
func (s *testWebpushErrorStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
if err := s.getWebpushSubscriptionsErr.Load(); err != nil {
return nil, *err