mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
bddb808b25
Fixes all our Go file imports to match the preferred spec that we've _mostly_ been using. For example: ``` import ( "context" "time" "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" "gopkg.in/natefinch/lumberjack.v2" "cdr.dev/slog/v3" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/serpent" ) ``` 3 groups: standard library, third-party libs, Coder libs. This PR makes the change across the codebase. The PR in the stack above modifies our formatting to maintain this state of affairs, and is a separate PR so it's possible to review that one in detail.
475 lines
14 KiB
Go
475 lines
14 KiB
Go
package workspacesdk_test
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"tailscale.com/tailcfg"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/apiversion"
|
|
"github.com/coder/coder/v2/coderd/httpapi"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
"github.com/coder/coder/v2/tailnet"
|
|
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
|
"github.com/coder/coder/v2/tailnet/tailnettest"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/websocket"
|
|
)
|
|
|
|
func TestWebsocketDialer_TokenController(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
fTokenProv := newFakeTokenController(ctx, t)
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
dialTokens := make(chan string, 1)
|
|
wsErr := make(chan error, 1)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Error("timed out sending token")
|
|
case dialTokens <- r.URL.Query().Get("resume_token"):
|
|
// OK
|
|
}
|
|
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
|
// streamID can be empty because we don't call RPCs in this test.
|
|
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
clientCh := make(chan tailnet.ControlProtocolClients, 1)
|
|
go func() {
|
|
clients, err := uut.Dial(ctx, fTokenProv)
|
|
assert.NoError(t, err)
|
|
clientCh <- clients
|
|
}()
|
|
|
|
call := testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
|
|
call <- tokenResponse{"test token", true}
|
|
gotToken := <-dialTokens
|
|
require.Equal(t, "test token", gotToken)
|
|
|
|
clients := testutil.TryReceive(ctx, t, clientCh)
|
|
clients.Closer.Close()
|
|
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
|
|
clientCh = make(chan tailnet.ControlProtocolClients, 1)
|
|
go func() {
|
|
clients, err := uut.Dial(ctx, fTokenProv)
|
|
assert.NoError(t, err)
|
|
clientCh <- clients
|
|
}()
|
|
|
|
call = testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
|
|
call <- tokenResponse{"test token", false}
|
|
gotToken = <-dialTokens
|
|
require.Equal(t, "", gotToken)
|
|
|
|
clients = testutil.TryReceive(ctx, t, clientCh)
|
|
require.Nil(t, clients.WorkspaceUpdates)
|
|
clients.Closer.Close()
|
|
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_NoTokenController(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
dialTokens := make(chan string, 1)
|
|
wsErr := make(chan error, 1)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Error("timed out sending token")
|
|
case dialTokens <- r.URL.Query().Get("resume_token"):
|
|
// OK
|
|
}
|
|
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
|
// streamID can be empty because we don't call RPCs in this test.
|
|
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
clientCh := make(chan tailnet.ControlProtocolClients, 1)
|
|
go func() {
|
|
clients, err := uut.Dial(ctx, nil)
|
|
assert.NoError(t, err)
|
|
clientCh <- clients
|
|
}()
|
|
|
|
gotToken := <-dialTokens
|
|
require.Equal(t, "", gotToken)
|
|
|
|
clients := testutil.TryReceive(ctx, t, clientCh)
|
|
clients.Closer.Close()
|
|
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_ResumeTokenFailure(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
fTokenProv := newFakeTokenController(ctx, t)
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
dialTokens := make(chan string, 1)
|
|
wsErr := make(chan error, 1)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
resumeToken := r.URL.Query().Get("resume_token")
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Error("timed out sending token")
|
|
case dialTokens <- resumeToken:
|
|
// OK
|
|
}
|
|
|
|
if resumeToken != "" {
|
|
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
|
Message: workspacesdk.CoordinateAPIInvalidResumeToken,
|
|
Validations: []codersdk.ValidationError{
|
|
{Field: "resume_token", Detail: workspacesdk.CoordinateAPIInvalidResumeToken},
|
|
},
|
|
})
|
|
return
|
|
}
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
|
// streamID can be empty because we don't call RPCs in this test.
|
|
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
_, err := uut.Dial(ctx, fTokenProv)
|
|
errCh <- err
|
|
}()
|
|
|
|
call := testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
|
|
call <- tokenResponse{"test token", true}
|
|
gotToken := <-dialTokens
|
|
require.Equal(t, "test token", gotToken)
|
|
|
|
err = testutil.TryReceive(ctx, t, errCh)
|
|
require.Error(t, err)
|
|
|
|
// redial should not use the token
|
|
clientCh := make(chan tailnet.ControlProtocolClients, 1)
|
|
go func() {
|
|
clients, err := uut.Dial(ctx, fTokenProv)
|
|
assert.NoError(t, err)
|
|
clientCh <- clients
|
|
}()
|
|
gotToken = <-dialTokens
|
|
require.Equal(t, "", gotToken)
|
|
|
|
clients := testutil.TryReceive(ctx, t, clientCh)
|
|
require.Error(t, err)
|
|
clients.Closer.Close()
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
|
|
// Successful dial should reset to using token again
|
|
go func() {
|
|
_, err := uut.Dial(ctx, fTokenProv)
|
|
errCh <- err
|
|
}()
|
|
call = testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
|
|
call <- tokenResponse{"test token", true}
|
|
gotToken = <-dialTokens
|
|
require.Equal(t, "test token", gotToken)
|
|
err = testutil.TryReceive(ctx, t, errCh)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_UnauthenticatedFailFast(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
_, err = uut.Dial(ctx, nil)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_UnauthorizedFailFast(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
|
|
|
_, err = uut.Dial(ctx, nil)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestWebsocketDialer_UplevelVersion(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
sVer := apiversion.New(2, 2)
|
|
|
|
// the following matches what Coderd does;
|
|
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
|
|
cVer := r.URL.Query().Get("version")
|
|
if err := sVer.Validate(cVer); err != nil {
|
|
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
|
|
Message: workspacesdk.AgentAPIMismatchMessage,
|
|
Validations: []codersdk.ValidationError{
|
|
{Field: "version", Detail: err.Error()},
|
|
},
|
|
})
|
|
return
|
|
}
|
|
}))
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(
|
|
logger, svrURL, &websocket.DialOptions{},
|
|
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{}),
|
|
)
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
_, err := uut.Dial(ctx, nil)
|
|
errCh <- err
|
|
}()
|
|
|
|
err = testutil.TryReceive(ctx, t, errCh)
|
|
var sdkErr *codersdk.Error
|
|
require.ErrorAs(t, err, &sdkErr)
|
|
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
|
require.Equal(t, workspacesdk.AgentAPIMismatchMessage, sdkErr.Message)
|
|
require.NotEmpty(t, sdkErr.Helper)
|
|
}
|
|
|
|
func TestWebsocketDialer_WorkspaceUpdates(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
ctrl := gomock.NewController(t)
|
|
mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
|
|
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
|
WorkspaceUpdatesProvider: mProvider,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
wsErr := make(chan error, 1)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// need 2.3 for WorkspaceUpdates RPC
|
|
cVer := r.URL.Query().Get("version")
|
|
assert.Equal(t, "2.3", cVer)
|
|
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
|
// streamID can be empty because we don't call RPCs in this test.
|
|
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
|
}))
|
|
defer svr.Close()
|
|
svrURL, err := url.Parse(svr.URL)
|
|
require.NoError(t, err)
|
|
|
|
userID := uuid.UUID{88}
|
|
|
|
mSub := tailnettest.NewMockSubscription(ctrl)
|
|
updateCh := make(chan *tailnetproto.WorkspaceUpdate, 1)
|
|
mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil)
|
|
mSub.EXPECT().Updates().MinTimes(1).Return(updateCh)
|
|
mSub.EXPECT().Close().Times(1).Return(nil)
|
|
|
|
uut := workspacesdk.NewWebsocketDialer(
|
|
logger, svrURL, &websocket.DialOptions{},
|
|
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{
|
|
WorkspaceOwnerId: userID[:],
|
|
}),
|
|
)
|
|
|
|
clients, err := uut.Dial(ctx, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, clients.WorkspaceUpdates)
|
|
|
|
wsID := uuid.UUID{99}
|
|
expectedUpdate := &tailnetproto.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*tailnetproto.Workspace{
|
|
{Id: wsID[:]},
|
|
},
|
|
}
|
|
updateCh <- expectedUpdate
|
|
|
|
gotUpdate, err := clients.WorkspaceUpdates.Recv()
|
|
require.NoError(t, err)
|
|
require.Equal(t, wsID[:], gotUpdate.GetUpsertedWorkspaces()[0].GetId())
|
|
|
|
clients.Closer.Close()
|
|
|
|
err = testutil.TryReceive(ctx, t, wsErr)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
type fakeResumeTokenController struct {
|
|
ctx context.Context
|
|
t testing.TB
|
|
tokenCalls chan chan tokenResponse
|
|
}
|
|
|
|
func (*fakeResumeTokenController) New(tailnet.ResumeTokenClient) tailnet.CloserWaiter {
|
|
panic("not implemented")
|
|
}
|
|
|
|
func (f *fakeResumeTokenController) Token() (string, bool) {
|
|
call := make(chan tokenResponse)
|
|
select {
|
|
case <-f.ctx.Done():
|
|
f.t.Error("timeout on Token() call")
|
|
case f.tokenCalls <- call:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
f.t.Error("timeout on Token() response")
|
|
return "", false
|
|
case r := <-call:
|
|
return r.token, r.ok
|
|
}
|
|
}
|
|
|
|
// Compile-time check that *fakeResumeTokenController implements
// tailnet.ResumeTokenController.
var _ tailnet.ResumeTokenController = (*fakeResumeTokenController)(nil)
|
|
|
|
func newFakeTokenController(ctx context.Context, t testing.TB) *fakeResumeTokenController {
|
|
return &fakeResumeTokenController{
|
|
ctx: ctx,
|
|
t: t,
|
|
tokenCalls: make(chan chan tokenResponse),
|
|
}
|
|
}
|
|
|
|
// tokenResponse is the reply a test sends back to a pending
// fakeResumeTokenController.Token() call. Tests build it positionally
// (tokenResponse{token, ok}), so field order matters.
type tokenResponse struct {
	// token is returned as Token()'s string result.
	token string
	// ok is returned as Token()'s bool result.
	ok bool
}
|