chore: refactor CLI agent auth tests as unit tests (#19609)

Fixes https://github.com/coder/internal/issues/933

Refactors CLI tests that check the `--auth` flag parsing for various public clouds into a unit test that just creates the agent Client and asserts on the type.

Testing that the agent client actually authenticates correctly with these auth types is well covered by Coderd tests, so we don't need to retread that ground here, and the deleted tests were flaky on Windows.
This commit is contained in:
Spike Curtis
2025-09-03 10:49:19 +04:00
committed by GitHub
parent 1354d84eb4
commit 18945a7949
12 changed files with 129 additions and 237 deletions
+1 -1
View File
@@ -179,7 +179,7 @@ func workspaceAgent() *serpent.Command {
slog.F("auth", agentAuth.agentAuth),
slog.F("version", version),
)
client, err := agentAuth.CreateClient(ctx)
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
-157
View File
@@ -1,7 +1,6 @@
package cli_test
import (
"context"
"fmt"
"net/http"
"os"
@@ -11,7 +10,6 @@ import (
"sync/atomic"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -21,10 +19,7 @@ import (
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
@@ -64,158 +59,6 @@ func TestWorkspaceAgent(t *testing.T) {
}, testutil.WaitLong, testutil.IntervalMedium)
})
t.Run("Azure", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
AzureCertificates: certificates,
})
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()
inv, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
inv = inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "azure-client", metadataClient),
)
ctx := inv.Context()
clitest.Start(t, inv)
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
})
t.Run("AWS", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
AWSCertificates: certificates,
})
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()
inv, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
inv = inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "aws-client", metadataClient),
)
clitest.Start(t, inv)
ctx := inv.Context()
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
})
t.Run("GoogleCloud", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
validator, metadataClient := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
GoogleTokenValidator: validator,
})
owner := coderdtest.CreateFirstUser(t, client)
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: memberUser.ID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()
inv, cfg := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
clitest.SetupConfig(t, member, cfg)
clitest.Start(t,
inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "gcp-client", metadataClient),
),
)
ctx := inv.Context()
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
sshClient, err := dialer.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
key := "CODER_AGENT_TOKEN"
command := "sh -c 'echo $" + key + "'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %" + key + "%"
}
token, err := session.CombinedOutput(command)
require.NoError(t, err)
_, err = uuid.Parse(strings.TrimSpace(string(token)))
require.NoError(t, err)
})
t.Run("PostStartup", func(t *testing.T) {
t.Parallel()
+2 -2
View File
@@ -149,7 +149,7 @@ func mcpConfigureClaudeCode() *serpent.Command {
binPath = testBinaryName
}
configureClaudeEnv := map[string]string{}
agentClient, err := agentAuth.CreateClient(inv.Context())
agentClient, err := agentAuth.CreateClient()
if err != nil {
cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err)
} else {
@@ -497,7 +497,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
}
// Try to create an agent client for status reporting. Not validated.
agentClient, err := agentAuth.CreateClient(inv.Context())
agentClient, err := agentAuth.CreateClient()
if err == nil {
cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String())
srv.agentClient = agentClient
+1 -1
View File
@@ -68,7 +68,7 @@ fi
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
defer stop()
client, err := agentAuth.CreateClient(ctx)
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
+1 -1
View File
@@ -33,7 +33,7 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("parse host: %w", err)
}
client, err := agentAuth.CreateClient(ctx)
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
+1 -1
View File
@@ -39,7 +39,7 @@ func gitssh() *serpent.Command {
return err
}
client, err := agentAuth.CreateClient(ctx)
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
+4 -35
View File
@@ -24,7 +24,6 @@ import (
"text/tabwriter"
"time"
"cloud.google.com/go/compute/metadata"
"github.com/mattn/go-isatty"
"github.com/mitchellh/go-wordwrap"
"golang.org/x/mod/semver"
@@ -687,7 +686,7 @@ func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) {
// CreateClient returns a new agent client from the command context. It works
// just like InitClient, but uses the agent token and URL instead.
func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) {
func (a *AgentAuth) CreateClient() (*agentsdk.Client, error) {
agentURL := a.agentURL
if agentURL.String() == "" {
return nil, xerrors.Errorf("%s must be set", envAgentURL)
@@ -711,41 +710,11 @@ func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error)
}
return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil
case "google-instance-identity":
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var gcpClient *metadata.Client
gcpClientRaw := ctx.Value("gcp-client")
if gcpClientRaw != nil {
gcpClient, _ = gcpClientRaw.(*metadata.Client)
}
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", gcpClient)), nil
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", nil)), nil
case "aws-instance-identity":
client := agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity())
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var awsClient *http.Client
awsClientRaw := ctx.Value("aws-client")
if awsClientRaw != nil {
awsClient, _ = awsClientRaw.(*http.Client)
if awsClient != nil {
client.SDK.HTTPClient = awsClient
}
}
return client, nil
return agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()), nil
case "azure-instance-identity":
client := agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity())
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var azureClient *http.Client
azureClientRaw := ctx.Value("azure-client")
if azureClientRaw != nil {
azureClient, _ = azureClientRaw.(*http.Client)
if azureClient != nil {
client.SDK.HTTPClient = azureClient
}
}
return client, nil
return agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()), nil
default:
return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth)
}
+88 -8
View File
@@ -10,20 +10,20 @@ import (
"sync/atomic"
"testing"
"github.com/coder/serpent"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
//nolint:tparallel,paralleltest
@@ -275,3 +275,83 @@ func TestHandlersOK(t *testing.T) {
clitest.HandlersOK(t, cmd)
}
func TestCreateAgentClient_Token(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--agent-token", "fake-token",
"--agent-url", "http://coder.fake")
require.Equal(t, "fake-token", client.GetSessionToken())
}
func TestCreateAgentClient_Google(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--auth", "google-instance-identity",
"--agent-url", "http://coder.fake")
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
require.True(t, ok)
require.NotNil(t, provider.TokenExchanger)
require.IsType(t, &agentsdk.GoogleSessionTokenExchanger{}, provider.TokenExchanger)
}
func TestCreateAgentClient_AWS(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--auth", "aws-instance-identity",
"--agent-url", "http://coder.fake")
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
require.True(t, ok)
require.NotNil(t, provider.TokenExchanger)
require.IsType(t, &agentsdk.AWSSessionTokenExchanger{}, provider.TokenExchanger)
}
func TestCreateAgentClient_Azure(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--auth", "azure-instance-identity",
"--agent-url", "http://coder.fake")
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
require.True(t, ok)
require.NotNil(t, provider.TokenExchanger)
require.IsType(t, &agentsdk.AzureSessionTokenExchanger{}, provider.TokenExchanger)
}
func createAgentWithFlags(t *testing.T, flags ...string) *agentsdk.Client {
t.Helper()
r := &cli.RootCmd{}
var client *agentsdk.Client
subCmd := agentClientCommand(&client)
cmd, err := r.Command([]*serpent.Command{subCmd})
require.NoError(t, err)
inv, _ := clitest.NewWithCommand(t, cmd,
append([]string{"agent-client"}, flags...)...)
err = inv.Run()
require.NoError(t, err)
require.NotNil(t, client)
return client
}
// agentClientCommand creates a subcommand that creates an agent client and stores it in the provided clientRef. Used to
// test the properties of the client with various root command flags.
func agentClientCommand(clientRef **agentsdk.Client) *serpent.Command {
agentAuth := &cli.AgentAuth{}
cmd := &serpent.Command{
Use: "agent-client",
Short: `Creates and agent client for testing.`,
Handler: func(inv *serpent.Invocation) error {
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
*clientRef = client
return nil
},
}
agentAuth.AttachOptions(cmd, false)
return cmd
}
+13 -13
View File
@@ -340,11 +340,11 @@ type RefreshableSessionTokenProvider interface {
RefreshToken(ctx context.Context) error
}
// instanceIdentitySessionTokenProvider implements RefreshableSessionTokenProvider via token exchange for a cloud
// InstanceIdentitySessionTokenProvider implements RefreshableSessionTokenProvider via token exchange for a cloud
// compute instance identity.
// @typescript-ignore instanceIdentitySessionTokenProvider
type instanceIdentitySessionTokenProvider struct {
tokenExchanger tokenExchanger
// @typescript-ignore InstanceIdentitySessionTokenProvider
type InstanceIdentitySessionTokenProvider struct {
TokenExchanger TokenExchanger
logger slog.Logger
// cache so we don't request each time
@@ -352,20 +352,20 @@ type instanceIdentitySessionTokenProvider struct {
sessionToken string
}
// tokenExchanger obtains a session token by exchanging a cloud instance identity credential for a Coder session token.
// @typescript-ignore tokenExchanger
type tokenExchanger interface {
// TokenExchanger obtains a session token by exchanging a cloud instance identity credential for a Coder session token.
// @typescript-ignore TokenExchanger
type TokenExchanger interface {
exchange(ctx context.Context) (AuthenticateResponse, error)
}
func (i *instanceIdentitySessionTokenProvider) AsRequestOption() codersdk.RequestOption {
func (i *InstanceIdentitySessionTokenProvider) AsRequestOption() codersdk.RequestOption {
t := i.GetSessionToken()
return func(req *http.Request) {
req.Header.Set(codersdk.SessionTokenHeader, t)
}
}
func (i *instanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
func (i *InstanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
t := i.GetSessionToken()
if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{}
@@ -375,7 +375,7 @@ func (i *instanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.Dia
}
}
func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string {
func (i *InstanceIdentitySessionTokenProvider) GetSessionToken() string {
i.mu.Lock()
defer i.mu.Unlock()
if i.sessionToken != "" {
@@ -383,7 +383,7 @@ func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string {
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
resp, err := i.tokenExchanger.exchange(ctx)
resp, err := i.TokenExchanger.exchange(ctx)
if err != nil {
i.logger.Error(ctx, "failed to exchange session token: %v", err)
return ""
@@ -392,10 +392,10 @@ func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string {
return i.sessionToken
}
func (i *instanceIdentitySessionTokenProvider) RefreshToken(ctx context.Context) error {
func (i *InstanceIdentitySessionTokenProvider) RefreshToken(ctx context.Context) error {
i.mu.Lock()
defer i.mu.Unlock()
resp, err := i.tokenExchanger.exchange(ctx)
resp, err := i.TokenExchanger.exchange(ctx)
if err != nil {
return err
}
+6 -6
View File
@@ -16,16 +16,16 @@ type AWSInstanceIdentityToken struct {
Document string `json:"document" validate:"required"`
}
// awsSessionTokenExchanger exchanges AWS instance metadata for a Coder session token.
// @typescript-ignore awsSessionTokenExchanger
type awsSessionTokenExchanger struct {
// AWSSessionTokenExchanger exchanges AWS instance metadata for a Coder session token.
// @typescript-ignore AWSSessionTokenExchanger
type AWSSessionTokenExchanger struct {
client *codersdk.Client
}
func WithAWSInstanceIdentity() SessionTokenSetup {
return func(client *codersdk.Client) RefreshableSessionTokenProvider {
return &instanceIdentitySessionTokenProvider{
tokenExchanger: &awsSessionTokenExchanger{client: client},
return &InstanceIdentitySessionTokenProvider{
TokenExchanger: &AWSSessionTokenExchanger{client: client},
}
}
}
@@ -34,7 +34,7 @@ func WithAWSInstanceIdentity() SessionTokenSetup {
// agent.
//
// The requesting instance must be registered as a resource in the latest history for a workspace.
func (a *awsSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
func (a *AWSSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil)
if err != nil {
return AuthenticateResponse{}, nil
+6 -6
View File
@@ -13,23 +13,23 @@ type AzureInstanceIdentityToken struct {
Encoding string `json:"encoding" validate:"required"`
}
// azureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token.
// @typescript-ignore azureSessionTokenExchanger
type azureSessionTokenExchanger struct {
// AzureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token.
// @typescript-ignore AzureSessionTokenExchanger
type AzureSessionTokenExchanger struct {
client *codersdk.Client
}
func WithAzureInstanceIdentity() SessionTokenSetup {
return func(client *codersdk.Client) RefreshableSessionTokenProvider {
return &instanceIdentitySessionTokenProvider{
tokenExchanger: &azureSessionTokenExchanger{client: client},
return &InstanceIdentitySessionTokenProvider{
TokenExchanger: &AzureSessionTokenExchanger{client: client},
}
}
}
// AuthWorkspaceAzureInstanceIdentity uses the Azure Instance Metadata Service to
// fetch a signed payload, and exchange it for a session token for a workspace agent.
func (a *azureSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
func (a *AzureSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/metadata/attested/document?api-version=2020-09-01", nil)
if err != nil {
return AuthenticateResponse{}, nil
+6 -6
View File
@@ -16,9 +16,9 @@ type GoogleInstanceIdentityToken struct {
JSONWebToken string `json:"json_web_token" validate:"required"`
}
// googleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token.
// @typescript-ignore googleSessionTokenExchanger
type googleSessionTokenExchanger struct {
// GoogleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token.
// @typescript-ignore GoogleSessionTokenExchanger
type GoogleSessionTokenExchanger struct {
serviceAccount string
gcpClient *metadata.Client
client *codersdk.Client
@@ -26,8 +26,8 @@ type googleSessionTokenExchanger struct {
func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup {
return func(client *codersdk.Client) RefreshableSessionTokenProvider {
return &instanceIdentitySessionTokenProvider{
tokenExchanger: &googleSessionTokenExchanger{
return &InstanceIdentitySessionTokenProvider{
TokenExchanger: &GoogleSessionTokenExchanger{
client: client,
gcpClient: gcpClient,
serviceAccount: serviceAccount,
@@ -40,7 +40,7 @@ func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Clien
// workspace agent.
//
// The requesting instance must be registered as a resource in the latest history for a workspace.
func (g *googleSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
func (g *GoogleSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
if g.serviceAccount == "" {
// This is the default name specified by Google.
g.serviceAccount = "default"