feat: add support for WorkspaceUpdates to WebsocketDialer (#15534)

closes #14730

Adds support for WorkspaceUpdates to the WebsocketDialer. This allows us to dial the new endpoint added in #14847 and connect it up to a `tailnet.Controllers` to connect to all agents over the tailnet.

I refactored the fakeWorkspaceUpdatesProvider to a mock and moved it to `tailnettest` so it could be more easily reused.  The Mock is a little more full-featured.
This commit is contained in:
Spike Curtis
2024-11-18 10:54:11 +04:00
committed by GitHub
parent 16992ee548
commit 747f7ce173
9 changed files with 305 additions and 78 deletions
+10 -5
View File
@@ -482,6 +482,13 @@ DB_GEN_FILES := \
coderd/database/dbauthz/dbauthz.go \
coderd/database/dbmock/dbmock.go
TAILNETTEST_MOCKS := \
tailnet/tailnettest/coordinatormock.go \
tailnet/tailnettest/coordinateemock.go \
tailnet/tailnettest/workspaceupdatesprovidermock.go \
tailnet/tailnettest/subscriptionmock.go
# all gen targets should be added here and to gen/mark-fresh
gen: \
tailnet/proto/tailnet.pb.go \
@@ -506,8 +513,7 @@ gen: \
site/e2e/provisionerGenerated.ts \
site/src/theme/icons.json \
examples/examples.gen.json \
tailnet/tailnettest/coordinatormock.go \
tailnet/tailnettest/coordinateemock.go \
$(TAILNETTEST_MOCKS) \
coderd/database/pubsub/psmock/psmock.go
.PHONY: gen
@@ -536,8 +542,7 @@ gen/mark-fresh:
site/e2e/provisionerGenerated.ts \
site/src/theme/icons.json \
examples/examples.gen.json \
tailnet/tailnettest/coordinatormock.go \
tailnet/tailnettest/coordinateemock.go \
$(TAILNETTEST_MOCKS) \
coderd/database/pubsub/psmock/psmock.go \
"
@@ -570,7 +575,7 @@ coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.
coderd/database/pubsub/psmock/psmock.go: coderd/database/pubsub/pubsub.go
go generate ./coderd/database/pubsub/psmock
tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go
$(TAILNETTEST_MOCKS): tailnet/coordinator.go tailnet/service.go
go generate ./tailnet/tailnettest/
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
+56 -13
View File
@@ -25,14 +25,26 @@ var permanentErrorStatuses = []int{
}
type WebsocketDialer struct {
logger slog.Logger
dialOptions *websocket.DialOptions
url *url.URL
logger slog.Logger
dialOptions *websocket.DialOptions
url *url.URL
// workspaceUpdatesReq != nil means that the dialer should call the WorkspaceUpdates RPC and
// return the corresponding client
workspaceUpdatesReq *proto.WorkspaceUpdatesRequest
resumeTokenFailed bool
connected chan error
isFirst bool
}
type WebsocketDialerOption func(*WebsocketDialer)
func WithWorkspaceUpdates(req *proto.WorkspaceUpdatesRequest) WebsocketDialerOption {
return func(w *WebsocketDialer) {
w.workspaceUpdatesReq = req
}
}
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
) (
tailnet.ControlProtocolClients, error,
@@ -41,14 +53,27 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
u := new(url.URL)
*u = *w.url
q := u.Query()
if r != nil && !w.resumeTokenFailed {
if token, ok := r.Token(); ok {
q := u.Query()
q.Set("resume_token", token)
u.RawQuery = q.Encode()
w.logger.Debug(ctx, "using resume token on dial")
}
}
// The current version includes additions
//
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
// 2.2 PostTelemetry on the Tailnet API
// 2.3 RefreshResumeToken, WorkspaceUpdates
//
// Resume tokens and telemetry are optional, and fail gracefully. So we use version 2.0 for
// maximum compatibility if we don't need WorkspaceUpdates. If we do, we use 2.3.
if w.workspaceUpdatesReq != nil {
q.Add("version", "2.3")
} else {
q.Add("version", "2.0")
}
u.RawQuery = q.Encode()
// nolint:bodyclose
ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
@@ -115,12 +140,23 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
return tailnet.ControlProtocolClients{}, err
}
var updates tailnet.WorkspaceUpdatesClient
if w.workspaceUpdatesReq != nil {
updates, err = client.WorkspaceUpdates(context.Background(), w.workspaceUpdatesReq)
if err != nil {
w.logger.Debug(ctx, "failed to create WorkspaceUpdates stream", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
return tailnet.ControlProtocolClients{}, err
}
}
return tailnet.ControlProtocolClients{
Closer: client.DRPCConn(),
Coordinator: coord,
DERP: derps,
ResumeToken: client,
Telemetry: client,
Closer: client.DRPCConn(),
Coordinator: coord,
DERP: derps,
ResumeToken: client,
Telemetry: client,
WorkspaceUpdates: updates,
}, nil
}
@@ -128,12 +164,19 @@ func (w *WebsocketDialer) Connected() <-chan error {
return w.connected
}
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
return &WebsocketDialer{
func NewWebsocketDialer(
logger slog.Logger, u *url.URL, websocketOptions *websocket.DialOptions,
dialerOptions ...WebsocketDialerOption,
) *WebsocketDialer {
w := &WebsocketDialer{
logger: logger,
dialOptions: opts,
dialOptions: websocketOptions,
url: u,
connected: make(chan error, 1),
isFirst: true,
}
for _, o := range dialerOptions {
o(w)
}
return w
}
+87 -3
View File
@@ -9,8 +9,10 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
@@ -21,7 +23,7 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
@@ -102,6 +104,7 @@ func TestWebsocketDialer_TokenController(t *testing.T) {
require.Equal(t, "", gotToken)
clients = testutil.RequireRecvCtx(ctx, t, clientCh)
require.Nil(t, clients.WorkspaceUpdates)
clients.Closer.Close()
err = testutil.RequireRecvCtx(ctx, t, wsErr)
@@ -273,7 +276,7 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1)
sVer := apiversion.New(2, 2)
// the following matches what Coderd does;
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
@@ -291,7 +294,10 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
svrURL, err := url.Parse(svr.URL)
require.NoError(t, err)
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
uut := workspacesdk.NewWebsocketDialer(
logger, svrURL, &websocket.DialOptions{},
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{}),
)
errCh := make(chan error, 1)
go func() {
@@ -307,6 +313,84 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
require.NotEmpty(t, sdkErr.Helper)
}
func TestWebsocketDialer_WorkspaceUpdates(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
fCoord := tailnettest.NewFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
ctrl := gomock.NewController(t)
mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
WorkspaceUpdatesProvider: mProvider,
})
require.NoError(t, err)
wsErr := make(chan error, 1)
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// need 2.3 for WorkspaceUpdates RPC
cVer := r.URL.Query().Get("version")
assert.Equal(t, "2.3", cVer)
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
// streamID can be empty because we don't call RPCs in this test.
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
}))
defer svr.Close()
svrURL, err := url.Parse(svr.URL)
require.NoError(t, err)
userID := uuid.UUID{88}
mSub := tailnettest.NewMockSubscription(ctrl)
updateCh := make(chan *tailnetproto.WorkspaceUpdate, 1)
mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil)
mSub.EXPECT().Updates().MinTimes(1).Return(updateCh)
mSub.EXPECT().Close().Times(1).Return(nil)
uut := workspacesdk.NewWebsocketDialer(
logger, svrURL, &websocket.DialOptions{},
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{
WorkspaceOwnerId: userID[:],
}),
)
clients, err := uut.Dial(ctx, nil)
require.NoError(t, err)
require.NotNil(t, clients.WorkspaceUpdates)
wsID := uuid.UUID{99}
expectedUpdate := &tailnetproto.WorkspaceUpdate{
UpsertedWorkspaces: []*tailnetproto.Workspace{
{Id: wsID[:]},
},
}
updateCh <- expectedUpdate
gotUpdate, err := clients.WorkspaceUpdates.Recv()
require.NoError(t, err)
require.Equal(t, wsID[:], gotUpdate.GetUpsertedWorkspaces()[0].GetId())
clients.Closer.Close()
err = testutil.RequireRecvCtx(ctx, t, wsErr)
require.NoError(t, err)
}
type fakeResumeTokenController struct {
ctx context.Context
t testing.TB
-11
View File
@@ -216,17 +216,6 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
q := coordinateURL.Query()
// The current version includes additions
//
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
// 2.2 PostTelemetry on the Tailnet API
// 2.3 RefreshResumeToken, WorkspaceUpdates
//
// Since resume tokens and telemetry are optional, and fail gracefully, and we don't use
// WorkspaceUpdates to talk to a single agent, we ask for version 2.0 for maximum compatibility
q.Add("version", "2.0")
coordinateURL.RawQuery = q.Encode()
dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
HTTPClient: c.client.HTTPClient,
@@ -21,18 +21,6 @@ import (
agpl "github.com/coder/coder/v2/tailnet"
)
// TailnetAPIVersion is the version of the Tailnet API we use for wsproxy.
//
// # The current version of the Tailnet API includes additions
//
// 2.1 GetAnnouncementBanners on the Agent API (version locked to Tailnet API)
// 2.2 PostTelemetry on the Tailnet API
// 2.3 RefreshResumeToken, WorkspaceUpdates
//
// Since resume tokens and telemetry are optional, and fail gracefully, and we don't use
// WorkspaceUpdates in the wsproxy, we ask for version 2.0 for maximum compatibility
const TailnetAPIVersion = "2.0"
// Client is a HTTP client for a subset of Coder API routes that external
// proxies need.
type Client struct {
@@ -518,9 +506,6 @@ func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) {
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
q := coordinateURL.Query()
q.Add("version", TailnetAPIVersion)
coordinateURL.RawQuery = q.Encode()
coordinateHeaders := make(http.Header)
tokenHeader := codersdk.SessionTokenHeader
if c.SDKClient.SessionTokenHeader != "" {
+11 -31
View File
@@ -11,6 +11,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
@@ -236,8 +237,8 @@ func TestClientUserCoordinateeAuth(t *testing.T) {
agentID2 := uuid.UUID{0x02}
clientID := uuid.UUID{0x03}
updatesCh := make(chan *proto.WorkspaceUpdate, 1)
updatesProvider := &fakeUpdatesProvider{ch: updatesCh}
ctrl := gomock.NewController(t)
updatesProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
fCoord, client := createUpdateService(t, ctx, clientID, updatesProvider)
@@ -271,8 +272,10 @@ func TestWorkspaceUpdates(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
updatesProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
mSub := tailnettest.NewMockSubscription(ctrl)
updatesCh := make(chan *proto.WorkspaceUpdate, 1)
updatesProvider := &fakeUpdatesProvider{ch: updatesCh}
clientID := uuid.UUID{0x03}
wsID := uuid.UUID{0x04}
@@ -293,6 +296,11 @@ func TestWorkspaceUpdates(t *testing.T) {
DeletedAgents: []*proto.Agent{},
}
updatesCh <- expected
updatesProvider.EXPECT().Subscribe(gomock.Any(), clientID).
Times(1).
Return(mSub, nil)
mSub.EXPECT().Updates().MinTimes(1).Return(updatesCh)
mSub.EXPECT().Close().Times(1).Return(nil)
updatesStream, err := client.WorkspaceUpdates(ctx, &proto.WorkspaceUpdatesRequest{
WorkspaceOwnerId: tailnet.UUIDToByteSlice(clientID),
@@ -354,34 +362,6 @@ func createUpdateService(t *testing.T, ctx context.Context, clientID uuid.UUID,
return fCoord, client
}
type fakeUpdatesProvider struct {
ch chan *proto.WorkspaceUpdate
}
func (*fakeUpdatesProvider) Close() error {
return nil
}
func (f *fakeUpdatesProvider) Subscribe(context.Context, uuid.UUID) (tailnet.Subscription, error) {
return &fakeSubscription{ch: f.ch}, nil
}
type fakeSubscription struct {
ch chan *proto.WorkspaceUpdate
}
func (*fakeSubscription) Close() error {
return nil
}
func (f *fakeSubscription) Updates() <-chan *proto.WorkspaceUpdate {
return f.ch
}
var _ tailnet.Subscription = (*fakeSubscription)(nil)
var _ tailnet.WorkspaceUpdatesProvider = (*fakeUpdatesProvider)(nil)
type fakeTunnelAuth struct{}
// AuthorizeTunnel implements tailnet.TunnelAuthorizer.
+68
View File
@@ -0,0 +1,68 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/tailnet (interfaces: Subscription)
//
// Generated by this command:
//
// mockgen -destination ./subscriptionmock.go -package tailnettest github.com/coder/coder/v2/tailnet Subscription
//
// Package tailnettest is a generated GoMock package.
package tailnettest
import (
reflect "reflect"
proto "github.com/coder/coder/v2/tailnet/proto"
gomock "go.uber.org/mock/gomock"
)
// MockSubscription is a mock of Subscription interface.
type MockSubscription struct {
ctrl *gomock.Controller
recorder *MockSubscriptionMockRecorder
}
// MockSubscriptionMockRecorder is the mock recorder for MockSubscription.
type MockSubscriptionMockRecorder struct {
mock *MockSubscription
}
// NewMockSubscription creates a new mock instance.
func NewMockSubscription(ctrl *gomock.Controller) *MockSubscription {
mock := &MockSubscription{ctrl: ctrl}
mock.recorder = &MockSubscriptionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSubscription) EXPECT() *MockSubscriptionMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockSubscription) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockSubscriptionMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSubscription)(nil).Close))
}
// Updates mocks base method.
func (m *MockSubscription) Updates() <-chan *proto.WorkspaceUpdate {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Updates")
ret0, _ := ret[0].(<-chan *proto.WorkspaceUpdate)
return ret0
}
// Updates indicates an expected call of Updates.
func (mr *MockSubscriptionMockRecorder) Updates() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Updates", reflect.TypeOf((*MockSubscription)(nil).Updates))
}
+2
View File
@@ -26,6 +26,8 @@ import (
//go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
//go:generate mockgen -destination ./coordinateemock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinatee
//go:generate mockgen -destination ./workspaceupdatesprovidermock.go -package tailnettest github.com/coder/coder/v2/tailnet WorkspaceUpdatesProvider
//go:generate mockgen -destination ./subscriptionmock.go -package tailnettest github.com/coder/coder/v2/tailnet Subscription
type derpAndSTUNCfg struct {
DisableSTUN bool
@@ -0,0 +1,71 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/tailnet (interfaces: WorkspaceUpdatesProvider)
//
// Generated by this command:
//
// mockgen -destination ./workspaceupdatesprovidermock.go -package tailnettest github.com/coder/coder/v2/tailnet WorkspaceUpdatesProvider
//
// Package tailnettest is a generated GoMock package.
package tailnettest
import (
context "context"
reflect "reflect"
tailnet "github.com/coder/coder/v2/tailnet"
uuid "github.com/google/uuid"
gomock "go.uber.org/mock/gomock"
)
// MockWorkspaceUpdatesProvider is a mock of WorkspaceUpdatesProvider interface.
type MockWorkspaceUpdatesProvider struct {
ctrl *gomock.Controller
recorder *MockWorkspaceUpdatesProviderMockRecorder
}
// MockWorkspaceUpdatesProviderMockRecorder is the mock recorder for MockWorkspaceUpdatesProvider.
type MockWorkspaceUpdatesProviderMockRecorder struct {
mock *MockWorkspaceUpdatesProvider
}
// NewMockWorkspaceUpdatesProvider creates a new mock instance.
func NewMockWorkspaceUpdatesProvider(ctrl *gomock.Controller) *MockWorkspaceUpdatesProvider {
mock := &MockWorkspaceUpdatesProvider{ctrl: ctrl}
mock.recorder = &MockWorkspaceUpdatesProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockWorkspaceUpdatesProvider) EXPECT() *MockWorkspaceUpdatesProviderMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockWorkspaceUpdatesProvider) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockWorkspaceUpdatesProviderMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockWorkspaceUpdatesProvider)(nil).Close))
}
// Subscribe mocks base method.
func (m *MockWorkspaceUpdatesProvider) Subscribe(arg0 context.Context, arg1 uuid.UUID) (tailnet.Subscription, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Subscribe", arg0, arg1)
ret0, _ := ret[0].(tailnet.Subscription)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Subscribe indicates an expected call of Subscribe.
func (mr *MockWorkspaceUpdatesProviderMockRecorder) Subscribe(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockWorkspaceUpdatesProvider)(nil).Subscribe), arg0, arg1)
}