mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
8b66a5ad59
Closes #18599. The linked issue was created due to me assuming the dialer didn't fail fast at all. In reality, it does fail fast, but only for a select few status codes. Auth[n|z] errors aren't any of those status codes, despite being 'permanent' in the same way a `400` is. This PR makes 401* and 403 'permanent' errors, meaning the dialer will give up immediately after receiving them from coderd. *One reason to receive a 401 is when the supplied resume_token is invalid. These are not permanent errors, and when we encounter those the dialer will retain the existing behaviour of unsetting the resume token and retrying.
475 lines
14 KiB
Go
475 lines
14 KiB
Go
package workspacesdk_test
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"tailscale.com/tailcfg"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/apiversion"
|
|
"github.com/coder/coder/v2/coderd/httpapi"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
"github.com/coder/coder/v2/tailnet"
|
|
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
|
"github.com/coder/coder/v2/tailnet/tailnettest"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/websocket"
|
|
)
|
|
|
|
func TestWebsocketDialer_TokenController(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
fTokenProv := newFakeTokenController(ctx, t)
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
dialTokens := make(chan string, 1)
|
|
wsErr := make(chan error, 1)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Error("timed out sending token")
|
|
case dialTokens <- r.URL.Query().Get("resume_token"):
|
|
// OK
|
|
}
|
|
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
|
// streamID can be empty because we don't call RPCs in this test.
|
|
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
clientCh := make(chan tailnet.ControlProtocolClients, 1)
|
|
go func() {
|
|
clients, err := uut.Dial(ctx, fTokenProv)
|
|
assert.NoError(t, err)
|
|
clientCh <- clients
|
|
}()
|
|
|
|
call := testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
|
|
call <- tokenResponse{"test token", true}
|
|
gotToken := <-dialTokens
|
|
require.Equal(t, "test token", gotToken)
|
|
|
|
clients := testutil.TryReceive(ctx, t, clientCh)
|
|
clients.Closer.Close()
|
|
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
|
|
clientCh = make(chan tailnet.ControlProtocolClients, 1)
|
|
go func() {
|
|
clients, err := uut.Dial(ctx, fTokenProv)
|
|
assert.NoError(t, err)
|
|
clientCh <- clients
|
|
}()
|
|
|
|
call = testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
|
|
call <- tokenResponse{"test token", false}
|
|
gotToken = <-dialTokens
|
|
require.Equal(t, "", gotToken)
|
|
|
|
clients = testutil.TryReceive(ctx, t, clientCh)
|
|
require.Nil(t, clients.WorkspaceUpdates)
|
|
clients.Closer.Close()
|
|
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_NoTokenController(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
dialTokens := make(chan string, 1)
|
|
wsErr := make(chan error, 1)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Error("timed out sending token")
|
|
case dialTokens <- r.URL.Query().Get("resume_token"):
|
|
// OK
|
|
}
|
|
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
|
// streamID can be empty because we don't call RPCs in this test.
|
|
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
clientCh := make(chan tailnet.ControlProtocolClients, 1)
|
|
go func() {
|
|
clients, err := uut.Dial(ctx, nil)
|
|
assert.NoError(t, err)
|
|
clientCh <- clients
|
|
}()
|
|
|
|
gotToken := <-dialTokens
|
|
require.Equal(t, "", gotToken)
|
|
|
|
clients := testutil.TryReceive(ctx, t, clientCh)
|
|
clients.Closer.Close()
|
|
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_ResumeTokenFailure(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
fTokenProv := newFakeTokenController(ctx, t)
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
dialTokens := make(chan string, 1)
|
|
wsErr := make(chan error, 1)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resumeToken := r.URL.Query().Get("resume_token")
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Error("timed out sending token")
|
|
case dialTokens <- resumeToken:
|
|
// OK
|
|
}
|
|
|
|
if resumeToken != "" {
|
|
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
|
Message: workspacesdk.CoordinateAPIInvalidResumeToken,
|
|
Validations: []codersdk.ValidationError{
|
|
{Field: "resume_token", Detail: workspacesdk.CoordinateAPIInvalidResumeToken},
|
|
},
|
|
})
|
|
return
|
|
}
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
|
// streamID can be empty because we don't call RPCs in this test.
|
|
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
_, err := uut.Dial(ctx, fTokenProv)
|
|
errCh <- err
|
|
}()
|
|
|
|
call := testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
|
|
call <- tokenResponse{"test token", true}
|
|
gotToken := <-dialTokens
|
|
require.Equal(t, "test token", gotToken)
|
|
|
|
err = testutil.TryReceive(ctx, t, errCh)
|
|
require.Error(t, err)
|
|
|
|
// redial should not use the token
|
|
clientCh := make(chan tailnet.ControlProtocolClients, 1)
|
|
go func() {
|
|
clients, err := uut.Dial(ctx, fTokenProv)
|
|
assert.NoError(t, err)
|
|
clientCh <- clients
|
|
}()
|
|
gotToken = <-dialTokens
|
|
require.Equal(t, "", gotToken)
|
|
|
|
clients := testutil.TryReceive(ctx, t, clientCh)
|
|
require.Error(t, err)
|
|
clients.Closer.Close()
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
|
|
// Successful dial should reset to using token again
|
|
go func() {
|
|
_, err := uut.Dial(ctx, fTokenProv)
|
|
errCh <- err
|
|
}()
|
|
call = testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
|
|
call <- tokenResponse{"test token", true}
|
|
gotToken = <-dialTokens
|
|
require.Equal(t, "test token", gotToken)
|
|
err = testutil.TryReceive(ctx, t, errCh)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_UnauthenticatedFailFast(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
_, err = uut.Dial(ctx, nil)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_UnauthorizedFailFast(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
_, err = uut.Dial(ctx, nil)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_UplevelVersion(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
sVer := apiversion.New(2, 2)
|
|
|
|
// the following matches what Coderd does;
|
|
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
|
|
cVer := r.URL.Query().Get("version")
|
|
if err := sVer.Validate(cVer); err != nil {
|
|
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
|
|
Message: workspacesdk.AgentAPIMismatchMessage,
|
|
Validations: []codersdk.ValidationError{
|
|
{Field: "version", Detail: err.Error()},
|
|
},
|
|
})
|
|
return
|
|
}
|
|
}))
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(
|
|
logger, svrURL, &websocket.DialOptions{},
|
|
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{}),
|
|
)
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
_, err := uut.Dial(ctx, nil)
|
|
errCh <- err
|
|
}()
|
|
|
|
err = testutil.TryReceive(ctx, t, errCh)
|
|
var sdkErr *codersdk.Error
|
|
require.ErrorAs(t, err, &sdkErr)
|
|
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
|
require.Equal(t, workspacesdk.AgentAPIMismatchMessage, sdkErr.Message)
|
|
require.NotEmpty(t, sdkErr.Helper)
|
|
}
|
|
|
|
func TestWebsocketDialer_WorkspaceUpdates(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
ctrl := gomock.NewController(t)
|
|
mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
|
|
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
|
WorkspaceUpdatesProvider: mProvider,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
wsErr := make(chan error, 1)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// need 2.3 for WorkspaceUpdates RPC
|
|
cVer := r.URL.Query().Get("version")
|
|
assert.Equal(t, "2.3", cVer)
|
|
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
|
// streamID can be empty because we don't call RPCs in this test.
|
|
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
userID := uuid.UUID{88}
|
|
|
|
mSub := tailnettest.NewMockSubscription(ctrl)
|
|
updateCh := make(chan *tailnetproto.WorkspaceUpdate, 1)
|
|
mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil)
|
|
mSub.EXPECT().Updates().MinTimes(1).Return(updateCh)
|
|
mSub.EXPECT().Close().Times(1).Return(nil)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(
|
|
logger, svrURL, &websocket.DialOptions{},
|
|
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{
|
|
WorkspaceOwnerId: userID[:],
|
|
}),
|
|
)
|
|
|
|
clients, err := uut.Dial(ctx, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, clients.WorkspaceUpdates)
|
|
|
|
wsID := uuid.UUID{99}
|
|
expectedUpdate := &tailnetproto.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*tailnetproto.Workspace{
|
|
{Id: wsID[:]},
|
|
},
|
|
}
|
|
updateCh <- expectedUpdate
|
|
|
|
gotUpdate, err := clients.WorkspaceUpdates.Recv()
|
|
require.NoError(t, err)
|
|
require.Equal(t, wsID[:], gotUpdate.GetUpsertedWorkspaces()[0].GetId())
|
|
|
|
clients.Closer.Close()
|
|
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
type fakeResumeTokenController struct {
|
|
ctx context.Context
|
|
t testing.TB
|
|
tokenCalls chan chan tokenResponse
|
|
}
|
|
|
|
func (*fakeResumeTokenController) New(tailnet.ResumeTokenClient) tailnet.CloserWaiter {
|
|
panic("not implemented")
|
|
}
|
|
|
|
func (f *fakeResumeTokenController) Token() (string, bool) {
|
|
call := make(chan tokenResponse)
|
|
select {
|
|
case <-f.ctx.Done():
|
|
f.t.Error("timeout on Token() call")
|
|
case f.tokenCalls <- call:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
f.t.Error("timeout on Token() response")
|
|
return "", false
|
|
case r := <-call:
|
|
return r.token, r.ok
|
|
}
|
|
}
|
|
|
|
// Compile-time check that the fake satisfies tailnet.ResumeTokenController.
var _ tailnet.ResumeTokenController = (*fakeResumeTokenController)(nil)
|
|
|
|
func newFakeTokenController(ctx context.Context, t testing.TB) *fakeResumeTokenController {
|
|
return &fakeResumeTokenController{
|
|
ctx: ctx,
|
|
t: t,
|
|
tokenCalls: make(chan chan tokenResponse),
|
|
}
|
|
}
|
|
|
|
// tokenResponse is the reply a test sends back to a pending
// fakeResumeTokenController.Token call. Tests build it with positional
// literals (tokenResponse{"tok", true}), so the field order must not change.
type tokenResponse struct {
	token string // first value returned by Token
	ok    bool   // second value returned by Token
}
|