chore: refactor CLI agent auth tests as unit tests (#19609)

Fixes https://github.com/coder/internal/issues/933

Refactors CLI tests that check the `--auth` flag parsing for various public clouds into a unit test that just creates the agent Client and asserts on the type.

Testing that the agent client actually authenticates correctly with these auth types is well covered by Coderd tests, so we don't need to retread that ground here, and the deleted tests were flaky on Windows.
This commit is contained in:
Spike Curtis
2025-09-03 10:49:19 +04:00
committed by GitHub
parent 1354d84eb4
commit 18945a7949
12 changed files with 129 additions and 237 deletions
+1 -1
View File
@@ -179,7 +179,7 @@ func workspaceAgent() *serpent.Command {
slog.F("auth", agentAuth.agentAuth), slog.F("auth", agentAuth.agentAuth),
slog.F("version", version), slog.F("version", version),
) )
client, err := agentAuth.CreateClient(ctx) client, err := agentAuth.CreateClient()
if err != nil { if err != nil {
return xerrors.Errorf("create agent client: %w", err) return xerrors.Errorf("create agent client: %w", err)
} }
-157
View File
@@ -1,7 +1,6 @@
package cli_test package cli_test
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@@ -11,7 +10,6 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -21,10 +19,7 @@ import (
"github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@@ -64,158 +59,6 @@ func TestWorkspaceAgent(t *testing.T) {
}, testutil.WaitLong, testutil.IntervalMedium) }, testutil.WaitLong, testutil.IntervalMedium)
}) })
t.Run("Azure", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
AzureCertificates: certificates,
})
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()
inv, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
inv = inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "azure-client", metadataClient),
)
ctx := inv.Context()
clitest.Start(t, inv)
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
})
t.Run("AWS", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
AWSCertificates: certificates,
})
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()
inv, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
inv = inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "aws-client", metadataClient),
)
clitest.Start(t, inv)
ctx := inv.Context()
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
})
t.Run("GoogleCloud", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
validator, metadataClient := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
GoogleTokenValidator: validator,
})
owner := coderdtest.CreateFirstUser(t, client)
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: memberUser.ID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()
inv, cfg := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
clitest.SetupConfig(t, member, cfg)
clitest.Start(t,
inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "gcp-client", metadataClient),
),
)
ctx := inv.Context()
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
sshClient, err := dialer.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
key := "CODER_AGENT_TOKEN"
command := "sh -c 'echo $" + key + "'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %" + key + "%"
}
token, err := session.CombinedOutput(command)
require.NoError(t, err)
_, err = uuid.Parse(strings.TrimSpace(string(token)))
require.NoError(t, err)
})
t.Run("PostStartup", func(t *testing.T) { t.Run("PostStartup", func(t *testing.T) {
t.Parallel() t.Parallel()
+2 -2
View File
@@ -149,7 +149,7 @@ func mcpConfigureClaudeCode() *serpent.Command {
binPath = testBinaryName binPath = testBinaryName
} }
configureClaudeEnv := map[string]string{} configureClaudeEnv := map[string]string{}
agentClient, err := agentAuth.CreateClient(inv.Context()) agentClient, err := agentAuth.CreateClient()
if err != nil { if err != nil {
cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err) cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err)
} else { } else {
@@ -497,7 +497,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
} }
// Try to create an agent client for status reporting. Not validated. // Try to create an agent client for status reporting. Not validated.
agentClient, err := agentAuth.CreateClient(inv.Context()) agentClient, err := agentAuth.CreateClient()
if err == nil { if err == nil {
cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String()) cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String())
srv.agentClient = agentClient srv.agentClient = agentClient
+1 -1
View File
@@ -68,7 +68,7 @@ fi
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
defer stop() defer stop()
client, err := agentAuth.CreateClient(ctx) client, err := agentAuth.CreateClient()
if err != nil { if err != nil {
return xerrors.Errorf("create agent client: %w", err) return xerrors.Errorf("create agent client: %w", err)
} }
+1 -1
View File
@@ -33,7 +33,7 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("parse host: %w", err) return xerrors.Errorf("parse host: %w", err)
} }
client, err := agentAuth.CreateClient(ctx) client, err := agentAuth.CreateClient()
if err != nil { if err != nil {
return xerrors.Errorf("create agent client: %w", err) return xerrors.Errorf("create agent client: %w", err)
} }
+1 -1
View File
@@ -39,7 +39,7 @@ func gitssh() *serpent.Command {
return err return err
} }
client, err := agentAuth.CreateClient(ctx) client, err := agentAuth.CreateClient()
if err != nil { if err != nil {
return xerrors.Errorf("create agent client: %w", err) return xerrors.Errorf("create agent client: %w", err)
} }
+4 -35
View File
@@ -24,7 +24,6 @@ import (
"text/tabwriter" "text/tabwriter"
"time" "time"
"cloud.google.com/go/compute/metadata"
"github.com/mattn/go-isatty" "github.com/mattn/go-isatty"
"github.com/mitchellh/go-wordwrap" "github.com/mitchellh/go-wordwrap"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
@@ -687,7 +686,7 @@ func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) {
// CreateClient returns a new agent client from the command context. It works // CreateClient returns a new agent client from the command context. It works
// just like InitClient, but uses the agent token and URL instead. // just like InitClient, but uses the agent token and URL instead.
func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) { func (a *AgentAuth) CreateClient() (*agentsdk.Client, error) {
agentURL := a.agentURL agentURL := a.agentURL
if agentURL.String() == "" { if agentURL.String() == "" {
return nil, xerrors.Errorf("%s must be set", envAgentURL) return nil, xerrors.Errorf("%s must be set", envAgentURL)
@@ -711,41 +710,11 @@ func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error)
} }
return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil
case "google-instance-identity": case "google-instance-identity":
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", nil)), nil
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var gcpClient *metadata.Client
gcpClientRaw := ctx.Value("gcp-client")
if gcpClientRaw != nil {
gcpClient, _ = gcpClientRaw.(*metadata.Client)
}
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", gcpClient)), nil
case "aws-instance-identity": case "aws-instance-identity":
client := agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()) return agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()), nil
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var awsClient *http.Client
awsClientRaw := ctx.Value("aws-client")
if awsClientRaw != nil {
awsClient, _ = awsClientRaw.(*http.Client)
if awsClient != nil {
client.SDK.HTTPClient = awsClient
}
}
return client, nil
case "azure-instance-identity": case "azure-instance-identity":
client := agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()) return agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()), nil
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var azureClient *http.Client
azureClientRaw := ctx.Value("azure-client")
if azureClientRaw != nil {
azureClient, _ = azureClientRaw.(*http.Client)
if azureClient != nil {
client.SDK.HTTPClient = azureClient
}
}
return client, nil
default: default:
return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth) return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth)
} }
+88 -8
View File
@@ -10,20 +10,20 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
"github.com/coder/serpent"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli" "github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
) )
//nolint:tparallel,paralleltest //nolint:tparallel,paralleltest
@@ -275,3 +275,83 @@ func TestHandlersOK(t *testing.T) {
clitest.HandlersOK(t, cmd) clitest.HandlersOK(t, cmd)
} }
func TestCreateAgentClient_Token(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--agent-token", "fake-token",
"--agent-url", "http://coder.fake")
require.Equal(t, "fake-token", client.GetSessionToken())
}
func TestCreateAgentClient_Google(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--auth", "google-instance-identity",
"--agent-url", "http://coder.fake")
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
require.True(t, ok)
require.NotNil(t, provider.TokenExchanger)
require.IsType(t, &agentsdk.GoogleSessionTokenExchanger{}, provider.TokenExchanger)
}
func TestCreateAgentClient_AWS(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--auth", "aws-instance-identity",
"--agent-url", "http://coder.fake")
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
require.True(t, ok)
require.NotNil(t, provider.TokenExchanger)
require.IsType(t, &agentsdk.AWSSessionTokenExchanger{}, provider.TokenExchanger)
}
func TestCreateAgentClient_Azure(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--auth", "azure-instance-identity",
"--agent-url", "http://coder.fake")
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
require.True(t, ok)
require.NotNil(t, provider.TokenExchanger)
require.IsType(t, &agentsdk.AzureSessionTokenExchanger{}, provider.TokenExchanger)
}
func createAgentWithFlags(t *testing.T, flags ...string) *agentsdk.Client {
t.Helper()
r := &cli.RootCmd{}
var client *agentsdk.Client
subCmd := agentClientCommand(&client)
cmd, err := r.Command([]*serpent.Command{subCmd})
require.NoError(t, err)
inv, _ := clitest.NewWithCommand(t, cmd,
append([]string{"agent-client"}, flags...)...)
err = inv.Run()
require.NoError(t, err)
require.NotNil(t, client)
return client
}
// agentClientCommand creates a subcommand that creates an agent client and stores it in the provided clientRef. Used to
// test the properties of the client with various root command flags.
func agentClientCommand(clientRef **agentsdk.Client) *serpent.Command {
agentAuth := &cli.AgentAuth{}
cmd := &serpent.Command{
Use: "agent-client",
Short: `Creates and agent client for testing.`,
Handler: func(inv *serpent.Invocation) error {
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
*clientRef = client
return nil
},
}
agentAuth.AttachOptions(cmd, false)
return cmd
}
+13 -13
View File
@@ -340,11 +340,11 @@ type RefreshableSessionTokenProvider interface {
RefreshToken(ctx context.Context) error RefreshToken(ctx context.Context) error
} }
// instanceIdentitySessionTokenProvider implements RefreshableSessionTokenProvider via token exchange for a cloud // InstanceIdentitySessionTokenProvider implements RefreshableSessionTokenProvider via token exchange for a cloud
// compute instance identity. // compute instance identity.
// @typescript-ignore instanceIdentitySessionTokenProvider // @typescript-ignore InstanceIdentitySessionTokenProvider
type instanceIdentitySessionTokenProvider struct { type InstanceIdentitySessionTokenProvider struct {
tokenExchanger tokenExchanger TokenExchanger TokenExchanger
logger slog.Logger logger slog.Logger
// cache so we don't request each time // cache so we don't request each time
@@ -352,20 +352,20 @@ type instanceIdentitySessionTokenProvider struct {
sessionToken string sessionToken string
} }
// tokenExchanger obtains a session token by exchanging a cloud instance identity credential for a Coder session token. // TokenExchanger obtains a session token by exchanging a cloud instance identity credential for a Coder session token.
// @typescript-ignore tokenExchanger // @typescript-ignore TokenExchanger
type tokenExchanger interface { type TokenExchanger interface {
exchange(ctx context.Context) (AuthenticateResponse, error) exchange(ctx context.Context) (AuthenticateResponse, error)
} }
func (i *instanceIdentitySessionTokenProvider) AsRequestOption() codersdk.RequestOption { func (i *InstanceIdentitySessionTokenProvider) AsRequestOption() codersdk.RequestOption {
t := i.GetSessionToken() t := i.GetSessionToken()
return func(req *http.Request) { return func(req *http.Request) {
req.Header.Set(codersdk.SessionTokenHeader, t) req.Header.Set(codersdk.SessionTokenHeader, t)
} }
} }
func (i *instanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.DialOptions) { func (i *InstanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
t := i.GetSessionToken() t := i.GetSessionToken()
if opts.HTTPHeader == nil { if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{} opts.HTTPHeader = http.Header{}
@@ -375,7 +375,7 @@ func (i *instanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.Dia
} }
} }
func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string { func (i *InstanceIdentitySessionTokenProvider) GetSessionToken() string {
i.mu.Lock() i.mu.Lock()
defer i.mu.Unlock() defer i.mu.Unlock()
if i.sessionToken != "" { if i.sessionToken != "" {
@@ -383,7 +383,7 @@ func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string {
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
resp, err := i.tokenExchanger.exchange(ctx) resp, err := i.TokenExchanger.exchange(ctx)
if err != nil { if err != nil {
i.logger.Error(ctx, "failed to exchange session token: %v", err) i.logger.Error(ctx, "failed to exchange session token: %v", err)
return "" return ""
@@ -392,10 +392,10 @@ func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string {
return i.sessionToken return i.sessionToken
} }
func (i *instanceIdentitySessionTokenProvider) RefreshToken(ctx context.Context) error { func (i *InstanceIdentitySessionTokenProvider) RefreshToken(ctx context.Context) error {
i.mu.Lock() i.mu.Lock()
defer i.mu.Unlock() defer i.mu.Unlock()
resp, err := i.tokenExchanger.exchange(ctx) resp, err := i.TokenExchanger.exchange(ctx)
if err != nil { if err != nil {
return err return err
} }
+6 -6
View File
@@ -16,16 +16,16 @@ type AWSInstanceIdentityToken struct {
Document string `json:"document" validate:"required"` Document string `json:"document" validate:"required"`
} }
// awsSessionTokenExchanger exchanges AWS instance metadata for a Coder session token. // AWSSessionTokenExchanger exchanges AWS instance metadata for a Coder session token.
// @typescript-ignore awsSessionTokenExchanger // @typescript-ignore AWSSessionTokenExchanger
type awsSessionTokenExchanger struct { type AWSSessionTokenExchanger struct {
client *codersdk.Client client *codersdk.Client
} }
func WithAWSInstanceIdentity() SessionTokenSetup { func WithAWSInstanceIdentity() SessionTokenSetup {
return func(client *codersdk.Client) RefreshableSessionTokenProvider { return func(client *codersdk.Client) RefreshableSessionTokenProvider {
return &instanceIdentitySessionTokenProvider{ return &InstanceIdentitySessionTokenProvider{
tokenExchanger: &awsSessionTokenExchanger{client: client}, TokenExchanger: &AWSSessionTokenExchanger{client: client},
} }
} }
} }
@@ -34,7 +34,7 @@ func WithAWSInstanceIdentity() SessionTokenSetup {
// agent. // agent.
// //
// The requesting instance must be registered as a resource in the latest history for a workspace. // The requesting instance must be registered as a resource in the latest history for a workspace.
func (a *awsSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) { func (a *AWSSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil) req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil)
if err != nil { if err != nil {
return AuthenticateResponse{}, nil return AuthenticateResponse{}, nil
+6 -6
View File
@@ -13,23 +13,23 @@ type AzureInstanceIdentityToken struct {
Encoding string `json:"encoding" validate:"required"` Encoding string `json:"encoding" validate:"required"`
} }
// azureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token. // AzureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token.
// @typescript-ignore azureSessionTokenExchanger // @typescript-ignore AzureSessionTokenExchanger
type azureSessionTokenExchanger struct { type AzureSessionTokenExchanger struct {
client *codersdk.Client client *codersdk.Client
} }
func WithAzureInstanceIdentity() SessionTokenSetup { func WithAzureInstanceIdentity() SessionTokenSetup {
return func(client *codersdk.Client) RefreshableSessionTokenProvider { return func(client *codersdk.Client) RefreshableSessionTokenProvider {
return &instanceIdentitySessionTokenProvider{ return &InstanceIdentitySessionTokenProvider{
tokenExchanger: &azureSessionTokenExchanger{client: client}, TokenExchanger: &AzureSessionTokenExchanger{client: client},
} }
} }
} }
// AuthWorkspaceAzureInstanceIdentity uses the Azure Instance Metadata Service to // AuthWorkspaceAzureInstanceIdentity uses the Azure Instance Metadata Service to
// fetch a signed payload, and exchange it for a session token for a workspace agent. // fetch a signed payload, and exchange it for a session token for a workspace agent.
func (a *azureSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) { func (a *AzureSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/metadata/attested/document?api-version=2020-09-01", nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/metadata/attested/document?api-version=2020-09-01", nil)
if err != nil { if err != nil {
return AuthenticateResponse{}, nil return AuthenticateResponse{}, nil
+6 -6
View File
@@ -16,9 +16,9 @@ type GoogleInstanceIdentityToken struct {
JSONWebToken string `json:"json_web_token" validate:"required"` JSONWebToken string `json:"json_web_token" validate:"required"`
} }
// googleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token. // GoogleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token.
// @typescript-ignore googleSessionTokenExchanger // @typescript-ignore GoogleSessionTokenExchanger
type googleSessionTokenExchanger struct { type GoogleSessionTokenExchanger struct {
serviceAccount string serviceAccount string
gcpClient *metadata.Client gcpClient *metadata.Client
client *codersdk.Client client *codersdk.Client
@@ -26,8 +26,8 @@ type googleSessionTokenExchanger struct {
func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup { func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup {
return func(client *codersdk.Client) RefreshableSessionTokenProvider { return func(client *codersdk.Client) RefreshableSessionTokenProvider {
return &instanceIdentitySessionTokenProvider{ return &InstanceIdentitySessionTokenProvider{
tokenExchanger: &googleSessionTokenExchanger{ TokenExchanger: &GoogleSessionTokenExchanger{
client: client, client: client,
gcpClient: gcpClient, gcpClient: gcpClient,
serviceAccount: serviceAccount, serviceAccount: serviceAccount,
@@ -40,7 +40,7 @@ func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Clien
// workspace agent. // workspace agent.
// //
// The requesting instance must be registered as a resource in the latest history for a workspace. // The requesting instance must be registered as a resource in the latest history for a workspace.
func (g *googleSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) { func (g *GoogleSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
if g.serviceAccount == "" { if g.serviceAccount == "" {
// This is the default name specified by Google. // This is the default name specified by Google.
g.serviceAccount = "default" g.serviceAccount = "default"