mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore: refactor codersdk to use SessionTokenProvider (#19565)
Refactors `codersdk.Client`'s use of session tokens to use a `SessionTokenProvider`, which abstracts the obtaining and storing of the session token. The main motiviation is to unify Agent authentication an an upstack PR, which can use cloud instance identity via token exchange, rather than a fixed session token. However, the abstraction could also allow functionality like obtaining the session token from other external sources like the OS credential manager, or an external secret/key management system like Vault.
This commit is contained in:
@@ -243,13 +243,12 @@ STATE CHANGED STATUS STATE MESSAGE
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
now = time.Now().UTC() // TODO: replace with quartz
|
||||
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now)))
|
||||
client = new(codersdk.Client)
|
||||
client = codersdk.New(testutil.MustURL(t, srv.URL))
|
||||
sb = strings.Builder{}
|
||||
args = []string{"exp", "task", "status", "--watch-interval", testutil.IntervalFast.String()}
|
||||
)
|
||||
|
||||
t.Cleanup(srv.Close)
|
||||
client.URL = testutil.MustURL(t, srv.URL)
|
||||
args = append(args, tc.args...)
|
||||
inv, root := clitest.New(t, args...)
|
||||
inv.Stdout = &sb
|
||||
|
||||
@@ -5,14 +5,12 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
@@ -236,7 +234,7 @@ func TestTaskCreate(t *testing.T) {
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
srv = httptest.NewServer(tt.handler(t, ctx))
|
||||
client = new(codersdk.Client)
|
||||
client = codersdk.New(testutil.MustURL(t, srv.URL))
|
||||
args = []string{"exp", "task", "create"}
|
||||
sb strings.Builder
|
||||
err error
|
||||
@@ -244,9 +242,6 @@ func TestTaskCreate(t *testing.T) {
|
||||
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
client.URL, err = url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(t, append(args, tt.args...)...)
|
||||
inv.Environ = serpent.ParseEnviron(tt.env, "")
|
||||
inv.Stdout = &sb
|
||||
|
||||
@@ -635,6 +635,9 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod
|
||||
}
|
||||
|
||||
func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error {
|
||||
if client.SessionTokenProvider == nil {
|
||||
client.SessionTokenProvider = codersdk.FixedSessionTokenProvider{}
|
||||
}
|
||||
transport := http.DefaultTransport
|
||||
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||
if !r.noVersionCheck {
|
||||
|
||||
@@ -641,7 +641,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken
|
||||
|
||||
// ExternalLogin does the oauth2 flow for external auth providers. This requires
|
||||
// an authenticated coder client.
|
||||
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) {
|
||||
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...codersdk.RequestOption) {
|
||||
coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID))
|
||||
require.NoError(t, err)
|
||||
f.SetRedirect(t, coderOauthURL.String())
|
||||
@@ -660,11 +660,7 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
// External auth flow requires the user be authenticated.
|
||||
headerName := client.SessionTokenHeader
|
||||
if headerName == "" {
|
||||
headerName = codersdk.SessionTokenHeader
|
||||
}
|
||||
req.Header.Set(headerName, client.SessionToken())
|
||||
opts = append([]codersdk.RequestOption{client.SessionTokenProvider.AsRequestOption()}, opts...)
|
||||
if cli.Jar == nil {
|
||||
cli.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err, "failed to create cookie jar")
|
||||
|
||||
@@ -115,7 +115,7 @@ func TestMCPHTTP_ToolRegistration(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "client cannot be nil", "Should reject nil client with appropriate error message")
|
||||
|
||||
// Test registering tools with valid client should succeed
|
||||
client := &codersdk.Client{}
|
||||
client := codersdk.New(testutil.MustURL(t, "http://not-used"))
|
||||
err = server.RegisterTools(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
+20
-25
@@ -110,6 +110,7 @@ func New(serverURL *url.URL) *Client {
|
||||
return &Client{
|
||||
URL: serverURL,
|
||||
HTTPClient: &http.Client{},
|
||||
SessionTokenProvider: FixedSessionTokenProvider{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,17 +120,13 @@ type Client struct {
|
||||
// mu protects the fields sessionToken, logger, and logBodies. These
|
||||
// need to be safe for concurrent access.
|
||||
mu sync.RWMutex
|
||||
sessionToken string
|
||||
SessionTokenProvider SessionTokenProvider
|
||||
logger slog.Logger
|
||||
logBodies bool
|
||||
|
||||
HTTPClient *http.Client
|
||||
URL *url.URL
|
||||
|
||||
// SessionTokenHeader is an optional custom header to use for setting tokens. By
|
||||
// default 'Coder-Session-Token' is used.
|
||||
SessionTokenHeader string
|
||||
|
||||
// PlainLogger may be set to log HTTP traffic in a human-readable form.
|
||||
// It uses the LogBodies option.
|
||||
PlainLogger io.Writer
|
||||
@@ -176,14 +173,20 @@ func (c *Client) SetLogBodies(logBodies bool) {
|
||||
func (c *Client) SessionToken() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.sessionToken
|
||||
return c.SessionTokenProvider.GetSessionToken()
|
||||
}
|
||||
|
||||
// SetSessionToken returns the currently set token for the client.
|
||||
// SetSessionToken sets a fixed token for the client.
|
||||
// Deprecated: Create a new client instead of changing the token after creation.
|
||||
func (c *Client) SetSessionToken(token string) {
|
||||
c.SetSessionTokenProvider(FixedSessionTokenProvider{SessionToken: token})
|
||||
}
|
||||
|
||||
// SetSessionTokenProvider sets the session token provider for the client.
|
||||
func (c *Client) SetSessionTokenProvider(provider SessionTokenProvider) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.sessionToken = token
|
||||
c.SessionTokenProvider = provider
|
||||
}
|
||||
|
||||
func prefixLines(prefix, s []byte) []byte {
|
||||
@@ -199,6 +202,14 @@ func prefixLines(prefix, s []byte) []byte {
|
||||
// Request performs a HTTP request with the body provided. The caller is
|
||||
// responsible for closing the response body.
|
||||
func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
|
||||
opts = append([]RequestOption{c.SessionTokenProvider.AsRequestOption()}, opts...)
|
||||
return c.RequestWithoutSessionToken(ctx, method, path, body, opts...)
|
||||
}
|
||||
|
||||
// RequestWithoutSessionToken performs a HTTP request. It is similar to Request, but does not set
|
||||
// the session token in the request header, nor does it make a call to the SessionTokenProvider.
|
||||
// This allows session token providers to call this method without causing reentrancy issues.
|
||||
func (c *Client) RequestWithoutSessionToken(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
|
||||
if ctx == nil {
|
||||
return nil, xerrors.Errorf("context should not be nil")
|
||||
}
|
||||
@@ -248,12 +259,6 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
|
||||
return nil, xerrors.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
tokenHeader := c.SessionTokenHeader
|
||||
if tokenHeader == "" {
|
||||
tokenHeader = SessionTokenHeader
|
||||
}
|
||||
req.Header.Set(tokenHeader, c.SessionToken())
|
||||
|
||||
if r != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
@@ -345,20 +350,10 @@ func (c *Client) Dial(ctx context.Context, path string, opts *websocket.DialOpti
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenHeader := c.SessionTokenHeader
|
||||
if tokenHeader == "" {
|
||||
tokenHeader = SessionTokenHeader
|
||||
}
|
||||
|
||||
if opts == nil {
|
||||
opts = &websocket.DialOptions{}
|
||||
}
|
||||
if opts.HTTPHeader == nil {
|
||||
opts.HTTPHeader = http.Header{}
|
||||
}
|
||||
if opts.HTTPHeader.Get(tokenHeader) == "" {
|
||||
opts.HTTPHeader.Set(tokenHeader, c.SessionToken())
|
||||
}
|
||||
c.SessionTokenProvider.SetDialOption(opts)
|
||||
|
||||
conn, resp, err := websocket.Dial(ctx, u.String(), opts)
|
||||
if resp != nil && resp.Body != nil {
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// SessionTokenProvider provides the session token to access the Coder service (coderd).
|
||||
// @typescript-ignore SessionTokenProvider
|
||||
type SessionTokenProvider interface {
|
||||
// AsRequestOption returns a request option that attaches the session token to an HTTP request.
|
||||
AsRequestOption() RequestOption
|
||||
// SetDialOption sets the session token on a websocket request via DialOptions
|
||||
SetDialOption(options *websocket.DialOptions)
|
||||
// GetSessionToken returns the session token as a string.
|
||||
GetSessionToken() string
|
||||
}
|
||||
|
||||
// FixedSessionTokenProvider provides a given, fixed, session token. E.g. one read from file or environment variable
|
||||
// at the program start.
|
||||
// @typescript-ignore FixedSessionTokenProvider
|
||||
type FixedSessionTokenProvider struct {
|
||||
SessionToken string
|
||||
// SessionTokenHeader is an optional custom header to use for setting tokens. By
|
||||
// default, 'Coder-Session-Token' is used.
|
||||
SessionTokenHeader string
|
||||
}
|
||||
|
||||
func (f FixedSessionTokenProvider) AsRequestOption() RequestOption {
|
||||
return func(req *http.Request) {
|
||||
tokenHeader := f.SessionTokenHeader
|
||||
if tokenHeader == "" {
|
||||
tokenHeader = SessionTokenHeader
|
||||
}
|
||||
req.Header.Set(tokenHeader, f.SessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
func (f FixedSessionTokenProvider) GetSessionToken() string {
|
||||
return f.SessionToken
|
||||
}
|
||||
|
||||
func (f FixedSessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
|
||||
tokenHeader := f.SessionTokenHeader
|
||||
if tokenHeader == "" {
|
||||
tokenHeader = SessionTokenHeader
|
||||
}
|
||||
if opts.HTTPHeader == nil {
|
||||
opts.HTTPHeader = http.Header{}
|
||||
}
|
||||
if opts.HTTPHeader.Get(tokenHeader) == "" {
|
||||
opts.HTTPHeader.Set(tokenHeader, f.SessionToken)
|
||||
}
|
||||
}
|
||||
@@ -215,12 +215,12 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
||||
options.BlockEndpoints = true
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
tokenHeader := codersdk.SessionTokenHeader
|
||||
if c.client.SessionTokenHeader != "" {
|
||||
tokenHeader = c.client.SessionTokenHeader
|
||||
wsOptions := &websocket.DialOptions{
|
||||
HTTPClient: c.client.HTTPClient,
|
||||
// Need to disable compression to avoid a data-race.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
}
|
||||
headers.Set(tokenHeader, c.client.SessionToken())
|
||||
c.client.SessionTokenProvider.SetDialOption(wsOptions)
|
||||
|
||||
// New context, separate from dialCtx. We don't want to cancel the
|
||||
// connection if dialCtx is canceled.
|
||||
@@ -236,12 +236,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
||||
dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
|
||||
HTTPClient: c.client.HTTPClient,
|
||||
HTTPHeader: headers,
|
||||
// Need to disable compression to avoid a data-race.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
dialer := NewWebsocketDialer(options.Logger, coordinateURL, wsOptions)
|
||||
clk := quartz.NewReal()
|
||||
controller := tailnet.NewController(options.Logger, dialer)
|
||||
controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk)
|
||||
|
||||
@@ -312,8 +312,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyClient := wsproxysdk.New(client.URL)
|
||||
proxyClient.SetSessionToken(createRes.ProxyToken)
|
||||
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
|
||||
|
||||
// Register
|
||||
req := wsproxysdk.RegisterWorkspaceProxyRequest{
|
||||
@@ -427,8 +426,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyClient := wsproxysdk.New(client.URL)
|
||||
proxyClient.SetSessionToken(createRes.ProxyToken)
|
||||
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
|
||||
|
||||
req := wsproxysdk.RegisterWorkspaceProxyRequest{
|
||||
AccessURL: "https://proxy.coder.test",
|
||||
@@ -472,8 +470,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyClient := wsproxysdk.New(client.URL)
|
||||
proxyClient.SetSessionToken(createRes.ProxyToken)
|
||||
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
|
||||
|
||||
err = proxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{
|
||||
ReplicaID: uuid.New(),
|
||||
@@ -501,8 +498,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
|
||||
|
||||
// Register a replica on proxy 2. This shouldn't be returned by replicas
|
||||
// for proxy 1.
|
||||
proxyClient2 := wsproxysdk.New(client.URL)
|
||||
proxyClient2.SetSessionToken(createRes2.ProxyToken)
|
||||
proxyClient2 := wsproxysdk.New(client.URL, createRes2.ProxyToken)
|
||||
_, err = proxyClient2.RegisterWorkspaceProxy(ctx, wsproxysdk.RegisterWorkspaceProxyRequest{
|
||||
AccessURL: "https://other.proxy.coder.test",
|
||||
WildcardHostname: "*.other.proxy.coder.test",
|
||||
@@ -516,8 +512,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Register replica 1.
|
||||
proxyClient1 := wsproxysdk.New(client.URL)
|
||||
proxyClient1.SetSessionToken(createRes1.ProxyToken)
|
||||
proxyClient1 := wsproxysdk.New(client.URL, createRes1.ProxyToken)
|
||||
req1 := wsproxysdk.RegisterWorkspaceProxyRequest{
|
||||
AccessURL: "https://one.proxy.coder.test",
|
||||
WildcardHostname: "*.one.proxy.coder.test",
|
||||
@@ -574,8 +569,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyClient := wsproxysdk.New(client.URL)
|
||||
proxyClient.SetSessionToken(createRes.ProxyToken)
|
||||
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
ok := false
|
||||
@@ -652,8 +646,7 @@ func TestIssueSignedAppToken(t *testing.T) {
|
||||
|
||||
t.Run("BadAppRequest", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
proxyClient := wsproxysdk.New(client.URL)
|
||||
proxyClient.SetSessionToken(proxyRes.ProxyToken)
|
||||
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_, err := proxyClient.IssueSignedAppToken(ctx, workspaceapps.IssueTokenRequest{
|
||||
@@ -674,8 +667,7 @@ func TestIssueSignedAppToken(t *testing.T) {
|
||||
}
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
proxyClient := wsproxysdk.New(client.URL)
|
||||
proxyClient.SetSessionToken(proxyRes.ProxyToken)
|
||||
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_, err := proxyClient.IssueSignedAppToken(ctx, goodRequest)
|
||||
@@ -684,8 +676,7 @@ func TestIssueSignedAppToken(t *testing.T) {
|
||||
|
||||
t.Run("OKHTML", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
proxyClient := wsproxysdk.New(client.URL)
|
||||
proxyClient.SetSessionToken(proxyRes.ProxyToken)
|
||||
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
@@ -1032,8 +1023,7 @@ func TestGetCryptoKeys(t *testing.T) {
|
||||
Name: testutil.GetRandomName(t),
|
||||
})
|
||||
|
||||
client := wsproxysdk.New(cclient.URL)
|
||||
client.SetSessionToken(cclient.SessionToken())
|
||||
client := wsproxysdk.New(cclient.URL, cclient.SessionToken())
|
||||
|
||||
_, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -163,11 +163,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := wsproxysdk.New(opts.DashboardURL)
|
||||
err := client.SetSessionToken(opts.ProxySessionToken)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("set client token: %w", err)
|
||||
}
|
||||
client := wsproxysdk.New(opts.DashboardURL, opts.ProxySessionToken)
|
||||
|
||||
// Use the configured client if provided.
|
||||
if opts.HTTPClient != nil {
|
||||
|
||||
@@ -577,8 +577,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) {
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
// Register a proxy.
|
||||
wsproxyClient := wsproxysdk.New(primaryAccessURL)
|
||||
wsproxyClient.SetSessionToken(token)
|
||||
wsproxyClient := wsproxysdk.New(primaryAccessURL, token)
|
||||
hostname, err := cryptorand.String(6)
|
||||
require.NoError(t, err)
|
||||
replicaID := uuid.New()
|
||||
@@ -879,8 +878,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) {
|
||||
require.Contains(t, respJSON.Warnings[0], "High availability networking")
|
||||
|
||||
// Deregister the other replica.
|
||||
wsproxyClient := wsproxysdk.New(api.AccessURL)
|
||||
wsproxyClient.SetSessionToken(proxy.Options.ProxySessionToken)
|
||||
wsproxyClient := wsproxysdk.New(api.AccessURL, proxy.Options.ProxySessionToken)
|
||||
err = wsproxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{
|
||||
ReplicaID: otherReplicaID,
|
||||
})
|
||||
|
||||
@@ -33,15 +33,20 @@ type Client struct {
|
||||
|
||||
// New creates a external proxy client for the provided primary coder server
|
||||
// URL.
|
||||
func New(serverURL *url.URL) *Client {
|
||||
func New(serverURL *url.URL, sessionToken string) *Client {
|
||||
sdkClient := codersdk.New(serverURL)
|
||||
sdkClient.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader
|
||||
|
||||
sdkClient.SessionTokenProvider = codersdk.FixedSessionTokenProvider{
|
||||
SessionToken: sessionToken,
|
||||
SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader,
|
||||
}
|
||||
sdkClientIgnoreRedirects := codersdk.New(serverURL)
|
||||
sdkClientIgnoreRedirects.HTTPClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
sdkClientIgnoreRedirects.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader
|
||||
sdkClientIgnoreRedirects.SessionTokenProvider = codersdk.FixedSessionTokenProvider{
|
||||
SessionToken: sessionToken,
|
||||
SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader,
|
||||
}
|
||||
|
||||
return &Client{
|
||||
SDKClient: sdkClient,
|
||||
@@ -49,14 +54,6 @@ func New(serverURL *url.URL) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
// SetSessionToken sets the session token for the client. An error is returned
|
||||
// if the session token is not in the correct format for external proxies.
|
||||
func (c *Client) SetSessionToken(token string) error {
|
||||
c.SDKClient.SetSessionToken(token)
|
||||
c.sdkClientIgnoreRedirects.SetSessionToken(token)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SessionToken returns the currently set token for the client.
|
||||
func (c *Client) SessionToken() string {
|
||||
return c.SDKClient.SessionToken()
|
||||
@@ -506,17 +503,12 @@ func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) {
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
coordinateHeaders := make(http.Header)
|
||||
tokenHeader := codersdk.SessionTokenHeader
|
||||
if c.SDKClient.SessionTokenHeader != "" {
|
||||
tokenHeader = c.SDKClient.SessionTokenHeader
|
||||
}
|
||||
coordinateHeaders.Set(tokenHeader, c.SessionToken())
|
||||
|
||||
return workspacesdk.NewWebsocketDialer(logger, coordinateURL, &websocket.DialOptions{
|
||||
wsOptions := &websocket.DialOptions{
|
||||
HTTPClient: c.SDKClient.HTTPClient,
|
||||
HTTPHeader: coordinateHeaders,
|
||||
}), nil
|
||||
}
|
||||
c.SDKClient.SessionTokenProvider.SetDialOption(wsOptions)
|
||||
|
||||
return workspacesdk.NewWebsocketDialer(logger, coordinateURL, wsOptions), nil
|
||||
}
|
||||
|
||||
type CryptoKeysResponse struct {
|
||||
|
||||
@@ -60,8 +60,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) {
|
||||
|
||||
u, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
client := wsproxysdk.New(u)
|
||||
client.SetSessionToken(expectedProxyToken)
|
||||
client := wsproxysdk.New(u, expectedProxyToken)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
@@ -111,8 +110,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) {
|
||||
|
||||
u, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
client := wsproxysdk.New(u)
|
||||
_ = client.SetSessionToken(expectedProxyToken)
|
||||
client := wsproxysdk.New(u, expectedProxyToken)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -269,18 +268,13 @@ func (w *wrappedSSHConn) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func appClientConn(ctx context.Context, client *codersdk.Client, url string) (*countReadWriteCloser, error) {
|
||||
headers := http.Header{}
|
||||
tokenHeader := codersdk.SessionTokenHeader
|
||||
if client.SessionTokenHeader != "" {
|
||||
tokenHeader = client.SessionTokenHeader
|
||||
wsOptions := &websocket.DialOptions{
|
||||
HTTPClient: client.HTTPClient,
|
||||
}
|
||||
headers.Set(tokenHeader, client.SessionToken())
|
||||
client.SessionTokenProvider.SetDialOption(wsOptions)
|
||||
|
||||
//nolint:bodyclose // The websocket conn manages the body.
|
||||
conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
HTTPClient: client.HTTPClient,
|
||||
HTTPHeader: headers,
|
||||
})
|
||||
conn, _, err := websocket.Dial(ctx, url, wsOptions)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("websocket dial: %w", err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -313,9 +314,7 @@ func TestRun(t *testing.T) {
|
||||
readMetrics = &testMetrics{}
|
||||
writeMetrics = &testMetrics{}
|
||||
)
|
||||
client := &codersdk.Client{
|
||||
HTTPClient: &http.Client{},
|
||||
}
|
||||
client := codersdk.New(&url.URL{})
|
||||
runner := workspacetraffic.NewRunner(client, workspacetraffic.Config{
|
||||
BytesPerTick: int64(bytesPerTick),
|
||||
TickInterval: tickInterval,
|
||||
|
||||
Reference in New Issue
Block a user