feat: support multiple agents with shared instance-identity auth (#24325)

> This PR was authored by Mux on behalf of Mike.

## Summary

Adds support for multiple peer root workspace agents sharing the same
`auth_instance_id`, so AWS, Azure, and GCP instance-identity auth can
issue the correct session token for a selected agent instead of assuming
a
single root agent per instance.

## Problem

When a Terraform template attaches two or more `coder_agent` resources
(with `auth = "aws-instance-identity"`) to a single compute instance,
every agent shares the same cloud instance ID. The existing singular
lookup picks whichever agent was created most recently, silently
ignoring
the others.

## Solution

Introduce an optional pre-auth agent selector (`CODER_AGENT_NAME`) and
make the server-side lookup ambiguity-aware.

**Database layer:**
- `GetWorkspaceAgentsByInstanceID` (`:many`): returns all matching root
  agents for an instance ID.
- `GetWorkspaceAgentByInstanceIDAndName` (`:one`): returns the named
root
  agent for disambiguation.

**SDK and CLI:**
- `agent_name` field added to AWS, Azure, and GCP request structs
  (`omitempty` for backward compatibility).
- `CODER_AGENT_NAME` env var and `--agent-name` flag wired into the
agent
  bootstrap before instance-identity auth runs.

**Server handler (`handleAuthInstanceID`):**
- When `agent_name` is present: direct lookup by (instance ID, name).
- When absent: legacy lookup, then resource-scoped ambiguity check.
  Returns 409 with available agent names if multiple root agents match.
- Whitespace-only names are trimmed and treated as unspecified.
- Sub-agents remain excluded (`parent_id IS NULL` filter).

**Verification template:**
- `examples/templates/aws-multi-agent/` provisions one EC2 instance with
  two agents (`main` and `dev`), both using instance-identity auth with
  `CODER_AGENT_NAME` set in the cloud-init user data.

## Backward compatibility

Existing single-agent deployments work unchanged. The `agent_name` field
is optional with `omitempty`, and the unnamed path preserves today's
behavior when only one root agent matches.
This commit is contained in:
Michael Suchacz
2026-04-16 13:59:09 +02:00
committed by GitHub
parent 1cf0354f72
commit e5707a13d6
29 changed files with 1563 additions and 286 deletions
+17 -3
View File
@@ -86,6 +86,7 @@ const (
envAgentTokenFile = "CODER_AGENT_TOKEN_FILE" envAgentTokenFile = "CODER_AGENT_TOKEN_FILE"
envAgentURL = "CODER_AGENT_URL" envAgentURL = "CODER_AGENT_URL"
envAgentAuth = "CODER_AGENT_AUTH" envAgentAuth = "CODER_AGENT_AUTH"
envAgentName = "CODER_AGENT_NAME"
envURL = "CODER_URL" envURL = "CODER_URL"
) )
@@ -789,6 +790,7 @@ type AgentAuth struct {
agentTokenFile string agentTokenFile string
agentURL url.URL agentURL url.URL
agentAuth string agentAuth string
agentName string
} }
func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) { func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) {
@@ -821,6 +823,13 @@ func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) {
Default: "token", Default: "token",
Value: serpent.StringOf(&a.agentAuth), Value: serpent.StringOf(&a.agentAuth),
Hidden: hidden, Hidden: hidden,
}, serpent.Option{
Name: "Agent Name",
Description: "The name of the agent to authenticate as (only applicable for instance identity).",
Flag: "agent-name",
Env: envAgentName,
Value: serpent.StringOf(&a.agentName),
Hidden: hidden,
}) })
} }
@@ -832,6 +841,11 @@ func (a *AgentAuth) CreateClient() (*agentsdk.Client, error) {
return nil, xerrors.Errorf("%s must be set", envAgentURL) return nil, xerrors.Errorf("%s must be set", envAgentURL)
} }
var iiOpts []agentsdk.InstanceIdentityOption
if a.agentName != "" {
iiOpts = append(iiOpts, agentsdk.WithInstanceIdentityAgentName(a.agentName))
}
switch a.agentAuth { switch a.agentAuth {
case "token": case "token":
token := a.agentToken token := a.agentToken
@@ -850,11 +864,11 @@ func (a *AgentAuth) CreateClient() (*agentsdk.Client, error) {
} }
return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil
case "google-instance-identity": case "google-instance-identity":
return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", nil)), nil return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", nil, iiOpts...)), nil
case "aws-instance-identity": case "aws-instance-identity":
return agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()), nil return agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity(iiOpts...)), nil
case "azure-instance-identity": case "azure-instance-identity":
return agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()), nil return agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity(iiOpts...)), nil
default: default:
return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth) return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth)
} }
+63
View File
@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"runtime" "runtime"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -346,6 +347,68 @@ func TestCreateAgentClient_Azure(t *testing.T) {
require.IsType(t, &agentsdk.AzureSessionTokenExchanger{}, provider.TokenExchanger) require.IsType(t, &agentsdk.AzureSessionTokenExchanger{}, provider.TokenExchanger)
} }
func TestCreateAgentClient_GoogleAgentName(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--auth", "google-instance-identity",
"--agent-url", "http://coder.fake",
"--agent-name", "google-agent")
requireInstanceIdentityAgentName(t, client, &agentsdk.GoogleSessionTokenExchanger{}, "google-agent")
}
func TestCreateAgentClient_AWSAgentName(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--auth", "aws-instance-identity",
"--agent-url", "http://coder.fake",
"--agent-name", "aws-agent")
requireInstanceIdentityAgentName(t, client, &agentsdk.AWSSessionTokenExchanger{}, "aws-agent")
}
func TestCreateAgentClient_AzureAgentName(t *testing.T) {
t.Parallel()
client := createAgentWithFlags(t,
"--auth", "azure-instance-identity",
"--agent-url", "http://coder.fake",
"--agent-name", "azure-agent")
requireInstanceIdentityAgentName(t, client, &agentsdk.AzureSessionTokenExchanger{}, "azure-agent")
}
func TestCreateAgentClient_GoogleAgentNameEnv(t *testing.T) {
t.Parallel()
r := &cli.RootCmd{}
var client *agentsdk.Client
subCmd := agentClientCommand(&client)
cmd, err := r.Command([]*serpent.Command{subCmd})
require.NoError(t, err)
inv, _ := clitest.NewWithCommand(t, cmd,
"agent-client",
"--auth", "google-instance-identity",
"--agent-url", "http://coder.fake")
inv.Environ.Set("CODER_AGENT_NAME", "env-agent")
err = inv.Run()
require.NoError(t, err)
require.NotNil(t, client)
requireInstanceIdentityAgentName(t, client, &agentsdk.GoogleSessionTokenExchanger{}, "env-agent")
}
func requireInstanceIdentityAgentName(t *testing.T, client *agentsdk.Client, expectedExchanger any, want string) {
t.Helper()
provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider)
require.True(t, ok)
require.NotNil(t, provider.TokenExchanger)
require.IsType(t, expectedExchanger, provider.TokenExchanger)
agentNameField := reflect.ValueOf(provider.TokenExchanger).Elem().FieldByName("agentName")
require.True(t, agentNameField.IsValid())
require.Equal(t, want, agentNameField.String())
}
func createAgentWithFlags(t *testing.T, flags ...string) *agentsdk.Client { func createAgentWithFlags(t *testing.T, flags ...string) *agentsdk.Client {
t.Helper() t.Helper()
r := &cli.RootCmd{} r := &cli.RootCmd{}
+4
View File
@@ -9,6 +9,10 @@ OPTIONS:
--auth string, $CODER_AGENT_AUTH (default: token) --auth string, $CODER_AGENT_AUTH (default: token)
Specify the authentication type to use for the agent. Specify the authentication type to use for the agent.
--agent-name string, $CODER_AGENT_NAME
The name of the agent to authenticate as (only applicable for instance
identity).
--agent-token string, $CODER_AGENT_TOKEN --agent-token string, $CODER_AGENT_TOKEN
An agent authentication token. An agent authentication token.
@@ -28,6 +28,10 @@ OPTIONS:
--auth string, $CODER_AGENT_AUTH (default: token) --auth string, $CODER_AGENT_AUTH (default: token)
Specify the authentication type to use for the agent. Specify the authentication type to use for the agent.
--agent-name string, $CODER_AGENT_NAME
The name of the agent to authenticate as (only applicable for instance
identity).
--agent-token string, $CODER_AGENT_TOKEN --agent-token string, $CODER_AGENT_TOKEN
An agent authentication token. An agent authentication token.
+3 -1
View File
@@ -213,8 +213,10 @@ func TestSubAgentAPI(t *testing.T) {
// Double-check: looking up by the parent's instance ID must // Double-check: looking up by the parent's instance ID must
// still return the parent, not the sub-agent. // still return the parent, not the sub-agent.
lookedUp, err := db.GetWorkspaceAgentByInstanceID(dbauthz.AsSystemRestricted(ctx), parentAgent.AuthInstanceID.String) agents, err := db.GetWorkspaceAgentsByInstanceID(dbauthz.AsSystemRestricted(ctx), parentAgent.AuthInstanceID.String)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, agents, 1)
lookedUp := agents[0]
assert.Equal(t, parentAgent.ID, lookedUp.ID, "instance ID lookup should still return the parent agent") assert.Equal(t, parentAgent.ID, lookedUp.ID, "instance ID lookup should still return the parent agent")
}) })
+15 -3
View File
@@ -10096,7 +10096,7 @@ const docTemplate = `{
"operationId": "authenticate-agent-on-aws-instance", "operationId": "authenticate-agent-on-aws-instance",
"parameters": [ "parameters": [
{ {
"description": "Instance identity token", "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.",
"name": "request", "name": "request",
"in": "body", "in": "body",
"required": true, "required": true,
@@ -10135,7 +10135,7 @@ const docTemplate = `{
"operationId": "authenticate-agent-on-azure-instance", "operationId": "authenticate-agent-on-azure-instance",
"parameters": [ "parameters": [
{ {
"description": "Instance identity token", "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.",
"name": "request", "name": "request",
"in": "body", "in": "body",
"required": true, "required": true,
@@ -10202,7 +10202,7 @@ const docTemplate = `{
"operationId": "authenticate-agent-on-google-cloud-instance", "operationId": "authenticate-agent-on-google-cloud-instance",
"parameters": [ "parameters": [
{ {
"description": "Instance identity token", "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.",
"name": "request", "name": "request",
"in": "body", "in": "body",
"required": true, "required": true,
@@ -12780,6 +12780,10 @@ const docTemplate = `{
"signature" "signature"
], ],
"properties": { "properties": {
"agent_name": {
"description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.",
"type": "string"
},
"document": { "document": {
"type": "string" "type": "string"
}, },
@@ -12803,6 +12807,10 @@ const docTemplate = `{
"signature" "signature"
], ],
"properties": { "properties": {
"agent_name": {
"description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.",
"type": "string"
},
"encoding": { "encoding": {
"type": "string" "type": "string"
}, },
@@ -12853,6 +12861,10 @@ const docTemplate = `{
"json_web_token" "json_web_token"
], ],
"properties": { "properties": {
"agent_name": {
"description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.",
"type": "string"
},
"json_web_token": { "json_web_token": {
"type": "string" "type": "string"
} }
+15 -3
View File
@@ -8949,7 +8949,7 @@
"operationId": "authenticate-agent-on-aws-instance", "operationId": "authenticate-agent-on-aws-instance",
"parameters": [ "parameters": [
{ {
"description": "Instance identity token", "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.",
"name": "request", "name": "request",
"in": "body", "in": "body",
"required": true, "required": true,
@@ -8982,7 +8982,7 @@
"operationId": "authenticate-agent-on-azure-instance", "operationId": "authenticate-agent-on-azure-instance",
"parameters": [ "parameters": [
{ {
"description": "Instance identity token", "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.",
"name": "request", "name": "request",
"in": "body", "in": "body",
"required": true, "required": true,
@@ -9039,7 +9039,7 @@
"operationId": "authenticate-agent-on-google-cloud-instance", "operationId": "authenticate-agent-on-google-cloud-instance",
"parameters": [ "parameters": [
{ {
"description": "Instance identity token", "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.",
"name": "request", "name": "request",
"in": "body", "in": "body",
"required": true, "required": true,
@@ -11337,6 +11337,10 @@
"type": "object", "type": "object",
"required": ["document", "signature"], "required": ["document", "signature"],
"properties": { "properties": {
"agent_name": {
"description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.",
"type": "string"
},
"document": { "document": {
"type": "string" "type": "string"
}, },
@@ -11357,6 +11361,10 @@
"type": "object", "type": "object",
"required": ["encoding", "signature"], "required": ["encoding", "signature"],
"properties": { "properties": {
"agent_name": {
"description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.",
"type": "string"
},
"encoding": { "encoding": {
"type": "string" "type": "string"
}, },
@@ -11405,6 +11413,10 @@
"type": "object", "type": "object",
"required": ["json_web_token"], "required": ["json_web_token"],
"properties": { "properties": {
"agent_name": {
"description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.",
"type": "string"
},
"json_web_token": { "json_web_token": {
"type": "string" "type": "string"
} }
+27 -16
View File
@@ -4422,22 +4422,6 @@ func (q *querier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (data
return q.db.GetWorkspaceAgentByID(ctx, id) return q.db.GetWorkspaceAgentByID(ctx, id)
} }
// GetWorkspaceAgentByInstanceID might want to be a system call? Unsure exactly,
// but this will fail. Need to figure out what AuthInstanceID is, and if it
// is essentially an auth token. But the caller using this function is not
// an authenticated user. So this authz check will fail.
func (q *querier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (database.WorkspaceAgent, error) {
agent, err := q.db.GetWorkspaceAgentByInstanceID(ctx, authInstanceID)
if err != nil {
return database.WorkspaceAgent{}, err
}
_, err = q.GetWorkspaceByAgentID(ctx, agent.ID)
if err != nil {
return database.WorkspaceAgent{}, err
}
return agent, nil
}
func (q *querier) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) { func (q *querier) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) {
_, err := q.GetWorkspaceAgentByID(ctx, workspaceAgentID) _, err := q.GetWorkspaceAgentByID(ctx, workspaceAgentID)
if err != nil { if err != nil {
@@ -4527,6 +4511,33 @@ func (q *querier) GetWorkspaceAgentUsageStatsAndLabels(ctx context.Context, crea
return q.db.GetWorkspaceAgentUsageStatsAndLabels(ctx, createdAt) return q.db.GetWorkspaceAgentUsageStatsAndLabels(ctx, createdAt)
} }
func (q *querier) GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]database.WorkspaceAgent, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err == nil {
return q.db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
}
agents, err := q.db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
if err != nil {
return nil, err
}
// Filter to agents whose workspace is accessible. Template-version
// agents can share the same instance ID but do not belong to a
// workspace, so GetWorkspaceByAgentID returns sql.ErrNoRows for
// them. Exclude those agents rather than failing the entire lookup.
filtered := make([]database.WorkspaceAgent, 0, len(agents))
for _, agent := range agents {
_, err = q.GetWorkspaceByAgentID(ctx, agent.ID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}
return nil, err
}
filtered = append(filtered, agent)
}
return filtered, nil
}
func (q *querier) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) { func (q *querier) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) {
workspace, err := q.db.GetWorkspaceByAgentID(ctx, parentID) workspace, err := q.db.GetWorkspaceByAgentID(ctx, parentID)
if err != nil { if err != nil {
+6 -3
View File
@@ -3012,13 +3012,16 @@ func (s *MethodTestSuite) TestWorkspace() {
dbm.EXPECT().BatchUpdateWorkspaceAgentMetadata(gomock.Any(), arg).Return(nil).AnyTimes() dbm.EXPECT().BatchUpdateWorkspaceAgentMetadata(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceWorkspace.All(), policy.ActionUpdate).Returns() check.Args(arg).Asserts(rbac.ResourceWorkspace.All(), policy.ActionUpdate).Returns()
})) }))
s.Run("GetWorkspaceAgentByInstanceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { s.Run("GetWorkspaceAgentsByInstanceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{}) w := testutil.Fake(s.T(), faker, database.Workspace{})
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
authInstanceID := "instance-id" authInstanceID := "instance-id"
dbm.EXPECT().GetWorkspaceAgentByInstanceID(gomock.Any(), authInstanceID).Return(agt, nil).AnyTimes() dbm.EXPECT().GetWorkspaceAgentsByInstanceID(gomock.Any(), authInstanceID).Return([]database.WorkspaceAgent{agt}, nil).AnyTimes()
dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes() dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
check.Args(authInstanceID).Asserts(w, policy.ActionRead).Returns(agt) check.Args(authInstanceID).
Asserts(rbac.ResourceSystem, policy.ActionRead, w, policy.ActionRead).
Returns([]database.WorkspaceAgent{agt}).
FailSystemObjectChecks()
})) }))
s.Run("UpdateWorkspaceAgentLifecycleStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { s.Run("UpdateWorkspaceAgentLifecycleStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{}) w := testutil.Fake(s.T(), faker, database.Workspace{})
+8 -8
View File
@@ -2864,14 +2864,6 @@ func (m queryMetricsStore) GetWorkspaceAgentByID(ctx context.Context, id uuid.UU
return r0, r1 return r0, r1
} }
func (m queryMetricsStore) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (database.WorkspaceAgent, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceAgentByInstanceID(ctx, authInstanceID)
m.queryLatencies.WithLabelValues("GetWorkspaceAgentByInstanceID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceAgentByInstanceID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) { func (m queryMetricsStore) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) {
start := time.Now() start := time.Now()
r0, r1 := m.s.GetWorkspaceAgentDevcontainersByAgentID(ctx, workspaceAgentID) r0, r1 := m.s.GetWorkspaceAgentDevcontainersByAgentID(ctx, workspaceAgentID)
@@ -2968,6 +2960,14 @@ func (m queryMetricsStore) GetWorkspaceAgentUsageStatsAndLabels(ctx context.Cont
return r0, r1 return r0, r1
} }
func (m queryMetricsStore) GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]database.WorkspaceAgent, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
m.queryLatencies.WithLabelValues("GetWorkspaceAgentsByInstanceID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceAgentsByInstanceID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) { func (m queryMetricsStore) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) {
start := time.Now() start := time.Now()
r0, r1 := m.s.GetWorkspaceAgentsByParentID(ctx, parentID) r0, r1 := m.s.GetWorkspaceAgentsByParentID(ctx, parentID)
+15 -15
View File
@@ -5357,21 +5357,6 @@ func (mr *MockStoreMockRecorder) GetWorkspaceAgentByID(ctx, id any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentByID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentByID), ctx, id) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentByID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentByID), ctx, id)
} }
// GetWorkspaceAgentByInstanceID mocks base method.
func (m *MockStore) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (database.WorkspaceAgent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWorkspaceAgentByInstanceID", ctx, authInstanceID)
ret0, _ := ret[0].(database.WorkspaceAgent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetWorkspaceAgentByInstanceID indicates an expected call of GetWorkspaceAgentByInstanceID.
func (mr *MockStoreMockRecorder) GetWorkspaceAgentByInstanceID(ctx, authInstanceID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentByInstanceID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentByInstanceID), ctx, authInstanceID)
}
// GetWorkspaceAgentDevcontainersByAgentID mocks base method. // GetWorkspaceAgentDevcontainersByAgentID mocks base method.
func (m *MockStore) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) { func (m *MockStore) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -5552,6 +5537,21 @@ func (mr *MockStoreMockRecorder) GetWorkspaceAgentUsageStatsAndLabels(ctx, creat
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentUsageStatsAndLabels", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentUsageStatsAndLabels), ctx, createdAt) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentUsageStatsAndLabels", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentUsageStatsAndLabels), ctx, createdAt)
} }
// GetWorkspaceAgentsByInstanceID mocks base method.
func (m *MockStore) GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]database.WorkspaceAgent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWorkspaceAgentsByInstanceID", ctx, authInstanceID)
ret0, _ := ret[0].([]database.WorkspaceAgent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetWorkspaceAgentsByInstanceID indicates an expected call of GetWorkspaceAgentsByInstanceID.
func (mr *MockStoreMockRecorder) GetWorkspaceAgentsByInstanceID(ctx, authInstanceID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentsByInstanceID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentsByInstanceID), ctx, authInstanceID)
}
// GetWorkspaceAgentsByParentID mocks base method. // GetWorkspaceAgentsByParentID mocks base method.
func (m *MockStore) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) { func (m *MockStore) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
+1 -1
View File
@@ -683,7 +683,6 @@ type sqlcQuerier interface {
GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (GetWorkspaceACLByIDRow, error) GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (GetWorkspaceACLByIDRow, error)
GetWorkspaceAgentAndWorkspaceByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentAndWorkspaceByIDRow, error) GetWorkspaceAgentAndWorkspaceByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentAndWorkspaceByIDRow, error)
GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error)
GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]WorkspaceAgentDevcontainer, error) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]WorkspaceAgentDevcontainer, error)
GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentLifecycleStateByIDRow, error) GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentLifecycleStateByIDRow, error)
GetWorkspaceAgentLogSourcesByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgentLogSource, error) GetWorkspaceAgentLogSourcesByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgentLogSource, error)
@@ -697,6 +696,7 @@ type sqlcQuerier interface {
// `minute_buckets` could return 0 rows if there are no usage stats since `created_at`. // `minute_buckets` could return 0 rows if there are no usage stats since `created_at`.
GetWorkspaceAgentUsageStats(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsRow, error) GetWorkspaceAgentUsageStats(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsRow, error)
GetWorkspaceAgentUsageStatsAndLabels(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsAndLabelsRow, error) GetWorkspaceAgentUsageStatsAndLabels(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsAndLabelsRow, error)
GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]WorkspaceAgent, error)
GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]WorkspaceAgent, error) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]WorkspaceAgent, error)
GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error)
GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]WorkspaceAgent, error) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]WorkspaceAgent, error)
+119 -26
View File
@@ -7184,38 +7184,55 @@ func TestGetWorkspaceAgentsByParentID(t *testing.T) {
}) })
} }
func TestGetWorkspaceAgentByInstanceID(t *testing.T) { func setupWorkspaceAgentQueryResources(t *testing.T, db database.Store, count int) []database.WorkspaceResource {
t.Helper()
org := dbgen.Organization(t, db, database.Organization{})
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeTemplateVersionImport,
OrganizationID: org.ID,
})
resources := make([]database.WorkspaceResource, 0, count)
for i := 0; i < count; i++ {
resources = append(resources, dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: job.ID,
}))
}
return resources
}
func markWorkspaceAgentDeleted(ctx context.Context, t *testing.T, sqlDB *sql.DB, agentID uuid.UUID) {
t.Helper()
_, err := sqlDB.ExecContext(ctx, "UPDATE workspace_agents SET deleted = TRUE WHERE id = $1", agentID)
require.NoError(t, err)
}
func TestGetWorkspaceAgentsByInstanceID(t *testing.T) {
t.Parallel() t.Parallel()
// Context: https://github.com/coder/coder/pull/22196 t.Run("ReturnsAllMatchingRootAgents", func(t *testing.T) {
t.Run("DoesNotReturnSubAgents", func(t *testing.T) {
t.Parallel() t.Parallel()
// Given: A parent workspace agent with an AuthInstanceID and a
// sub-agent that shares the same AuthInstanceID.
db, _ := dbtestutil.NewDB(t) db, _ := dbtestutil.NewDB(t)
org := dbgen.Organization(t, db, database.Organization{}) resources := setupWorkspaceAgentQueryResources(t, db, 2)
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeTemplateVersionImport,
OrganizationID: org.ID,
})
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: job.ID,
})
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano()) authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
parentAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ olderCreatedAt := dbtime.Now().Add(-time.Hour)
ResourceID: resource.ID, newerCreatedAt := dbtime.Now()
olderAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resources[0].ID,
CreatedAt: olderCreatedAt,
AuthInstanceID: sql.NullString{ AuthInstanceID: sql.NullString{
String: authInstanceID, String: authInstanceID,
Valid: true, Valid: true,
}, },
}) })
// Create a sub-agent with the same AuthInstanceID (simulating newerAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
// the old behavior before the fix). ResourceID: resources[1].ID,
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ CreatedAt: newerCreatedAt,
ParentID: uuid.NullUUID{UUID: parentAgent.ID, Valid: true},
ResourceID: resource.ID,
AuthInstanceID: sql.NullString{ AuthInstanceID: sql.NullString{
String: authInstanceID, String: authInstanceID,
Valid: true, Valid: true,
@@ -7224,13 +7241,89 @@ func TestGetWorkspaceAgentByInstanceID(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort) ctx := testutil.Context(t, testutil.WaitShort)
// When: We look up the agent by instance ID. agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
agent, err := db.GetWorkspaceAgentByInstanceID(ctx, authInstanceID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, agents, 2)
assert.Equal(t, []uuid.UUID{newerAgent.ID, olderAgent.ID}, []uuid.UUID{agents[0].ID, agents[1].ID})
})
// Then: The result must be the parent agent, not the sub-agent. t.Run("ExcludesDeletedAndSubAgents", func(t *testing.T) {
assert.Equal(t, parentAgent.ID, agent.ID, "instance ID lookup should return the parent agent, not a sub-agent") t.Parallel()
assert.False(t, agent.ParentID.Valid, "returned agent should not have a parent (should be the parent itself)")
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t)
resources := setupWorkspaceAgentQueryResources(t, db, 2)
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
baseCreatedAt := dbtime.Now()
rootAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resources[0].ID,
CreatedAt: baseCreatedAt.Add(-time.Hour),
AuthInstanceID: sql.NullString{
String: authInstanceID,
Valid: true,
},
})
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ParentID: uuid.NullUUID{UUID: rootAgent.ID, Valid: true},
ResourceID: resources[0].ID,
CreatedAt: baseCreatedAt,
AuthInstanceID: sql.NullString{
String: authInstanceID,
Valid: true,
},
})
deletedRootAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resources[1].ID,
CreatedAt: baseCreatedAt.Add(time.Minute),
AuthInstanceID: sql.NullString{
String: authInstanceID,
Valid: true,
},
})
ctx := testutil.Context(t, testutil.WaitShort)
markWorkspaceAgentDeleted(ctx, t, sqlDB, deletedRootAgent.ID)
agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
require.NoError(t, err)
require.Len(t, agents, 1)
assert.Equal(t, rootAgent.ID, agents[0].ID)
assert.False(t, agents[0].ParentID.Valid)
})
t.Run("OrdersNewestFirst", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
resources := setupWorkspaceAgentQueryResources(t, db, 2)
authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano())
olderCreatedAt := dbtime.Now().Add(-time.Hour)
newerCreatedAt := dbtime.Now()
olderAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resources[0].ID,
CreatedAt: olderCreatedAt,
AuthInstanceID: sql.NullString{
String: authInstanceID,
Valid: true,
},
})
newerAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resources[1].ID,
CreatedAt: newerCreatedAt,
AuthInstanceID: sql.NullString{
String: authInstanceID,
Valid: true,
},
})
ctx := testutil.Context(t, testutil.WaitShort)
agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID)
require.NoError(t, err)
require.Len(t, agents, 2)
assert.Equal(t, newerAgent.ID, agents[0].ID)
assert.Equal(t, olderAgent.ID, agents[1].ID)
}) })
} }
+73 -57
View File
@@ -26566,63 +26566,6 @@ func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (W
return i, err return i, err
} }
const getWorkspaceAgentByInstanceID = `-- name: GetWorkspaceAgentByInstanceID :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, expanded_directory, logs_length, logs_overflowed, started_at, ready_at, subsystems, display_apps, api_version, display_order, parent_id, api_key_scope, deleted
FROM
workspace_agents
WHERE
auth_instance_id = $1 :: TEXT
-- Filter out deleted sub agents.
AND deleted = FALSE
-- Filter out sub agents, they do not authenticate with auth_instance_id.
AND parent_id IS NULL
ORDER BY
created_at DESC
`
func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceAgentByInstanceID, authInstanceID)
var i WorkspaceAgent
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.Architecture,
&i.EnvironmentVariables,
&i.OperatingSystem,
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.Version,
&i.LastConnectedReplicaID,
&i.ConnectionTimeoutSeconds,
&i.TroubleshootingURL,
&i.MOTDFile,
&i.LifecycleState,
&i.ExpandedDirectory,
&i.LogsLength,
&i.LogsOverflowed,
&i.StartedAt,
&i.ReadyAt,
pq.Array(&i.Subsystems),
pq.Array(&i.DisplayApps),
&i.APIVersion,
&i.DisplayOrder,
&i.ParentID,
&i.APIKeyScope,
&i.Deleted,
)
return i, err
}
const getWorkspaceAgentLifecycleStateByID = `-- name: GetWorkspaceAgentLifecycleStateByID :one const getWorkspaceAgentLifecycleStateByID = `-- name: GetWorkspaceAgentLifecycleStateByID :one
SELECT SELECT
lifecycle_state, lifecycle_state,
@@ -26836,6 +26779,79 @@ func (q *sqlQuerier) GetWorkspaceAgentScriptTimingsByBuildID(ctx context.Context
return items, nil return items, nil
} }
const getWorkspaceAgentsByInstanceID = `-- name: GetWorkspaceAgentsByInstanceID :many
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, expanded_directory, logs_length, logs_overflowed, started_at, ready_at, subsystems, display_apps, api_version, display_order, parent_id, api_key_scope, deleted
FROM
workspace_agents
WHERE
auth_instance_id = $1 :: TEXT
-- Filter out deleted agents.
AND deleted = FALSE
-- Filter out sub agents, they do not authenticate with auth_instance_id.
AND parent_id IS NULL
ORDER BY
created_at DESC
`
func (q *sqlQuerier) GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]WorkspaceAgent, error) {
rows, err := q.db.QueryContext(ctx, getWorkspaceAgentsByInstanceID, authInstanceID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []WorkspaceAgent
for rows.Next() {
var i WorkspaceAgent
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.Architecture,
&i.EnvironmentVariables,
&i.OperatingSystem,
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.Version,
&i.LastConnectedReplicaID,
&i.ConnectionTimeoutSeconds,
&i.TroubleshootingURL,
&i.MOTDFile,
&i.LifecycleState,
&i.ExpandedDirectory,
&i.LogsLength,
&i.LogsOverflowed,
&i.StartedAt,
&i.ReadyAt,
pq.Array(&i.Subsystems),
pq.Array(&i.DisplayApps),
&i.APIVersion,
&i.DisplayOrder,
&i.ParentID,
&i.APIKeyScope,
&i.Deleted,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getWorkspaceAgentsByParentID = `-- name: GetWorkspaceAgentsByParentID :many const getWorkspaceAgentsByParentID = `-- name: GetWorkspaceAgentsByParentID :many
SELECT SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, expanded_directory, logs_length, logs_overflowed, started_at, ready_at, subsystems, display_apps, api_version, display_order, parent_id, api_key_scope, deleted id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, expanded_directory, logs_length, logs_overflowed, started_at, ready_at, subsystems, display_apps, api_version, display_order, parent_id, api_key_scope, deleted
+2 -2
View File
@@ -8,14 +8,14 @@ WHERE
-- Filter out deleted sub agents. -- Filter out deleted sub agents.
AND deleted = FALSE; AND deleted = FALSE;
-- name: GetWorkspaceAgentByInstanceID :one -- name: GetWorkspaceAgentsByInstanceID :many
SELECT SELECT
* *
FROM FROM
workspace_agents workspace_agents
WHERE WHERE
auth_instance_id = @auth_instance_id :: TEXT auth_instance_id = @auth_instance_id :: TEXT
-- Filter out deleted sub agents. -- Filter out deleted agents.
AND deleted = FALSE AND deleted = FALSE
-- Filter out sub agents, they do not authenticate with auth_instance_id. -- Filter out sub agents, they do not authenticate with auth_instance_id.
AND parent_id IS NULL AND parent_id IS NULL
@@ -4286,8 +4286,10 @@ func TestInsertWorkspaceResource(t *testing.T) {
// Looking up by the parent's instance ID must still // Looking up by the parent's instance ID must still
// return the parent, not the sub-agent. // return the parent, not the sub-agent.
lookedUp, err := db.GetWorkspaceAgentByInstanceID(ctx, parentAgent.AuthInstanceID.String) agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, parentAgent.AuthInstanceID.String)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, agents, 1)
lookedUp := agents[0]
assert.Equal(t, parentAgent.ID, lookedUp.ID, "instance ID lookup should still return the parent agent") assert.Equal(t, parentAgent.ID, lookedUp.ID, "instance ID lookup should still return the parent agent")
}, },
}, },
+91 -23
View File
@@ -4,7 +4,10 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"sort"
"strings"
"github.com/google/uuid"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/coder/coder/v2/coderd/awsidentity" "github.com/coder/coder/v2/coderd/awsidentity"
@@ -26,7 +29,7 @@ import (
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Tags Agents // @Tags Agents
// @Param request body agentsdk.AzureInstanceIdentityToken true "Instance identity token" // @Param request body agentsdk.AzureInstanceIdentityToken true "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID."
// @Success 200 {object} agentsdk.AuthenticateResponse // @Success 200 {object} agentsdk.AuthenticateResponse
// @Router /workspaceagents/azure-instance-identity [post] // @Router /workspaceagents/azure-instance-identity [post]
func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r *http.Request) { func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r *http.Request) {
@@ -45,7 +48,7 @@ func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r
}) })
return return
} }
api.handleAuthInstanceID(rw, r, instanceID) api.handleAuthInstanceID(rw, r, instanceID, req.AgentName)
} }
// AWS supports instance identity verification: // AWS supports instance identity verification:
@@ -58,7 +61,7 @@ func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Tags Agents // @Tags Agents
// @Param request body agentsdk.AWSInstanceIdentityToken true "Instance identity token" // @Param request body agentsdk.AWSInstanceIdentityToken true "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID."
// @Success 200 {object} agentsdk.AuthenticateResponse // @Success 200 {object} agentsdk.AuthenticateResponse
// @Router /workspaceagents/aws-instance-identity [post] // @Router /workspaceagents/aws-instance-identity [post]
func (api *API) postWorkspaceAuthAWSInstanceIdentity(rw http.ResponseWriter, r *http.Request) { func (api *API) postWorkspaceAuthAWSInstanceIdentity(rw http.ResponseWriter, r *http.Request) {
@@ -75,7 +78,7 @@ func (api *API) postWorkspaceAuthAWSInstanceIdentity(rw http.ResponseWriter, r *
}) })
return return
} }
api.handleAuthInstanceID(rw, r, identity.InstanceID) api.handleAuthInstanceID(rw, r, identity.InstanceID, req.AgentName)
} }
// Google Compute Engine supports instance identity verification: // Google Compute Engine supports instance identity verification:
@@ -88,7 +91,7 @@ func (api *API) postWorkspaceAuthAWSInstanceIdentity(rw http.ResponseWriter, r *
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Tags Agents // @Tags Agents
// @Param request body agentsdk.GoogleInstanceIdentityToken true "Instance identity token" // @Param request body agentsdk.GoogleInstanceIdentityToken true "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID."
// @Success 200 {object} agentsdk.AuthenticateResponse // @Success 200 {object} agentsdk.AuthenticateResponse
// @Router /workspaceagents/google-instance-identity [post] // @Router /workspaceagents/google-instance-identity [post]
func (api *API) postWorkspaceAuthGoogleInstanceIdentity(rw http.ResponseWriter, r *http.Request) { func (api *API) postWorkspaceAuthGoogleInstanceIdentity(rw http.ResponseWriter, r *http.Request) {
@@ -122,19 +125,18 @@ func (api *API) postWorkspaceAuthGoogleInstanceIdentity(rw http.ResponseWriter,
}) })
return return
} }
api.handleAuthInstanceID(rw, r, claims.Google.ComputeEngine.InstanceID) api.handleAuthInstanceID(rw, r, claims.Google.ComputeEngine.InstanceID, req.AgentName)
} }
func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, instanceID string) { func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, instanceID string, agentName string) {
ctx := r.Context() ctx := r.Context()
//nolint:gocritic // needed for auth instance id // Instance identity auth happens before the agent has a session token, so
agent, err := api.Database.GetWorkspaceAgentByInstanceID(dbauthz.AsSystemRestricted(ctx), instanceID) // these lookups must use a restricted system context.
if httpapi.Is404Error(err) { //nolint:gocritic // Instance identity auth happens before agent auth.
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ systemCtx := dbauthz.AsSystemRestricted(ctx)
Message: fmt.Sprintf("Instance with id %q not found.", instanceID), agentName = strings.TrimSpace(agentName)
})
return agents, err := api.Database.GetWorkspaceAgentsByInstanceID(systemCtx, instanceID)
}
if err != nil { if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job agent.", Message: "Internal error fetching provisioner job agent.",
@@ -142,8 +144,77 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
}) })
return return
} }
//nolint:gocritic // needed for auth instance id
resource, err := api.Database.GetWorkspaceResourceByID(dbauthz.AsSystemRestricted(ctx), agent.ResourceID) // Template version agents can share an instance ID with workspace build
// agents. Keep only workspace build agents before resolving ambiguity so
// template version agents do not force CODER_AGENT_NAME.
buildAgents := agents[:0]
for _, candidate := range agents {
resource, err := api.Database.GetWorkspaceResourceByID(systemCtx, candidate.ResourceID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job resource.",
Detail: err.Error(),
})
return
}
job, err := api.Database.GetProvisionerJobByID(systemCtx, resource.JobID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.",
Detail: err.Error(),
})
return
}
if job.Type == database.ProvisionerJobTypeWorkspaceBuild {
buildAgents = append(buildAgents, candidate)
}
}
agents = buildAgents
if len(agents) == 0 {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Instance with id %q not found.", instanceID),
})
return
}
var agent database.WorkspaceAgent
if agentName != "" {
for _, candidate := range agents {
if candidate.Name == agentName {
agent = candidate
break
}
}
if agent.ID == uuid.Nil {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("No agent found with instance ID %q and name %q.", instanceID, agentName),
})
return
}
} else {
if len(agents) != 1 {
// Include agent names in the error message to help operators
// configure CODER_AGENT_NAME. The caller has already proven
// cloud instance identity, so agent names are not sensitive
// here.
names := make([]string, len(agents))
for i, candidate := range agents {
names[i] = candidate.Name
}
sort.Strings(names)
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: fmt.Sprintf(
"Multiple agents found with instance ID %q. Set CODER_AGENT_NAME to one of: %s",
instanceID,
strings.Join(names, ", "),
),
})
return
}
agent = agents[0]
}
resource, err := api.Database.GetWorkspaceResourceByID(systemCtx, agent.ResourceID)
if err != nil { if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job resource.", Message: "Internal error fetching provisioner job resource.",
@@ -151,8 +222,7 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
}) })
return return
} }
//nolint:gocritic // needed for auth instance id job, err := api.Database.GetProvisionerJobByID(systemCtx, resource.JobID)
job, err := api.Database.GetProvisionerJobByID(dbauthz.AsSystemRestricted(ctx), resource.JobID)
if err != nil { if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.", Message: "Internal error fetching provisioner job.",
@@ -175,8 +245,7 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
}) })
return return
} }
//nolint:gocritic // needed for auth instance id resourceHistory, err := api.Database.GetWorkspaceBuildByID(systemCtx, jobData.WorkspaceBuildID)
resourceHistory, err := api.Database.GetWorkspaceBuildByID(dbauthz.AsSystemRestricted(ctx), jobData.WorkspaceBuildID)
if err != nil { if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace build.", Message: "Internal error fetching workspace build.",
@@ -187,8 +256,7 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
// This token should only be exchanged if the instance ID is valid // This token should only be exchanged if the instance ID is valid
// for the latest history. If an instance ID is recycled by a cloud, // for the latest history. If an instance ID is recycled by a cloud,
// we'd hate to leak access to a user's workspace. // we'd hate to leak access to a user's workspace.
//nolint:gocritic // needed for auth instance id latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(systemCtx, resourceHistory.WorkspaceID)
latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(dbauthz.AsSystemRestricted(ctx), resourceHistory.WorkspaceID)
if err != nil { if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching the latest workspace build.", Message: "Internal error fetching the latest workspace build.",
+347 -97
View File
@@ -2,12 +2,20 @@ package coderd_test
import ( import (
"context" "context"
"database/sql"
"encoding/json"
"fmt"
"io"
"net/http" "net/http"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/provisioner/echo" "github.com/coder/coder/v2/provisioner/echo"
@@ -17,96 +25,274 @@ import (
func TestPostWorkspaceAuthAzureInstanceIdentity(t *testing.T) { func TestPostWorkspaceAuthAzureInstanceIdentity(t *testing.T) {
t.Parallel() t.Parallel()
instanceID := "instanceidentifier"
certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
client := coderdtest.New(t, &coderdtest.Options{
AzureCertificates: certificates,
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionGraph: []*proto.Response{{
Type: &proto.Response_Graph{
Graph: &proto.GraphComplete{
Resources: []*proto.Resource{{
Name: "somename",
Type: "someinstance",
Agents: []*proto.Agent{{
Name: "dev",
Auth: &proto.Agent_InstanceId{
InstanceId: instanceID,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) t.Run("Success", func(t *testing.T) {
defer cancel() t.Parallel()
agentClient := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity()) instanceID := newTestInstanceID(t)
agentClient.SDK.HTTPClient = metadataClient certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
err := agentClient.RefreshToken(ctx) client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
require.NoError(t, err) AzureCertificates: certificates,
}, workspaceAgentsForInstanceID(instanceID, "dev"))
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agentClient := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity())
agentClient.SDK.HTTPClient = metadataClient
err := agentClient.RefreshToken(ctx)
require.NoError(t, err)
})
t.Run("Ambiguous/AzureWithSelector", func(t *testing.T) {
t.Parallel()
instanceID := newTestInstanceID(t)
certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{
AzureCertificates: certificates,
}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
expectedAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "alpha")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agentClient := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity(
agentsdk.WithInstanceIdentityAgentName("alpha"),
))
agentClient.SDK.HTTPClient = metadataClient
err := agentClient.RefreshToken(ctx)
require.NoError(t, err)
require.Equal(t, expectedAgent.AuthToken.String(), agentClient.SDK.SessionToken())
})
} }
func TestPostWorkspaceAuthAWSInstanceIdentity(t *testing.T) { func TestPostWorkspaceAuthAWSInstanceIdentity(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Run("Ambiguous/SingleAgent", func(t *testing.T) {
t.Parallel() t.Parallel()
instanceID := "instanceidentifier"
instanceID := newTestInstanceID(t)
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID) certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
client := coderdtest.New(t, &coderdtest.Options{ client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
AWSCertificates: certificates, AWSCertificates: certificates,
IncludeProvisionerDaemon: true, }, workspaceAgentsForInstanceID(instanceID, "dev"))
})
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionGraph: []*proto.Response{{
Type: &proto.Response_Graph{
Graph: &proto.GraphComplete{
Resources: []*proto.Resource{{
Name: "somename",
Type: "someinstance",
Agents: []*proto.Agent{{
Name: "dev",
Auth: &proto.Agent_InstanceId{
InstanceId: instanceID,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity()) agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity())
agentClient.SDK.HTTPClient = metadataClient agentClient.SDK.HTTPClient = metadataClient
err := agentClient.RefreshToken(ctx) err := agentClient.RefreshToken(ctx)
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("Ambiguous/MultipleAgentsNoSelector", func(t *testing.T) {
t.Parallel()
instanceID := newTestInstanceID(t)
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
AWSCertificates: certificates,
}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity())
agentClient.SDK.HTTPClient = metadataClient
err := agentClient.RefreshToken(ctx)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusConflict, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "CODER_AGENT_NAME")
require.Contains(t, apiErr.Message, "alpha, beta")
})
t.Run("Ambiguous/EmptyAgentNameTreatedAsUnset", func(t *testing.T) {
t.Parallel()
instanceID := newTestInstanceID(t)
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
AWSCertificates: certificates,
}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
signatureReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/signature", nil)
require.NoError(t, err)
signatureRes, err := metadataClient.Do(signatureReq)
require.NoError(t, err)
defer signatureRes.Body.Close()
signature, err := io.ReadAll(signatureRes.Body)
require.NoError(t, err)
documentReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil)
require.NoError(t, err)
documentRes, err := metadataClient.Do(documentReq)
require.NoError(t, err)
defer documentRes.Body.Close()
document, err := io.ReadAll(documentRes.Body)
require.NoError(t, err)
reqBody, err := json.Marshal(map[string]string{
"signature": string(signature),
"document": string(document),
"agent_name": "",
})
require.NoError(t, err)
res, err := client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", reqBody)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusConflict, res.StatusCode)
err = codersdk.ReadBodyAsError(res)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusConflict, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "CODER_AGENT_NAME")
require.Contains(t, apiErr.Message, "alpha, beta")
})
t.Run("Ambiguous/WhitespaceAgentNameTreatedAsUnset", func(t *testing.T) {
t.Parallel()
instanceID := newTestInstanceID(t)
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
AWSCertificates: certificates,
}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
signatureReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/signature", nil)
require.NoError(t, err)
signatureRes, err := metadataClient.Do(signatureReq)
require.NoError(t, err)
defer signatureRes.Body.Close()
signature, err := io.ReadAll(signatureRes.Body)
require.NoError(t, err)
documentReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil)
require.NoError(t, err)
documentRes, err := metadataClient.Do(documentReq)
require.NoError(t, err)
defer documentRes.Body.Close()
document, err := io.ReadAll(documentRes.Body)
require.NoError(t, err)
reqBody, err := json.Marshal(map[string]string{
"signature": string(signature),
"document": string(document),
"agent_name": " ",
})
require.NoError(t, err)
res, err := client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", reqBody)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusConflict, res.StatusCode)
err = codersdk.ReadBodyAsError(res)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusConflict, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "CODER_AGENT_NAME")
require.Contains(t, apiErr.Message, "alpha, beta")
})
t.Run("Ambiguous/MultipleAgentsWithSelector", func(t *testing.T) {
t.Parallel()
instanceID := newTestInstanceID(t)
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{
AWSCertificates: certificates,
}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
expectedAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "alpha")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity(
agentsdk.WithInstanceIdentityAgentName("alpha"),
))
agentClient.SDK.HTTPClient = metadataClient
err := agentClient.RefreshToken(ctx)
require.NoError(t, err)
require.Equal(t, expectedAgent.AuthToken.String(), agentClient.SDK.SessionToken())
})
t.Run("Ambiguous/MultipleAgentsUnknownSelector", func(t *testing.T) {
t.Parallel()
instanceID := newTestInstanceID(t)
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
AWSCertificates: certificates,
}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity(
agentsdk.WithInstanceIdentityAgentName("nonexistent"),
))
agentClient.SDK.HTTPClient = metadataClient
err := agentClient.RefreshToken(ctx)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})
t.Run("Ambiguous/SubAgentExcluded", func(t *testing.T) {
t.Parallel()
instanceID := newTestInstanceID(t)
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{
AWSCertificates: certificates,
}, workspaceAgentsForInstanceID(instanceID, "dev"))
rootAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "dev")
_ = dbgen.WorkspaceSubAgent(t, store, rootAgent, database.WorkspaceAgent{
Name: "sub",
AuthInstanceID: sql.NullString{
String: instanceID,
Valid: true,
},
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity())
agentClient.SDK.HTTPClient = metadataClient
err := agentClient.RefreshToken(ctx)
require.NoError(t, err)
require.Equal(t, rootAgent.AuthToken.String(), agentClient.SDK.SessionToken())
})
} }
func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("Expired", func(t *testing.T) { t.Run("Expired", func(t *testing.T) {
t.Parallel() t.Parallel()
instanceID := "instanceidentifier"
instanceID := newTestInstanceID(t)
validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, true) validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, true)
client := coderdtest.New(t, &coderdtest.Options{ client := coderdtest.New(t, &coderdtest.Options{
GoogleTokenValidator: validator, GoogleTokenValidator: validator,
@@ -124,7 +310,8 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) {
t.Run("InstanceNotFound", func(t *testing.T) { t.Run("InstanceNotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
instanceID := "instanceidentifier"
instanceID := newTestInstanceID(t)
validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false) validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
client := coderdtest.New(t, &coderdtest.Options{ client := coderdtest.New(t, &coderdtest.Options{
GoogleTokenValidator: validator, GoogleTokenValidator: validator,
@@ -142,36 +329,12 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) {
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
t.Parallel() t.Parallel()
instanceID := "instanceidentifier"
instanceID := newTestInstanceID(t)
validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false) validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
client := coderdtest.New(t, &coderdtest.Options{ client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
GoogleTokenValidator: validator, GoogleTokenValidator: validator,
IncludeProvisionerDaemon: true, }, workspaceAgentsForInstanceID(instanceID, "dev"))
})
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionGraph: []*proto.Response{{
Type: &proto.Response_Graph{
Graph: &proto.GraphComplete{
Resources: []*proto.Resource{{
Name: "somename",
Type: "someinstance",
Agents: []*proto.Agent{{
Name: "dev",
Auth: &proto.Agent_InstanceId{
InstanceId: instanceID,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
@@ -180,4 +343,91 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) {
err := agentClient.RefreshToken(ctx) err := agentClient.RefreshToken(ctx)
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("Ambiguous/GoogleWithSelector", func(t *testing.T) {
t.Parallel()
instanceID := newTestInstanceID(t)
validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{
GoogleTokenValidator: validator,
}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
expectedAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "alpha")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agentClient := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity(
"",
metadata,
agentsdk.WithInstanceIdentityAgentName("alpha"),
))
err := agentClient.RefreshToken(ctx)
require.NoError(t, err)
require.Equal(t, expectedAgent.AuthToken.String(), agentClient.SDK.SessionToken())
})
}
func setupInstanceIDWorkspace(t *testing.T, opts *coderdtest.Options, agents []*proto.Agent) (*codersdk.Client, database.Store) {
t.Helper()
actualOpts := &coderdtest.Options{}
if opts != nil {
*actualOpts = *opts
}
actualOpts.IncludeProvisionerDaemon = true
client, store := coderdtest.NewWithDatabase(t, actualOpts)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionGraph: []*proto.Response{{
Type: &proto.Response_Graph{
Graph: &proto.GraphComplete{
Resources: []*proto.Resource{{
Name: "resource",
Type: "instance",
Agents: agents,
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
return client, store
}
func workspaceAgentsForInstanceID(instanceID string, names ...string) []*proto.Agent {
agents := make([]*proto.Agent, 0, len(names))
for _, name := range names {
agents = append(agents, &proto.Agent{
Name: name,
Auth: &proto.Agent_InstanceId{InstanceId: instanceID},
})
}
return agents
}
func requireWorkspaceAgentByInstanceIDAndName(t testing.TB, store database.Store, instanceID string, name string) database.WorkspaceAgent {
t.Helper()
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong))
agents, err := store.GetWorkspaceAgentsByInstanceID(ctx, instanceID)
require.NoError(t, err)
for _, agent := range agents {
if agent.Name == name {
return agent
}
}
require.FailNow(t, "workspace agent not found", "instance ID %q, name %q", instanceID, name)
return database.WorkspaceAgent{}
}
func newTestInstanceID(t testing.TB) string {
t.Helper()
return fmt.Sprintf("instance-%d", time.Now().UnixNano())
} }
+27
View File
@@ -465,6 +465,33 @@ func (FixedSessionTokenProvider) RefreshToken(_ context.Context) error {
return nil return nil
} }
// InstanceIdentityConfig holds optional configuration for cloud
// instance-identity authentication.
type InstanceIdentityConfig struct {
AgentName string
}
// InstanceIdentityOption configures instance-identity authentication.
type InstanceIdentityOption func(*InstanceIdentityConfig)
// WithInstanceIdentityAgentName sets the agent name selector sent with
// the instance-identity authentication request.
func WithInstanceIdentityAgentName(name string) InstanceIdentityOption {
return func(c *InstanceIdentityConfig) {
c.AgentName = name
}
}
// applyInstanceIdentityOptions applies the given options and returns
// the resulting configuration.
func applyInstanceIdentityOptions(opts []InstanceIdentityOption) InstanceIdentityConfig {
var cfg InstanceIdentityConfig
for _, o := range opts {
o(&cfg)
}
return cfg
}
func WithFixedToken(token string) SessionTokenSetup { func WithFixedToken(token string) SessionTokenSetup {
return func(_ *codersdk.Client) RefreshableSessionTokenProvider { return func(_ *codersdk.Client) RefreshableSessionTokenProvider {
return FixedSessionTokenProvider{FixedSessionTokenProvider: codersdk.FixedSessionTokenProvider{SessionToken: token}} return FixedSessionTokenProvider{FixedSessionTokenProvider: codersdk.FixedSessionTokenProvider{SessionToken: token}}
+10 -3
View File
@@ -14,18 +14,24 @@ import (
type AWSInstanceIdentityToken struct { type AWSInstanceIdentityToken struct {
Signature string `json:"signature" validate:"required"` Signature string `json:"signature" validate:"required"`
Document string `json:"document" validate:"required"` Document string `json:"document" validate:"required"`
// AgentName optionally selects a specific agent when multiple
// agents share the same instance identity. An empty string is
// treated as unspecified.
AgentName string `json:"agent_name,omitempty"`
} }
// AWSSessionTokenExchanger exchanges AWS instance metadata for a Coder session token. // AWSSessionTokenExchanger exchanges AWS instance metadata for a Coder session token.
// @typescript-ignore AWSSessionTokenExchanger // @typescript-ignore AWSSessionTokenExchanger
type AWSSessionTokenExchanger struct { type AWSSessionTokenExchanger struct {
client *codersdk.Client client *codersdk.Client
agentName string
} }
func WithAWSInstanceIdentity() SessionTokenSetup { func WithAWSInstanceIdentity(opts ...InstanceIdentityOption) SessionTokenSetup {
cfg := applyInstanceIdentityOptions(opts)
return func(client *codersdk.Client) RefreshableSessionTokenProvider { return func(client *codersdk.Client) RefreshableSessionTokenProvider {
return &InstanceIdentitySessionTokenProvider{ return &InstanceIdentitySessionTokenProvider{
TokenExchanger: &AWSSessionTokenExchanger{client: client}, TokenExchanger: &AWSSessionTokenExchanger{client: client, agentName: cfg.AgentName},
} }
} }
} }
@@ -84,6 +90,7 @@ func (a *AWSSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateRe
res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", AWSInstanceIdentityToken{ res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", AWSInstanceIdentityToken{
Signature: string(signature), Signature: string(signature),
Document: string(document), Document: string(document),
AgentName: a.agentName,
}) })
if err != nil { if err != nil {
return AuthenticateResponse{}, err return AuthenticateResponse{}, err
+10 -3
View File
@@ -11,18 +11,24 @@ import (
type AzureInstanceIdentityToken struct { type AzureInstanceIdentityToken struct {
Signature string `json:"signature" validate:"required"` Signature string `json:"signature" validate:"required"`
Encoding string `json:"encoding" validate:"required"` Encoding string `json:"encoding" validate:"required"`
// AgentName optionally selects a specific agent when multiple
// agents share the same instance identity. An empty string is
// treated as unspecified.
AgentName string `json:"agent_name,omitempty"`
} }
// AzureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token. // AzureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token.
// @typescript-ignore AzureSessionTokenExchanger // @typescript-ignore AzureSessionTokenExchanger
type AzureSessionTokenExchanger struct { type AzureSessionTokenExchanger struct {
client *codersdk.Client client *codersdk.Client
agentName string
} }
func WithAzureInstanceIdentity() SessionTokenSetup { func WithAzureInstanceIdentity(opts ...InstanceIdentityOption) SessionTokenSetup {
cfg := applyInstanceIdentityOptions(opts)
return func(client *codersdk.Client) RefreshableSessionTokenProvider { return func(client *codersdk.Client) RefreshableSessionTokenProvider {
return &InstanceIdentitySessionTokenProvider{ return &InstanceIdentitySessionTokenProvider{
TokenExchanger: &AzureSessionTokenExchanger{client: client}, TokenExchanger: &AzureSessionTokenExchanger{client: client, agentName: cfg.AgentName},
} }
} }
} }
@@ -46,6 +52,7 @@ func (a *AzureSessionTokenExchanger) exchange(ctx context.Context) (Authenticate
if err != nil { if err != nil {
return AuthenticateResponse{}, err return AuthenticateResponse{}, err
} }
token.AgentName = a.agentName
res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/azure-instance-identity", token) res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/azure-instance-identity", token)
if err != nil { if err != nil {
+9 -1
View File
@@ -14,6 +14,10 @@ import (
type GoogleInstanceIdentityToken struct { type GoogleInstanceIdentityToken struct {
JSONWebToken string `json:"json_web_token" validate:"required"` JSONWebToken string `json:"json_web_token" validate:"required"`
// AgentName optionally selects a specific agent when multiple
// agents share the same instance identity. An empty string is
// treated as unspecified.
AgentName string `json:"agent_name,omitempty"`
} }
// GoogleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token. // GoogleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token.
@@ -22,15 +26,18 @@ type GoogleSessionTokenExchanger struct {
serviceAccount string serviceAccount string
gcpClient *metadata.Client gcpClient *metadata.Client
client *codersdk.Client client *codersdk.Client
agentName string
} }
func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup { func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client, opts ...InstanceIdentityOption) SessionTokenSetup {
cfg := applyInstanceIdentityOptions(opts)
return func(client *codersdk.Client) RefreshableSessionTokenProvider { return func(client *codersdk.Client) RefreshableSessionTokenProvider {
return &InstanceIdentitySessionTokenProvider{ return &InstanceIdentitySessionTokenProvider{
TokenExchanger: &GoogleSessionTokenExchanger{ TokenExchanger: &GoogleSessionTokenExchanger{
client: client, client: client,
gcpClient: gcpClient, gcpClient: gcpClient,
serviceAccount: serviceAccount, serviceAccount: serviceAccount,
agentName: cfg.AgentName,
}, },
} }
} }
@@ -58,6 +65,7 @@ func (g *GoogleSessionTokenExchanger) exchange(ctx context.Context) (Authenticat
// request without the token to avoid re-entering this function // request without the token to avoid re-entering this function
res, err := g.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/google-instance-identity", GoogleInstanceIdentityToken{ res, err := g.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/google-instance-identity", GoogleInstanceIdentityToken{
JSONWebToken: jwt, JSONWebToken: jwt,
AgentName: g.agentName,
}) })
if err != nil { if err != nil {
return AuthenticateResponse{}, err return AuthenticateResponse{}, err
@@ -0,0 +1,217 @@
package agentsdk
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"cloud.google.com/go/compute/metadata"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestAWSInstanceIdentityExchange_AgentName(t *testing.T) {
t.Parallel()
capturedBody := runAWSInstanceIdentityExchange(t, WithInstanceIdentityAgentName("test-agent"))
assertJSONField(t, capturedBody, "agent_name", "test-agent")
}
func TestAWSInstanceIdentityExchange_OmitsAgentName(t *testing.T) {
t.Parallel()
capturedBody := runAWSInstanceIdentityExchange(t)
assertJSONFieldAbsent(t, capturedBody, "agent_name")
}
func TestAzureInstanceIdentityExchange_AgentName(t *testing.T) {
t.Parallel()
capturedBody := runAzureInstanceIdentityExchange(t, WithInstanceIdentityAgentName("test-agent"))
assertJSONField(t, capturedBody, "agent_name", "test-agent")
}
func TestAzureInstanceIdentityExchange_OmitsAgentName(t *testing.T) {
t.Parallel()
capturedBody := runAzureInstanceIdentityExchange(t)
assertJSONFieldAbsent(t, capturedBody, "agent_name")
}
func TestGoogleInstanceIdentityExchange_AgentName(t *testing.T) {
t.Parallel()
capturedBody := runGoogleInstanceIdentityExchange(t, WithInstanceIdentityAgentName("test-agent"))
assertJSONField(t, capturedBody, "agent_name", "test-agent")
}
func TestGoogleInstanceIdentityExchange_OmitsAgentName(t *testing.T) {
t.Parallel()
capturedBody := runGoogleInstanceIdentityExchange(t)
assertJSONFieldAbsent(t, capturedBody, "agent_name")
}
func runAWSInstanceIdentityExchange(t *testing.T, opts ...InstanceIdentityOption) []byte {
t.Helper()
var capturedBody []byte
server := newInstanceIdentityServer(t, "/api/v2/workspaceagents/aws-instance-identity", &capturedBody)
defer server.Close()
client := newCodersdkClient(t, server, roundTripFunc(func(req *http.Request) (*http.Response, error) {
switch {
case req.URL.Host == "169.254.169.254" && req.Method == http.MethodPut && req.URL.Path == "/latest/api/token":
return httpResponse(req, http.StatusOK, "fake-imds-token", nil), nil
case req.URL.Host == "169.254.169.254" && req.Method == http.MethodGet && req.URL.Path == "/latest/dynamic/instance-identity/signature":
return httpResponse(req, http.StatusOK, "fakesig", nil), nil
case req.URL.Host == "169.254.169.254" && req.Method == http.MethodGet && req.URL.Path == "/latest/dynamic/instance-identity/document":
return httpResponse(req, http.StatusOK, "fakedoc", nil), nil
default:
return http.DefaultTransport.RoundTrip(req)
}
}))
provider := requireInstanceIdentityProvider(t, WithAWSInstanceIdentity(opts...)(client))
resp, err := provider.TokenExchanger.exchange(context.Background())
require.NoError(t, err)
require.Equal(t, "test-session-token", resp.SessionToken)
return capturedBody
}
func runAzureInstanceIdentityExchange(t *testing.T, opts ...InstanceIdentityOption) []byte {
t.Helper()
var capturedBody []byte
server := newInstanceIdentityServer(t, "/api/v2/workspaceagents/azure-instance-identity", &capturedBody)
defer server.Close()
client := newCodersdkClient(t, server, roundTripFunc(func(req *http.Request) (*http.Response, error) {
switch {
case req.URL.Host == "169.254.169.254" && req.Method == http.MethodGet && req.URL.Path == "/metadata/attested/document":
return httpResponse(req, http.StatusOK, `{"signature":"fakesig","encoding":"fakeenc"}`, http.Header{"Content-Type": []string{"application/json"}}), nil
default:
return http.DefaultTransport.RoundTrip(req)
}
}))
provider := requireInstanceIdentityProvider(t, WithAzureInstanceIdentity(opts...)(client))
resp, err := provider.TokenExchanger.exchange(context.Background())
require.NoError(t, err)
require.Equal(t, "test-session-token", resp.SessionToken)
return capturedBody
}
func runGoogleInstanceIdentityExchange(t *testing.T, opts ...InstanceIdentityOption) []byte {
t.Helper()
var capturedBody []byte
server := newInstanceIdentityServer(t, "/api/v2/workspaceagents/google-instance-identity", &capturedBody)
defer server.Close()
metadataClient := metadata.NewClient(&http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
require.Equal(t, "169.254.169.254", req.URL.Host)
require.Equal(t, http.MethodGet, req.Method)
require.Equal(t, "/computeMetadata/v1/instance/service-accounts/test-service-account/identity", req.URL.Path)
require.Equal(t, "audience=coder&format=full", req.URL.RawQuery)
require.Equal(t, "Google", req.Header.Get("Metadata-Flavor"))
return httpResponse(req, http.StatusOK, "fake-jwt", nil), nil
})})
client := newCodersdkClient(t, server, http.DefaultTransport)
provider := requireInstanceIdentityProvider(t, WithGoogleInstanceIdentity("test-service-account", metadataClient, opts...)(client))
resp, err := provider.TokenExchanger.exchange(context.Background())
require.NoError(t, err)
require.Equal(t, "test-session-token", resp.SessionToken)
return capturedBody
}
func newInstanceIdentityServer(t *testing.T, path string, capturedBody *[]byte) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
require.Equal(t, http.MethodPost, req.Method)
require.Equal(t, path, req.URL.Path)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.NoError(t, req.Body.Close())
*capturedBody = body
rw.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(rw).Encode(AuthenticateResponse{SessionToken: "test-session-token"}))
}))
}
func newCodersdkClient(t *testing.T, server *httptest.Server, transport http.RoundTripper) *codersdk.Client {
t.Helper()
serverURL, err := url.Parse(server.URL)
require.NoError(t, err)
return &codersdk.Client{
URL: serverURL,
HTTPClient: &http.Client{
Transport: transport,
},
}
}
func requireInstanceIdentityProvider(t *testing.T, provider RefreshableSessionTokenProvider) *InstanceIdentitySessionTokenProvider {
t.Helper()
identityProvider, ok := provider.(*InstanceIdentitySessionTokenProvider)
require.True(t, ok)
return identityProvider
}
func httpResponse(req *http.Request, statusCode int, body string, headers http.Header) *http.Response {
if headers == nil {
headers = make(http.Header)
}
return &http.Response{
StatusCode: statusCode,
Header: headers,
Body: io.NopCloser(strings.NewReader(body)),
Request: req,
}
}
func decodeJSONBody(t *testing.T, body []byte) map[string]any {
t.Helper()
var decoded map[string]any
require.NoError(t, json.Unmarshal(body, &decoded))
return decoded
}
func assertJSONField(t *testing.T, body []byte, key string, want string) {
t.Helper()
decoded := decodeJSONBody(t, body)
require.Equal(t, want, decoded[key])
}
func assertJSONFieldAbsent(t *testing.T, body []byte, key string) {
t.Helper()
decoded := decodeJSONBody(t, body)
_, ok := decoded[key]
require.False(t, ok)
}
+12 -9
View File
@@ -58,6 +58,7 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/aws-instance-identi
```json ```json
{ {
"agent_name": "string",
"document": "string", "document": "string",
"signature": "string" "signature": "string"
} }
@@ -65,9 +66,9 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/aws-instance-identi
### Parameters ### Parameters
| Name | In | Type | Required | Description | | Name | In | Type | Required | Description |
|--------|------|----------------------------------------------------------------------------------|----------|-------------------------| |--------|------|----------------------------------------------------------------------------------|----------|-----------------------------------------------------------------------------------------------------------------------|
| `body` | body | [agentsdk.AWSInstanceIdentityToken](schemas.md#agentsdkawsinstanceidentitytoken) | true | Instance identity token | | `body` | body | [agentsdk.AWSInstanceIdentityToken](schemas.md#agentsdkawsinstanceidentitytoken) | true | Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID. |
### Example responses ### Example responses
@@ -105,6 +106,7 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/azure-instance-iden
```json ```json
{ {
"agent_name": "string",
"encoding": "string", "encoding": "string",
"signature": "string" "signature": "string"
} }
@@ -112,9 +114,9 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/azure-instance-iden
### Parameters ### Parameters
| Name | In | Type | Required | Description | | Name | In | Type | Required | Description |
|--------|------|--------------------------------------------------------------------------------------|----------|-------------------------| |--------|------|--------------------------------------------------------------------------------------|----------|-----------------------------------------------------------------------------------------------------------------------|
| `body` | body | [agentsdk.AzureInstanceIdentityToken](schemas.md#agentsdkazureinstanceidentitytoken) | true | Instance identity token | | `body` | body | [agentsdk.AzureInstanceIdentityToken](schemas.md#agentsdkazureinstanceidentitytoken) | true | Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID. |
### Example responses ### Example responses
@@ -152,15 +154,16 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/google-instance-ide
```json ```json
{ {
"agent_name": "string",
"json_web_token": "string" "json_web_token": "string"
} }
``` ```
### Parameters ### Parameters
| Name | In | Type | Required | Description | | Name | In | Type | Required | Description |
|--------|------|----------------------------------------------------------------------------------------|----------|-------------------------| |--------|------|----------------------------------------------------------------------------------------|----------|-----------------------------------------------------------------------------------------------------------------------|
| `body` | body | [agentsdk.GoogleInstanceIdentityToken](schemas.md#agentsdkgoogleinstanceidentitytoken) | true | Instance identity token | | `body` | body | [agentsdk.GoogleInstanceIdentityToken](schemas.md#agentsdkgoogleinstanceidentitytoken) | true | Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID. |
### Example responses ### Example responses
+17 -11
View File
@@ -4,6 +4,7 @@
```json ```json
{ {
"agent_name": "string",
"document": "string", "document": "string",
"signature": "string" "signature": "string"
} }
@@ -11,10 +12,11 @@
### Properties ### Properties
| Name | Type | Required | Restrictions | Description | | Name | Type | Required | Restrictions | Description |
|-------------|--------|----------|--------------|-------------| |--------------|--------|----------|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------|
| `document` | string | true | | | | `agent_name` | string | false | | Agent name optionally selects a specific agent when multiple agents share the same instance identity. An empty string is treated as unspecified. |
| `signature` | string | true | | | | `document` | string | true | | |
| `signature` | string | true | | |
## agentsdk.AuthenticateResponse ## agentsdk.AuthenticateResponse
@@ -34,6 +36,7 @@
```json ```json
{ {
"agent_name": "string",
"encoding": "string", "encoding": "string",
"signature": "string" "signature": "string"
} }
@@ -41,10 +44,11 @@
### Properties ### Properties
| Name | Type | Required | Restrictions | Description | | Name | Type | Required | Restrictions | Description |
|-------------|--------|----------|--------------|-------------| |--------------|--------|----------|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------|
| `encoding` | string | true | | | | `agent_name` | string | false | | Agent name optionally selects a specific agent when multiple agents share the same instance identity. An empty string is treated as unspecified. |
| `signature` | string | true | | | | `encoding` | string | true | | |
| `signature` | string | true | | |
## agentsdk.ExternalAuthResponse ## agentsdk.ExternalAuthResponse
@@ -90,15 +94,17 @@
```json ```json
{ {
"agent_name": "string",
"json_web_token": "string" "json_web_token": "string"
} }
``` ```
### Properties ### Properties
| Name | Type | Required | Restrictions | Description | | Name | Type | Required | Restrictions | Description |
|------------------|--------|----------|--------------|-------------| |------------------|--------|----------|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------|
| `json_web_token` | string | true | | | | `agent_name` | string | false | | Agent name optionally selects a specific agent when multiple agents share the same instance identity. An empty string is treated as unspecified. |
| `json_web_token` | string | true | | |
## agentsdk.Log ## agentsdk.Log
+9
View File
@@ -77,3 +77,12 @@ URL for an agent to access your deployment.
| Default | <code>token</code> | | Default | <code>token</code> |
Specify the authentication type to use for the agent. Specify the authentication type to use for the agent.
### --agent-name
| | |
|-------------|--------------------------------|
| Type | <code>string</code> |
| Environment | <code>$CODER_AGENT_NAME</code> |
The name of the agent to authenticate as (only applicable for instance identity).
@@ -0,0 +1,81 @@
---
display_name: AWS EC2 Multi-Agent Instance Identity
description: Verify AWS instance identity auth for two Coder agents on one EC2 instance
icon: ../../../site/static/icon/aws.svg
maintainer_github: coder
verified: true
tags: [vm, linux, aws, multi-agent, instance-identity]
---
# AWS multi-agent instance identity verification
This template verifies the multi-agent instance-identity authentication flow on
AWS. It provisions a single EC2 instance with two peer root workspace agents,
`main` and `dev`, that both use AWS instance identity authentication.
The key behavior under test is `CODER_AGENT_NAME` disambiguation. Each agent
starts on the same VM with the same EC2 instance identity, but sets a distinct
`CODER_AGENT_NAME` so the Coder server can issue a separate session token for
that specific agent.
## Prerequisites
- AWS credentials configured for Terraform, such as environment variables or an
attached IAM role.
- A Coder deployment that includes the multi-agent instance-auth changes from
this branch.
- No special Coder server configuration. AWS instance identity certificates are
built in.
## What this template creates
- One VPC, subnet, internet gateway, route table, and route table association.
- One security group that allows SSH from anywhere for test access.
- One Ubuntu 24.04 EC2 instance.
- Two Coder agents, `main` and `dev`, on that single EC2 instance.
- Two agent startup flows that set `CODER_AGENT_NAME` before launching the
corresponding agent init script.
## How to verify
```bash
cd examples/templates/aws-multi-agent
coder templates push verify-multi-agent
coder create test-multi-agent --template verify-multi-agent
coder list
```
After the workspace starts, verify that both agents are connected in the Coder
Dashboard for `test-multi-agent`. You can also connect to each agent directly:
```bash
coder ssh test-multi-agent -a main true
coder ssh test-multi-agent -a dev true
```
## Expected behavior
- Both agents authenticate independently using AWS instance identity.
- Each agent receives its own session token.
- The workspace shows two connected agents in the Coder Dashboard.
- If `CODER_AGENT_NAME` is omitted, the server should return `409 Conflict`
because the shared instance identity is ambiguous.
## Troubleshooting
- If one agent gets `409 Conflict`, `CODER_AGENT_NAME` is not being set
correctly for that agent.
- If both agents fail, instance identity authentication is not working. Check
EC2 metadata service access from the instance.
- Check cloud-init logs with `journalctl -u cloud-init`.
- Check agent logs at `/tmp/coder-agent-main.log` and
`/tmp/coder-agent-dev.log`.
## Cleanup
```bash
coder delete test-multi-agent
coder templates delete verify-multi-agent
```
@@ -0,0 +1,18 @@
#!/bin/bash
set -euo pipefail
# Create the user if it doesn't exist.
if ! id -u "${linux_user}" >/dev/null 2>&1; then
useradd -m -s /bin/bash "${linux_user}"
fi
# Start main agent with disambiguation name.
CODER_AGENT_NAME=main sudo -u '${linux_user}' sh -c '${main_init_script}' \
>/tmp/coder-agent-main.log 2>&1 &
# Start dev agent with disambiguation name.
CODER_AGENT_NAME=dev sudo -u '${linux_user}' sh -c '${dev_init_script}' \
>/tmp/coder-agent-dev.log 2>&1 &
# Wait for both agent processes to start.
wait
+340
View File
@@ -0,0 +1,340 @@
terraform {
required_providers {
coder = {
source = "coder/coder"
}
aws = {
source = "hashicorp/aws"
}
cloudinit = {
source = "hashicorp/cloudinit"
}
}
}
# Last updated 2023-03-14
# aws ec2 describe-regions | jq -r '[.Regions[].RegionName] | sort'
data "coder_parameter" "region" {
name = "region"
display_name = "Region"
description = "The region to deploy the workspace in."
default = "us-east-1"
mutable = false
option {
name = "Asia Pacific (Tokyo)"
value = "ap-northeast-1"
icon = "/emojis/1f1ef-1f1f5.png"
}
option {
name = "Asia Pacific (Seoul)"
value = "ap-northeast-2"
icon = "/emojis/1f1f0-1f1f7.png"
}
option {
name = "Asia Pacific (Osaka)"
value = "ap-northeast-3"
icon = "/emojis/1f1ef-1f1f5.png"
}
option {
name = "Asia Pacific (Mumbai)"
value = "ap-south-1"
icon = "/emojis/1f1ee-1f1f3.png"
}
option {
name = "Asia Pacific (Singapore)"
value = "ap-southeast-1"
icon = "/emojis/1f1f8-1f1ec.png"
}
option {
name = "Asia Pacific (Sydney)"
value = "ap-southeast-2"
icon = "/emojis/1f1e6-1f1fa.png"
}
option {
name = "Canada (Central)"
value = "ca-central-1"
icon = "/emojis/1f1e8-1f1e6.png"
}
option {
name = "EU (Frankfurt)"
value = "eu-central-1"
icon = "/emojis/1f1ea-1f1fa.png"
}
option {
name = "EU (Stockholm)"
value = "eu-north-1"
icon = "/emojis/1f1ea-1f1fa.png"
}
option {
name = "EU (Ireland)"
value = "eu-west-1"
icon = "/emojis/1f1ea-1f1fa.png"
}
option {
name = "EU (London)"
value = "eu-west-2"
icon = "/emojis/1f1ea-1f1fa.png"
}
option {
name = "EU (Paris)"
value = "eu-west-3"
icon = "/emojis/1f1ea-1f1fa.png"
}
option {
name = "South America (São Paulo)"
value = "sa-east-1"
icon = "/emojis/1f1e7-1f1f7.png"
}
option {
name = "US East (N. Virginia)"
value = "us-east-1"
icon = "/emojis/1f1fa-1f1f8.png"
}
option {
name = "US East (Ohio)"
value = "us-east-2"
icon = "/emojis/1f1fa-1f1f8.png"
}
option {
name = "US West (N. California)"
value = "us-west-1"
icon = "/emojis/1f1fa-1f1f8.png"
}
option {
name = "US West (Oregon)"
value = "us-west-2"
icon = "/emojis/1f1fa-1f1f8.png"
}
}
data "coder_parameter" "instance_type" {
name = "instance_type"
display_name = "Instance type"
description = "What instance type should your workspace use?"
default = "t3.micro"
mutable = false
option {
name = "2 vCPU, 1 GiB RAM"
value = "t3.micro"
}
option {
name = "2 vCPU, 2 GiB RAM"
value = "t3.small"
}
option {
name = "2 vCPU, 4 GiB RAM"
value = "t3.medium"
}
option {
name = "2 vCPU, 8 GiB RAM"
value = "t3.large"
}
option {
name = "4 vCPU, 16 GiB RAM"
value = "t3.xlarge"
}
option {
name = "8 vCPU, 32 GiB RAM"
value = "t3.2xlarge"
}
}
provider "aws" {
region = data.coder_parameter.region.value
}
data "coder_workspace" "me" {}
data "coder_workspace_owner" "me" {}
data "aws_ami" "ubuntu" {
most_recent = true
filter {
name = "name"
values = ["ubuntu/images/hvm-ssd-gp3/ubuntu-noble-24.04-amd64-server-*"]
}
filter {
name = "virtualization-type"
values = ["hvm"]
}
owners = ["099720109477"] # Canonical
}
resource "coder_agent" "main" {
count = data.coder_workspace.me.start_count
os = "linux"
arch = "amd64"
auth = "aws-instance-identity"
startup_script = <<-EOT
#!/bin/bash
set -e
echo "Agent 'main' started successfully"
echo "CODER_AGENT_NAME=$CODER_AGENT_NAME"
EOT
metadata {
key = "agent-identity"
display_name = "Agent Identity"
interval = 60
timeout = 5
script = "echo main"
}
}
resource "coder_agent" "dev" {
count = data.coder_workspace.me.start_count
os = "linux"
arch = "amd64"
auth = "aws-instance-identity"
startup_script = <<-EOT
#!/bin/bash
set -e
echo "Agent 'dev' started successfully"
echo "CODER_AGENT_NAME=$CODER_AGENT_NAME"
EOT
metadata {
key = "agent-identity"
display_name = "Agent Identity"
interval = 60
timeout = 5
script = "echo dev"
}
}
locals {
aws_availability_zone = "${data.coder_parameter.region.value}a"
hostname = lower(data.coder_workspace.me.name)
linux_user = "coder"
}
data "cloudinit_config" "user_data" {
gzip = false
base64_encode = false
boundary = "//"
part {
filename = "userdata.sh"
content_type = "text/x-shellscript"
content = templatefile("${path.module}/cloud-init/userdata.sh.tftpl", {
linux_user = local.linux_user
main_init_script = try(coder_agent.main[0].init_script, "")
dev_init_script = try(coder_agent.dev[0].init_script, "")
})
}
}
resource "aws_vpc" "workspace" {
cidr_block = "10.0.0.0/16"
enable_dns_hostnames = true
enable_dns_support = true
tags = {
Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}"
}
}
resource "aws_subnet" "workspace" {
vpc_id = aws_vpc.workspace.id
cidr_block = "10.0.1.0/24"
availability_zone = local.aws_availability_zone
map_public_ip_on_launch = true
tags = {
Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}"
}
}
resource "aws_internet_gateway" "workspace" {
vpc_id = aws_vpc.workspace.id
tags = {
Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}"
}
}
resource "aws_route_table" "workspace" {
vpc_id = aws_vpc.workspace.id
route {
cidr_block = "0.0.0.0/0"
gateway_id = aws_internet_gateway.workspace.id
}
tags = {
Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}"
}
}
resource "aws_route_table_association" "workspace" {
subnet_id = aws_subnet.workspace.id
route_table_id = aws_route_table.workspace.id
}
resource "aws_security_group" "workspace" {
name_prefix = "coder-${local.hostname}-"
description = "Allow SSH access for testing."
vpc_id = aws_vpc.workspace.id
ingress {
description = "SSH"
from_port = 22
to_port = 22
protocol = "tcp"
cidr_blocks = ["0.0.0.0/0"]
}
egress {
from_port = 0
to_port = 0
protocol = "-1"
cidr_blocks = ["0.0.0.0/0"]
}
tags = {
Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}"
}
}
resource "aws_instance" "dev" {
ami = data.aws_ami.ubuntu.id
availability_zone = local.aws_availability_zone
instance_type = data.coder_parameter.instance_type.value
subnet_id = aws_subnet.workspace.id
vpc_security_group_ids = [aws_security_group.workspace.id]
associate_public_ip_address = true
user_data = data.cloudinit_config.user_data.rendered
tags = {
Name = "coder-${data.coder_workspace_owner.me.name}-${data.coder_workspace.me.name}"
# Required if you are using our example policy, see template README
Coder_Provisioned = "true"
}
lifecycle {
ignore_changes = [ami]
}
depends_on = [aws_route_table_association.workspace]
}
resource "coder_metadata" "workspace_info" {
resource_id = aws_instance.dev.id
item {
key = "region"
value = data.coder_parameter.region.value
}
item {
key = "instance type"
value = aws_instance.dev.instance_type
}
item {
key = "ami"
value = aws_instance.dev.ami
}
}
resource "aws_ec2_instance_state" "dev" {
instance_id = aws_instance.dev.id
state = data.coder_workspace.me.transition == "start" ? "running" : "stopped"
}