fix(coderd): validate webpush subscription endpoints (#24347)

Co-authored-by: Cian Johnston <cian@coder.com>
This commit is contained in:
Thomas Kosiewski
2026-04-15 11:31:43 +02:00
committed by GitHub
parent e317f3b239
commit 5812f84e1c
5 changed files with 386 additions and 18 deletions
+74 -18
View File
@@ -3,7 +3,7 @@ package coderd_test
import (
"context"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
@@ -30,49 +30,48 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
client := coderdtest.New(t, &coderdtest.Options{})
dispatcher := &testWebpushDispatcher{}
client := coderdtest.New(t, &coderdtest.Options{
WebpushDispatcher: dispatcher,
})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
_, anotherMember := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
var handlerCalls atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
handlerCalls.Add(1)
}))
defer server.Close()
endpoint := "https://push.example.com/subscription/abc123"
// Seed the dispatcher cache with an empty subscription set. Creating the
// subscription should invalidate that entry so the next dispatch sees the new
// subscription immediately.
err := memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message without a subscription")
require.Zero(t, handlerCalls.Load(), "a user without subscriptions should not receive a push")
require.Equal(t, int32(1), dispatcher.dispatchCalls.Load(), "dispatch should be called even with no subscriptions")
err = memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
require.NoError(t, err, "create webpush subscription")
require.Equal(t, int32(1), handlerCalls.Load(), "subscription validation should hit the endpoint once")
require.Equal(t, int32(1), dispatcher.testCalls.Load(), "subscription validation should call dispatcher test once")
require.Equal(t, 1, dispatcher.invalidateCount(), "subscribing should invalidate the user's cached subscriptions")
err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message after subscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate empty cache entries after subscribing")
require.Equal(t, int32(2), dispatcher.dispatchCalls.Load(), "dispatch should be called after subscribing")
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
})
require.NoError(t, err, "delete webpush subscription")
require.Equal(t, 2, dispatcher.invalidateCount(), "unsubscribing should invalidate the user's cached subscriptions")
err = memberClient.PostTestWebpushMessage(ctx)
require.NoError(t, err, "test webpush message after unsubscribing")
require.Equal(t, int32(2), handlerCalls.Load(), "the dispatcher should invalidate cached subscriptions after unsubscribing")
require.Equal(t, int32(3), dispatcher.dispatchCalls.Load(), "dispatch should be called after unsubscribing")
// Deleting the subscription for a non-existent endpoint should return a 404.
err = memberClient.DeleteWebpushSubscription(ctx, "me", codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
})
var sdkError *codersdk.Error
require.Error(t, err)
@@ -81,7 +80,7 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
// Creating a subscription for another user should not be allowed.
err = memberClient.PostWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.WebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
@@ -89,11 +88,33 @@ func TestWebpushSubscribeUnsubscribe(t *testing.T) {
// Deleting a subscription for another user should not be allowed.
err = memberClient.DeleteWebpushSubscription(ctx, anotherMember.ID.String(), codersdk.DeleteWebpushSubscription{
Endpoint: server.URL,
Endpoint: endpoint,
})
require.Error(t, err, "delete webpush subscription for another user")
}
func TestWebpushSubscribeRejectsInvalidEndpoint(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
client := coderdtest.New(t, &coderdtest.Options{
WebpushDispatcher: &testWebpushDispatcher{},
})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
err := memberClient.PostWebpushSubscription(ctx, "me", codersdk.WebpushSubscription{
Endpoint: "http://127.0.0.1:8080/subscription",
AuthKey: validEndpointAuthKey,
P256DHKey: validEndpointP256dhKey,
})
var sdkError *codersdk.Error
require.Error(t, err)
require.ErrorAsf(t, err, &sdkError, "error should be of type *codersdk.Error")
require.Equal(t, http.StatusBadRequest, sdkError.StatusCode())
require.Contains(t, sdkError.Error(), "endpoint URL scheme must be https")
}
// testWebpushErrorStore wraps a real database.Store and allows injecting
// errors into GetWebpushSubscriptionsByUserID.
type testWebpushErrorStore struct {
@@ -101,6 +122,41 @@ type testWebpushErrorStore struct {
getWebpushSubscriptionsErr atomic.Pointer[error]
}
type testWebpushDispatcher struct {
testCalls atomic.Int32
dispatchCalls atomic.Int32
invalidateUserIDs []uuid.UUID
invalidateUserLock sync.Mutex
}
func (d *testWebpushDispatcher) Dispatch(_ context.Context, _ uuid.UUID, _ codersdk.WebpushMessage) error {
d.dispatchCalls.Add(1)
return nil
}
func (d *testWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error {
d.testCalls.Add(1)
return nil
}
func (*testWebpushDispatcher) PublicKey() string {
return ""
}
// InvalidateUser implements webpush.SubscriptionCacheInvalidator so the
// handler exercises the cache-invalidation path on subscribe/unsubscribe.
func (d *testWebpushDispatcher) InvalidateUser(userID uuid.UUID) {
d.invalidateUserLock.Lock()
defer d.invalidateUserLock.Unlock()
d.invalidateUserIDs = append(d.invalidateUserIDs, userID)
}
func (d *testWebpushDispatcher) invalidateCount() int {
d.invalidateUserLock.Lock()
defer d.invalidateUserLock.Unlock()
return len(d.invalidateUserIDs)
}
func (s *testWebpushErrorStore) GetWebpushSubscriptionsByUserID(ctx context.Context, userID uuid.UUID) ([]database.WebpushSubscription, error) {
if err := s.getWebpushSubscriptionsErr.Load(); err != nil {
return nil, *err