diff --git a/cli/exp_task_status_test.go b/cli/exp_task_status_test.go index 6631980ac1..b520d27288 100644 --- a/cli/exp_task_status_test.go +++ b/cli/exp_task_status_test.go @@ -243,13 +243,12 @@ STATE CHANGED STATUS STATE MESSAGE ctx = testutil.Context(t, testutil.WaitShort) now = time.Now().UTC() // TODO: replace with quartz srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now))) - client = new(codersdk.Client) + client = codersdk.New(testutil.MustURL(t, srv.URL)) sb = strings.Builder{} args = []string{"exp", "task", "status", "--watch-interval", testutil.IntervalFast.String()} ) t.Cleanup(srv.Close) - client.URL = testutil.MustURL(t, srv.URL) args = append(args, tc.args...) inv, root := clitest.New(t, args...) inv.Stdout = &sb diff --git a/cli/exp_taskcreate_test.go b/cli/exp_taskcreate_test.go index 121f22eb52..f49c2fee11 100644 --- a/cli/exp_taskcreate_test.go +++ b/cli/exp_taskcreate_test.go @@ -5,14 +5,12 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" "strings" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/cliui" @@ -236,7 +234,7 @@ func TestTaskCreate(t *testing.T) { var ( ctx = testutil.Context(t, testutil.WaitShort) srv = httptest.NewServer(tt.handler(t, ctx)) - client = new(codersdk.Client) + client = codersdk.New(testutil.MustURL(t, srv.URL)) args = []string{"exp", "task", "create"} sb strings.Builder err error @@ -244,9 +242,6 @@ func TestTaskCreate(t *testing.T) { t.Cleanup(srv.Close) - client.URL, err = url.Parse(srv.URL) - require.NoError(t, err) - inv, root := clitest.New(t, append(args, tt.args...)...) inv.Environ = serpent.ParseEnviron(tt.env, "") inv.Stdout = &sb diff --git a/cli/root.go b/cli/root.go index b3e67a46ad..ed6869b6a1 100644 --- a/cli/root.go +++ b/cli/root.go @@ -635,6 +635,9 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod } func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error { + if client.SessionTokenProvider == nil { + client.SessionTokenProvider = codersdk.FixedSessionTokenProvider{} + } transport := http.DefaultTransport transport = wrapTransportWithTelemetryHeader(transport, inv) if !r.noVersionCheck { diff --git a/coderd/coderdtest/oidctest/idp.go b/coderd/coderdtest/oidctest/idp.go index c7f7d35937..a76f6447dc 100644 --- a/coderd/coderdtest/oidctest/idp.go +++ b/coderd/coderdtest/oidctest/idp.go @@ -641,7 +641,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken // ExternalLogin does the oauth2 flow for external auth providers. This requires // an authenticated coder client. -func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) { +func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...codersdk.RequestOption) { coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID)) require.NoError(t, err) f.SetRedirect(t, coderOauthURL.String()) @@ -660,11 +660,7 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil) require.NoError(t, err) // External auth flow requires the user be authenticated. - headerName := client.SessionTokenHeader - if headerName == "" { - headerName = codersdk.SessionTokenHeader - } - req.Header.Set(headerName, client.SessionToken()) + opts = append([]codersdk.RequestOption{client.SessionTokenProvider.AsRequestOption()}, opts...) if cli.Jar == nil { cli.Jar, err = cookiejar.New(nil) require.NoError(t, err, "failed to create cookie jar") diff --git a/coderd/mcp/mcp_test.go b/coderd/mcp/mcp_test.go index 0c53c899b9..b7b5a71478 100644 --- a/coderd/mcp/mcp_test.go +++ b/coderd/mcp/mcp_test.go @@ -115,7 +115,7 @@ func TestMCPHTTP_ToolRegistration(t *testing.T) { require.Contains(t, err.Error(), "client cannot be nil", "Should reject nil client with appropriate error message") // Test registering tools with valid client should succeed - client := &codersdk.Client{} + client := codersdk.New(testutil.MustURL(t, "http://not-used")) err = server.RegisterTools(client) require.NoError(t, err) diff --git a/codersdk/client.go b/codersdk/client.go index 105c8437f8..b6f10465e3 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -108,8 +108,9 @@ var loggableMimeTypes = map[string]struct{}{ // New creates a Coder client for the provided URL. func New(serverURL *url.URL) *Client { return &Client{ - URL: serverURL, - HTTPClient: &http.Client{}, + URL: serverURL, + HTTPClient: &http.Client{}, + SessionTokenProvider: FixedSessionTokenProvider{}, } } @@ -118,18 +119,14 @@ func New(serverURL *url.URL) *Client { type Client struct { // mu protects the fields sessionToken, logger, and logBodies. These // need to be safe for concurrent access. - mu sync.RWMutex - sessionToken string - logger slog.Logger - logBodies bool + mu sync.RWMutex + SessionTokenProvider SessionTokenProvider + logger slog.Logger + logBodies bool HTTPClient *http.Client URL *url.URL - // SessionTokenHeader is an optional custom header to use for setting tokens. By - // default 'Coder-Session-Token' is used. - SessionTokenHeader string - // PlainLogger may be set to log HTTP traffic in a human-readable form. // It uses the LogBodies option. PlainLogger io.Writer @@ -176,14 +173,20 @@ func (c *Client) SetLogBodies(logBodies bool) { func (c *Client) SessionToken() string { c.mu.RLock() defer c.mu.RUnlock() - return c.sessionToken + return c.SessionTokenProvider.GetSessionToken() } -// SetSessionToken returns the currently set token for the client. +// SetSessionToken sets a fixed token for the client. +// Deprecated: Create a new client instead of changing the token after creation. func (c *Client) SetSessionToken(token string) { + c.SetSessionTokenProvider(FixedSessionTokenProvider{SessionToken: token}) +} + +// SetSessionTokenProvider sets the session token provider for the client. +func (c *Client) SetSessionTokenProvider(provider SessionTokenProvider) { c.mu.Lock() defer c.mu.Unlock() - c.sessionToken = token + c.SessionTokenProvider = provider } func prefixLines(prefix, s []byte) []byte { @@ -199,6 +202,14 @@ func prefixLines(prefix, s []byte) []byte { // Request performs a HTTP request with the body provided. The caller is // responsible for closing the response body. func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) { + opts = append([]RequestOption{c.SessionTokenProvider.AsRequestOption()}, opts...) + return c.RequestWithoutSessionToken(ctx, method, path, body, opts...) +} + +// RequestWithoutSessionToken performs a HTTP request. It is similar to Request, but does not set +// the session token in the request header, nor does it make a call to the SessionTokenProvider. +// This allows session token providers to call this method without causing reentrancy issues. +func (c *Client) RequestWithoutSessionToken(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) { if ctx == nil { return nil, xerrors.Errorf("context should not be nil") } @@ -248,12 +259,6 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac return nil, xerrors.Errorf("create request: %w", err) } - tokenHeader := c.SessionTokenHeader - if tokenHeader == "" { - tokenHeader = SessionTokenHeader - } - req.Header.Set(tokenHeader, c.SessionToken()) - if r != nil { req.Header.Set("Content-Type", "application/json") } @@ -345,20 +350,10 @@ func (c *Client) Dial(ctx context.Context, path string, opts *websocket.DialOpti return nil, err } - tokenHeader := c.SessionTokenHeader - if tokenHeader == "" { - tokenHeader = SessionTokenHeader - } - if opts == nil { opts = &websocket.DialOptions{} } - if opts.HTTPHeader == nil { - opts.HTTPHeader = http.Header{} - } - if opts.HTTPHeader.Get(tokenHeader) == "" { - opts.HTTPHeader.Set(tokenHeader, c.SessionToken()) - } + c.SessionTokenProvider.SetDialOption(opts) conn, resp, err := websocket.Dial(ctx, u.String(), opts) if resp != nil && resp.Body != nil { diff --git a/codersdk/credentials.go b/codersdk/credentials.go new file mode 100644 index 0000000000..06dc8cc22a --- /dev/null +++ b/codersdk/credentials.go @@ -0,0 +1,55 @@ +package codersdk + +import ( + "net/http" + + "github.com/coder/websocket" +) + +// SessionTokenProvider provides the session token to access the Coder service (coderd). +// @typescript-ignore SessionTokenProvider +type SessionTokenProvider interface { + // AsRequestOption returns a request option that attaches the session token to an HTTP request. + AsRequestOption() RequestOption + // SetDialOption sets the session token on a websocket request via DialOptions + SetDialOption(options *websocket.DialOptions) + // GetSessionToken returns the session token as a string. + GetSessionToken() string +} + +// FixedSessionTokenProvider provides a given, fixed, session token. E.g. one read from file or environment variable +// at the program start. +// @typescript-ignore FixedSessionTokenProvider +type FixedSessionTokenProvider struct { + SessionToken string + // SessionTokenHeader is an optional custom header to use for setting tokens. By + // default, 'Coder-Session-Token' is used. + SessionTokenHeader string +} + +func (f FixedSessionTokenProvider) AsRequestOption() RequestOption { + return func(req *http.Request) { + tokenHeader := f.SessionTokenHeader + if tokenHeader == "" { + tokenHeader = SessionTokenHeader + } + req.Header.Set(tokenHeader, f.SessionToken) + } +} + +func (f FixedSessionTokenProvider) GetSessionToken() string { + return f.SessionToken +} + +func (f FixedSessionTokenProvider) SetDialOption(opts *websocket.DialOptions) { + tokenHeader := f.SessionTokenHeader + if tokenHeader == "" { + tokenHeader = SessionTokenHeader + } + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + if opts.HTTPHeader.Get(tokenHeader) == "" { + opts.HTTPHeader.Set(tokenHeader, f.SessionToken) + } +} diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index ddaec06388..29ddbd1f53 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -215,12 +215,12 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * options.BlockEndpoints = true } - headers := make(http.Header) - tokenHeader := codersdk.SessionTokenHeader - if c.client.SessionTokenHeader != "" { - tokenHeader = c.client.SessionTokenHeader + wsOptions := &websocket.DialOptions{ + HTTPClient: c.client.HTTPClient, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, } - headers.Set(tokenHeader, c.client.SessionToken()) + c.client.SessionTokenProvider.SetDialOption(wsOptions) // New context, separate from dialCtx. We don't want to cancel the // connection if dialCtx is canceled. @@ -236,12 +236,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options * return nil, xerrors.Errorf("parse url: %w", err) } - dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{ - HTTPClient: c.client.HTTPClient, - HTTPHeader: headers, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) + dialer := NewWebsocketDialer(options.Logger, coordinateURL, wsOptions) clk := quartz.NewReal() controller := tailnet.NewController(options.Logger, dialer) controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk) diff --git a/enterprise/coderd/workspaceproxy_test.go b/enterprise/coderd/workspaceproxy_test.go index 7024ad2366..28d46c0137 100644 --- a/enterprise/coderd/workspaceproxy_test.go +++ b/enterprise/coderd/workspaceproxy_test.go @@ -312,8 +312,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) // Register req := wsproxysdk.RegisterWorkspaceProxyRequest{ @@ -427,8 +426,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) req := wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: "https://proxy.coder.test", @@ -472,8 +470,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) err = proxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{ ReplicaID: uuid.New(), @@ -501,8 +498,7 @@ func TestProxyRegisterDeregister(t *testing.T) { // Register a replica on proxy 2. This shouldn't be returned by replicas // for proxy 1. - proxyClient2 := wsproxysdk.New(client.URL) - proxyClient2.SetSessionToken(createRes2.ProxyToken) + proxyClient2 := wsproxysdk.New(client.URL, createRes2.ProxyToken) _, err = proxyClient2.RegisterWorkspaceProxy(ctx, wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: "https://other.proxy.coder.test", WildcardHostname: "*.other.proxy.coder.test", @@ -516,8 +512,7 @@ func TestProxyRegisterDeregister(t *testing.T) { require.NoError(t, err) // Register replica 1. - proxyClient1 := wsproxysdk.New(client.URL) - proxyClient1.SetSessionToken(createRes1.ProxyToken) + proxyClient1 := wsproxysdk.New(client.URL, createRes1.ProxyToken) req1 := wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: "https://one.proxy.coder.test", WildcardHostname: "*.one.proxy.coder.test", @@ -574,8 +569,7 @@ func TestProxyRegisterDeregister(t *testing.T) { }) require.NoError(t, err) - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(createRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken) for i := 0; i < 100; i++ { ok := false @@ -652,8 +646,7 @@ func TestIssueSignedAppToken(t *testing.T) { t.Run("BadAppRequest", func(t *testing.T) { t.Parallel() - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(proxyRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken) ctx := testutil.Context(t, testutil.WaitLong) _, err := proxyClient.IssueSignedAppToken(ctx, workspaceapps.IssueTokenRequest{ @@ -674,8 +667,7 @@ func TestIssueSignedAppToken(t *testing.T) { } t.Run("OK", func(t *testing.T) { t.Parallel() - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(proxyRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken) ctx := testutil.Context(t, testutil.WaitLong) _, err := proxyClient.IssueSignedAppToken(ctx, goodRequest) @@ -684,8 +676,7 @@ func TestIssueSignedAppToken(t *testing.T) { t.Run("OKHTML", func(t *testing.T) { t.Parallel() - proxyClient := wsproxysdk.New(client.URL) - proxyClient.SetSessionToken(proxyRes.ProxyToken) + proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken) rw := httptest.NewRecorder() ctx := testutil.Context(t, testutil.WaitLong) @@ -1032,8 +1023,7 @@ func TestGetCryptoKeys(t *testing.T) { Name: testutil.GetRandomName(t), }) - client := wsproxysdk.New(cclient.URL) - client.SetSessionToken(cclient.SessionToken()) + client := wsproxysdk.New(cclient.URL, cclient.SessionToken()) _, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey) require.Error(t, err) diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index c2ac1baf2d..6e1da2f258 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -163,11 +163,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { return nil, err } - client := wsproxysdk.New(opts.DashboardURL) - err := client.SetSessionToken(opts.ProxySessionToken) - if err != nil { - return nil, xerrors.Errorf("set client token: %w", err) - } + client := wsproxysdk.New(opts.DashboardURL, opts.ProxySessionToken) // Use the configured client if provided. if opts.HTTPClient != nil { diff --git a/enterprise/wsproxy/wsproxy_test.go b/enterprise/wsproxy/wsproxy_test.go index 523d429476..0e8e61af88 100644 --- a/enterprise/wsproxy/wsproxy_test.go +++ b/enterprise/wsproxy/wsproxy_test.go @@ -577,8 +577,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) { t.Cleanup(srv.Close) // Register a proxy. - wsproxyClient := wsproxysdk.New(primaryAccessURL) - wsproxyClient.SetSessionToken(token) + wsproxyClient := wsproxysdk.New(primaryAccessURL, token) hostname, err := cryptorand.String(6) require.NoError(t, err) replicaID := uuid.New() @@ -879,8 +878,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) { require.Contains(t, respJSON.Warnings[0], "High availability networking") // Deregister the other replica. - wsproxyClient := wsproxysdk.New(api.AccessURL) - wsproxyClient.SetSessionToken(proxy.Options.ProxySessionToken) + wsproxyClient := wsproxysdk.New(api.AccessURL, proxy.Options.ProxySessionToken) err = wsproxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{ ReplicaID: otherReplicaID, }) diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 72f5a4291c..15400a2d33 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -33,15 +33,20 @@ type Client struct { // New creates a external proxy client for the provided primary coder server // URL. -func New(serverURL *url.URL) *Client { +func New(serverURL *url.URL, sessionToken string) *Client { sdkClient := codersdk.New(serverURL) - sdkClient.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader - + sdkClient.SessionTokenProvider = codersdk.FixedSessionTokenProvider{ + SessionToken: sessionToken, + SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader, + } sdkClientIgnoreRedirects := codersdk.New(serverURL) sdkClientIgnoreRedirects.HTTPClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { return http.ErrUseLastResponse } - sdkClientIgnoreRedirects.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader + sdkClientIgnoreRedirects.SessionTokenProvider = codersdk.FixedSessionTokenProvider{ + SessionToken: sessionToken, + SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader, + } return &Client{ SDKClient: sdkClient, @@ -49,14 +54,6 @@ func New(serverURL *url.URL) *Client { } } -// SetSessionToken sets the session token for the client. An error is returned -// if the session token is not in the correct format for external proxies. -func (c *Client) SetSessionToken(token string) error { - c.SDKClient.SetSessionToken(token) - c.sdkClientIgnoreRedirects.SetSessionToken(token) - return nil -} - // SessionToken returns the currently set token for the client. func (c *Client) SessionToken() string { return c.SDKClient.SessionToken() @@ -506,17 +503,12 @@ func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) { if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } - coordinateHeaders := make(http.Header) - tokenHeader := codersdk.SessionTokenHeader - if c.SDKClient.SessionTokenHeader != "" { - tokenHeader = c.SDKClient.SessionTokenHeader - } - coordinateHeaders.Set(tokenHeader, c.SessionToken()) - - return workspacesdk.NewWebsocketDialer(logger, coordinateURL, &websocket.DialOptions{ + wsOptions := &websocket.DialOptions{ HTTPClient: c.SDKClient.HTTPClient, - HTTPHeader: coordinateHeaders, - }), nil + } + c.SDKClient.SessionTokenProvider.SetDialOption(wsOptions) + + return workspacesdk.NewWebsocketDialer(logger, coordinateURL, wsOptions), nil } type CryptoKeysResponse struct { diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go index aada23da9d..6b4da6831c 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go @@ -60,8 +60,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) { u, err := url.Parse(srv.URL) require.NoError(t, err) - client := wsproxysdk.New(u) - client.SetSessionToken(expectedProxyToken) + client := wsproxysdk.New(u, expectedProxyToken) ctx := testutil.Context(t, testutil.WaitLong) @@ -111,8 +110,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) { u, err := url.Parse(srv.URL) require.NoError(t, err) - client := wsproxysdk.New(u) - _ = client.SetSessionToken(expectedProxyToken) + client := wsproxysdk.New(u, expectedProxyToken) ctx := testutil.Context(t, testutil.WaitLong) diff --git a/scaletest/workspacetraffic/conn.go b/scaletest/workspacetraffic/conn.go index 17cbc7c501..3b516c6347 100644 --- a/scaletest/workspacetraffic/conn.go +++ b/scaletest/workspacetraffic/conn.go @@ -6,7 +6,6 @@ import ( "errors" "io" "net" - "net/http" "sync" "time" @@ -269,18 +268,13 @@ func (w *wrappedSSHConn) Write(p []byte) (n int, err error) { } func appClientConn(ctx context.Context, client *codersdk.Client, url string) (*countReadWriteCloser, error) { - headers := http.Header{} - tokenHeader := codersdk.SessionTokenHeader - if client.SessionTokenHeader != "" { - tokenHeader = client.SessionTokenHeader + wsOptions := &websocket.DialOptions{ + HTTPClient: client.HTTPClient, } - headers.Set(tokenHeader, client.SessionToken()) + client.SessionTokenProvider.SetDialOption(wsOptions) //nolint:bodyclose // The websocket conn manages the body. - conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ - HTTPClient: client.HTTPClient, - HTTPHeader: headers, - }) + conn, _, err := websocket.Dial(ctx, url, wsOptions) if err != nil { return nil, xerrors.Errorf("websocket dial: %w", err) } diff --git a/scaletest/workspacetraffic/run_test.go b/scaletest/workspacetraffic/run_test.go index 59801e68d8..dd84747886 100644 --- a/scaletest/workspacetraffic/run_test.go +++ b/scaletest/workspacetraffic/run_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "runtime" "slices" "strings" @@ -313,9 +314,7 @@ func TestRun(t *testing.T) { readMetrics = &testMetrics{} writeMetrics = &testMetrics{} ) - client := &codersdk.Client{ - HTTPClient: &http.Client{}, - } + client := codersdk.New(&url.URL{}) runner := workspacetraffic.NewRunner(client, workspacetraffic.Config{ BytesPerTick: int64(bytesPerTick), TickInterval: tickInterval,