mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore: refactor CLI agent auth tests as unit tests (#19609)
Fixes https://github.com/coder/internal/issues/933 Refactors CLI tests that check the `--auth` flag parsing for various public clouds into a unit test that just creates the agent Client and asserts on the type. Testing that the agent client actually authenticates correctly with these auth types is well covered by Coderd tests, so we don't need to retread that ground here, and the deleted tests were flaky on Windows.
This commit is contained in:
+1
-1
@@ -179,7 +179,7 @@ func workspaceAgent() *serpent.Command {
|
|||||||
slog.F("auth", agentAuth.agentAuth),
|
slog.F("auth", agentAuth.agentAuth),
|
||||||
slog.F("version", version),
|
slog.F("version", version),
|
||||||
)
|
)
|
||||||
client, err := agentAuth.CreateClient(ctx)
|
client, err := agentAuth.CreateClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create agent client: %w", err)
|
return xerrors.Errorf("create agent client: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package cli_test
|
package cli_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -11,7 +10,6 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -21,10 +19,7 @@ import (
|
|||||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
||||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -64,158 +59,6 @@ func TestWorkspaceAgent(t *testing.T) {
|
|||||||
}, testutil.WaitLong, testutil.IntervalMedium)
|
}, testutil.WaitLong, testutil.IntervalMedium)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Azure", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
instanceID := "instanceidentifier"
|
|
||||||
certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
|
|
||||||
db, ps := dbtestutil.NewDB(t,
|
|
||||||
dbtestutil.WithDumpOnFailure(),
|
|
||||||
)
|
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
|
||||||
Database: db,
|
|
||||||
Pubsub: ps,
|
|
||||||
AzureCertificates: certificates,
|
|
||||||
})
|
|
||||||
user := coderdtest.CreateFirstUser(t, client)
|
|
||||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
||||||
OrganizationID: user.OrganizationID,
|
|
||||||
OwnerID: user.UserID,
|
|
||||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
|
||||||
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
|
|
||||||
return agents
|
|
||||||
}).Do()
|
|
||||||
|
|
||||||
inv, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
|
|
||||||
inv = inv.WithContext(
|
|
||||||
//nolint:revive,staticcheck
|
|
||||||
context.WithValue(inv.Context(), "azure-client", metadataClient),
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx := inv.Context()
|
|
||||||
clitest.Start(t, inv)
|
|
||||||
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
|
|
||||||
MatchResources(matchAgentWithVersion).Wait()
|
|
||||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
resources := workspace.LatestBuild.Resources
|
|
||||||
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
|
|
||||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
|
||||||
}
|
|
||||||
dialer, err := workspacesdk.New(client).
|
|
||||||
DialAgent(ctx, resources[0].Agents[0].ID, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer dialer.Close()
|
|
||||||
require.True(t, dialer.AwaitReachable(ctx))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("AWS", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
instanceID := "instanceidentifier"
|
|
||||||
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
|
|
||||||
db, ps := dbtestutil.NewDB(t,
|
|
||||||
dbtestutil.WithDumpOnFailure(),
|
|
||||||
)
|
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
|
||||||
Database: db,
|
|
||||||
Pubsub: ps,
|
|
||||||
AWSCertificates: certificates,
|
|
||||||
})
|
|
||||||
user := coderdtest.CreateFirstUser(t, client)
|
|
||||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
||||||
OrganizationID: user.OrganizationID,
|
|
||||||
OwnerID: user.UserID,
|
|
||||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
|
||||||
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
|
|
||||||
return agents
|
|
||||||
}).Do()
|
|
||||||
|
|
||||||
inv, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
|
|
||||||
inv = inv.WithContext(
|
|
||||||
//nolint:revive,staticcheck
|
|
||||||
context.WithValue(inv.Context(), "aws-client", metadataClient),
|
|
||||||
)
|
|
||||||
|
|
||||||
clitest.Start(t, inv)
|
|
||||||
ctx := inv.Context()
|
|
||||||
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
|
|
||||||
MatchResources(matchAgentWithVersion).
|
|
||||||
Wait()
|
|
||||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
resources := workspace.LatestBuild.Resources
|
|
||||||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
|
||||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
|
||||||
}
|
|
||||||
dialer, err := workspacesdk.New(client).
|
|
||||||
DialAgent(ctx, resources[0].Agents[0].ID, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer dialer.Close()
|
|
||||||
require.True(t, dialer.AwaitReachable(ctx))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("GoogleCloud", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
instanceID := "instanceidentifier"
|
|
||||||
validator, metadataClient := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
|
|
||||||
db, ps := dbtestutil.NewDB(t,
|
|
||||||
dbtestutil.WithDumpOnFailure(),
|
|
||||||
)
|
|
||||||
client := coderdtest.New(t, &coderdtest.Options{
|
|
||||||
Database: db,
|
|
||||||
Pubsub: ps,
|
|
||||||
GoogleTokenValidator: validator,
|
|
||||||
})
|
|
||||||
owner := coderdtest.CreateFirstUser(t, client)
|
|
||||||
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
|
||||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
|
||||||
OrganizationID: owner.OrganizationID,
|
|
||||||
OwnerID: memberUser.ID,
|
|
||||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
|
||||||
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
|
|
||||||
return agents
|
|
||||||
}).Do()
|
|
||||||
|
|
||||||
inv, cfg := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
|
|
||||||
clitest.SetupConfig(t, member, cfg)
|
|
||||||
|
|
||||||
clitest.Start(t,
|
|
||||||
inv.WithContext(
|
|
||||||
//nolint:revive,staticcheck
|
|
||||||
context.WithValue(inv.Context(), "gcp-client", metadataClient),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx := inv.Context()
|
|
||||||
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
|
|
||||||
MatchResources(matchAgentWithVersion).
|
|
||||||
Wait()
|
|
||||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
resources := workspace.LatestBuild.Resources
|
|
||||||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
|
||||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
|
||||||
}
|
|
||||||
dialer, err := workspacesdk.New(client).DialAgent(ctx, resources[0].Agents[0].ID, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer dialer.Close()
|
|
||||||
require.True(t, dialer.AwaitReachable(ctx))
|
|
||||||
sshClient, err := dialer.SSHClient(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer sshClient.Close()
|
|
||||||
session, err := sshClient.NewSession()
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer session.Close()
|
|
||||||
key := "CODER_AGENT_TOKEN"
|
|
||||||
command := "sh -c 'echo $" + key + "'"
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
command = "cmd.exe /c echo %" + key + "%"
|
|
||||||
}
|
|
||||||
token, err := session.CombinedOutput(command)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = uuid.Parse(strings.TrimSpace(string(token)))
|
|
||||||
require.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("PostStartup", func(t *testing.T) {
|
t.Run("PostStartup", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
+2
-2
@@ -149,7 +149,7 @@ func mcpConfigureClaudeCode() *serpent.Command {
|
|||||||
binPath = testBinaryName
|
binPath = testBinaryName
|
||||||
}
|
}
|
||||||
configureClaudeEnv := map[string]string{}
|
configureClaudeEnv := map[string]string{}
|
||||||
agentClient, err := agentAuth.CreateClient(inv.Context())
|
agentClient, err := agentAuth.CreateClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err)
|
cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err)
|
||||||
} else {
|
} else {
|
||||||
@@ -497,7 +497,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Try to create an agent client for status reporting. Not validated.
|
// Try to create an agent client for status reporting. Not validated.
|
||||||
agentClient, err := agentAuth.CreateClient(inv.Context())
|
agentClient, err := agentAuth.CreateClient()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String())
|
cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String())
|
||||||
srv.agentClient = agentClient
|
srv.agentClient = agentClient
|
||||||
|
|||||||
+1
-1
@@ -68,7 +68,7 @@ fi
|
|||||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||||
defer stop()
|
defer stop()
|
||||||
|
|
||||||
client, err := agentAuth.CreateClient(ctx)
|
client, err := agentAuth.CreateClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create agent client: %w", err)
|
return xerrors.Errorf("create agent client: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -33,7 +33,7 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
|||||||
return xerrors.Errorf("parse host: %w", err)
|
return xerrors.Errorf("parse host: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := agentAuth.CreateClient(ctx)
|
client, err := agentAuth.CreateClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create agent client: %w", err)
|
return xerrors.Errorf("create agent client: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -39,7 +39,7 @@ func gitssh() *serpent.Command {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := agentAuth.CreateClient(ctx)
|
client, err := agentAuth.CreateClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create agent client: %w", err)
|
return xerrors.Errorf("create agent client: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+4
-35
@@ -24,7 +24,6 @@ import (
|
|||||||
"text/tabwriter"
|
"text/tabwriter"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"cloud.google.com/go/compute/metadata"
|
|
||||||
"github.com/mattn/go-isatty"
|
"github.com/mattn/go-isatty"
|
||||||
"github.com/mitchellh/go-wordwrap"
|
"github.com/mitchellh/go-wordwrap"
|
||||||
"golang.org/x/mod/semver"
|
"golang.org/x/mod/semver"
|
||||||
@@ -687,7 +686,7 @@ func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) {
|
|||||||
|
|
||||||
// CreateClient returns a new agent client from the command context. It works
|
// CreateClient returns a new agent client from the command context. It works
|
||||||
// just like InitClient, but uses the agent token and URL instead.
|
// just like InitClient, but uses the agent token and URL instead.
|
||||||
func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error) {
|
func (a *AgentAuth) CreateClient() (*agentsdk.Client, error) {
|
||||||
agentURL := a.agentURL
|
agentURL := a.agentURL
|
||||||
if agentURL.String() == "" {
|
if agentURL.String() == "" {
|
||||||
return nil, xerrors.Errorf("%s must be set", envAgentURL)
|
return nil, xerrors.Errorf("%s must be set", envAgentURL)
|
||||||
@@ -711,41 +710,11 @@ func (a *AgentAuth) CreateClient(ctx context.Context) (*agentsdk.Client, error)
|
|||||||
}
|
}
|
||||||
return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil
|
return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil
|
||||||
case "google-instance-identity":
|
case "google-instance-identity":
|
||||||
|
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", nil)), nil
|
||||||
// This is *only* done for testing to mock client authentication.
|
|
||||||
// This will never be set in a production scenario.
|
|
||||||
var gcpClient *metadata.Client
|
|
||||||
gcpClientRaw := ctx.Value("gcp-client")
|
|
||||||
if gcpClientRaw != nil {
|
|
||||||
gcpClient, _ = gcpClientRaw.(*metadata.Client)
|
|
||||||
}
|
|
||||||
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", gcpClient)), nil
|
|
||||||
case "aws-instance-identity":
|
case "aws-instance-identity":
|
||||||
client := agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity())
|
return agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()), nil
|
||||||
// This is *only* done for testing to mock client authentication.
|
|
||||||
// This will never be set in a production scenario.
|
|
||||||
var awsClient *http.Client
|
|
||||||
awsClientRaw := ctx.Value("aws-client")
|
|
||||||
if awsClientRaw != nil {
|
|
||||||
awsClient, _ = awsClientRaw.(*http.Client)
|
|
||||||
if awsClient != nil {
|
|
||||||
client.SDK.HTTPClient = awsClient
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return client, nil
|
|
||||||
case "azure-instance-identity":
|
case "azure-instance-identity":
|
||||||
client := agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity())
|
return agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()), nil
|
||||||
// This is *only* done for testing to mock client authentication.
|
|
||||||
// This will never be set in a production scenario.
|
|
||||||
var azureClient *http.Client
|
|
||||||
azureClientRaw := ctx.Value("azure-client")
|
|
||||||
if azureClientRaw != nil {
|
|
||||||
azureClient, _ = azureClientRaw.(*http.Client)
|
|
||||||
if azureClient != nil {
|
|
||||||
client.SDK.HTTPClient = azureClient
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return client, nil
|
|
||||||
default:
|
default:
|
||||||
return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth)
|
return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth)
|
||||||
}
|
}
|
||||||
|
|||||||
+88
-8
@@ -10,20 +10,20 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coder/serpent"
|
|
||||||
|
|
||||||
"github.com/coder/coder/v2/coderd"
|
|
||||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
|
||||||
"github.com/coder/coder/v2/codersdk"
|
|
||||||
"github.com/coder/coder/v2/pty/ptytest"
|
|
||||||
"github.com/coder/coder/v2/testutil"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/buildinfo"
|
"github.com/coder/coder/v2/buildinfo"
|
||||||
"github.com/coder/coder/v2/cli"
|
"github.com/coder/coder/v2/cli"
|
||||||
"github.com/coder/coder/v2/cli/clitest"
|
"github.com/coder/coder/v2/cli/clitest"
|
||||||
|
"github.com/coder/coder/v2/coderd"
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||||
|
"github.com/coder/coder/v2/pty/ptytest"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
"github.com/coder/serpent"
|
||||||
)
|
)
|
||||||
|
|
||||||
//nolint:tparallel,paralleltest
|
//nolint:tparallel,paralleltest
|
||||||
@@ -275,3 +275,83 @@ func TestHandlersOK(t *testing.T) {
|
|||||||
|
|
||||||
clitest.HandlersOK(t, cmd)
|
clitest.HandlersOK(t, cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateAgentClient_Token(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
client := createAgentWithFlags(t,
|
||||||
|
"--agent-token", "fake-token",
|
||||||
|
"--agent-url", "http://coder.fake")
|
||||||
|
require.Equal(t, "fake-token", client.GetSessionToken())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAgentClient_Google(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
client := createAgentWithFlags(t,
|
||||||
|
"--auth", "google-instance-identity",
|
||||||
|
"--agent-url", "http://coder.fake")
|
||||||
|
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotNil(t, provider.TokenExchanger)
|
||||||
|
require.IsType(t, &agentsdk.GoogleSessionTokenExchanger{}, provider.TokenExchanger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAgentClient_AWS(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
client := createAgentWithFlags(t,
|
||||||
|
"--auth", "aws-instance-identity",
|
||||||
|
"--agent-url", "http://coder.fake")
|
||||||
|
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotNil(t, provider.TokenExchanger)
|
||||||
|
require.IsType(t, &agentsdk.AWSSessionTokenExchanger{}, provider.TokenExchanger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAgentClient_Azure(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
client := createAgentWithFlags(t,
|
||||||
|
"--auth", "azure-instance-identity",
|
||||||
|
"--agent-url", "http://coder.fake")
|
||||||
|
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotNil(t, provider.TokenExchanger)
|
||||||
|
require.IsType(t, &agentsdk.AzureSessionTokenExchanger{}, provider.TokenExchanger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAgentWithFlags(t *testing.T, flags ...string) *agentsdk.Client {
|
||||||
|
t.Helper()
|
||||||
|
r := &cli.RootCmd{}
|
||||||
|
var client *agentsdk.Client
|
||||||
|
subCmd := agentClientCommand(&client)
|
||||||
|
cmd, err := r.Command([]*serpent.Command{subCmd})
|
||||||
|
require.NoError(t, err)
|
||||||
|
inv, _ := clitest.NewWithCommand(t, cmd,
|
||||||
|
append([]string{"agent-client"}, flags...)...)
|
||||||
|
err = inv.Run()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, client)
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
// agentClientCommand creates a subcommand that creates an agent client and stores it in the provided clientRef. Used to
|
||||||
|
// test the properties of the client with various root command flags.
|
||||||
|
func agentClientCommand(clientRef **agentsdk.Client) *serpent.Command {
|
||||||
|
agentAuth := &cli.AgentAuth{}
|
||||||
|
cmd := &serpent.Command{
|
||||||
|
Use: "agent-client",
|
||||||
|
Short: `Creates and agent client for testing.`,
|
||||||
|
Handler: func(inv *serpent.Invocation) error {
|
||||||
|
client, err := agentAuth.CreateClient()
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("create agent client: %w", err)
|
||||||
|
}
|
||||||
|
*clientRef = client
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
agentAuth.AttachOptions(cmd, false)
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|||||||
@@ -340,11 +340,11 @@ type RefreshableSessionTokenProvider interface {
|
|||||||
RefreshToken(ctx context.Context) error
|
RefreshToken(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// instanceIdentitySessionTokenProvider implements RefreshableSessionTokenProvider via token exchange for a cloud
|
// InstanceIdentitySessionTokenProvider implements RefreshableSessionTokenProvider via token exchange for a cloud
|
||||||
// compute instance identity.
|
// compute instance identity.
|
||||||
// @typescript-ignore instanceIdentitySessionTokenProvider
|
// @typescript-ignore InstanceIdentitySessionTokenProvider
|
||||||
type instanceIdentitySessionTokenProvider struct {
|
type InstanceIdentitySessionTokenProvider struct {
|
||||||
tokenExchanger tokenExchanger
|
TokenExchanger TokenExchanger
|
||||||
logger slog.Logger
|
logger slog.Logger
|
||||||
|
|
||||||
// cache so we don't request each time
|
// cache so we don't request each time
|
||||||
@@ -352,20 +352,20 @@ type instanceIdentitySessionTokenProvider struct {
|
|||||||
sessionToken string
|
sessionToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokenExchanger obtains a session token by exchanging a cloud instance identity credential for a Coder session token.
|
// TokenExchanger obtains a session token by exchanging a cloud instance identity credential for a Coder session token.
|
||||||
// @typescript-ignore tokenExchanger
|
// @typescript-ignore TokenExchanger
|
||||||
type tokenExchanger interface {
|
type TokenExchanger interface {
|
||||||
exchange(ctx context.Context) (AuthenticateResponse, error)
|
exchange(ctx context.Context) (AuthenticateResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *instanceIdentitySessionTokenProvider) AsRequestOption() codersdk.RequestOption {
|
func (i *InstanceIdentitySessionTokenProvider) AsRequestOption() codersdk.RequestOption {
|
||||||
t := i.GetSessionToken()
|
t := i.GetSessionToken()
|
||||||
return func(req *http.Request) {
|
return func(req *http.Request) {
|
||||||
req.Header.Set(codersdk.SessionTokenHeader, t)
|
req.Header.Set(codersdk.SessionTokenHeader, t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *instanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
|
func (i *InstanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.DialOptions) {
|
||||||
t := i.GetSessionToken()
|
t := i.GetSessionToken()
|
||||||
if opts.HTTPHeader == nil {
|
if opts.HTTPHeader == nil {
|
||||||
opts.HTTPHeader = http.Header{}
|
opts.HTTPHeader = http.Header{}
|
||||||
@@ -375,7 +375,7 @@ func (i *instanceIdentitySessionTokenProvider) SetDialOption(opts *websocket.Dia
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string {
|
func (i *InstanceIdentitySessionTokenProvider) GetSessionToken() string {
|
||||||
i.mu.Lock()
|
i.mu.Lock()
|
||||||
defer i.mu.Unlock()
|
defer i.mu.Unlock()
|
||||||
if i.sessionToken != "" {
|
if i.sessionToken != "" {
|
||||||
@@ -383,7 +383,7 @@ func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string {
|
|||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
resp, err := i.tokenExchanger.exchange(ctx)
|
resp, err := i.TokenExchanger.exchange(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
i.logger.Error(ctx, "failed to exchange session token: %v", err)
|
i.logger.Error(ctx, "failed to exchange session token: %v", err)
|
||||||
return ""
|
return ""
|
||||||
@@ -392,10 +392,10 @@ func (i *instanceIdentitySessionTokenProvider) GetSessionToken() string {
|
|||||||
return i.sessionToken
|
return i.sessionToken
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *instanceIdentitySessionTokenProvider) RefreshToken(ctx context.Context) error {
|
func (i *InstanceIdentitySessionTokenProvider) RefreshToken(ctx context.Context) error {
|
||||||
i.mu.Lock()
|
i.mu.Lock()
|
||||||
defer i.mu.Unlock()
|
defer i.mu.Unlock()
|
||||||
resp, err := i.tokenExchanger.exchange(ctx)
|
resp, err := i.TokenExchanger.exchange(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,16 +16,16 @@ type AWSInstanceIdentityToken struct {
|
|||||||
Document string `json:"document" validate:"required"`
|
Document string `json:"document" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// awsSessionTokenExchanger exchanges AWS instance metadata for a Coder session token.
|
// AWSSessionTokenExchanger exchanges AWS instance metadata for a Coder session token.
|
||||||
// @typescript-ignore awsSessionTokenExchanger
|
// @typescript-ignore AWSSessionTokenExchanger
|
||||||
type awsSessionTokenExchanger struct {
|
type AWSSessionTokenExchanger struct {
|
||||||
client *codersdk.Client
|
client *codersdk.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithAWSInstanceIdentity() SessionTokenSetup {
|
func WithAWSInstanceIdentity() SessionTokenSetup {
|
||||||
return func(client *codersdk.Client) RefreshableSessionTokenProvider {
|
return func(client *codersdk.Client) RefreshableSessionTokenProvider {
|
||||||
return &instanceIdentitySessionTokenProvider{
|
return &InstanceIdentitySessionTokenProvider{
|
||||||
tokenExchanger: &awsSessionTokenExchanger{client: client},
|
TokenExchanger: &AWSSessionTokenExchanger{client: client},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -34,7 +34,7 @@ func WithAWSInstanceIdentity() SessionTokenSetup {
|
|||||||
// agent.
|
// agent.
|
||||||
//
|
//
|
||||||
// The requesting instance must be registered as a resource in the latest history for a workspace.
|
// The requesting instance must be registered as a resource in the latest history for a workspace.
|
||||||
func (a *awsSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
|
func (a *AWSSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://169.254.169.254/latest/api/token", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AuthenticateResponse{}, nil
|
return AuthenticateResponse{}, nil
|
||||||
|
|||||||
@@ -13,23 +13,23 @@ type AzureInstanceIdentityToken struct {
|
|||||||
Encoding string `json:"encoding" validate:"required"`
|
Encoding string `json:"encoding" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// azureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token.
|
// AzureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token.
|
||||||
// @typescript-ignore azureSessionTokenExchanger
|
// @typescript-ignore AzureSessionTokenExchanger
|
||||||
type azureSessionTokenExchanger struct {
|
type AzureSessionTokenExchanger struct {
|
||||||
client *codersdk.Client
|
client *codersdk.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithAzureInstanceIdentity() SessionTokenSetup {
|
func WithAzureInstanceIdentity() SessionTokenSetup {
|
||||||
return func(client *codersdk.Client) RefreshableSessionTokenProvider {
|
return func(client *codersdk.Client) RefreshableSessionTokenProvider {
|
||||||
return &instanceIdentitySessionTokenProvider{
|
return &InstanceIdentitySessionTokenProvider{
|
||||||
tokenExchanger: &azureSessionTokenExchanger{client: client},
|
TokenExchanger: &AzureSessionTokenExchanger{client: client},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthWorkspaceAzureInstanceIdentity uses the Azure Instance Metadata Service to
|
// AuthWorkspaceAzureInstanceIdentity uses the Azure Instance Metadata Service to
|
||||||
// fetch a signed payload, and exchange it for a session token for a workspace agent.
|
// fetch a signed payload, and exchange it for a session token for a workspace agent.
|
||||||
func (a *azureSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
|
func (a *AzureSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/metadata/attested/document?api-version=2020-09-01", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/metadata/attested/document?api-version=2020-09-01", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AuthenticateResponse{}, nil
|
return AuthenticateResponse{}, nil
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ type GoogleInstanceIdentityToken struct {
|
|||||||
JSONWebToken string `json:"json_web_token" validate:"required"`
|
JSONWebToken string `json:"json_web_token" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// googleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token.
|
// GoogleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token.
|
||||||
// @typescript-ignore googleSessionTokenExchanger
|
// @typescript-ignore GoogleSessionTokenExchanger
|
||||||
type googleSessionTokenExchanger struct {
|
type GoogleSessionTokenExchanger struct {
|
||||||
serviceAccount string
|
serviceAccount string
|
||||||
gcpClient *metadata.Client
|
gcpClient *metadata.Client
|
||||||
client *codersdk.Client
|
client *codersdk.Client
|
||||||
@@ -26,8 +26,8 @@ type googleSessionTokenExchanger struct {
|
|||||||
|
|
||||||
func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup {
|
func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup {
|
||||||
return func(client *codersdk.Client) RefreshableSessionTokenProvider {
|
return func(client *codersdk.Client) RefreshableSessionTokenProvider {
|
||||||
return &instanceIdentitySessionTokenProvider{
|
return &InstanceIdentitySessionTokenProvider{
|
||||||
tokenExchanger: &googleSessionTokenExchanger{
|
TokenExchanger: &GoogleSessionTokenExchanger{
|
||||||
client: client,
|
client: client,
|
||||||
gcpClient: gcpClient,
|
gcpClient: gcpClient,
|
||||||
serviceAccount: serviceAccount,
|
serviceAccount: serviceAccount,
|
||||||
@@ -40,7 +40,7 @@ func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Clien
|
|||||||
// workspace agent.
|
// workspace agent.
|
||||||
//
|
//
|
||||||
// The requesting instance must be registered as a resource in the latest history for a workspace.
|
// The requesting instance must be registered as a resource in the latest history for a workspace.
|
||||||
func (g *googleSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
|
func (g *GoogleSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateResponse, error) {
|
||||||
if g.serviceAccount == "" {
|
if g.serviceAccount == "" {
|
||||||
// This is the default name specified by Google.
|
// This is the default name specified by Google.
|
||||||
g.serviceAccount = "default"
|
g.serviceAccount = "default"
|
||||||
|
|||||||
Reference in New Issue
Block a user