diff --git a/agent/agent.go b/agent/agent.go index e4d7ab60e0..aed6652de6 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -74,7 +74,6 @@ type Options struct { LogDir string TempDir string ScriptDataDir string - ExchangeToken func(ctx context.Context) (string, error) Client Client ReconnectingPTYTimeout time.Duration EnvironmentVariables map[string]string @@ -99,6 +98,7 @@ type Client interface { proto.DRPCAgentClient26, tailnetproto.DRPCTailnetClient26, error, ) tailnet.DERPMapRewriter + agentsdk.RefreshableSessionTokenProvider } type Agent interface { @@ -131,11 +131,6 @@ func New(options Options) Agent { } options.ScriptDataDir = options.TempDir } - if options.ExchangeToken == nil { - options.ExchangeToken = func(_ context.Context) (string, error) { - return "", nil - } - } if options.ReportMetadataInterval == 0 { options.ReportMetadataInterval = time.Second } @@ -172,7 +167,6 @@ func New(options Options) Agent { coordDisconnected: make(chan struct{}), environmentVariables: options.EnvironmentVariables, client: options.Client, - exchangeToken: options.ExchangeToken, filesystem: options.Filesystem, logDir: options.LogDir, tempDir: options.TempDir, @@ -203,7 +197,6 @@ func New(options Options) Agent { // coordinator during shut down. close(a.coordDisconnected) a.announcementBanners.Store(new([]codersdk.BannerConfig)) - a.sessionToken.Store(new(string)) a.init() return a } @@ -212,7 +205,6 @@ type agent struct { clock quartz.Clock logger slog.Logger client Client - exchangeToken func(ctx context.Context) (string, error) tailnetListenPort uint16 filesystem afero.Fs logDir string @@ -254,7 +246,6 @@ type agent struct { scriptRunner *agentscripts.Runner announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated. announcementBannersRefreshInterval time.Duration - sessionToken atomic.Pointer[string] sshServer *agentssh.Server sshMaxTimeout time.Duration blockFileTransfer bool @@ -916,11 +907,10 @@ func (a *agent) run() (retErr error) { // This allows the agent to refresh its token if necessary. // For instance identity this is required, since the instance // may not have re-provisioned, but a new agent ID was created. - sessionToken, err := a.exchangeToken(a.hardCtx) + err := a.client.RefreshToken(a.hardCtx) if err != nil { - return xerrors.Errorf("exchange token: %w", err) + return xerrors.Errorf("refresh token: %w", err) } - a.sessionToken.Store(&sessionToken) // ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs aAPI, tAPI, err := a.client.ConnectRPC26(a.hardCtx) @@ -1359,7 +1349,7 @@ func (a *agent) updateCommandEnv(current []string) (updated []string, err error) "CODER_WORKSPACE_OWNER_NAME": manifest.OwnerName, // Specific Coder subcommands require the agent token exposed! - "CODER_AGENT_TOKEN": *a.sessionToken.Load(), + "CODER_AGENT_TOKEN": a.client.GetSessionToken(), // Git on Windows resolves with UNIX-style paths. // If using backslashes, it's unable to find the executable. diff --git a/agent/agent_test.go b/agent/agent_test.go index d80f5d1982..e8b3b99a95 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -22,7 +22,6 @@ import ( "slices" "strconv" "strings" - "sync/atomic" "testing" "time" @@ -2926,11 +2925,11 @@ func TestAgent_Speedtest(t *testing.T) { func TestAgent_Reconnect(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) logger := testutil.Logger(t) // After the agent is disconnected from a coordinator, it's supposed // to reconnect! - coordinator := tailnet.NewCoordinator(logger) - defer coordinator.Close() + fCoordinator := tailnettest.NewFakeCoordinator() agentID := uuid.New() statsCh := make(chan *proto.Stats, 50) @@ -2942,27 +2941,24 @@ func TestAgent_Reconnect(t *testing.T) { DERPMap: derpMap, }, statsCh, - coordinator, + fCoordinator, ) defer client.Close() - initialized := atomic.Int32{} + closer := agent.New(agent.Options{ - ExchangeToken: func(ctx context.Context) (string, error) { - initialized.Add(1) - return "", nil - }, Client: client, Logger: logger.Named("agent"), }) defer closer.Close() - require.Eventually(t, func() bool { - return coordinator.Node(agentID) != nil - }, testutil.WaitShort, testutil.IntervalFast) - client.LastWorkspaceAgent() - require.Eventually(t, func() bool { - return initialized.Load() == 2 - }, testutil.WaitShort, testutil.IntervalFast) + call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls) + require.Equal(t, client.GetNumRefreshTokenCalls(), 1) + close(call1.Resps) // hang up + // expect reconnect + testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls) + // Check that the agent refreshes the token when it reconnects. + require.Equal(t, client.GetNumRefreshTokenCalls(), 2) + closer.Close() } func TestAgent_WriteVSCodeConfigs(t *testing.T) { @@ -2984,9 +2980,6 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) { defer client.Close() filesystem := afero.NewMemMapFs() closer := agent.New(agent.Options{ - ExchangeToken: func(ctx context.Context) (string, error) { - return "", nil - }, Client: client, Logger: logger.Named("agent"), Filesystem: filesystem, @@ -3015,9 +3008,6 @@ func TestAgent_DebugServer(t *testing.T) { conn, _, _, _, agnt := setupAgent(t, agentsdk.Manifest{ DERPMap: derpMap, }, 0, func(c *agenttest.Client, o *agent.Options) { - o.ExchangeToken = func(context.Context) (string, error) { - return "token", nil - } o.LogDir = logDir }) diff --git a/agent/agenttest/agent.go b/agent/agenttest/agent.go index d25170dfc2..a6356e6e25 100644 --- a/agent/agenttest/agent.go +++ b/agent/agenttest/agent.go @@ -1,7 +1,6 @@ package agenttest import ( - "context" "net/url" "testing" @@ -31,18 +30,11 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent } if o.Client == nil { - agentClient := agentsdk.New(coderURL) - agentClient.SetSessionToken(agentToken) + agentClient := agentsdk.New(coderURL, agentsdk.WithFixedToken(agentToken)) agentClient.SDK.SetLogger(log) o.Client = agentClient } - if o.ExchangeToken == nil { - o.ExchangeToken = func(_ context.Context) (string, error) { - return agentToken, nil - } - } - if o.LogDir == "" { o.LogDir = t.TempDir() } diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index 5d78dfe697..ff601a7d08 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -3,6 +3,7 @@ package agenttest import ( "context" "io" + "net/http" "slices" "sync" "sync/atomic" @@ -28,6 +29,7 @@ import ( "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" ) const statsInterval = 500 * time.Millisecond @@ -86,10 +88,34 @@ type Client struct { fakeAgentAPI *FakeAgentAPI LastWorkspaceAgent func() - mu sync.Mutex // Protects following. - logs []agentsdk.Log - derpMapUpdates chan *tailcfg.DERPMap - derpMapOnce sync.Once + mu sync.Mutex // Protects following. + logs []agentsdk.Log + derpMapUpdates chan *tailcfg.DERPMap + derpMapOnce sync.Once + refreshTokenCalls int +} + +func (*Client) AsRequestOption() codersdk.RequestOption { + return func(_ *http.Request) {} +} + +func (*Client) SetDialOption(*websocket.DialOptions) {} + +func (*Client) GetSessionToken() string { + return "agenttest-token" +} + +func (c *Client) RefreshToken(context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + c.refreshTokenCalls++ + return nil +} + +func (c *Client) GetNumRefreshTokenCalls() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.refreshTokenCalls } func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {} diff --git a/cli/agent.go b/cli/agent.go index c192d4429c..2b8efad55b 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "cloud.google.com/go/compute/metadata" "golang.org/x/xerrors" "gopkg.in/natefinch/lumberjack.v2" @@ -38,9 +37,8 @@ import ( "github.com/coder/coder/v2/codersdk/agentsdk" ) -func (r *RootCmd) workspaceAgent() *serpent.Command { +func workspaceAgent() *serpent.Command { var ( - auth string logDir string scriptDataDir string pprofAddress string @@ -59,6 +57,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { devcontainerProjectDiscovery bool devcontainerDiscoveryAutostart bool ) + agentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "agent", Short: `Starts the Coder workspace agent.`, @@ -176,12 +175,14 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { version := buildinfo.Version() logger.Info(ctx, "agent is starting now", - slog.F("url", r.agentURL), - slog.F("auth", auth), + slog.F("url", agentAuth.agentURL), + slog.F("auth", agentAuth.agentAuth), slog.F("version", version), ) - - client := agentsdk.New(r.agentURL) + client, err := agentAuth.CreateClient(ctx) + if err != nil { + return xerrors.Errorf("create agent client: %w", err) + } client.SDK.SetLogger(logger) // Set a reasonable timeout so requests can't hang forever! // The timeout needs to be reasonably long, because requests @@ -190,7 +191,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { client.SDK.HTTPClient.Timeout = 30 * time.Second // Attach header transport so we process --agent-header and // --agent-header-command flags - headerTransport, err := headerTransport(ctx, r.agentURL, agentHeader, agentHeaderCommand) + headerTransport, err := headerTransport(ctx, &agentAuth.agentURL, agentHeader, agentHeaderCommand) if err != nil { return xerrors.Errorf("configure header transport: %w", err) } @@ -214,68 +215,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { ignorePorts[port] = "debug" } - // exchangeToken returns a session token. - // This is abstracted to allow for the same looping condition - // regardless of instance identity auth type. - var exchangeToken func(context.Context) (agentsdk.AuthenticateResponse, error) - switch auth { - case "token": - token, _ := inv.ParsedFlags().GetString(varAgentToken) - if token == "" { - tokenFile, _ := inv.ParsedFlags().GetString(varAgentTokenFile) - if tokenFile != "" { - tokenBytes, err := os.ReadFile(tokenFile) - if err != nil { - return xerrors.Errorf("read token file %q: %w", tokenFile, err) - } - token = strings.TrimSpace(string(tokenBytes)) - } - } - if token == "" { - return xerrors.Errorf("CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE must be set for token auth") - } - client.SetSessionToken(token) - case "google-instance-identity": - // This is *only* done for testing to mock client authentication. - // This will never be set in a production scenario. - var gcpClient *metadata.Client - gcpClientRaw := ctx.Value("gcp-client") - if gcpClientRaw != nil { - gcpClient, _ = gcpClientRaw.(*metadata.Client) - } - exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) { - return client.AuthGoogleInstanceIdentity(ctx, "", gcpClient) - } - case "aws-instance-identity": - // This is *only* done for testing to mock client authentication. - // This will never be set in a production scenario. - var awsClient *http.Client - awsClientRaw := ctx.Value("aws-client") - if awsClientRaw != nil { - awsClient, _ = awsClientRaw.(*http.Client) - if awsClient != nil { - client.SDK.HTTPClient = awsClient - } - } - exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) { - return client.AuthAWSInstanceIdentity(ctx) - } - case "azure-instance-identity": - // This is *only* done for testing to mock client authentication. - // This will never be set in a production scenario. - var azureClient *http.Client - azureClientRaw := ctx.Value("azure-client") - if azureClientRaw != nil { - azureClient, _ = azureClientRaw.(*http.Client) - if azureClient != nil { - client.SDK.HTTPClient = azureClient - } - } - exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) { - return client.AuthAzureInstanceIdentity(ctx) - } - } - executablePath, err := os.Executable() if err != nil { return xerrors.Errorf("getting os executable: %w", err) @@ -343,18 +282,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { LogDir: logDir, ScriptDataDir: scriptDataDir, // #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535) - TailnetListenPort: uint16(tailnetListenPort), - ExchangeToken: func(ctx context.Context) (string, error) { - if exchangeToken == nil { - return client.SDK.SessionToken(), nil - } - resp, err := exchangeToken(ctx) - if err != nil { - return "", err - } - client.SetSessionToken(resp.SessionToken) - return resp.SessionToken, nil - }, + TailnetListenPort: uint16(tailnetListenPort), EnvironmentVariables: environmentVariables, IgnorePorts: ignorePorts, SSHMaxTimeout: sshMaxTimeout, @@ -365,7 +293,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { Execer: execer, Devcontainers: devcontainers, DevcontainerAPIOptions: []agentcontainers.Option{ - agentcontainers.WithSubAgentURL(r.agentURL.String()), + agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()), agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery), agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart), }, @@ -400,13 +328,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { } cmd.Options = serpent.OptionSet{ - { - Flag: "auth", - Default: "token", - Description: "Specify the authentication type to use for the agent.", - Env: "CODER_AGENT_AUTH", - Value: serpent.StringOf(&auth), - }, { Flag: "log-dir", Default: os.TempDir(), @@ -529,7 +450,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { Value: serpent.BoolOf(&devcontainerDiscoveryAutostart), }, } - + agentAuth.AttachOptions(cmd, false) return cmd } diff --git a/cli/exp_mcp.go b/cli/exp_mcp.go index d5ea267390..8388a5a4c7 100644 --- a/cli/exp_mcp.go +++ b/cli/exp_mcp.go @@ -56,7 +56,7 @@ func (r *RootCmd) mcpConfigure() *serpent.Command { }, Children: []*serpent.Command{ r.mcpConfigureClaudeDesktop(), - r.mcpConfigureClaudeCode(), + mcpConfigureClaudeCode(), r.mcpConfigureCursor(), }, } @@ -117,7 +117,7 @@ func (*RootCmd) mcpConfigureClaudeDesktop() *serpent.Command { return cmd } -func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command { +func mcpConfigureClaudeCode() *serpent.Command { var ( claudeAPIKey string claudeConfigPath string @@ -131,6 +131,7 @@ func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command { deprecatedCoderMCPClaudeAPIKey string ) + agentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "claude-code ", Short: "Configure the Claude Code server. You will need to run this command for each project you want to use. Specify the project directory as the first argument.", @@ -148,7 +149,7 @@ func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command { binPath = testBinaryName } configureClaudeEnv := map[string]string{} - agentClient, err := r.createAgentClient() + agentClient, err := agentAuth.CreateClient(inv.Context()) if err != nil { cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err) } else { @@ -292,6 +293,7 @@ func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command { }, }, } + agentAuth.AttachOptions(cmd, false) return cmd } @@ -403,7 +405,8 @@ func (r *RootCmd) mcpServer() *serpent.Command { appStatusSlug string aiAgentAPIURL url.URL ) - return &serpent.Command{ + agentAuth := &AgentAuth{} + cmd := &serpent.Command{ Use: "server", Handler: func(inv *serpent.Invocation) error { var lastReport taskReport @@ -494,7 +497,7 @@ func (r *RootCmd) mcpServer() *serpent.Command { } // Try to create an agent client for status reporting. Not validated. - agentClient, err := r.createAgentClient() + agentClient, err := agentAuth.CreateClient(inv.Context()) if err == nil { cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String()) srv.agentClient = agentClient @@ -579,6 +582,8 @@ func (r *RootCmd) mcpServer() *serpent.Command { }, }, } + agentAuth.AttachOptions(cmd, false) + return cmd } func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation) { diff --git a/cli/externalauth.go b/cli/externalauth.go index 98bd853992..4aaa72c197 100644 --- a/cli/externalauth.go +++ b/cli/externalauth.go @@ -2,19 +2,16 @@ package cli import ( "encoding/json" - "fmt" - - "golang.org/x/xerrors" "github.com/tidwall/gjson" + "golang.org/x/xerrors" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/codersdk/agentsdk" - "github.com/coder/pretty" "github.com/coder/serpent" ) -func (r *RootCmd) externalAuth() *serpent.Command { +func externalAuth() *serpent.Command { return &serpent.Command{ Use: "external-auth", Short: "Manage external authentication", @@ -23,14 +20,15 @@ func (r *RootCmd) externalAuth() *serpent.Command { return i.Command.HelpHandler(i) }, Children: []*serpent.Command{ - r.externalAuthAccessToken(), + externalAuthAccessToken(), }, } } -func (r *RootCmd) externalAuthAccessToken() *serpent.Command { +func externalAuthAccessToken() *serpent.Command { var extra string - return &serpent.Command{ + agentAuth := &AgentAuth{} + cmd := &serpent.Command{ Use: "access-token ", Short: "Print auth for an external provider", Long: "Print an access-token for an external auth provider. " + @@ -70,12 +68,7 @@ fi ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) defer stop() - if r.agentToken == "" { - _, _ = fmt.Fprint(inv.Stderr, pretty.Sprintf(headLineStyle(), "No agent token found, this command must be run from inside a running workspace.\n")) - return xerrors.Errorf("agent token not found") - } - - client, err := r.tryCreateAgentClient() + client, err := agentAuth.CreateClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } @@ -115,4 +108,6 @@ fi return nil }, } + agentAuth.AttachOptions(cmd, false) + return cmd } diff --git a/cli/gitaskpass.go b/cli/gitaskpass.go index e54d93478d..4729b333ae 100644 --- a/cli/gitaskpass.go +++ b/cli/gitaskpass.go @@ -18,8 +18,8 @@ import ( // gitAskpass is used by the Coder agent to automatically authenticate // with Git providers based on a hostname. -func (r *RootCmd) gitAskpass() *serpent.Command { - return &serpent.Command{ +func gitAskpass(agentAuth *AgentAuth) *serpent.Command { + cmd := &serpent.Command{ Use: "gitaskpass", Hidden: true, Handler: func(inv *serpent.Invocation) error { @@ -33,7 +33,7 @@ func (r *RootCmd) gitAskpass() *serpent.Command { return xerrors.Errorf("parse host: %w", err) } - client, err := r.tryCreateAgentClient() + client, err := agentAuth.CreateClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } @@ -90,4 +90,6 @@ func (r *RootCmd) gitAskpass() *serpent.Command { return nil }, } + agentAuth.AttachOptions(cmd, false) + return cmd } diff --git a/cli/gitaskpass_test.go b/cli/gitaskpass_test.go index 8e51411de9..584e003427 100644 --- a/cli/gitaskpass_test.go +++ b/cli/gitaskpass_test.go @@ -16,6 +16,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/pty/ptytest" + "github.com/coder/coder/v2/testutil" ) func TestGitAskpass(t *testing.T) { @@ -32,6 +33,7 @@ func TestGitAskpass(t *testing.T) { url := srv.URL inv, _ := clitest.New(t, "--agent-url", url, "Username for 'https://github.com':") inv.Environ.Set("GIT_PREFIX", "/") + inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token") pty := ptytest.New(t) inv.Stdout = pty.Output() clitest.Start(t, inv) @@ -39,6 +41,7 @@ func TestGitAskpass(t *testing.T) { inv, _ = clitest.New(t, "--agent-url", url, "Password for 'https://potato@github.com':") inv.Environ.Set("GIT_PREFIX", "/") + inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token") pty = ptytest.New(t) inv.Stdout = pty.Output() clitest.Start(t, inv) @@ -56,6 +59,7 @@ func TestGitAskpass(t *testing.T) { url := srv.URL inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") inv.Environ.Set("GIT_PREFIX", "/") + inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token") pty := ptytest.New(t) inv.Stderr = pty.Output() err := inv.Run() @@ -65,6 +69,7 @@ func TestGitAskpass(t *testing.T) { t.Run("Poll", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) resp := atomic.Pointer[agentsdk.ExternalAuthResponse]{} resp.Store(&agentsdk.ExternalAuthResponse{ URL: "https://something.org", @@ -86,6 +91,7 @@ func TestGitAskpass(t *testing.T) { inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") inv.Environ.Set("GIT_PREFIX", "/") + inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token") stdout := ptytest.New(t) inv.Stdout = stdout.Output() stderr := ptytest.New(t) @@ -94,7 +100,7 @@ func TestGitAskpass(t *testing.T) { err := inv.Run() assert.NoError(t, err) }() - <-poll + testutil.RequireReceive(ctx, t, poll) stderr.ExpectMatch("Open the following URL to authenticate") resp.Store(&agentsdk.ExternalAuthResponse{ Username: "username", diff --git a/cli/gitssh.go b/cli/gitssh.go index 566d3cc6f1..043049b7e8 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -18,7 +18,8 @@ import ( "github.com/coder/serpent" ) -func (r *RootCmd) gitssh() *serpent.Command { +func gitssh() *serpent.Command { + agentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "gitssh", Hidden: true, @@ -38,7 +39,7 @@ func (r *RootCmd) gitssh() *serpent.Command { return err } - client, err := r.tryCreateAgentClient() + client, err := agentAuth.CreateClient(ctx) if err != nil { return xerrors.Errorf("create agent client: %w", err) } @@ -108,7 +109,7 @@ func (r *RootCmd) gitssh() *serpent.Command { return nil }, } - + agentAuth.AttachOptions(cmd, false) return cmd } diff --git a/cli/gitssh_test.go b/cli/gitssh_test.go index 6d574ae651..8ff32363e9 100644 --- a/cli/gitssh_test.go +++ b/cli/gitssh_test.go @@ -54,8 +54,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*agentsdk.Client, str }).WithAgent().Do() // start workspace agent - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { o.Client = agentClient }) diff --git a/cli/root.go b/cli/root.go index ed6869b6a1..a18401e253 100644 --- a/cli/root.go +++ b/cli/root.go @@ -24,6 +24,7 @@ import ( "text/tabwriter" "time" + "cloud.google.com/go/compute/metadata" "github.com/mattn/go-isatty" "github.com/mitchellh/go-wordwrap" "golang.org/x/mod/semver" @@ -59,9 +60,6 @@ var ( const ( varURL = "url" varToken = "token" - varAgentToken = "agent-token" - varAgentTokenFile = "agent-token-file" - varAgentURL = "agent-url" varHeader = "header" varHeaderCommand = "header-command" varNoOpen = "no-open" @@ -82,6 +80,7 @@ const ( //nolint:gosec envAgentTokenFile = "CODER_AGENT_TOKEN_FILE" envAgentURL = "CODER_AGENT_URL" + envAgentAuth = "CODER_AGENT_AUTH" envURL = "CODER_URL" ) @@ -90,7 +89,7 @@ func (r *RootCmd) CoreSubcommands() []*serpent.Command { return []*serpent.Command{ r.completion(), r.dotfiles(), - r.externalAuth(), + externalAuth(), r.login(), r.logout(), r.netcheck(), @@ -130,11 +129,11 @@ func (r *RootCmd) CoreSubcommands() []*serpent.Command { // Hidden r.connectCmd(), r.expCmd(), - r.gitssh(), + gitssh(), r.support(), r.vpnDaemon(), r.vscodeSSH(), - r.workspaceAgent(), + workspaceAgent(), } } @@ -198,6 +197,7 @@ func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) { func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, error) { fmtLong := `Coder %s — A tool for provisioning self-hosted development environments with Terraform. ` + hiddenAgentAuth := &AgentAuth{} cmd := &serpent.Command{ Use: "coder [global-flags] ", Long: fmt.Sprintf(fmtLong, buildinfo.Version()) + FormatExamples( @@ -220,7 +220,7 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err // with a `gitaskpass` subcommand, we override the entrypoint // to check if the command was invoked. if gitauth.CheckCommand(i.Args, i.Environ.ToOS()) { - return r.gitAskpass().Handler(i) + return gitAskpass(hiddenAgentAuth).Handler(i) } return i.Command.HelpHandler(i) }, @@ -349,9 +349,6 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err } }) - if r.agentURL == nil { - r.agentURL = new(url.URL) - } if r.clientURL == nil { r.clientURL = new(url.URL) } @@ -381,30 +378,6 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err Value: serpent.StringOf(&r.token), Group: globalGroup, }, - { - Flag: varAgentToken, - Env: envAgentToken, - Description: "An agent authentication token.", - Value: serpent.StringOf(&r.agentToken), - Hidden: true, - Group: globalGroup, - }, - { - Flag: varAgentTokenFile, - Env: envAgentTokenFile, - Description: "A file containing an agent authentication token.", - Value: serpent.StringOf(&r.agentTokenFile), - Hidden: true, - Group: globalGroup, - }, - { - Flag: varAgentURL, - Env: envAgentURL, - Description: "URL for an agent to access your deployment.", - Value: serpent.URLOf(r.agentURL), - Hidden: true, - Group: globalGroup, - }, { Flag: varNoVersionCheck, Env: envNoVersionCheck, @@ -496,26 +469,25 @@ func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, err Hidden: true, }, } + hiddenAgentAuth.AttachOptions(cmd, true) return cmd, nil } // RootCmd contains parameters and helpers useful to all commands. type RootCmd struct { - clientURL *url.URL - token string - globalConfig string - header []string - headerCommand string - agentToken string - agentTokenFile string - agentURL *url.URL - forceTTY bool - noOpen bool - verbose bool - versionFlag bool - disableDirect bool - debugHTTP bool + clientURL *url.URL + token string + globalConfig string + header []string + headerCommand string + + forceTTY bool + noOpen bool + verbose bool + versionFlag bool + disableDirect bool + debugHTTP bool disableNetworkTelemetry bool noVersionCheck bool @@ -672,38 +644,111 @@ func (r *RootCmd) createUnauthenticatedClient(ctx context.Context, serverURL *ur return &client, err } -// createAgentClient returns a new client from the command context. It works -// just like InitClient, but uses the agent token and URL instead. -func (r *RootCmd) createAgentClient() (*agentsdk.Client, error) { - agentURL := r.agentURL - if agentURL == nil || agentURL.String() == "" { - return nil, xerrors.Errorf("%s must be set", envAgentURL) - } - token := r.agentToken - if token == "" { - if r.agentTokenFile == "" { - return nil, xerrors.Errorf("Either %s or %s must be set", envAgentToken, envAgentTokenFile) - } - tokenBytes, err := os.ReadFile(r.agentTokenFile) - if err != nil { - return nil, xerrors.Errorf("read token file %q: %w", r.agentTokenFile, err) - } - token = strings.TrimSpace(string(tokenBytes)) - } - client := agentsdk.New(agentURL) - client.SetSessionToken(token) - return client, nil +type AgentAuth struct { + // Agent Client config + agentToken string + agentTokenFile string + agentURL url.URL + agentAuth string } -// tryCreateAgentClient returns a new client from the command context. It works -// just like tryCreateAgentClient, but does not error. -func (r *RootCmd) tryCreateAgentClient() (*agentsdk.Client, error) { - // TODO: Why does this not actually return any errors despite the function - // signature? Could we just use createAgentClient instead, or is it expected - // that we return a client in some cases even without a valid URL or token? - client := agentsdk.New(r.agentURL) - client.SetSessionToken(r.agentToken) - return client, nil +func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) { + cmd.Options = append(cmd.Options, serpent.Option{ + Name: "Agent Token", + Description: "An agent authentication token.", + Flag: "agent-token", + Env: envAgentToken, + Value: serpent.StringOf(&a.agentToken), + Hidden: hidden, + }, serpent.Option{ + Name: "Agent Token File", + Description: "A file containing an agent authentication token.", + Flag: "agent-token-file", + Env: envAgentTokenFile, + Value: serpent.StringOf(&a.agentTokenFile), + Hidden: hidden, + }, serpent.Option{ + Name: "Agent URL", + Description: "URL for an agent to access your deployment.", + Flag: "agent-url", + Env: envAgentURL, + Value: serpent.URLOf(&a.agentURL), + Hidden: hidden, + }, serpent.Option{ + Name: "Agent Auth", + Description: "Specify the authentication type to use for the agent.", + Flag: "auth", + Env: envAgentAuth, + Default: "token", + Value: serpent.StringOf(&a.agentAuth), + Hidden: hidden, + }) +} + +// CreateClient returns a new agent client from the command context. It works +// just like InitClient, but uses the agent token and URL instead. +func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) { + agentURL := a.agentURL + if agentURL.String() == "" { + return nil, xerrors.Errorf("%s must be set", envAgentURL) + } + + switch a.agentAuth { + case "token": + token := a.agentToken + if token == "" { + if a.agentTokenFile == "" { + return nil, xerrors.Errorf("Either %s or %s must be set", envAgentToken, envAgentTokenFile) + } + tokenBytes, err := os.ReadFile(a.agentTokenFile) + if err != nil { + return nil, xerrors.Errorf("read token file %q: %w", a.agentTokenFile, err) + } + token = strings.TrimSpace(string(tokenBytes)) + } + if token == "" { + return nil, xerrors.Errorf("CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE must be set for token auth") + } + return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil + case "google-instance-identity": + + // This is *only* done for testing to mock client authentication. + // This will never be set in a production scenario. + var gcpClient *metadata.Client + gcpClientRaw := ctx.Value("gcp-client") + if gcpClientRaw != nil { + gcpClient, _ = gcpClientRaw.(*metadata.Client) + } + return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", gcpClient)), nil + case "aws-instance-identity": + client := agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()) + // This is *only* done for testing to mock client authentication. + // This will never be set in a production scenario. + var awsClient *http.Client + awsClientRaw := ctx.Value("aws-client") + if awsClientRaw != nil { + awsClient, _ = awsClientRaw.(*http.Client) + if awsClient != nil { + client.SDK.HTTPClient = awsClient + } + } + return client, nil + case "azure-instance-identity": + client := agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()) + // This is *only* done for testing to mock client authentication. + // This will never be set in a production scenario. + var azureClient *http.Client + azureClientRaw := ctx.Value("azure-client") + if azureClientRaw != nil { + azureClient, _ = azureClientRaw.(*http.Client) + if azureClient != nil { + client.SDK.HTTPClient = azureClient + } + } + return client, nil + default: + return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth) + } } type OrganizationContext struct { diff --git a/cli/testdata/coder_agent_--help.golden b/cli/testdata/coder_agent_--help.golden index c6d75705a6..1f25fc6941 100644 --- a/cli/testdata/coder_agent_--help.golden +++ b/cli/testdata/coder_agent_--help.golden @@ -6,6 +6,18 @@ USAGE: Starts the Coder workspace agent. OPTIONS: + --auth string, $CODER_AGENT_AUTH (default: token) + Specify the authentication type to use for the agent. + + --agent-token string, $CODER_AGENT_TOKEN + An agent authentication token. + + --agent-token-file string, $CODER_AGENT_TOKEN_FILE + A file containing an agent authentication token. + + --agent-url url, $CODER_AGENT_URL + URL for an agent to access your deployment. + --log-human string, $CODER_AGENT_LOGGING_HUMAN (default: /dev/stderr) Output human-readable logs to a given file. @@ -24,9 +36,6 @@ OPTIONS: requests. The command must output each header as `key=value` on its own line. - --auth string, $CODER_AGENT_AUTH (default: token) - Specify the authentication type to use for the agent. - --block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false) Block file transfer using known applications: nc,rsync,scp,sftp. diff --git a/cli/testdata/coder_external-auth_access-token_--help.golden b/cli/testdata/coder_external-auth_access-token_--help.golden index e4693a6fb9..234cca5d4f 100644 --- a/cli/testdata/coder_external-auth_access-token_--help.golden +++ b/cli/testdata/coder_external-auth_access-token_--help.golden @@ -25,6 +25,18 @@ USAGE: $ coder external-auth access-token slack --extra "authed_user.id" OPTIONS: + --auth string, $CODER_AGENT_AUTH (default: token) + Specify the authentication type to use for the agent. + + --agent-token string, $CODER_AGENT_TOKEN + An agent authentication token. + + --agent-token-file string, $CODER_AGENT_TOKEN_FILE + A file containing an agent authentication token. + + --agent-url url, $CODER_AGENT_URL + URL for an agent to access your deployment. + --extra string Extract a field from the "extra" properties of the OAuth token. diff --git a/coderd/externalauth_test.go b/coderd/externalauth_test.go index c9ba491121..68244bf3a4 100644 --- a/coderd/externalauth_test.go +++ b/coderd/externalauth_test.go @@ -432,8 +432,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) _, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com", }) @@ -464,8 +463,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) token, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", }) @@ -565,8 +563,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) resp := coderdtest.RequestExternalAuthCallback(t, "github", client) require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) @@ -627,8 +624,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) token, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", @@ -674,8 +670,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) token, err := agentClient.ExternalAuth(context.Background(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", @@ -740,8 +735,7 @@ func TestExternalAuthCallback(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) token, err := agentClient.ExternalAuth(t.Context(), agentsdk.ExternalAuthRequest{ Match: "github.com/asd/asd", diff --git a/coderd/gitsshkey_test.go b/coderd/gitsshkey_test.go index abd18508ce..27f9121bd3 100644 --- a/coderd/gitsshkey_test.go +++ b/coderd/gitsshkey_test.go @@ -118,8 +118,7 @@ func TestAgentGitSSHKey(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, project.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() @@ -157,8 +156,7 @@ func TestAgentGitSSHKey_APIKeyScopes(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, project.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() diff --git a/coderd/insights_test.go b/coderd/insights_test.go index cf5f63065d..99bf9b9a66 100644 --- a/coderd/insights_test.go +++ b/coderd/insights_test.go @@ -585,8 +585,7 @@ func TestTemplateInsights_Golden(t *testing.T) { continue } authToken := uuid.New() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken.String()) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken.String())) workspace.agentClient = agentClient var apps []*proto.App @@ -1494,8 +1493,7 @@ func TestUserActivityInsights_Golden(t *testing.T) { continue } authToken := uuid.New() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken.String()) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken.String())) workspace.agentClient = agentClient var apps []*proto.App diff --git a/coderd/prometheusmetrics/insights/metricscollector_test.go b/coderd/prometheusmetrics/insights/metricscollector_test.go index 5c18ec6d1a..560a601992 100644 --- a/coderd/prometheusmetrics/insights/metricscollector_test.go +++ b/coderd/prometheusmetrics/insights/metricscollector_test.go @@ -90,8 +90,7 @@ func TestCollectInsights(t *testing.T) { // Start an agent so that we can generate stats. var agentClients []agentproto.DRPCAgentClient for i, agent := range []database.WorkspaceAgent{agent1, agent2} { - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(agent.AuthToken.String()) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(agent.AuthToken.String())) agentClient.SDK.SetLogger(logger.Leveled(slog.LevelDebug).Named(fmt.Sprintf("agent%d", i+1))) conn, err := agentClient.ConnectRPC(context.Background()) require.NoError(t, err) diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index 3d8704f924..e75f86e51b 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -875,8 +875,7 @@ func prepareWorkspaceAndAgent(ctx context.Context, t *testing.T, client *codersd }) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - ac := agentsdk.New(client.URL) - ac.SetSessionToken(authToken) + ac := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) conn, err := ac.ConnectRPC(ctx) require.NoError(t, err) agentAPI := agentproto.NewDRPCAgentClient(conn) diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 6a817966f4..e950f97075 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -228,8 +228,7 @@ func TestWorkspaceAgentLogs(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ Logs: []agentsdk.Log{ { @@ -269,8 +268,7 @@ func TestWorkspaceAgentLogs(t *testing.T) { OrganizationID: user.OrganizationID, OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ Logs: []agentsdk.Log{ { @@ -314,8 +312,7 @@ func TestWorkspaceAgentLogs(t *testing.T) { updates, err := client.WatchWorkspace(ctx, r.Workspace.ID) require.NoError(t, err) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) err = agentClient.PatchLogs(ctx, agentsdk.PatchLogs{ Logs: []agentsdk.Log{{ CreatedAt: dbtime.Now(), @@ -360,8 +357,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) { return a }).Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) @@ -542,8 +538,7 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) { require.NoError(t, err) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, stopBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) _, err = agentClient.ConnectRPC(ctx) require.Error(t, err) @@ -568,8 +563,7 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) { ) require.NoError(t, err) // Then: the agent token should no longer be valid - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(wsb.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken((wsb.AgentToken))) _, err = agentClient.ConnectRPC(ctx) require.Error(t, err) var sdkErr *codersdk.Error @@ -890,8 +884,7 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) // Verify that the manifest has DisableDirectConnections set to true. - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) rpc, err := agentClient.ConnectRPC(ctx) require.NoError(t, err) defer func() { @@ -1742,8 +1735,7 @@ func TestWorkspaceAgentAppHealth(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) conn, err := agentClient.ConnectRPC(ctx) require.NoError(t, err) defer func() { @@ -1818,8 +1810,7 @@ func TestWorkspaceAgentPostLogSource(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) req := agentsdk.PostLogSourceRequest{ ID: uuid.New(), @@ -1867,8 +1858,7 @@ func TestWorkspaceAgent_LifecycleState(t *testing.T) { } } - ac := agentsdk.New(client.URL) - ac.SetSessionToken(r.AgentToken) + ac := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) conn, err := ac.ConnectRPC(ctx) require.NoError(t, err) defer func() { @@ -1965,8 +1955,7 @@ func TestWorkspaceAgent_Metadata(t *testing.T) { } } - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitMedium) conn, err := agentClient.ConnectRPC(ctx) @@ -2229,8 +2218,7 @@ func TestWorkspaceAgent_Metadata_CatchMemoryLeak(t *testing.T) { } } - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitSuperLong) conn, err := agentClient.ConnectRPC(ctx) @@ -2335,8 +2323,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) { OrganizationID: user.OrganizationID, OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitMedium) @@ -2382,8 +2369,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) ctx := testutil.Context(t, testutil.WaitMedium) @@ -2547,8 +2533,7 @@ func TestWorkspaceAgentExternalAuthListen(t *testing.T) { return agents }).Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) // We need to include an invalid oauth token that is not expired. dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{ @@ -3028,8 +3013,7 @@ func TestReinit(t *testing.T) { pubsubSpy.Unlock() agentCtx := testutil.Context(t, testutil.WaitShort) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent) go func() { diff --git a/coderd/workspaceagentsrpc_test.go b/coderd/workspaceagentsrpc_test.go index 5175f80b0b..525b8a981d 100644 --- a/coderd/workspaceagentsrpc_test.go +++ b/coderd/workspaceagentsrpc_test.go @@ -68,8 +68,7 @@ func TestWorkspaceAgentReportStats(t *testing.T) { }, ).Do() - ac := agentsdk.New(client.URL) - ac.SetSessionToken(r.AgentToken) + ac := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) conn, err := ac.ConnectRPC(context.Background()) require.NoError(t, err) defer func() { @@ -155,8 +154,7 @@ func TestAgentAPI_LargeManifest(t *testing.T) { agents[0].ApiKeyScope = string(tc.apiKeyScope) return agents }).Do() - ac := agentsdk.New(client.URL) - ac.SetSessionToken(r.AgentToken) + ac := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) conn, err := ac.ConnectRPC(ctx) defer func() { _ = conn.Close() diff --git a/coderd/workspaceapps/apptest/setup.go b/coderd/workspaceapps/apptest/setup.go index 296934591e..05bfb66219 100644 --- a/coderd/workspaceapps/apptest/setup.go +++ b/coderd/workspaceapps/apptest/setup.go @@ -482,8 +482,7 @@ func createWorkspaceWithApps(t *testing.T, client *codersdk.Client, orgID uuid.U require.Equal(t, appURL.String(), app.SubdomainName) } - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) // TODO (@dean): currently, the primary app host is used when generating // the port URL we tell the agent to use. We don't have any plans to change diff --git a/coderd/workspaceresourceauth_test.go b/coderd/workspaceresourceauth_test.go index 8c1b64feaf..73524a63ad 100644 --- a/coderd/workspaceresourceauth_test.go +++ b/coderd/workspaceresourceauth_test.go @@ -51,11 +51,9 @@ func TestPostWorkspaceAuthAzureInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - client.HTTPClient = metadataClient - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthAzureInstanceIdentity(ctx) + agentClient := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity()) + agentClient.SDK.HTTPClient = metadataClient + err := agentClient.RefreshToken(ctx) require.NoError(t, err) } @@ -97,11 +95,9 @@ func TestPostWorkspaceAuthAWSInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - client.HTTPClient = metadataClient - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthAWSInstanceIdentity(ctx) + agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity()) + agentClient.SDK.HTTPClient = metadataClient + err := agentClient.RefreshToken(ctx) require.NoError(t, err) }) } @@ -119,10 +115,8 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthGoogleInstanceIdentity(ctx, "", metadata) + agentClient := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata)) + err := agentClient.RefreshToken(ctx) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusUnauthorized, apiErr.StatusCode()) @@ -139,10 +133,8 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthGoogleInstanceIdentity(ctx, "", metadata) + agentClient := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata)) + err := agentClient.RefreshToken(ctx) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusNotFound, apiErr.StatusCode()) @@ -184,10 +176,8 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - agentClient := &agentsdk.Client{ - SDK: client, - } - _, err := agentClient.AuthGoogleInstanceIdentity(ctx, "", metadata) + agentClient := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata)) + err := agentClient.RefreshToken(ctx) require.NoError(t, err) }) } diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index 5bd0030456..d13f600a03 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -8,9 +8,9 @@ import ( "net/http" "net/http/cookiejar" "net/url" + "sync" "time" - "cloud.google.com/go/compute/metadata" "github.com/google/uuid" "github.com/hashicorp/yamux" "golang.org/x/xerrors" @@ -37,24 +37,31 @@ import ( // log-source. This should be removed in the future. var ExternalLogSourceID = uuid.MustParse("3b579bf4-1ed8-4b99-87a8-e9a1e3410410") -// New returns a client that is used to interact with the -// Coder API from a workspace agent. -func New(serverURL *url.URL) *Client { +// SessionTokenSetup is a function that creates the token provider while setting up the workspace agent. We do it this +// way because cloud instance identity (AWS, Azure, Google, etc.) requires interacting with coderd to exchange tokens. +// This means that the token providers need a codersdk.Client. However, the SessionTokenProvider is itself used by +// the client to authenticate requests. Thus, the dependency is bidirectional. Functions of this type are used in +// New() to ensure that things are set up correctly so there is only one instance of the codersdk.Client created. +// @typescript-ignore SessionTokenSetup +type SessionTokenSetup func(client *codersdk.Client) RefreshableSessionTokenProvider + +func New(serverURL *url.URL, setup SessionTokenSetup) *Client { + c := codersdk.New(serverURL) + provider := setup(c) + c.SessionTokenProvider = provider return &Client{ - SDK: codersdk.New(serverURL), + SDK: c, + RefreshableSessionTokenProvider: provider, } } // Client wraps `codersdk.Client` with specific functions // scoped to a workspace agent. type Client struct { + RefreshableSessionTokenProvider SDK *codersdk.Client } -func (c *Client) SetSessionToken(token string) { - c.SDK.SetSessionToken(token) -} - type GitSSHKey struct { PublicKey string `json:"public_key"` PrivateKey string `json:"private_key"` @@ -326,146 +333,91 @@ type AuthenticateResponse struct { SessionToken string `json:"session_token"` } -type GoogleInstanceIdentityToken struct { - JSONWebToken string `json:"json_web_token" validate:"required"` +// RefreshableSessionTokenProvider is a SessionTokenProvider that can be refreshed, for example, via token exchange. +// @typescript-ignore RefreshableSessionTokenProvider +type RefreshableSessionTokenProvider interface { + codersdk.SessionTokenProvider + RefreshToken(ctx context.Context) error } -// AuthWorkspaceGoogleInstanceIdentity uses the Google Compute Engine Metadata API to -// fetch a signed JWT, and exchange it for a session token for a workspace agent. -// -// The requesting instance must be registered as a resource in the latest history for a workspace. -func (c *Client) AuthGoogleInstanceIdentity(ctx context.Context, serviceAccount string, gcpClient *metadata.Client) (AuthenticateResponse, error) { - if serviceAccount == "" { - // This is the default name specified by Google. - serviceAccount = "default" - } - if gcpClient == nil { - gcpClient = metadata.NewClient(c.SDK.HTTPClient) - } - // "format=full" is required, otherwise the responding payload will be missing "instance_id". - jwt, err := gcpClient.Get(fmt.Sprintf("instance/service-accounts/%s/identity?audience=coder&format=full", serviceAccount)) - if err != nil { - return AuthenticateResponse{}, xerrors.Errorf("get metadata identity: %w", err) - } - res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/google-instance-identity", GoogleInstanceIdentityToken{ - JSONWebToken: jwt, - }) - if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) - } - var resp AuthenticateResponse - return resp, json.NewDecoder(res.Body).Decode(&resp) +// instanceIdentitySessionTokenProvider implements RefreshableSessionTokenProvider via token exchange for a cloud +// compute instance identity. +// @typescript-ignore instanceIdentitySessionTokenProvider +type instanceIdentitySessionTokenProvider struct { + tokenExchanger tokenExchanger + logger slog.Logger + + // cache so we don't request each time + mu sync.Mutex + sessionToken string } -type AWSInstanceIdentityToken struct { - Signature string `json:"signature" validate:"required"` - Document string `json:"document" validate:"required"` +// tokenExchanger obtains a session token by exchanging a cloud instance identity credential for a Coder session token. +// @typescript-ignore tokenExchanger +type tokenExchanger interface { + exchange(ctx context.Context) (AuthenticateResponse, error) } -// AuthWorkspaceAWSInstanceIdentity uses the Amazon Metadata API to -// fetch a signed payload, and exchange it for a session token for a workspace agent. -// -// The requesting instance must be registered as a resource in the latest history for a workspace. -func (c *Client) AuthAWSInstanceIdentity(ctx context.Context) (AuthenticateResponse, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil) - if err != nil { - return AuthenticateResponse{}, nil +func (i *instanceIdentitySessionTokenProvider) AsRequestOption() codersdk.RequestOption { + t := i.GetSessionToken() + return func(req *http.Request) { + req.Header.Set(codersdk.SessionTokenHeader, t) } - req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600") - res, err := c.SDK.HTTPClient.Do(req) - if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - token, err := io.ReadAll(res.Body) - if err != nil { - return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) - } - - req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/signature", nil) - if err != nil { - return AuthenticateResponse{}, nil - } - req.Header.Set("X-aws-ec2-metadata-token", string(token)) - res, err = c.SDK.HTTPClient.Do(req) - if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - signature, err := io.ReadAll(res.Body) - if err != nil { - return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) - } - - req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil) - if err != nil { - return AuthenticateResponse{}, nil - } - req.Header.Set("X-aws-ec2-metadata-token", string(token)) - res, err = c.SDK.HTTPClient.Do(req) - if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - document, err := io.ReadAll(res.Body) - if err != nil { - return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) - } - - res, err = c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", AWSInstanceIdentityToken{ - Signature: string(signature), - Document: string(document), - }) - if err != nil { - return AuthenticateResponse{}, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) - } - var resp AuthenticateResponse - return resp, json.NewDecoder(res.Body).Decode(&resp) } -type AzureInstanceIdentityToken struct { - Signature string `json:"signature" validate:"required"` - Encoding string `json:"encoding" validate:"required"` +func (i *instanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.DialOptions) { + t := i.GetSessionToken() + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + if opts.HTTPHeader.Get(codersdk.SessionTokenHeader) == "" { + opts.HTTPHeader.Set(codersdk.SessionTokenHeader, t) + } } -// AuthWorkspaceAzureInstanceIdentity uses the Azure Instance Metadata Service to -// fetch a signed payload, and exchange it for a session token for a workspace agent. -func (c *Client) AuthAzureInstanceIdentity(ctx context.Context) (AuthenticateResponse, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/metadata/attested/document?api-version=2020-09-01", nil) - if err != nil { - return AuthenticateResponse{}, nil +func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string { + i.mu.Lock() + defer i.mu.Unlock() + if i.sessionToken != "" { + return i.sessionToken } - req.Header.Set("Metadata", "true") - res, err := c.SDK.HTTPClient.Do(req) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + resp, err := i.tokenExchanger.exchange(ctx) if err != nil { - return AuthenticateResponse{}, err + i.logger.Error(ctx, "failed to exchange session token: %v", err) + return "" } - defer res.Body.Close() + i.sessionToken = resp.SessionToken + return i.sessionToken +} - var token AzureInstanceIdentityToken - err = json.NewDecoder(res.Body).Decode(&token) +func (i *instanceIdentitySessionTokenProvider) RefreshToken(ctx context.Context) error { + i.mu.Lock() + defer i.mu.Unlock() + resp, err := i.tokenExchanger.exchange(ctx) if err != nil { - return AuthenticateResponse{}, err + return err } + i.sessionToken = resp.SessionToken + return nil +} - res, err = c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/azure-instance-identity", token) - if err != nil { - return AuthenticateResponse{}, err +// FixedSessionTokenProvider wraps the codersdk variant to add a no-op RefreshToken method to satisfy the +// RefreshableSessionTokenProvider interface. +// @typescript-ignore FixedSessionTokenProvider +type FixedSessionTokenProvider struct { + codersdk.FixedSessionTokenProvider +} + +func (FixedSessionTokenProvider) RefreshToken(_ context.Context) error { + return nil +} + +func WithFixedToken(token string) SessionTokenSetup { + return func(_ *codersdk.Client) RefreshableSessionTokenProvider { + return FixedSessionTokenProvider{FixedSessionTokenProvider: codersdk.FixedSessionTokenProvider{SessionToken: token}} } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) - } - var resp AuthenticateResponse - return resp, json.NewDecoder(res.Body).Decode(&resp) } // Stats records the Agent's network connection statistics for use in diff --git a/codersdk/agentsdk/agentsdk_test.go b/codersdk/agentsdk/agentsdk_test.go index e6ea6838dd..4f3d7d838b 100644 --- a/codersdk/agentsdk/agentsdk_test.go +++ b/codersdk/agentsdk/agentsdk_test.go @@ -141,7 +141,7 @@ func TestRewriteDERPMap(t *testing.T) { } parsed, err := url.Parse("https://coconuts.org:44558") require.NoError(t, err) - client := agentsdk.New(parsed) + client := agentsdk.New(parsed, agentsdk.WithFixedToken("unused")) client.RewriteDERPMap(dm) region := dm.Regions[1] require.True(t, region.EmbeddedRelay) diff --git a/codersdk/agentsdk/aws.go b/codersdk/agentsdk/aws.go new file mode 100644 index 0000000000..b4f30ec4e9 --- /dev/null +++ b/codersdk/agentsdk/aws.go @@ -0,0 +1,97 @@ +package agentsdk + +import ( + "context" + "encoding/json" + "io" + "net/http" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" +) + +type AWSInstanceIdentityToken struct { + Signature string `json:"signature" validate:"required"` + Document string `json:"document" validate:"required"` +} + +// awsSessionTokenExchanger exchanges AWS instance metadata for a Coder session token. +// @typescript-ignore awsSessionTokenExchanger +type awsSessionTokenExchanger struct { + client *codersdk.Client +} + +func WithAWSInstanceIdentity() SessionTokenSetup { + return func(client *codersdk.Client) RefreshableSessionTokenProvider { + return &instanceIdentitySessionTokenProvider{ + tokenExchanger: &awsSessionTokenExchanger{client: client}, + } + } +} + +// exchange uses the Amazon Metadata API to fetch a signed payload, and exchange it for a session token for a workspace +// agent. +// +// The requesting instance must be registered as a resource in the latest history for a workspace. +func (a *awsSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil) + if err != nil { + return AuthenticateResponse{}, nil + } + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600") + res, err := a.client.HTTPClient.Do(req) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + token, err := io.ReadAll(res.Body) + if err != nil { + return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) + } + + req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/signature", nil) + if err != nil { + return AuthenticateResponse{}, nil + } + req.Header.Set("X-aws-ec2-metadata-token", string(token)) + res, err = a.client.HTTPClient.Do(req) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + signature, err := io.ReadAll(res.Body) + if err != nil { + return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) + } + + req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil) + if err != nil { + return AuthenticateResponse{}, nil + } + req.Header.Set("X-aws-ec2-metadata-token", string(token)) + res, err = a.client.HTTPClient.Do(req) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + document, err := io.ReadAll(res.Body) + if err != nil { + return AuthenticateResponse{}, xerrors.Errorf("read token: %w", err) + } + + // request without the token to avoid re-entering this function + res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", AWSInstanceIdentityToken{ + Signature: string(signature), + Document: string(document), + }) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) + } + var resp AuthenticateResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/codersdk/agentsdk/azure.go b/codersdk/agentsdk/azure.go new file mode 100644 index 0000000000..eb66e21097 --- /dev/null +++ b/codersdk/agentsdk/azure.go @@ -0,0 +1,60 @@ +package agentsdk + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/coder/coder/v2/codersdk" +) + +type AzureInstanceIdentityToken struct { + Signature string `json:"signature" validate:"required"` + Encoding string `json:"encoding" validate:"required"` +} + +// azureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token. +// @typescript-ignore azureSessionTokenExchanger +type azureSessionTokenExchanger struct { + client *codersdk.Client +} + +func WithAzureInstanceIdentity() SessionTokenSetup { + return func(client *codersdk.Client) RefreshableSessionTokenProvider { + return &instanceIdentitySessionTokenProvider{ + tokenExchanger: &azureSessionTokenExchanger{client: client}, + } + } +} + +// AuthWorkspaceAzureInstanceIdentity uses the Azure Instance Metadata Service to +// fetch a signed payload, and exchange it for a session token for a workspace agent. +func (a *azureSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/metadata/attested/document?api-version=2020-09-01", nil) + if err != nil { + return AuthenticateResponse{}, nil + } + req.Header.Set("Metadata", "true") + res, err := a.client.HTTPClient.Do(req) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + + var token AzureInstanceIdentityToken + err = json.NewDecoder(res.Body).Decode(&token) + if err != nil { + return AuthenticateResponse{}, err + } + + res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/azure-instance-identity", token) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) + } + var resp AuthenticateResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/codersdk/agentsdk/google.go b/codersdk/agentsdk/google.go new file mode 100644 index 0000000000..e462ba2404 --- /dev/null +++ b/codersdk/agentsdk/google.go @@ -0,0 +1,71 @@ +package agentsdk + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "cloud.google.com/go/compute/metadata" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" +) + +type GoogleInstanceIdentityToken struct { + JSONWebToken string `json:"json_web_token" validate:"required"` +} + +// googleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token. +// @typescript-ignore googleSessionTokenExchanger +type googleSessionTokenExchanger struct { + serviceAccount string + gcpClient *metadata.Client + client *codersdk.Client +} + +func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup { + return func(client *codersdk.Client) RefreshableSessionTokenProvider { + return &instanceIdentitySessionTokenProvider{ + tokenExchanger: &googleSessionTokenExchanger{ + client: client, + gcpClient: gcpClient, + serviceAccount: serviceAccount, + }, + } + } +} + +// exchange uses the Google Compute Engine Metadata API to fetch a signed JWT, and exchange it for a session token for a +// workspace agent. +// +// The requesting instance must be registered as a resource in the latest history for a workspace. +func (g *googleSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) { + if g.serviceAccount == "" { + // This is the default name specified by Google. + g.serviceAccount = "default" + } + gcpClient := metadata.NewClient(g.client.HTTPClient) + if g.gcpClient != nil { + gcpClient = g.gcpClient + } + + // "format=full" is required, otherwise the responding payload will be missing "instance_id". + jwt, err := gcpClient.Get(fmt.Sprintf("instance/service-accounts/%s/identity?audience=coder&format=full", g.serviceAccount)) + if err != nil { + return AuthenticateResponse{}, xerrors.Errorf("get metadata identity: %w", err) + } + // request without the token to avoid re-entering this function + res, err := g.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/google-instance-identity", GoogleInstanceIdentityToken{ + JSONWebToken: jwt, + }) + if err != nil { + return AuthenticateResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AuthenticateResponse{}, codersdk.ReadBodyAsError(res) + } + var resp AuthenticateResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index fb321e90e7..6d4031e22a 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -75,8 +75,7 @@ func TestTools(t *testing.T) { }).Do() // Given: a client configured with the agent token. - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) // Get the agent ID from the API. Overriding it in dbfake doesn't work. ws, err := client.Workspace(setupCtx, r.Workspace.ID) require.NoError(t, err) diff --git a/docs/reference/cli/external-auth_access-token.md b/docs/reference/cli/external-auth_access-token.md index 2303e8f076..7fb022077a 100644 --- a/docs/reference/cli/external-auth_access-token.md +++ b/docs/reference/cli/external-auth_access-token.md @@ -40,3 +40,40 @@ fi | Type | string | Extract a field from the "extra" properties of the OAuth token. + +### --agent-token + +| | | +|-------------|---------------------------------| +| Type | string | +| Environment | $CODER_AGENT_TOKEN | + +An agent authentication token. + +### --agent-token-file + +| | | +|-------------|--------------------------------------| +| Type | string | +| Environment | $CODER_AGENT_TOKEN_FILE | + +A file containing an agent authentication token. + +### --agent-url + +| | | +|-------------|-------------------------------| +| Type | url | +| Environment | $CODER_AGENT_URL | + +URL for an agent to access your deployment. + +### --auth + +| | | +|-------------|--------------------------------| +| Type | string | +| Environment | $CODER_AGENT_AUTH | +| Default | token | + +Specify the authentication type to use for the agent. diff --git a/enterprise/coderd/appearance_test.go b/enterprise/coderd/appearance_test.go index 8550f13904..81ba7eddc7 100644 --- a/enterprise/coderd/appearance_test.go +++ b/enterprise/coderd/appearance_test.go @@ -153,15 +153,13 @@ func TestAnnouncementBanners(t *testing.T) { OwnerID: user.UserID, }).WithAgent().Do() - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(r.AgentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken)) banners := requireGetAnnouncementBanners(ctx, t, agentClient) require.Equal(t, cfg.AnnouncementBanners, banners) // Create an AGPL Coderd against the same database agplClient := coderdtest.New(t, &coderdtest.Options{Database: store, Pubsub: ps}) - agplAgentClient := agentsdk.New(agplClient.URL) - agplAgentClient.SetSessionToken(r.AgentToken) + agplAgentClient := agentsdk.New(agplClient.URL, agentsdk.WithFixedToken(r.AgentToken)) banners = requireGetAnnouncementBanners(ctx, t, agplAgentClient) require.Equal(t, []codersdk.BannerConfig{}, banners) diff --git a/enterprise/coderd/gitsshkey_test.go b/enterprise/coderd/gitsshkey_test.go index a4978ac8fd..7045c8dd86 100644 --- a/enterprise/coderd/gitsshkey_test.go +++ b/enterprise/coderd/gitsshkey_test.go @@ -69,8 +69,7 @@ func TestAgentGitSSHKeyCustomRoles(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, project.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() diff --git a/enterprise/coderd/workspaceagents_test.go b/enterprise/coderd/workspaceagents_test.go index c9d44e667c..917d44dff2 100644 --- a/enterprise/coderd/workspaceagents_test.go +++ b/enterprise/coderd/workspaceagents_test.go @@ -319,7 +319,7 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) workspace := coderdtest.CreateWorkspace(t, client, template.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) agentClient.SDK.HTTPClient = &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -328,7 +328,6 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr }, }, } - agentClient.SetSessionToken(authToken) agnt := agent.New(agent.Options{ Client: agentClient, Logger: testutil.Logger(t).Named("agent"), diff --git a/scaletest/createworkspaces/run_test.go b/scaletest/createworkspaces/run_test.go index c63854ff8a..edade6b79e 100644 --- a/scaletest/createworkspaces/run_test.go +++ b/scaletest/createworkspaces/run_test.go @@ -561,8 +561,7 @@ func goEventuallyStartFakeAgent(ctx context.Context, t *testing.T, client *coder coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(agentToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(agentToken)) agentCloser := agent.New(agent.Options{ Client: agentClient, Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}). diff --git a/scaletest/workspacebuild/run_test.go b/scaletest/workspacebuild/run_test.go index 5949f04d5b..f813019d0f 100644 --- a/scaletest/workspacebuild/run_test.go +++ b/scaletest/workspacebuild/run_test.go @@ -134,8 +134,7 @@ func Test_Runner(t *testing.T) { for i, authToken := range []string{authToken1, authToken2, authToken3} { i := i + 1 - agentClient := agentsdk.New(client.URL) - agentClient.SetSessionToken(authToken) + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken)) agentCloser := agent.New(agent.Options{ Client: agentClient, Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).