chore: refactor codersdk to use SessionTokenProvider (#19565)

Refactors `codersdk.Client`'s use of session tokens to use a `SessionTokenProvider`, which abstracts the obtaining and storing of the session token.

The main motiviation is to unify Agent authentication an an upstack PR, which can use cloud instance identity via token exchange, rather than a fixed session token.

However, the abstraction could also allow functionality like obtaining the session token from other external sources like the OS credential manager, or an external secret/key management system like Vault.
This commit is contained in:
Spike Curtis
2025-08-29 10:41:32 +02:00
committed by GitHub
parent f721f3d9d7
commit 192c81e8f9
15 changed files with 129 additions and 124 deletions
+1 -2
View File
@@ -243,13 +243,12 @@ STATE CHANGED STATUS STATE MESSAGE
ctx = testutil.Context(t, testutil.WaitShort)
now = time.Now().UTC() // TODO: replace with quartz
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now)))
client = new(codersdk.Client)
client = codersdk.New(testutil.MustURL(t, srv.URL))
sb = strings.Builder{}
args = []string{"exp", "task", "status", "--watch-interval", testutil.IntervalFast.String()}
)
t.Cleanup(srv.Close)
client.URL = testutil.MustURL(t, srv.URL)
args = append(args, tc.args...)
inv, root := clitest.New(t, args...)
inv.Stdout = &sb
+1 -6
View File
@@ -5,14 +5,12 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/cli/cliui"
@@ -236,7 +234,7 @@ func TestTaskCreate(t *testing.T) {
var (
ctx = testutil.Context(t, testutil.WaitShort)
srv = httptest.NewServer(tt.handler(t, ctx))
client = new(codersdk.Client)
client = codersdk.New(testutil.MustURL(t, srv.URL))
args = []string{"exp", "task", "create"}
sb strings.Builder
err error
@@ -244,9 +242,6 @@ func TestTaskCreate(t *testing.T) {
t.Cleanup(srv.Close)
client.URL, err = url.Parse(srv.URL)
require.NoError(t, err)
inv, root := clitest.New(t, append(args, tt.args...)...)
inv.Environ = serpent.ParseEnviron(tt.env, "")
inv.Stdout = &sb
+3
View File
@@ -635,6 +635,9 @@ func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*cod
}
func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error {
if client.SessionTokenProvider == nil {
client.SessionTokenProvider = codersdk.FixedSessionTokenProvider{}
}
transport := http.DefaultTransport
transport = wrapTransportWithTelemetryHeader(transport, inv)
if !r.noVersionCheck {
+2 -6
View File
@@ -641,7 +641,7 @@ func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idToken
// ExternalLogin does the oauth2 flow for external auth providers. This requires
// an authenticated coder client.
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...func(r *http.Request)) {
func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...codersdk.RequestOption) {
coderOauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback", f.externalProviderID))
require.NoError(t, err)
f.SetRedirect(t, coderOauthURL.String())
@@ -660,11 +660,7 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
req, err := http.NewRequestWithContext(ctx, "GET", coderOauthURL.String(), nil)
require.NoError(t, err)
// External auth flow requires the user be authenticated.
headerName := client.SessionTokenHeader
if headerName == "" {
headerName = codersdk.SessionTokenHeader
}
req.Header.Set(headerName, client.SessionToken())
opts = append([]codersdk.RequestOption{client.SessionTokenProvider.AsRequestOption()}, opts...)
if cli.Jar == nil {
cli.Jar, err = cookiejar.New(nil)
require.NoError(t, err, "failed to create cookie jar")
+1 -1
View File
@@ -115,7 +115,7 @@ func TestMCPHTTP_ToolRegistration(t *testing.T) {
require.Contains(t, err.Error(), "client cannot be nil", "Should reject nil client with appropriate error message")
// Test registering tools with valid client should succeed
client := &codersdk.Client{}
client := codersdk.New(testutil.MustURL(t, "http://not-used"))
err = server.RegisterTools(client)
require.NoError(t, err)
+20 -25
View File
@@ -110,6 +110,7 @@ func New(serverURL *url.URL) *Client {
return &Client{
URL: serverURL,
HTTPClient: &http.Client{},
SessionTokenProvider: FixedSessionTokenProvider{},
}
}
@@ -119,17 +120,13 @@ type Client struct {
// mu protects the fields sessionToken, logger, and logBodies. These
// need to be safe for concurrent access.
mu sync.RWMutex
sessionToken string
SessionTokenProvider SessionTokenProvider
logger slog.Logger
logBodies bool
HTTPClient *http.Client
URL *url.URL
// SessionTokenHeader is an optional custom header to use for setting tokens. By
// default 'Coder-Session-Token' is used.
SessionTokenHeader string
// PlainLogger may be set to log HTTP traffic in a human-readable form.
// It uses the LogBodies option.
PlainLogger io.Writer
@@ -176,14 +173,20 @@ func (c *Client) SetLogBodies(logBodies bool) {
func (c *Client) SessionToken() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.sessionToken
return c.SessionTokenProvider.GetSessionToken()
}
// SetSessionToken returns the currently set token for the client.
// SetSessionToken sets a fixed token for the client.
// Deprecated: Create a new client instead of changing the token after creation.
func (c *Client) SetSessionToken(token string) {
c.SetSessionTokenProvider(FixedSessionTokenProvider{SessionToken: token})
}
// SetSessionTokenProvider sets the session token provider for the client.
func (c *Client) SetSessionTokenProvider(provider SessionTokenProvider) {
c.mu.Lock()
defer c.mu.Unlock()
c.sessionToken = token
c.SessionTokenProvider = provider
}
func prefixLines(prefix, s []byte) []byte {
@@ -199,6 +202,14 @@ func prefixLines(prefix, s []byte) []byte {
// Request performs a HTTP request with the body provided. The caller is
// responsible for closing the response body.
func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
opts = append([]RequestOption{c.SessionTokenProvider.AsRequestOption()}, opts...)
return c.RequestWithoutSessionToken(ctx, method, path, body, opts...)
}
// RequestWithoutSessionToken performs a HTTP request. It is similar to Request, but does not set
// the session token in the request header, nor does it make a call to the SessionTokenProvider.
// This allows session token providers to call this method without causing reentrancy issues.
func (c *Client) RequestWithoutSessionToken(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
if ctx == nil {
return nil, xerrors.Errorf("context should not be nil")
}
@@ -248,12 +259,6 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
return nil, xerrors.Errorf("create request: %w", err)
}
tokenHeader := c.SessionTokenHeader
if tokenHeader == "" {
tokenHeader = SessionTokenHeader
}
req.Header.Set(tokenHeader, c.SessionToken())
if r != nil {
req.Header.Set("Content-Type", "application/json")
}
@@ -345,20 +350,10 @@ func (c *Client) Dial(ctx context.Context, path string, opts *websocket.DialOpti
return nil, err
}
tokenHeader := c.SessionTokenHeader
if tokenHeader == "" {
tokenHeader = SessionTokenHeader
}
if opts == nil {
opts = &websocket.DialOptions{}
}
if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{}
}
if opts.HTTPHeader.Get(tokenHeader) == "" {
opts.HTTPHeader.Set(tokenHeader, c.SessionToken())
}
c.SessionTokenProvider.SetDialOption(opts)
conn, resp, err := websocket.Dial(ctx, u.String(), opts)
if resp != nil && resp.Body != nil {
+55
View File
@@ -0,0 +1,55 @@
package codersdk
import (
"net/http"
"github.com/coder/websocket"
)
// SessionTokenProvider provides the session token to access the Coder service (coderd).
// @typescript-ignore SessionTokenProvider
type SessionTokenProvider interface {
// AsRequestOption returns a request option that attaches the session token to an HTTP request.
AsRequestOption() RequestOption
// SetDialOption sets the session token on a websocket request via DialOptions
SetDialOption(options *websocket.DialOptions)
// GetSessionToken returns the session token as a string.
GetSessionToken() string
}
// FixedSessionTokenProvider provides a given, fixed, session token. E.g. one read from file or environment variable
// at the program start.
// @typescript-ignore FixedSessionTokenProvider
type FixedSessionTokenProvider struct {
SessionToken string
// SessionTokenHeader is an optional custom header to use for setting tokens. By
// default, 'Coder-Session-Token' is used.
SessionTokenHeader string
}
func (f FixedSessionTokenProvider) AsRequestOption() RequestOption {
return func(req *http.Request) {
tokenHeader := f.SessionTokenHeader
if tokenHeader == "" {
tokenHeader = SessionTokenHeader
}
req.Header.Set(tokenHeader, f.SessionToken)
}
}
func (f FixedSessionTokenProvider) GetSessionToken() string {
return f.SessionToken
}
func (f FixedSessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
tokenHeader := f.SessionTokenHeader
if tokenHeader == "" {
tokenHeader = SessionTokenHeader
}
if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{}
}
if opts.HTTPHeader.Get(tokenHeader) == "" {
opts.HTTPHeader.Set(tokenHeader, f.SessionToken)
}
}
+6 -11
View File
@@ -215,12 +215,12 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
options.BlockEndpoints = true
}
headers := make(http.Header)
tokenHeader := codersdk.SessionTokenHeader
if c.client.SessionTokenHeader != "" {
tokenHeader = c.client.SessionTokenHeader
wsOptions := &websocket.DialOptions{
HTTPClient: c.client.HTTPClient,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
}
headers.Set(tokenHeader, c.client.SessionToken())
c.client.SessionTokenProvider.SetDialOption(wsOptions)
// New context, separate from dialCtx. We don't want to cancel the
// connection if dialCtx is canceled.
@@ -236,12 +236,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
return nil, xerrors.Errorf("parse url: %w", err)
}
dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
HTTPClient: c.client.HTTPClient,
HTTPHeader: headers,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
dialer := NewWebsocketDialer(options.Logger, coordinateURL, wsOptions)
clk := quartz.NewReal()
controller := tailnet.NewController(options.Logger, dialer)
controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk)
+10 -20
View File
@@ -312,8 +312,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
})
require.NoError(t, err)
proxyClient := wsproxysdk.New(client.URL)
proxyClient.SetSessionToken(createRes.ProxyToken)
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
// Register
req := wsproxysdk.RegisterWorkspaceProxyRequest{
@@ -427,8 +426,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
})
require.NoError(t, err)
proxyClient := wsproxysdk.New(client.URL)
proxyClient.SetSessionToken(createRes.ProxyToken)
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
req := wsproxysdk.RegisterWorkspaceProxyRequest{
AccessURL: "https://proxy.coder.test",
@@ -472,8 +470,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
})
require.NoError(t, err)
proxyClient := wsproxysdk.New(client.URL)
proxyClient.SetSessionToken(createRes.ProxyToken)
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
err = proxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{
ReplicaID: uuid.New(),
@@ -501,8 +498,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
// Register a replica on proxy 2. This shouldn't be returned by replicas
// for proxy 1.
proxyClient2 := wsproxysdk.New(client.URL)
proxyClient2.SetSessionToken(createRes2.ProxyToken)
proxyClient2 := wsproxysdk.New(client.URL, createRes2.ProxyToken)
_, err = proxyClient2.RegisterWorkspaceProxy(ctx, wsproxysdk.RegisterWorkspaceProxyRequest{
AccessURL: "https://other.proxy.coder.test",
WildcardHostname: "*.other.proxy.coder.test",
@@ -516,8 +512,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
require.NoError(t, err)
// Register replica 1.
proxyClient1 := wsproxysdk.New(client.URL)
proxyClient1.SetSessionToken(createRes1.ProxyToken)
proxyClient1 := wsproxysdk.New(client.URL, createRes1.ProxyToken)
req1 := wsproxysdk.RegisterWorkspaceProxyRequest{
AccessURL: "https://one.proxy.coder.test",
WildcardHostname: "*.one.proxy.coder.test",
@@ -574,8 +569,7 @@ func TestProxyRegisterDeregister(t *testing.T) {
})
require.NoError(t, err)
proxyClient := wsproxysdk.New(client.URL)
proxyClient.SetSessionToken(createRes.ProxyToken)
proxyClient := wsproxysdk.New(client.URL, createRes.ProxyToken)
for i := 0; i < 100; i++ {
ok := false
@@ -652,8 +646,7 @@ func TestIssueSignedAppToken(t *testing.T) {
t.Run("BadAppRequest", func(t *testing.T) {
t.Parallel()
proxyClient := wsproxysdk.New(client.URL)
proxyClient.SetSessionToken(proxyRes.ProxyToken)
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
ctx := testutil.Context(t, testutil.WaitLong)
_, err := proxyClient.IssueSignedAppToken(ctx, workspaceapps.IssueTokenRequest{
@@ -674,8 +667,7 @@ func TestIssueSignedAppToken(t *testing.T) {
}
t.Run("OK", func(t *testing.T) {
t.Parallel()
proxyClient := wsproxysdk.New(client.URL)
proxyClient.SetSessionToken(proxyRes.ProxyToken)
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
ctx := testutil.Context(t, testutil.WaitLong)
_, err := proxyClient.IssueSignedAppToken(ctx, goodRequest)
@@ -684,8 +676,7 @@ func TestIssueSignedAppToken(t *testing.T) {
t.Run("OKHTML", func(t *testing.T) {
t.Parallel()
proxyClient := wsproxysdk.New(client.URL)
proxyClient.SetSessionToken(proxyRes.ProxyToken)
proxyClient := wsproxysdk.New(client.URL, proxyRes.ProxyToken)
rw := httptest.NewRecorder()
ctx := testutil.Context(t, testutil.WaitLong)
@@ -1032,8 +1023,7 @@ func TestGetCryptoKeys(t *testing.T) {
Name: testutil.GetRandomName(t),
})
client := wsproxysdk.New(cclient.URL)
client.SetSessionToken(cclient.SessionToken())
client := wsproxysdk.New(cclient.URL, cclient.SessionToken())
_, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
require.Error(t, err)
+1 -5
View File
@@ -163,11 +163,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
return nil, err
}
client := wsproxysdk.New(opts.DashboardURL)
err := client.SetSessionToken(opts.ProxySessionToken)
if err != nil {
return nil, xerrors.Errorf("set client token: %w", err)
}
client := wsproxysdk.New(opts.DashboardURL, opts.ProxySessionToken)
// Use the configured client if provided.
if opts.HTTPClient != nil {
+2 -4
View File
@@ -577,8 +577,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) {
t.Cleanup(srv.Close)
// Register a proxy.
wsproxyClient := wsproxysdk.New(primaryAccessURL)
wsproxyClient.SetSessionToken(token)
wsproxyClient := wsproxysdk.New(primaryAccessURL, token)
hostname, err := cryptorand.String(6)
require.NoError(t, err)
replicaID := uuid.New()
@@ -879,8 +878,7 @@ func TestWorkspaceProxyDERPMeshProbe(t *testing.T) {
require.Contains(t, respJSON.Warnings[0], "High availability networking")
// Deregister the other replica.
wsproxyClient := wsproxysdk.New(api.AccessURL)
wsproxyClient.SetSessionToken(proxy.Options.ProxySessionToken)
wsproxyClient := wsproxysdk.New(api.AccessURL, proxy.Options.ProxySessionToken)
err = wsproxyClient.DeregisterWorkspaceProxy(ctx, wsproxysdk.DeregisterWorkspaceProxyRequest{
ReplicaID: otherReplicaID,
})
+14 -22
View File
@@ -33,15 +33,20 @@ type Client struct {
// New creates a external proxy client for the provided primary coder server
// URL.
func New(serverURL *url.URL) *Client {
func New(serverURL *url.URL, sessionToken string) *Client {
sdkClient := codersdk.New(serverURL)
sdkClient.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader
sdkClient.SessionTokenProvider = codersdk.FixedSessionTokenProvider{
SessionToken: sessionToken,
SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader,
}
sdkClientIgnoreRedirects := codersdk.New(serverURL)
sdkClientIgnoreRedirects.HTTPClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
}
sdkClientIgnoreRedirects.SessionTokenHeader = httpmw.WorkspaceProxyAuthTokenHeader
sdkClientIgnoreRedirects.SessionTokenProvider = codersdk.FixedSessionTokenProvider{
SessionToken: sessionToken,
SessionTokenHeader: httpmw.WorkspaceProxyAuthTokenHeader,
}
return &Client{
SDKClient: sdkClient,
@@ -49,14 +54,6 @@ func New(serverURL *url.URL) *Client {
}
}
// SetSessionToken sets the session token for the client. An error is returned
// if the session token is not in the correct format for external proxies.
func (c *Client) SetSessionToken(token string) error {
c.SDKClient.SetSessionToken(token)
c.sdkClientIgnoreRedirects.SetSessionToken(token)
return nil
}
// SessionToken returns the currently set token for the client.
func (c *Client) SessionToken() string {
return c.SDKClient.SessionToken()
@@ -506,17 +503,12 @@ func (c *Client) TailnetDialer() (*workspacesdk.WebsocketDialer, error) {
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
coordinateHeaders := make(http.Header)
tokenHeader := codersdk.SessionTokenHeader
if c.SDKClient.SessionTokenHeader != "" {
tokenHeader = c.SDKClient.SessionTokenHeader
}
coordinateHeaders.Set(tokenHeader, c.SessionToken())
return workspacesdk.NewWebsocketDialer(logger, coordinateURL, &websocket.DialOptions{
wsOptions := &websocket.DialOptions{
HTTPClient: c.SDKClient.HTTPClient,
HTTPHeader: coordinateHeaders,
}), nil
}
c.SDKClient.SessionTokenProvider.SetDialOption(wsOptions)
return workspacesdk.NewWebsocketDialer(logger, coordinateURL, wsOptions), nil
}
type CryptoKeysResponse struct {
@@ -60,8 +60,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) {
u, err := url.Parse(srv.URL)
require.NoError(t, err)
client := wsproxysdk.New(u)
client.SetSessionToken(expectedProxyToken)
client := wsproxysdk.New(u, expectedProxyToken)
ctx := testutil.Context(t, testutil.WaitLong)
@@ -111,8 +110,7 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) {
u, err := url.Parse(srv.URL)
require.NoError(t, err)
client := wsproxysdk.New(u)
_ = client.SetSessionToken(expectedProxyToken)
client := wsproxysdk.New(u, expectedProxyToken)
ctx := testutil.Context(t, testutil.WaitLong)
+4 -10
View File
@@ -6,7 +6,6 @@ import (
"errors"
"io"
"net"
"net/http"
"sync"
"time"
@@ -269,18 +268,13 @@ func (w *wrappedSSHConn) Write(p []byte) (n int, err error) {
}
func appClientConn(ctx context.Context, client *codersdk.Client, url string) (*countReadWriteCloser, error) {
headers := http.Header{}
tokenHeader := codersdk.SessionTokenHeader
if client.SessionTokenHeader != "" {
tokenHeader = client.SessionTokenHeader
wsOptions := &websocket.DialOptions{
HTTPClient: client.HTTPClient,
}
headers.Set(tokenHeader, client.SessionToken())
client.SessionTokenProvider.SetDialOption(wsOptions)
//nolint:bodyclose // The websocket conn manages the body.
conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
HTTPClient: client.HTTPClient,
HTTPHeader: headers,
})
conn, _, err := websocket.Dial(ctx, url, wsOptions)
if err != nil {
return nil, xerrors.Errorf("websocket dial: %w", err)
}
+2 -3
View File
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"slices"
"strings"
@@ -313,9 +314,7 @@ func TestRun(t *testing.T) {
readMetrics = &testMetrics{}
writeMetrics = &testMetrics{}
)
client := &codersdk.Client{
HTTPClient: &http.Client{},
}
client := codersdk.New(&url.URL{})
runner := workspacetraffic.NewRunner(client, workspacetraffic.Config{
BytesPerTick: int64(bytesPerTick),
TickInterval: tickInterval,