From e5707a13d6a04bd6f33da14b2deafd609b798983 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Thu, 16 Apr 2026 13:59:09 +0200 Subject: [PATCH] feat: support multiple agents with shared instance-identity auth (#24325) > This PR was authored by Mux on behalf of Mike. ## Summary Adds support for multiple peer root workspace agents sharing the same `auth_instance_id`, so AWS, Azure, and GCP instance-identity auth can issue the correct session token for a selected agent instead of assuming a single root agent per instance. ## Problem When a Terraform template attaches two or more `coder_agent` resources (with `auth = "aws-instance-identity"`) to a single compute instance, every agent shares the same cloud instance ID. The existing singular lookup picks whichever agent was created most recently, silently ignoring the others. ## Solution Introduce an optional pre-auth agent selector (`CODER_AGENT_NAME`) and make the server-side lookup ambiguity-aware. **Database layer:** - `GetWorkspaceAgentsByInstanceID` (`:many`): returns all matching root agents for an instance ID. - `GetWorkspaceAgentByInstanceIDAndName` (`:one`): returns the named root agent for disambiguation. **SDK and CLI:** - `agent_name` field added to AWS, Azure, and GCP request structs (`omitempty` for backward compatibility). - `CODER_AGENT_NAME` env var and `--agent-name` flag wired into the agent bootstrap before instance-identity auth runs. **Server handler (`handleAuthInstanceID`):** - When `agent_name` is present: direct lookup by (instance ID, name). - When absent: legacy lookup, then resource-scoped ambiguity check. Returns 409 with available agent names if multiple root agents match. - Whitespace-only names are trimmed and treated as unspecified. - Sub-agents remain excluded (`parent_id IS NULL` filter). 
**Verification template:** - `examples/templates/aws-multi-agent/` provisions one EC2 instance with two agents (`main` and `dev`), both using instance-identity auth with `CODER_AGENT_NAME` set in the cloud-init user data. ## Backward compatibility Existing single-agent deployments work unchanged. The `agent_name` field is optional with `omitempty`, and the unnamed path preserves today's behavior when only one root agent matches. --- cli/root.go | 20 +- cli/root_test.go | 63 +++ cli/testdata/coder_agent_--help.golden | 4 + ...r_external-auth_access-token_--help.golden | 4 + coderd/agentapi/subagent_test.go | 4 +- coderd/apidoc/docs.go | 18 +- coderd/apidoc/swagger.json | 18 +- coderd/database/dbauthz/dbauthz.go | 43 +- coderd/database/dbauthz/dbauthz_test.go | 9 +- coderd/database/dbmetrics/querymetrics.go | 16 +- coderd/database/dbmock/dbmock.go | 30 +- coderd/database/querier.go | 2 +- coderd/database/querier_test.go | 145 +++++- coderd/database/queries.sql.go | 130 ++--- coderd/database/queries/workspaceagents.sql | 4 +- .../provisionerdserver_test.go | 4 +- coderd/workspaceresourceauth.go | 114 ++++- coderd/workspaceresourceauth_test.go | 444 ++++++++++++++---- codersdk/agentsdk/agentsdk.go | 27 ++ codersdk/agentsdk/aws.go | 13 +- codersdk/agentsdk/azure.go | 13 +- codersdk/agentsdk/google.go | 10 +- .../instanceidentity_internal_test.go | 217 +++++++++ docs/reference/api/agents.md | 21 +- docs/reference/api/schemas.md | 28 +- .../cli/external-auth_access-token.md | 9 + examples/templates/aws-multi-agent/README.md | 81 ++++ .../cloud-init/userdata.sh.tftpl | 18 + examples/templates/aws-multi-agent/main.tf | 340 ++++++++++++++ 29 files changed, 1563 insertions(+), 286 deletions(-) create mode 100644 codersdk/agentsdk/instanceidentity_internal_test.go create mode 100644 examples/templates/aws-multi-agent/README.md create mode 100644 examples/templates/aws-multi-agent/cloud-init/userdata.sh.tftpl create mode 100644 examples/templates/aws-multi-agent/main.tf diff 
--git a/cli/root.go b/cli/root.go index 5bda9a416f..8af6b3c96a 100644 --- a/cli/root.go +++ b/cli/root.go @@ -86,6 +86,7 @@ const ( envAgentTokenFile = "CODER_AGENT_TOKEN_FILE" envAgentURL = "CODER_AGENT_URL" envAgentAuth = "CODER_AGENT_AUTH" + envAgentName = "CODER_AGENT_NAME" envURL = "CODER_URL" ) @@ -789,6 +790,7 @@ type AgentAuth struct { agentTokenFile string agentURL url.URL agentAuth string + agentName string } func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) { @@ -821,6 +823,13 @@ func (a *AgentAuth) AttachOptions(cmd *serpent.Command, hidden bool) { Default: "token", Value: serpent.StringOf(&a.agentAuth), Hidden: hidden, + }, serpent.Option{ + Name: "Agent Name", + Description: "The name of the agent to authenticate as (only applicable for instance identity).", + Flag: "agent-name", + Env: envAgentName, + Value: serpent.StringOf(&a.agentName), + Hidden: hidden, }) } @@ -832,6 +841,11 @@ func (a *AgentAuth) CreateClient() (*agentsdk.Client, error) { return nil, xerrors.Errorf("%s must be set", envAgentURL) } + var iiOpts []agentsdk.InstanceIdentityOption + if a.agentName != "" { + iiOpts = append(iiOpts, agentsdk.WithInstanceIdentityAgentName(a.agentName)) + } + switch a.agentAuth { case "token": token := a.agentToken @@ -850,11 +864,11 @@ func (a *AgentAuth) CreateClient() (*agentsdk.Client, error) { } return agentsdk.New(&a.agentURL, agentsdk.WithFixedToken(token)), nil case "google-instance-identity": - return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", nil)), nil + return agentsdk.New(&a.agentURL, agentsdk.WithGoogleInstanceIdentity("", nil, iiOpts...)), nil case "aws-instance-identity": - return agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity()), nil + return agentsdk.New(&a.agentURL, agentsdk.WithAWSInstanceIdentity(iiOpts...)), nil case "azure-instance-identity": - return agentsdk.New(&a.agentURL, agentsdk.WithAzureInstanceIdentity()), nil + return agentsdk.New(&a.agentURL, 
agentsdk.WithAzureInstanceIdentity(iiOpts...)), nil default: return nil, xerrors.Errorf("unknown agent auth type: %s", a.agentAuth) } diff --git a/cli/root_test.go b/cli/root_test.go index 10642d6c99..3aab248dec 100644 --- a/cli/root_test.go +++ b/cli/root_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" "runtime" "strings" "sync/atomic" @@ -346,6 +347,68 @@ func TestCreateAgentClient_Azure(t *testing.T) { require.IsType(t, &agentsdk.AzureSessionTokenExchanger{}, provider.TokenExchanger) } +func TestCreateAgentClient_GoogleAgentName(t *testing.T) { + t.Parallel() + + client := createAgentWithFlags(t, + "--auth", "google-instance-identity", + "--agent-url", "http://coder.fake", + "--agent-name", "google-agent") + requireInstanceIdentityAgentName(t, client, &agentsdk.GoogleSessionTokenExchanger{}, "google-agent") +} + +func TestCreateAgentClient_AWSAgentName(t *testing.T) { + t.Parallel() + + client := createAgentWithFlags(t, + "--auth", "aws-instance-identity", + "--agent-url", "http://coder.fake", + "--agent-name", "aws-agent") + requireInstanceIdentityAgentName(t, client, &agentsdk.AWSSessionTokenExchanger{}, "aws-agent") +} + +func TestCreateAgentClient_AzureAgentName(t *testing.T) { + t.Parallel() + + client := createAgentWithFlags(t, + "--auth", "azure-instance-identity", + "--agent-url", "http://coder.fake", + "--agent-name", "azure-agent") + requireInstanceIdentityAgentName(t, client, &agentsdk.AzureSessionTokenExchanger{}, "azure-agent") +} + +func TestCreateAgentClient_GoogleAgentNameEnv(t *testing.T) { + t.Parallel() + + r := &cli.RootCmd{} + var client *agentsdk.Client + subCmd := agentClientCommand(&client) + cmd, err := r.Command([]*serpent.Command{subCmd}) + require.NoError(t, err) + inv, _ := clitest.NewWithCommand(t, cmd, + "agent-client", + "--auth", "google-instance-identity", + "--agent-url", "http://coder.fake") + inv.Environ.Set("CODER_AGENT_NAME", "env-agent") + err = inv.Run() + require.NoError(t, err) + 
require.NotNil(t, client) + requireInstanceIdentityAgentName(t, client, &agentsdk.GoogleSessionTokenExchanger{}, "env-agent") +} + +func requireInstanceIdentityAgentName(t *testing.T, client *agentsdk.Client, expectedExchanger any, want string) { + t.Helper() + + provider, ok := client.RefreshableSessionTokenProvider.(*agentsdk.InstanceIdentitySessionTokenProvider) + require.True(t, ok) + require.NotNil(t, provider.TokenExchanger) + require.IsType(t, expectedExchanger, provider.TokenExchanger) + + agentNameField := reflect.ValueOf(provider.TokenExchanger).Elem().FieldByName("agentName") + require.True(t, agentNameField.IsValid()) + require.Equal(t, want, agentNameField.String()) +} + func createAgentWithFlags(t *testing.T, flags ...string) *agentsdk.Client { t.Helper() r := &cli.RootCmd{} diff --git a/cli/testdata/coder_agent_--help.golden b/cli/testdata/coder_agent_--help.golden index 2e39e2fac5..e153548a60 100644 --- a/cli/testdata/coder_agent_--help.golden +++ b/cli/testdata/coder_agent_--help.golden @@ -9,6 +9,10 @@ OPTIONS: --auth string, $CODER_AGENT_AUTH (default: token) Specify the authentication type to use for the agent. + --agent-name string, $CODER_AGENT_NAME + The name of the agent to authenticate as (only applicable for instance + identity). + --agent-token string, $CODER_AGENT_TOKEN An agent authentication token. diff --git a/cli/testdata/coder_external-auth_access-token_--help.golden b/cli/testdata/coder_external-auth_access-token_--help.golden index 234cca5d4f..ce11b0a8a7 100644 --- a/cli/testdata/coder_external-auth_access-token_--help.golden +++ b/cli/testdata/coder_external-auth_access-token_--help.golden @@ -28,6 +28,10 @@ OPTIONS: --auth string, $CODER_AGENT_AUTH (default: token) Specify the authentication type to use for the agent. + --agent-name string, $CODER_AGENT_NAME + The name of the agent to authenticate as (only applicable for instance + identity). + --agent-token string, $CODER_AGENT_TOKEN An agent authentication token. 
diff --git a/coderd/agentapi/subagent_test.go b/coderd/agentapi/subagent_test.go index 78219aabe7..a7217cc513 100644 --- a/coderd/agentapi/subagent_test.go +++ b/coderd/agentapi/subagent_test.go @@ -213,8 +213,10 @@ func TestSubAgentAPI(t *testing.T) { // Double-check: looking up by the parent's instance ID must // still return the parent, not the sub-agent. - lookedUp, err := db.GetWorkspaceAgentByInstanceID(dbauthz.AsSystemRestricted(ctx), parentAgent.AuthInstanceID.String) + agents, err := db.GetWorkspaceAgentsByInstanceID(dbauthz.AsSystemRestricted(ctx), parentAgent.AuthInstanceID.String) require.NoError(t, err) + require.Len(t, agents, 1) + lookedUp := agents[0] assert.Equal(t, parentAgent.ID, lookedUp.ID, "instance ID lookup should still return the parent agent") }) diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 3c37594835..95bff0ecf5 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -10096,7 +10096,7 @@ const docTemplate = `{ "operationId": "authenticate-agent-on-aws-instance", "parameters": [ { - "description": "Instance identity token", + "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.", "name": "request", "in": "body", "required": true, @@ -10135,7 +10135,7 @@ const docTemplate = `{ "operationId": "authenticate-agent-on-azure-instance", "parameters": [ { - "description": "Instance identity token", + "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.", "name": "request", "in": "body", "required": true, @@ -10202,7 +10202,7 @@ const docTemplate = `{ "operationId": "authenticate-agent-on-google-cloud-instance", "parameters": [ { - "description": "Instance identity token", + "description": "Instance identity token. 
The optional agent_name field disambiguates when multiple agents share the same instance ID.", "name": "request", "in": "body", "required": true, @@ -12780,6 +12780,10 @@ const docTemplate = `{ "signature" ], "properties": { + "agent_name": { + "description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.", + "type": "string" + }, "document": { "type": "string" }, @@ -12803,6 +12807,10 @@ const docTemplate = `{ "signature" ], "properties": { + "agent_name": { + "description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.", + "type": "string" + }, "encoding": { "type": "string" }, @@ -12853,6 +12861,10 @@ const docTemplate = `{ "json_web_token" ], "properties": { + "agent_name": { + "description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.", + "type": "string" + }, "json_web_token": { "type": "string" } diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 175fdd892c..ecb1e47eb4 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -8949,7 +8949,7 @@ "operationId": "authenticate-agent-on-aws-instance", "parameters": [ { - "description": "Instance identity token", + "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.", "name": "request", "in": "body", "required": true, @@ -8982,7 +8982,7 @@ "operationId": "authenticate-agent-on-azure-instance", "parameters": [ { - "description": "Instance identity token", + "description": "Instance identity token. 
The optional agent_name field disambiguates when multiple agents share the same instance ID.", "name": "request", "in": "body", "required": true, @@ -9039,7 +9039,7 @@ "operationId": "authenticate-agent-on-google-cloud-instance", "parameters": [ { - "description": "Instance identity token", + "description": "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID.", "name": "request", "in": "body", "required": true, @@ -11337,6 +11337,10 @@ "type": "object", "required": ["document", "signature"], "properties": { + "agent_name": { + "description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.", + "type": "string" + }, "document": { "type": "string" }, @@ -11357,6 +11361,10 @@ "type": "object", "required": ["encoding", "signature"], "properties": { + "agent_name": { + "description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.", + "type": "string" + }, "encoding": { "type": "string" }, @@ -11405,6 +11413,10 @@ "type": "object", "required": ["json_web_token"], "properties": { + "agent_name": { + "description": "AgentName optionally selects a specific agent when multiple\nagents share the same instance identity. An empty string is\ntreated as unspecified.", + "type": "string" + }, "json_web_token": { "type": "string" } diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 7afbea6df8..8d9a95d909 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -4422,22 +4422,6 @@ func (q *querier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (data return q.db.GetWorkspaceAgentByID(ctx, id) } -// GetWorkspaceAgentByInstanceID might want to be a system call? Unsure exactly, -// but this will fail. 
Need to figure out what AuthInstanceID is, and if it -// is essentially an auth token. But the caller using this function is not -// an authenticated user. So this authz check will fail. -func (q *querier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (database.WorkspaceAgent, error) { - agent, err := q.db.GetWorkspaceAgentByInstanceID(ctx, authInstanceID) - if err != nil { - return database.WorkspaceAgent{}, err - } - _, err = q.GetWorkspaceByAgentID(ctx, agent.ID) - if err != nil { - return database.WorkspaceAgent{}, err - } - return agent, nil -} - func (q *querier) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) { _, err := q.GetWorkspaceAgentByID(ctx, workspaceAgentID) if err != nil { @@ -4527,6 +4511,33 @@ func (q *querier) GetWorkspaceAgentUsageStatsAndLabels(ctx context.Context, crea return q.db.GetWorkspaceAgentUsageStatsAndLabels(ctx, createdAt) } +func (q *querier) GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]database.WorkspaceAgent, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err == nil { + return q.db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID) + } + + agents, err := q.db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID) + if err != nil { + return nil, err + } + // Filter to agents whose workspace is accessible. Template-version + // agents can share the same instance ID but do not belong to a + // workspace, so GetWorkspaceByAgentID returns sql.ErrNoRows for + // them. Exclude those agents rather than failing the entire lookup. 
+ filtered := make([]database.WorkspaceAgent, 0, len(agents)) + for _, agent := range agents { + _, err = q.GetWorkspaceByAgentID(ctx, agent.ID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + continue + } + return nil, err + } + filtered = append(filtered, agent) + } + return filtered, nil +} + func (q *querier) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) { workspace, err := q.db.GetWorkspaceByAgentID(ctx, parentID) if err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 645a6ce643..0777373c79 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -3012,13 +3012,16 @@ func (s *MethodTestSuite) TestWorkspace() { dbm.EXPECT().BatchUpdateWorkspaceAgentMetadata(gomock.Any(), arg).Return(nil).AnyTimes() check.Args(arg).Asserts(rbac.ResourceWorkspace.All(), policy.ActionUpdate).Returns() })) - s.Run("GetWorkspaceAgentByInstanceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + s.Run("GetWorkspaceAgentsByInstanceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { w := testutil.Fake(s.T(), faker, database.Workspace{}) agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) authInstanceID := "instance-id" - dbm.EXPECT().GetWorkspaceAgentByInstanceID(gomock.Any(), authInstanceID).Return(agt, nil).AnyTimes() + dbm.EXPECT().GetWorkspaceAgentsByInstanceID(gomock.Any(), authInstanceID).Return([]database.WorkspaceAgent{agt}, nil).AnyTimes() dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes() - check.Args(authInstanceID).Asserts(w, policy.ActionRead).Returns(agt) + check.Args(authInstanceID). + Asserts(rbac.ResourceSystem, policy.ActionRead, w, policy.ActionRead). + Returns([]database.WorkspaceAgent{agt}). 
+ FailSystemObjectChecks() })) s.Run("UpdateWorkspaceAgentLifecycleStateByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { w := testutil.Fake(s.T(), faker, database.Workspace{}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 6c312614b7..2a1ffac7ad 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -2864,14 +2864,6 @@ func (m queryMetricsStore) GetWorkspaceAgentByID(ctx context.Context, id uuid.UU return r0, r1 } -func (m queryMetricsStore) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (database.WorkspaceAgent, error) { - start := time.Now() - r0, r1 := m.s.GetWorkspaceAgentByInstanceID(ctx, authInstanceID) - m.queryLatencies.WithLabelValues("GetWorkspaceAgentByInstanceID").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceAgentByInstanceID").Inc() - return r0, r1 -} - func (m queryMetricsStore) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) { start := time.Now() r0, r1 := m.s.GetWorkspaceAgentDevcontainersByAgentID(ctx, workspaceAgentID) @@ -2968,6 +2960,14 @@ func (m queryMetricsStore) GetWorkspaceAgentUsageStatsAndLabels(ctx context.Cont return r0, r1 } +func (m queryMetricsStore) GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]database.WorkspaceAgent, error) { + start := time.Now() + r0, r1 := m.s.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID) + m.queryLatencies.WithLabelValues("GetWorkspaceAgentsByInstanceID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceAgentsByInstanceID").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetWorkspaceAgentsByParentID(ctx context.Context, 
parentID uuid.UUID) ([]database.WorkspaceAgent, error) { start := time.Now() r0, r1 := m.s.GetWorkspaceAgentsByParentID(ctx, parentID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 76c7fc1c77..ace93212e0 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -5357,21 +5357,6 @@ func (mr *MockStoreMockRecorder) GetWorkspaceAgentByID(ctx, id any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentByID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentByID), ctx, id) } -// GetWorkspaceAgentByInstanceID mocks base method. -func (m *MockStore) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (database.WorkspaceAgent, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWorkspaceAgentByInstanceID", ctx, authInstanceID) - ret0, _ := ret[0].(database.WorkspaceAgent) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetWorkspaceAgentByInstanceID indicates an expected call of GetWorkspaceAgentByInstanceID. -func (mr *MockStoreMockRecorder) GetWorkspaceAgentByInstanceID(ctx, authInstanceID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentByInstanceID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentByInstanceID), ctx, authInstanceID) -} - // GetWorkspaceAgentDevcontainersByAgentID mocks base method. func (m *MockStore) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]database.WorkspaceAgentDevcontainer, error) { m.ctrl.T.Helper() @@ -5552,6 +5537,21 @@ func (mr *MockStoreMockRecorder) GetWorkspaceAgentUsageStatsAndLabels(ctx, creat return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentUsageStatsAndLabels", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentUsageStatsAndLabels), ctx, createdAt) } +// GetWorkspaceAgentsByInstanceID mocks base method. 
+func (m *MockStore) GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]database.WorkspaceAgent, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWorkspaceAgentsByInstanceID", ctx, authInstanceID) + ret0, _ := ret[0].([]database.WorkspaceAgent) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetWorkspaceAgentsByInstanceID indicates an expected call of GetWorkspaceAgentsByInstanceID. +func (mr *MockStoreMockRecorder) GetWorkspaceAgentsByInstanceID(ctx, authInstanceID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentsByInstanceID", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentsByInstanceID), ctx, authInstanceID) +} + // GetWorkspaceAgentsByParentID mocks base method. func (m *MockStore) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]database.WorkspaceAgent, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 19451b0852..7c508dce71 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -683,7 +683,6 @@ type sqlcQuerier interface { GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (GetWorkspaceACLByIDRow, error) GetWorkspaceAgentAndWorkspaceByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentAndWorkspaceByIDRow, error) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) - GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context, workspaceAgentID uuid.UUID) ([]WorkspaceAgentDevcontainer, error) GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentLifecycleStateByIDRow, error) GetWorkspaceAgentLogSourcesByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgentLogSource, error) @@ -697,6 +696,7 @@ type sqlcQuerier interface { // `minute_buckets` could return 0 rows if there 
are no usage stats since `created_at`. GetWorkspaceAgentUsageStats(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsRow, error) GetWorkspaceAgentUsageStatsAndLabels(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsAndLabelsRow, error) + GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]WorkspaceAgent, error) GetWorkspaceAgentsByParentID(ctx context.Context, parentID uuid.UUID) ([]WorkspaceAgent, error) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]WorkspaceAgent, error) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index f2b4cdab1b..881e800f0b 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -7184,38 +7184,55 @@ func TestGetWorkspaceAgentsByParentID(t *testing.T) { }) } -func TestGetWorkspaceAgentByInstanceID(t *testing.T) { +func setupWorkspaceAgentQueryResources(t *testing.T, db database.Store, count int) []database.WorkspaceResource { + t.Helper() + + org := dbgen.Organization(t, db, database.Organization{}) + job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ + Type: database.ProvisionerJobTypeTemplateVersionImport, + OrganizationID: org.ID, + }) + + resources := make([]database.WorkspaceResource, 0, count) + for i := 0; i < count; i++ { + resources = append(resources, dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ + JobID: job.ID, + })) + } + + return resources +} + +func markWorkspaceAgentDeleted(ctx context.Context, t *testing.T, sqlDB *sql.DB, agentID uuid.UUID) { + t.Helper() + + _, err := sqlDB.ExecContext(ctx, "UPDATE workspace_agents SET deleted = TRUE WHERE id = $1", agentID) + require.NoError(t, err) +} + +func TestGetWorkspaceAgentsByInstanceID(t *testing.T) { t.Parallel() - // Context: 
https://github.com/coder/coder/pull/22196 - t.Run("DoesNotReturnSubAgents", func(t *testing.T) { + t.Run("ReturnsAllMatchingRootAgents", func(t *testing.T) { t.Parallel() - // Given: A parent workspace agent with an AuthInstanceID and a - // sub-agent that shares the same AuthInstanceID. db, _ := dbtestutil.NewDB(t) - org := dbgen.Organization(t, db, database.Organization{}) - job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ - Type: database.ProvisionerJobTypeTemplateVersionImport, - OrganizationID: org.ID, - }) - resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ - JobID: job.ID, - }) - + resources := setupWorkspaceAgentQueryResources(t, db, 2) authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano()) - parentAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ - ResourceID: resource.ID, + olderCreatedAt := dbtime.Now().Add(-time.Hour) + newerCreatedAt := dbtime.Now() + + olderAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: resources[0].ID, + CreatedAt: olderCreatedAt, AuthInstanceID: sql.NullString{ String: authInstanceID, Valid: true, }, }) - // Create a sub-agent with the same AuthInstanceID (simulating - // the old behavior before the fix). - _ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ - ParentID: uuid.NullUUID{UUID: parentAgent.ID, Valid: true}, - ResourceID: resource.ID, + newerAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: resources[1].ID, + CreatedAt: newerCreatedAt, AuthInstanceID: sql.NullString{ String: authInstanceID, Valid: true, @@ -7224,13 +7241,89 @@ func TestGetWorkspaceAgentByInstanceID(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) - // When: We look up the agent by instance ID. 
- agent, err := db.GetWorkspaceAgentByInstanceID(ctx, authInstanceID) + agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID) require.NoError(t, err) + require.Len(t, agents, 2) + assert.Equal(t, []uuid.UUID{newerAgent.ID, olderAgent.ID}, []uuid.UUID{agents[0].ID, agents[1].ID}) + }) - // Then: The result must be the parent agent, not the sub-agent. - assert.Equal(t, parentAgent.ID, agent.ID, "instance ID lookup should return the parent agent, not a sub-agent") - assert.False(t, agent.ParentID.Valid, "returned agent should not have a parent (should be the parent itself)") + t.Run("ExcludesDeletedAndSubAgents", func(t *testing.T) { + t.Parallel() + + db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t) + resources := setupWorkspaceAgentQueryResources(t, db, 2) + authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano()) + baseCreatedAt := dbtime.Now() + + rootAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: resources[0].ID, + CreatedAt: baseCreatedAt.Add(-time.Hour), + AuthInstanceID: sql.NullString{ + String: authInstanceID, + Valid: true, + }, + }) + _ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ParentID: uuid.NullUUID{UUID: rootAgent.ID, Valid: true}, + ResourceID: resources[0].ID, + CreatedAt: baseCreatedAt, + AuthInstanceID: sql.NullString{ + String: authInstanceID, + Valid: true, + }, + }) + deletedRootAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: resources[1].ID, + CreatedAt: baseCreatedAt.Add(time.Minute), + AuthInstanceID: sql.NullString{ + String: authInstanceID, + Valid: true, + }, + }) + + ctx := testutil.Context(t, testutil.WaitShort) + markWorkspaceAgentDeleted(ctx, t, sqlDB, deletedRootAgent.ID) + + agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID) + require.NoError(t, err) + require.Len(t, agents, 1) + assert.Equal(t, rootAgent.ID, agents[0].ID) + assert.False(t, agents[0].ParentID.Valid) + }) + + t.Run("OrdersNewestFirst", func(t 
*testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + resources := setupWorkspaceAgentQueryResources(t, db, 2) + authInstanceID := fmt.Sprintf("instance-%s-%d", t.Name(), time.Now().UnixNano()) + olderCreatedAt := dbtime.Now().Add(-time.Hour) + newerCreatedAt := dbtime.Now() + + olderAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: resources[0].ID, + CreatedAt: olderCreatedAt, + AuthInstanceID: sql.NullString{ + String: authInstanceID, + Valid: true, + }, + }) + newerAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: resources[1].ID, + CreatedAt: newerCreatedAt, + AuthInstanceID: sql.NullString{ + String: authInstanceID, + Valid: true, + }, + }) + + ctx := testutil.Context(t, testutil.WaitShort) + + agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, authInstanceID) + require.NoError(t, err) + require.Len(t, agents, 2) + assert.Equal(t, newerAgent.ID, agents[0].ID) + assert.Equal(t, olderAgent.ID, agents[1].ID) }) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 6447b37dfe..f6cfc65021 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -26566,63 +26566,6 @@ func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (W return i, err } -const getWorkspaceAgentByInstanceID = `-- name: GetWorkspaceAgentByInstanceID :one -SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, expanded_directory, logs_length, logs_overflowed, started_at, ready_at, subsystems, display_apps, api_version, display_order, parent_id, api_key_scope, deleted -FROM - workspace_agents -WHERE - auth_instance_id = $1 :: TEXT - -- Filter out deleted sub 
agents. - AND deleted = FALSE - -- Filter out sub agents, they do not authenticate with auth_instance_id. - AND parent_id IS NULL -ORDER BY - created_at DESC -` - -func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error) { - row := q.db.QueryRowContext(ctx, getWorkspaceAgentByInstanceID, authInstanceID) - var i WorkspaceAgent - err := row.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.Name, - &i.FirstConnectedAt, - &i.LastConnectedAt, - &i.DisconnectedAt, - &i.ResourceID, - &i.AuthToken, - &i.AuthInstanceID, - &i.Architecture, - &i.EnvironmentVariables, - &i.OperatingSystem, - &i.InstanceMetadata, - &i.ResourceMetadata, - &i.Directory, - &i.Version, - &i.LastConnectedReplicaID, - &i.ConnectionTimeoutSeconds, - &i.TroubleshootingURL, - &i.MOTDFile, - &i.LifecycleState, - &i.ExpandedDirectory, - &i.LogsLength, - &i.LogsOverflowed, - &i.StartedAt, - &i.ReadyAt, - pq.Array(&i.Subsystems), - pq.Array(&i.DisplayApps), - &i.APIVersion, - &i.DisplayOrder, - &i.ParentID, - &i.APIKeyScope, - &i.Deleted, - ) - return i, err -} - const getWorkspaceAgentLifecycleStateByID = `-- name: GetWorkspaceAgentLifecycleStateByID :one SELECT lifecycle_state, @@ -26836,6 +26779,79 @@ func (q *sqlQuerier) GetWorkspaceAgentScriptTimingsByBuildID(ctx context.Context return items, nil } +const getWorkspaceAgentsByInstanceID = `-- name: GetWorkspaceAgentsByInstanceID :many +SELECT + id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, expanded_directory, logs_length, logs_overflowed, started_at, ready_at, subsystems, display_apps, api_version, display_order, parent_id, api_key_scope, deleted +FROM + workspace_agents +WHERE + 
auth_instance_id = $1 :: TEXT + -- Filter out deleted agents. + AND deleted = FALSE + -- Filter out sub agents, they do not authenticate with auth_instance_id. + AND parent_id IS NULL +ORDER BY + created_at DESC +` + +func (q *sqlQuerier) GetWorkspaceAgentsByInstanceID(ctx context.Context, authInstanceID string) ([]WorkspaceAgent, error) { + rows, err := q.db.QueryContext(ctx, getWorkspaceAgentsByInstanceID, authInstanceID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WorkspaceAgent + for rows.Next() { + var i WorkspaceAgent + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.Name, + &i.FirstConnectedAt, + &i.LastConnectedAt, + &i.DisconnectedAt, + &i.ResourceID, + &i.AuthToken, + &i.AuthInstanceID, + &i.Architecture, + &i.EnvironmentVariables, + &i.OperatingSystem, + &i.InstanceMetadata, + &i.ResourceMetadata, + &i.Directory, + &i.Version, + &i.LastConnectedReplicaID, + &i.ConnectionTimeoutSeconds, + &i.TroubleshootingURL, + &i.MOTDFile, + &i.LifecycleState, + &i.ExpandedDirectory, + &i.LogsLength, + &i.LogsOverflowed, + &i.StartedAt, + &i.ReadyAt, + pq.Array(&i.Subsystems), + pq.Array(&i.DisplayApps), + &i.APIVersion, + &i.DisplayOrder, + &i.ParentID, + &i.APIKeyScope, + &i.Deleted, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getWorkspaceAgentsByParentID = `-- name: GetWorkspaceAgentsByParentID :many SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, expanded_directory, logs_length, logs_overflowed, started_at, ready_at, subsystems, 
display_apps, api_version, display_order, parent_id, api_key_scope, deleted diff --git a/coderd/database/queries/workspaceagents.sql b/coderd/database/queries/workspaceagents.sql index 830ab57835..b75fb61b15 100644 --- a/coderd/database/queries/workspaceagents.sql +++ b/coderd/database/queries/workspaceagents.sql @@ -8,14 +8,14 @@ WHERE -- Filter out deleted sub agents. AND deleted = FALSE; --- name: GetWorkspaceAgentByInstanceID :one +-- name: GetWorkspaceAgentsByInstanceID :many SELECT * FROM workspace_agents WHERE auth_instance_id = @auth_instance_id :: TEXT - -- Filter out deleted sub agents. + -- Filter out deleted agents. AND deleted = FALSE -- Filter out sub agents, they do not authenticate with auth_instance_id. AND parent_id IS NULL diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index 5bf8d9f0e8..007c26cb18 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -4286,8 +4286,10 @@ func TestInsertWorkspaceResource(t *testing.T) { // Looking up by the parent's instance ID must still // return the parent, not the sub-agent. 
- lookedUp, err := db.GetWorkspaceAgentByInstanceID(ctx, parentAgent.AuthInstanceID.String) + agents, err := db.GetWorkspaceAgentsByInstanceID(ctx, parentAgent.AuthInstanceID.String) require.NoError(t, err) + require.Len(t, agents, 1) + lookedUp := agents[0] assert.Equal(t, parentAgent.ID, lookedUp.ID, "instance ID lookup should still return the parent agent") }, }, diff --git a/coderd/workspaceresourceauth.go b/coderd/workspaceresourceauth.go index c8608ea03c..f414c3a828 100644 --- a/coderd/workspaceresourceauth.go +++ b/coderd/workspaceresourceauth.go @@ -4,7 +4,10 @@ import ( "encoding/json" "fmt" "net/http" + "sort" + "strings" + "github.com/google/uuid" "github.com/mitchellh/mapstructure" "github.com/coder/coder/v2/coderd/awsidentity" @@ -26,7 +29,7 @@ import ( // @Accept json // @Produce json // @Tags Agents -// @Param request body agentsdk.AzureInstanceIdentityToken true "Instance identity token" +// @Param request body agentsdk.AzureInstanceIdentityToken true "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID." // @Success 200 {object} agentsdk.AuthenticateResponse // @Router /workspaceagents/azure-instance-identity [post] func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r *http.Request) { @@ -45,7 +48,7 @@ func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r }) return } - api.handleAuthInstanceID(rw, r, instanceID) + api.handleAuthInstanceID(rw, r, instanceID, req.AgentName) } // AWS supports instance identity verification: @@ -58,7 +61,7 @@ func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r // @Accept json // @Produce json // @Tags Agents -// @Param request body agentsdk.AWSInstanceIdentityToken true "Instance identity token" +// @Param request body agentsdk.AWSInstanceIdentityToken true "Instance identity token. 
The optional agent_name field disambiguates when multiple agents share the same instance ID." // @Success 200 {object} agentsdk.AuthenticateResponse // @Router /workspaceagents/aws-instance-identity [post] func (api *API) postWorkspaceAuthAWSInstanceIdentity(rw http.ResponseWriter, r *http.Request) { @@ -75,7 +78,7 @@ func (api *API) postWorkspaceAuthAWSInstanceIdentity(rw http.ResponseWriter, r * }) return } - api.handleAuthInstanceID(rw, r, identity.InstanceID) + api.handleAuthInstanceID(rw, r, identity.InstanceID, req.AgentName) } // Google Compute Engine supports instance identity verification: @@ -88,7 +91,7 @@ func (api *API) postWorkspaceAuthAWSInstanceIdentity(rw http.ResponseWriter, r * // @Accept json // @Produce json // @Tags Agents -// @Param request body agentsdk.GoogleInstanceIdentityToken true "Instance identity token" +// @Param request body agentsdk.GoogleInstanceIdentityToken true "Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID." 
// @Success 200 {object} agentsdk.AuthenticateResponse // @Router /workspaceagents/google-instance-identity [post] func (api *API) postWorkspaceAuthGoogleInstanceIdentity(rw http.ResponseWriter, r *http.Request) { @@ -122,19 +125,18 @@ func (api *API) postWorkspaceAuthGoogleInstanceIdentity(rw http.ResponseWriter, }) return } - api.handleAuthInstanceID(rw, r, claims.Google.ComputeEngine.InstanceID) + api.handleAuthInstanceID(rw, r, claims.Google.ComputeEngine.InstanceID, req.AgentName) } -func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, instanceID string) { +func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, instanceID string, agentName string) { ctx := r.Context() - //nolint:gocritic // needed for auth instance id - agent, err := api.Database.GetWorkspaceAgentByInstanceID(dbauthz.AsSystemRestricted(ctx), instanceID) - if httpapi.Is404Error(err) { - httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ - Message: fmt.Sprintf("Instance with id %q not found.", instanceID), - }) - return - } + // Instance identity auth happens before the agent has a session token, so + // these lookups must use a restricted system context. + //nolint:gocritic // Instance identity auth happens before agent auth. + systemCtx := dbauthz.AsSystemRestricted(ctx) + agentName = strings.TrimSpace(agentName) + + agents, err := api.Database.GetWorkspaceAgentsByInstanceID(systemCtx, instanceID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job agent.", @@ -142,8 +144,77 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in }) return } - //nolint:gocritic // needed for auth instance id - resource, err := api.Database.GetWorkspaceResourceByID(dbauthz.AsSystemRestricted(ctx), agent.ResourceID) + + // Template version agents can share an instance ID with workspace build + // agents. 
Keep only workspace build agents before resolving ambiguity so + // template version agents do not force CODER_AGENT_NAME. + buildAgents := agents[:0] + for _, candidate := range agents { + resource, err := api.Database.GetWorkspaceResourceByID(systemCtx, candidate.ResourceID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching provisioner job resource.", + Detail: err.Error(), + }) + return + } + job, err := api.Database.GetProvisionerJobByID(systemCtx, resource.JobID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching provisioner job.", + Detail: err.Error(), + }) + return + } + if job.Type == database.ProvisionerJobTypeWorkspaceBuild { + buildAgents = append(buildAgents, candidate) + } + } + agents = buildAgents + if len(agents) == 0 { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: fmt.Sprintf("Instance with id %q not found.", instanceID), + }) + return + } + + var agent database.WorkspaceAgent + if agentName != "" { + for _, candidate := range agents { + if candidate.Name == agentName { + agent = candidate + break + } + } + if agent.ID == uuid.Nil { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: fmt.Sprintf("No agent found with instance ID %q and name %q.", instanceID, agentName), + }) + return + } + } else { + if len(agents) != 1 { + // Include agent names in the error message to help operators + // configure CODER_AGENT_NAME. The caller has already proven + // cloud instance identity, so agent names are not sensitive + // here. + names := make([]string, len(agents)) + for i, candidate := range agents { + names[i] = candidate.Name + } + sort.Strings(names) + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: fmt.Sprintf( + "Multiple agents found with instance ID %q. 
Set CODER_AGENT_NAME to one of: %s", + instanceID, + strings.Join(names, ", "), + ), + }) + return + } + agent = agents[0] + } + resource, err := api.Database.GetWorkspaceResourceByID(systemCtx, agent.ResourceID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job resource.", @@ -151,8 +222,7 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in }) return } - //nolint:gocritic // needed for auth instance id - job, err := api.Database.GetProvisionerJobByID(dbauthz.AsSystemRestricted(ctx), resource.JobID) + job, err := api.Database.GetProvisionerJobByID(systemCtx, resource.JobID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching provisioner job.", @@ -175,8 +245,7 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in }) return } - //nolint:gocritic // needed for auth instance id - resourceHistory, err := api.Database.GetWorkspaceBuildByID(dbauthz.AsSystemRestricted(ctx), jobData.WorkspaceBuildID) + resourceHistory, err := api.Database.GetWorkspaceBuildByID(systemCtx, jobData.WorkspaceBuildID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace build.", @@ -187,8 +256,7 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in // This token should only be exchanged if the instance ID is valid // for the latest history. If an instance ID is recycled by a cloud, // we'd hate to leak access to a user's workspace. 
- //nolint:gocritic // needed for auth instance id - latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(dbauthz.AsSystemRestricted(ctx), resourceHistory.WorkspaceID) + latestHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(systemCtx, resourceHistory.WorkspaceID) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching the latest workspace build.", diff --git a/coderd/workspaceresourceauth_test.go b/coderd/workspaceresourceauth_test.go index 5282adb0fb..0b95b267a0 100644 --- a/coderd/workspaceresourceauth_test.go +++ b/coderd/workspaceresourceauth_test.go @@ -2,12 +2,20 @@ package coderd_test import ( "context" + "database/sql" + "encoding/json" + "fmt" + "io" "net/http" "testing" + "time" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/provisioner/echo" @@ -17,96 +25,274 @@ import ( func TestPostWorkspaceAuthAzureInstanceIdentity(t *testing.T) { t.Parallel() - instanceID := "instanceidentifier" - certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID) - client := coderdtest.New(t, &coderdtest.Options{ - AzureCertificates: certificates, - IncludeProvisionerDaemon: true, - }) - user := coderdtest.CreateFirstUser(t, client) - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionGraph: []*proto.Response{{ - Type: &proto.Response_Graph{ - Graph: &proto.GraphComplete{ - Resources: []*proto.Resource{{ - Name: "somename", - Type: "someinstance", - Agents: []*proto.Agent{{ - Name: "dev", - Auth: &proto.Agent_InstanceId{ - InstanceId: instanceID, - }, - }}, - }}, - }, - }, - 
}}, - }) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, template.ID) - coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() + t.Run("Success", func(t *testing.T) { + t.Parallel() - agentClient := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity()) - agentClient.SDK.HTTPClient = metadataClient - err := agentClient.RefreshToken(ctx) - require.NoError(t, err) + instanceID := newTestInstanceID(t) + certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID) + client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{ + AzureCertificates: certificates, + }, workspaceAgentsForInstanceID(instanceID, "dev")) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agentClient := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity()) + agentClient.SDK.HTTPClient = metadataClient + + err := agentClient.RefreshToken(ctx) + require.NoError(t, err) + }) + + t.Run("Ambiguous/AzureWithSelector", func(t *testing.T) { + t.Parallel() + + instanceID := newTestInstanceID(t) + certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID) + client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{ + AzureCertificates: certificates, + }, workspaceAgentsForInstanceID(instanceID, "alpha", "beta")) + + expectedAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "alpha") + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agentClient := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity( + agentsdk.WithInstanceIdentityAgentName("alpha"), + )) + agentClient.SDK.HTTPClient = metadataClient + + err := 
agentClient.RefreshToken(ctx) + require.NoError(t, err) + require.Equal(t, expectedAgent.AuthToken.String(), agentClient.SDK.SessionToken()) + }) } func TestPostWorkspaceAuthAWSInstanceIdentity(t *testing.T) { t.Parallel() - t.Run("Success", func(t *testing.T) { + + t.Run("Ambiguous/SingleAgent", func(t *testing.T) { t.Parallel() - instanceID := "instanceidentifier" + + instanceID := newTestInstanceID(t) certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID) - client := coderdtest.New(t, &coderdtest.Options{ - AWSCertificates: certificates, - IncludeProvisionerDaemon: true, - }) - user := coderdtest.CreateFirstUser(t, client) - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionGraph: []*proto.Response{{ - Type: &proto.Response_Graph{ - Graph: &proto.GraphComplete{ - Resources: []*proto.Resource{{ - Name: "somename", - Type: "someinstance", - Agents: []*proto.Agent{{ - Name: "dev", - Auth: &proto.Agent_InstanceId{ - InstanceId: instanceID, - }, - }}, - }}, - }, - }, - }}, - }) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, template.ID) - coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{ + AWSCertificates: certificates, + }, workspaceAgentsForInstanceID(instanceID, "dev")) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity()) agentClient.SDK.HTTPClient = metadataClient + err := agentClient.RefreshToken(ctx) require.NoError(t, err) }) + + t.Run("Ambiguous/MultipleAgentsNoSelector", func(t *testing.T) { + t.Parallel() + + instanceID := newTestInstanceID(t) + certificates, metadataClient := 
coderdtest.NewAWSInstanceIdentity(t, instanceID) + client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{ + AWSCertificates: certificates, + }, workspaceAgentsForInstanceID(instanceID, "alpha", "beta")) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity()) + agentClient.SDK.HTTPClient = metadataClient + + err := agentClient.RefreshToken(ctx) + var apiErr *codersdk.Error + require.ErrorAs(t, err, &apiErr) + require.Equal(t, http.StatusConflict, apiErr.StatusCode()) + require.Contains(t, apiErr.Message, "CODER_AGENT_NAME") + require.Contains(t, apiErr.Message, "alpha, beta") + }) + + t.Run("Ambiguous/EmptyAgentNameTreatedAsUnset", func(t *testing.T) { + t.Parallel() + + instanceID := newTestInstanceID(t) + certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID) + client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{ + AWSCertificates: certificates, + }, workspaceAgentsForInstanceID(instanceID, "alpha", "beta")) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + signatureReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/signature", nil) + require.NoError(t, err) + signatureRes, err := metadataClient.Do(signatureReq) + require.NoError(t, err) + defer signatureRes.Body.Close() + signature, err := io.ReadAll(signatureRes.Body) + require.NoError(t, err) + + documentReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil) + require.NoError(t, err) + documentRes, err := metadataClient.Do(documentReq) + require.NoError(t, err) + defer documentRes.Body.Close() + document, err := io.ReadAll(documentRes.Body) + require.NoError(t, err) + + reqBody, err := json.Marshal(map[string]string{ + "signature": string(signature), + 
"document": string(document), + "agent_name": "", + }) + require.NoError(t, err) + + res, err := client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", reqBody) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusConflict, res.StatusCode) + err = codersdk.ReadBodyAsError(res) + var apiErr *codersdk.Error + require.ErrorAs(t, err, &apiErr) + require.Equal(t, http.StatusConflict, apiErr.StatusCode()) + require.Contains(t, apiErr.Message, "CODER_AGENT_NAME") + require.Contains(t, apiErr.Message, "alpha, beta") + }) + + t.Run("Ambiguous/WhitespaceAgentNameTreatedAsUnset", func(t *testing.T) { + t.Parallel() + + instanceID := newTestInstanceID(t) + certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID) + client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{ + AWSCertificates: certificates, + }, workspaceAgentsForInstanceID(instanceID, "alpha", "beta")) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + signatureReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/signature", nil) + require.NoError(t, err) + signatureRes, err := metadataClient.Do(signatureReq) + require.NoError(t, err) + defer signatureRes.Body.Close() + signature, err := io.ReadAll(signatureRes.Body) + require.NoError(t, err) + + documentReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://169.254.169.254/latest/dynamic/instance-identity/document", nil) + require.NoError(t, err) + documentRes, err := metadataClient.Do(documentReq) + require.NoError(t, err) + defer documentRes.Body.Close() + document, err := io.ReadAll(documentRes.Body) + require.NoError(t, err) + + reqBody, err := json.Marshal(map[string]string{ + "signature": string(signature), + "document": string(document), + "agent_name": " ", + }) + require.NoError(t, err) + + res, err := 
client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", reqBody) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusConflict, res.StatusCode) + err = codersdk.ReadBodyAsError(res) + var apiErr *codersdk.Error + require.ErrorAs(t, err, &apiErr) + require.Equal(t, http.StatusConflict, apiErr.StatusCode()) + require.Contains(t, apiErr.Message, "CODER_AGENT_NAME") + require.Contains(t, apiErr.Message, "alpha, beta") + }) + + t.Run("Ambiguous/MultipleAgentsWithSelector", func(t *testing.T) { + t.Parallel() + + instanceID := newTestInstanceID(t) + certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID) + client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{ + AWSCertificates: certificates, + }, workspaceAgentsForInstanceID(instanceID, "alpha", "beta")) + + expectedAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "alpha") + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity( + agentsdk.WithInstanceIdentityAgentName("alpha"), + )) + agentClient.SDK.HTTPClient = metadataClient + + err := agentClient.RefreshToken(ctx) + require.NoError(t, err) + require.Equal(t, expectedAgent.AuthToken.String(), agentClient.SDK.SessionToken()) + }) + + t.Run("Ambiguous/MultipleAgentsUnknownSelector", func(t *testing.T) { + t.Parallel() + + instanceID := newTestInstanceID(t) + certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID) + client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{ + AWSCertificates: certificates, + }, workspaceAgentsForInstanceID(instanceID, "alpha", "beta")) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity( + 
agentsdk.WithInstanceIdentityAgentName("nonexistent"), + )) + agentClient.SDK.HTTPClient = metadataClient + + err := agentClient.RefreshToken(ctx) + var apiErr *codersdk.Error + require.ErrorAs(t, err, &apiErr) + require.Equal(t, http.StatusNotFound, apiErr.StatusCode()) + }) + + t.Run("Ambiguous/SubAgentExcluded", func(t *testing.T) { + t.Parallel() + + instanceID := newTestInstanceID(t) + certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID) + client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{ + AWSCertificates: certificates, + }, workspaceAgentsForInstanceID(instanceID, "dev")) + + rootAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "dev") + _ = dbgen.WorkspaceSubAgent(t, store, rootAgent, database.WorkspaceAgent{ + Name: "sub", + AuthInstanceID: sql.NullString{ + String: instanceID, + Valid: true, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity()) + agentClient.SDK.HTTPClient = metadataClient + + err := agentClient.RefreshToken(ctx) + require.NoError(t, err) + require.Equal(t, rootAgent.AuthToken.String(), agentClient.SDK.SessionToken()) + }) } func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { t.Parallel() + t.Run("Expired", func(t *testing.T) { t.Parallel() - instanceID := "instanceidentifier" + + instanceID := newTestInstanceID(t) validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, true) client := coderdtest.New(t, &coderdtest.Options{ GoogleTokenValidator: validator, @@ -124,7 +310,8 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { t.Run("InstanceNotFound", func(t *testing.T) { t.Parallel() - instanceID := "instanceidentifier" + + instanceID := newTestInstanceID(t) validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false) client := coderdtest.New(t, &coderdtest.Options{ 
GoogleTokenValidator: validator, @@ -142,36 +329,12 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - instanceID := "instanceidentifier" + + instanceID := newTestInstanceID(t) validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false) - client := coderdtest.New(t, &coderdtest.Options{ - GoogleTokenValidator: validator, - IncludeProvisionerDaemon: true, - }) - user := coderdtest.CreateFirstUser(t, client) - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionGraph: []*proto.Response{{ - Type: &proto.Response_Graph{ - Graph: &proto.GraphComplete{ - Resources: []*proto.Resource{{ - Name: "somename", - Type: "someinstance", - Agents: []*proto.Agent{{ - Name: "dev", - Auth: &proto.Agent_InstanceId{ - InstanceId: instanceID, - }, - }}, - }}, - }, - }, - }}, - }) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, template.ID) - coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{ + GoogleTokenValidator: validator, + }, workspaceAgentsForInstanceID(instanceID, "dev")) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() @@ -180,4 +343,91 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) { err := agentClient.RefreshToken(ctx) require.NoError(t, err) }) + + t.Run("Ambiguous/GoogleWithSelector", func(t *testing.T) { + t.Parallel() + + instanceID := newTestInstanceID(t) + validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false) + client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{ + GoogleTokenValidator: validator, + }, workspaceAgentsForInstanceID(instanceID, "alpha", 
"beta")) + + expectedAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "alpha") + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + agentClient := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity( + "", + metadata, + agentsdk.WithInstanceIdentityAgentName("alpha"), + )) + err := agentClient.RefreshToken(ctx) + require.NoError(t, err) + require.Equal(t, expectedAgent.AuthToken.String(), agentClient.SDK.SessionToken()) + }) +} + +func setupInstanceIDWorkspace(t *testing.T, opts *coderdtest.Options, agents []*proto.Agent) (*codersdk.Client, database.Store) { + t.Helper() + + actualOpts := &coderdtest.Options{} + if opts != nil { + *actualOpts = *opts + } + actualOpts.IncludeProvisionerDaemon = true + + client, store := coderdtest.NewWithDatabase(t, actualOpts) + user := coderdtest.CreateFirstUser(t, client) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionGraph: []*proto.Response{{ + Type: &proto.Response_Graph{ + Graph: &proto.GraphComplete{ + Resources: []*proto.Resource{{ + Name: "resource", + Type: "instance", + Agents: agents, + }}, + }, + }, + }}, + }) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, template.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) + + return client, store +} + +func workspaceAgentsForInstanceID(instanceID string, names ...string) []*proto.Agent { + agents := make([]*proto.Agent, 0, len(names)) + for _, name := range names { + agents = append(agents, &proto.Agent{ + Name: name, + Auth: &proto.Agent_InstanceId{InstanceId: instanceID}, + }) + } + return agents +} + +func requireWorkspaceAgentByInstanceIDAndName(t testing.TB, store database.Store, instanceID string, name string) 
database.WorkspaceAgent { + t.Helper() + + ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong)) + agents, err := store.GetWorkspaceAgentsByInstanceID(ctx, instanceID) + require.NoError(t, err) + for _, agent := range agents { + if agent.Name == name { + return agent + } + } + require.FailNow(t, "workspace agent not found", "instance ID %q, name %q", instanceID, name) + return database.WorkspaceAgent{} +} + +func newTestInstanceID(t testing.TB) string { + t.Helper() + return fmt.Sprintf("instance-%d", time.Now().UnixNano()) } diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index 5e72eef6c2..8b008aea01 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -465,6 +465,33 @@ func (FixedSessionTokenProvider) RefreshToken(_ context.Context) error { return nil } +// InstanceIdentityConfig holds optional configuration for cloud +// instance-identity authentication. +type InstanceIdentityConfig struct { + AgentName string +} + +// InstanceIdentityOption configures instance-identity authentication. +type InstanceIdentityOption func(*InstanceIdentityConfig) + +// WithInstanceIdentityAgentName sets the agent name selector sent with +// the instance-identity authentication request. +func WithInstanceIdentityAgentName(name string) InstanceIdentityOption { + return func(c *InstanceIdentityConfig) { + c.AgentName = name + } +} + +// applyInstanceIdentityOptions applies the given options and returns +// the resulting configuration. 
+func applyInstanceIdentityOptions(opts []InstanceIdentityOption) InstanceIdentityConfig { + var cfg InstanceIdentityConfig + for _, o := range opts { + o(&cfg) + } + return cfg +} + func WithFixedToken(token string) SessionTokenSetup { return func(_ *codersdk.Client) RefreshableSessionTokenProvider { return FixedSessionTokenProvider{FixedSessionTokenProvider: codersdk.FixedSessionTokenProvider{SessionToken: token}} diff --git a/codersdk/agentsdk/aws.go b/codersdk/agentsdk/aws.go index 5440151897..002f4333f7 100644 --- a/codersdk/agentsdk/aws.go +++ b/codersdk/agentsdk/aws.go @@ -14,18 +14,24 @@ import ( type AWSInstanceIdentityToken struct { Signature string `json:"signature" validate:"required"` Document string `json:"document" validate:"required"` + // AgentName optionally selects a specific agent when multiple + // agents share the same instance identity. An empty string is + // treated as unspecified. + AgentName string `json:"agent_name,omitempty"` } // AWSSessionTokenExchanger exchanges AWS instance metadata for a Coder session token. 
// @typescript-ignore AWSSessionTokenExchanger type AWSSessionTokenExchanger struct { - client *codersdk.Client + client *codersdk.Client + agentName string } -func WithAWSInstanceIdentity() SessionTokenSetup { +func WithAWSInstanceIdentity(opts ...InstanceIdentityOption) SessionTokenSetup { + cfg := applyInstanceIdentityOptions(opts) return func(client *codersdk.Client) RefreshableSessionTokenProvider { return &InstanceIdentitySessionTokenProvider{ - TokenExchanger: &AWSSessionTokenExchanger{client: client}, + TokenExchanger: &AWSSessionTokenExchanger{client: client, agentName: cfg.AgentName}, } } } @@ -84,6 +90,7 @@ func (a *AWSSessionTokenExchanger) exchange(ctx context.Context) (AuthenticateRe res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", AWSInstanceIdentityToken{ Signature: string(signature), Document: string(document), + AgentName: a.agentName, }) if err != nil { return AuthenticateResponse{}, err diff --git a/codersdk/agentsdk/azure.go b/codersdk/agentsdk/azure.go index 121292ac93..79898d61d2 100644 --- a/codersdk/agentsdk/azure.go +++ b/codersdk/agentsdk/azure.go @@ -11,18 +11,24 @@ import ( type AzureInstanceIdentityToken struct { Signature string `json:"signature" validate:"required"` Encoding string `json:"encoding" validate:"required"` + // AgentName optionally selects a specific agent when multiple + // agents share the same instance identity. An empty string is + // treated as unspecified. + AgentName string `json:"agent_name,omitempty"` } // AzureSessionTokenExchanger exchanges Azure attested metadata for a Coder session token. 
// @typescript-ignore AzureSessionTokenExchanger type AzureSessionTokenExchanger struct { - client *codersdk.Client + client *codersdk.Client + agentName string } -func WithAzureInstanceIdentity() SessionTokenSetup { +func WithAzureInstanceIdentity(opts ...InstanceIdentityOption) SessionTokenSetup { + cfg := applyInstanceIdentityOptions(opts) return func(client *codersdk.Client) RefreshableSessionTokenProvider { return &InstanceIdentitySessionTokenProvider{ - TokenExchanger: &AzureSessionTokenExchanger{client: client}, + TokenExchanger: &AzureSessionTokenExchanger{client: client, agentName: cfg.AgentName}, } } } @@ -46,6 +52,7 @@ func (a *AzureSessionTokenExchanger) exchange(ctx context.Context) (Authenticate if err != nil { return AuthenticateResponse{}, err } + token.AgentName = a.agentName res, err = a.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/azure-instance-identity", token) if err != nil { diff --git a/codersdk/agentsdk/google.go b/codersdk/agentsdk/google.go index 51dd138f8e..a2a281febd 100644 --- a/codersdk/agentsdk/google.go +++ b/codersdk/agentsdk/google.go @@ -14,6 +14,10 @@ import ( type GoogleInstanceIdentityToken struct { JSONWebToken string `json:"json_web_token" validate:"required"` + // AgentName optionally selects a specific agent when multiple + // agents share the same instance identity. An empty string is + // treated as unspecified. + AgentName string `json:"agent_name,omitempty"` } // GoogleSessionTokenExchanger exchanges a Google instance JWT document for a Coder session token. 
@@ -22,15 +26,18 @@ type GoogleSessionTokenExchanger struct { serviceAccount string gcpClient *metadata.Client client *codersdk.Client + agentName string } -func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client) SessionTokenSetup { +func WithGoogleInstanceIdentity(serviceAccount string, gcpClient *metadata.Client, opts ...InstanceIdentityOption) SessionTokenSetup { + cfg := applyInstanceIdentityOptions(opts) return func(client *codersdk.Client) RefreshableSessionTokenProvider { return &InstanceIdentitySessionTokenProvider{ TokenExchanger: &GoogleSessionTokenExchanger{ client: client, gcpClient: gcpClient, serviceAccount: serviceAccount, + agentName: cfg.AgentName, }, } } @@ -58,6 +65,7 @@ func (g *GoogleSessionTokenExchanger) exchange(ctx context.Context) (Authenticat // request without the token to avoid re-entering this function res, err := g.client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/google-instance-identity", GoogleInstanceIdentityToken{ JSONWebToken: jwt, + AgentName: g.agentName, }) if err != nil { return AuthenticateResponse{}, err diff --git a/codersdk/agentsdk/instanceidentity_internal_test.go b/codersdk/agentsdk/instanceidentity_internal_test.go new file mode 100644 index 0000000000..75966093ea --- /dev/null +++ b/codersdk/agentsdk/instanceidentity_internal_test.go @@ -0,0 +1,217 @@ +package agentsdk + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "cloud.google.com/go/compute/metadata" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/codersdk" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestAWSInstanceIdentityExchange_AgentName(t *testing.T) { + t.Parallel() + + capturedBody := runAWSInstanceIdentityExchange(t, WithInstanceIdentityAgentName("test-agent")) + 
assertJSONField(t, capturedBody, "agent_name", "test-agent") +} + +func TestAWSInstanceIdentityExchange_OmitsAgentName(t *testing.T) { + t.Parallel() + + capturedBody := runAWSInstanceIdentityExchange(t) + assertJSONFieldAbsent(t, capturedBody, "agent_name") +} + +func TestAzureInstanceIdentityExchange_AgentName(t *testing.T) { + t.Parallel() + + capturedBody := runAzureInstanceIdentityExchange(t, WithInstanceIdentityAgentName("test-agent")) + assertJSONField(t, capturedBody, "agent_name", "test-agent") +} + +func TestAzureInstanceIdentityExchange_OmitsAgentName(t *testing.T) { + t.Parallel() + + capturedBody := runAzureInstanceIdentityExchange(t) + assertJSONFieldAbsent(t, capturedBody, "agent_name") +} + +func TestGoogleInstanceIdentityExchange_AgentName(t *testing.T) { + t.Parallel() + + capturedBody := runGoogleInstanceIdentityExchange(t, WithInstanceIdentityAgentName("test-agent")) + assertJSONField(t, capturedBody, "agent_name", "test-agent") +} + +func TestGoogleInstanceIdentityExchange_OmitsAgentName(t *testing.T) { + t.Parallel() + + capturedBody := runGoogleInstanceIdentityExchange(t) + assertJSONFieldAbsent(t, capturedBody, "agent_name") +} + +func runAWSInstanceIdentityExchange(t *testing.T, opts ...InstanceIdentityOption) []byte { + t.Helper() + + var capturedBody []byte + server := newInstanceIdentityServer(t, "/api/v2/workspaceagents/aws-instance-identity", &capturedBody) + defer server.Close() + + client := newCodersdkClient(t, server, roundTripFunc(func(req *http.Request) (*http.Response, error) { + switch { + case req.URL.Host == "169.254.169.254" && req.Method == http.MethodPut && req.URL.Path == "/latest/api/token": + return httpResponse(req, http.StatusOK, "fake-imds-token", nil), nil + case req.URL.Host == "169.254.169.254" && req.Method == http.MethodGet && req.URL.Path == "/latest/dynamic/instance-identity/signature": + return httpResponse(req, http.StatusOK, "fakesig", nil), nil + case req.URL.Host == "169.254.169.254" && req.Method == 
http.MethodGet && req.URL.Path == "/latest/dynamic/instance-identity/document": + return httpResponse(req, http.StatusOK, "fakedoc", nil), nil + default: + return http.DefaultTransport.RoundTrip(req) + } + })) + + provider := requireInstanceIdentityProvider(t, WithAWSInstanceIdentity(opts...)(client)) + resp, err := provider.TokenExchanger.exchange(context.Background()) + require.NoError(t, err) + require.Equal(t, "test-session-token", resp.SessionToken) + + return capturedBody +} + +func runAzureInstanceIdentityExchange(t *testing.T, opts ...InstanceIdentityOption) []byte { + t.Helper() + + var capturedBody []byte + server := newInstanceIdentityServer(t, "/api/v2/workspaceagents/azure-instance-identity", &capturedBody) + defer server.Close() + + client := newCodersdkClient(t, server, roundTripFunc(func(req *http.Request) (*http.Response, error) { + switch { + case req.URL.Host == "169.254.169.254" && req.Method == http.MethodGet && req.URL.Path == "/metadata/attested/document": + return httpResponse(req, http.StatusOK, `{"signature":"fakesig","encoding":"fakeenc"}`, http.Header{"Content-Type": []string{"application/json"}}), nil + default: + return http.DefaultTransport.RoundTrip(req) + } + })) + + provider := requireInstanceIdentityProvider(t, WithAzureInstanceIdentity(opts...)(client)) + resp, err := provider.TokenExchanger.exchange(context.Background()) + require.NoError(t, err) + require.Equal(t, "test-session-token", resp.SessionToken) + + return capturedBody +} + +func runGoogleInstanceIdentityExchange(t *testing.T, opts ...InstanceIdentityOption) []byte { + t.Helper() + + var capturedBody []byte + server := newInstanceIdentityServer(t, "/api/v2/workspaceagents/google-instance-identity", &capturedBody) + defer server.Close() + + metadataClient := metadata.NewClient(&http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + require.Equal(t, "169.254.169.254", req.URL.Host) + require.Equal(t, http.MethodGet, req.Method) + 
require.Equal(t, "/computeMetadata/v1/instance/service-accounts/test-service-account/identity", req.URL.Path) + require.Equal(t, "audience=coder&format=full", req.URL.RawQuery) + require.Equal(t, "Google", req.Header.Get("Metadata-Flavor")) + return httpResponse(req, http.StatusOK, "fake-jwt", nil), nil + })}) + client := newCodersdkClient(t, server, http.DefaultTransport) + + provider := requireInstanceIdentityProvider(t, WithGoogleInstanceIdentity("test-service-account", metadataClient, opts...)(client)) + resp, err := provider.TokenExchanger.exchange(context.Background()) + require.NoError(t, err) + require.Equal(t, "test-session-token", resp.SessionToken) + + return capturedBody +} + +func newInstanceIdentityServer(t *testing.T, path string, capturedBody *[]byte) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + require.Equal(t, http.MethodPost, req.Method) + require.Equal(t, path, req.URL.Path) + + body, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.NoError(t, req.Body.Close()) + *capturedBody = body + + rw.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(rw).Encode(AuthenticateResponse{SessionToken: "test-session-token"})) + })) +} + +func newCodersdkClient(t *testing.T, server *httptest.Server, transport http.RoundTripper) *codersdk.Client { + t.Helper() + + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + return &codersdk.Client{ + URL: serverURL, + HTTPClient: &http.Client{ + Transport: transport, + }, + } +} + +func requireInstanceIdentityProvider(t *testing.T, provider RefreshableSessionTokenProvider) *InstanceIdentitySessionTokenProvider { + t.Helper() + + identityProvider, ok := provider.(*InstanceIdentitySessionTokenProvider) + require.True(t, ok) + return identityProvider +} + +func httpResponse(req *http.Request, statusCode int, body string, headers http.Header) *http.Response { + if headers == 
nil { + headers = make(http.Header) + } + + return &http.Response{ + StatusCode: statusCode, + Header: headers, + Body: io.NopCloser(strings.NewReader(body)), + Request: req, + } +} + +func decodeJSONBody(t *testing.T, body []byte) map[string]any { + t.Helper() + + var decoded map[string]any + require.NoError(t, json.Unmarshal(body, &decoded)) + return decoded +} + +func assertJSONField(t *testing.T, body []byte, key string, want string) { + t.Helper() + + decoded := decodeJSONBody(t, body) + require.Equal(t, want, decoded[key]) +} + +func assertJSONFieldAbsent(t *testing.T, body []byte, key string) { + t.Helper() + + decoded := decodeJSONBody(t, body) + _, ok := decoded[key] + require.False(t, ok) +} diff --git a/docs/reference/api/agents.md b/docs/reference/api/agents.md index bc0349eb11..2d2848755f 100644 --- a/docs/reference/api/agents.md +++ b/docs/reference/api/agents.md @@ -58,6 +58,7 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/aws-instance-identi ```json { + "agent_name": "string", "document": "string", "signature": "string" } @@ -65,9 +66,9 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/aws-instance-identi ### Parameters -| Name | In | Type | Required | Description | -|--------|------|----------------------------------------------------------------------------------|----------|-------------------------| -| `body` | body | [agentsdk.AWSInstanceIdentityToken](schemas.md#agentsdkawsinstanceidentitytoken) | true | Instance identity token | +| Name | In | Type | Required | Description | +|--------|------|----------------------------------------------------------------------------------|----------|-----------------------------------------------------------------------------------------------------------------------| +| `body` | body | [agentsdk.AWSInstanceIdentityToken](schemas.md#agentsdkawsinstanceidentitytoken) | true | Instance identity token. 
The optional agent_name field disambiguates when multiple agents share the same instance ID. | ### Example responses @@ -105,6 +106,7 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/azure-instance-iden ```json { + "agent_name": "string", "encoding": "string", "signature": "string" } @@ -112,9 +114,9 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/azure-instance-iden ### Parameters -| Name | In | Type | Required | Description | -|--------|------|--------------------------------------------------------------------------------------|----------|-------------------------| -| `body` | body | [agentsdk.AzureInstanceIdentityToken](schemas.md#agentsdkazureinstanceidentitytoken) | true | Instance identity token | +| Name | In | Type | Required | Description | +|--------|------|--------------------------------------------------------------------------------------|----------|-----------------------------------------------------------------------------------------------------------------------| +| `body` | body | [agentsdk.AzureInstanceIdentityToken](schemas.md#agentsdkazureinstanceidentitytoken) | true | Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID. 
| ### Example responses @@ -152,15 +154,16 @@ curl -X POST http://coder-server:8080/api/v2/workspaceagents/google-instance-ide ```json { + "agent_name": "string", "json_web_token": "string" } ``` ### Parameters -| Name | In | Type | Required | Description | -|--------|------|----------------------------------------------------------------------------------------|----------|-------------------------| -| `body` | body | [agentsdk.GoogleInstanceIdentityToken](schemas.md#agentsdkgoogleinstanceidentitytoken) | true | Instance identity token | +| Name | In | Type | Required | Description | +|--------|------|----------------------------------------------------------------------------------------|----------|-----------------------------------------------------------------------------------------------------------------------| +| `body` | body | [agentsdk.GoogleInstanceIdentityToken](schemas.md#agentsdkgoogleinstanceidentitytoken) | true | Instance identity token. The optional agent_name field disambiguates when multiple agents share the same instance ID. | ### Example responses diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index fc08cb695e..fd3fef6764 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -4,6 +4,7 @@ ```json { + "agent_name": "string", "document": "string", "signature": "string" } @@ -11,10 +12,11 @@ ### Properties -| Name | Type | Required | Restrictions | Description | -|-------------|--------|----------|--------------|-------------| -| `document` | string | true | | | -| `signature` | string | true | | | +| Name | Type | Required | Restrictions | Description | +|--------------|--------|----------|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------| +| `agent_name` | string | false | | Agent name optionally selects a specific agent when multiple agents share the same instance identity. 
An empty string is treated as unspecified. | +| `document` | string | true | | | +| `signature` | string | true | | | ## agentsdk.AuthenticateResponse @@ -34,6 +36,7 @@ ```json { + "agent_name": "string", "encoding": "string", "signature": "string" } @@ -41,10 +44,11 @@ ### Properties -| Name | Type | Required | Restrictions | Description | -|-------------|--------|----------|--------------|-------------| -| `encoding` | string | true | | | -| `signature` | string | true | | | +| Name | Type | Required | Restrictions | Description | +|--------------|--------|----------|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------| +| `agent_name` | string | false | | Agent name optionally selects a specific agent when multiple agents share the same instance identity. An empty string is treated as unspecified. | +| `encoding` | string | true | | | +| `signature` | string | true | | | ## agentsdk.ExternalAuthResponse @@ -90,15 +94,17 @@ ```json { + "agent_name": "string", "json_web_token": "string" } ``` ### Properties -| Name | Type | Required | Restrictions | Description | -|------------------|--------|----------|--------------|-------------| -| `json_web_token` | string | true | | | +| Name | Type | Required | Restrictions | Description | +|------------------|--------|----------|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------| +| `agent_name` | string | false | | Agent name optionally selects a specific agent when multiple agents share the same instance identity. An empty string is treated as unspecified. 
| +| `json_web_token` | string | true | | | ## agentsdk.Log diff --git a/docs/reference/cli/external-auth_access-token.md b/docs/reference/cli/external-auth_access-token.md index 7fb022077a..f7f8960b48 100644 --- a/docs/reference/cli/external-auth_access-token.md +++ b/docs/reference/cli/external-auth_access-token.md @@ -77,3 +77,12 @@ URL for an agent to access your deployment. | Default | token | Specify the authentication type to use for the agent. + +### --agent-name + +| | | +|-------------|--------------------------------| +| Type | string | +| Environment | $CODER_AGENT_NAME | + +The name of the agent to authenticate as (only applicable for instance identity). diff --git a/examples/templates/aws-multi-agent/README.md b/examples/templates/aws-multi-agent/README.md new file mode 100644 index 0000000000..143ffc8612 --- /dev/null +++ b/examples/templates/aws-multi-agent/README.md @@ -0,0 +1,81 @@ +--- +display_name: AWS EC2 Multi-Agent Instance Identity +description: Verify AWS instance identity auth for two Coder agents on one EC2 instance +icon: ../../../site/static/icon/aws.svg +maintainer_github: coder +verified: true +tags: [vm, linux, aws, multi-agent, instance-identity] +--- + +# AWS multi-agent instance identity verification + +This template verifies the multi-agent instance-identity authentication flow on +AWS. It provisions a single EC2 instance with two peer root workspace agents, +`main` and `dev`, that both use AWS instance identity authentication. + +The key behavior under test is `CODER_AGENT_NAME` disambiguation. Each agent +starts on the same VM with the same EC2 instance identity, but sets a distinct +`CODER_AGENT_NAME` so the Coder server can issue a separate session token for +that specific agent. + +## Prerequisites + +- AWS credentials configured for Terraform, such as environment variables or an + attached IAM role. +- A Coder deployment that includes the multi-agent instance-auth changes from + this branch. 
+- No special Coder server configuration. AWS instance identity certificates are + built in. + +## What this template creates + +- One VPC, subnet, internet gateway, route table, and route table association. +- One security group that allows SSH from anywhere for test access. +- One Ubuntu 24.04 EC2 instance. +- Two Coder agents, `main` and `dev`, on that single EC2 instance. +- Two agent startup flows that set `CODER_AGENT_NAME` before launching the + corresponding agent init script. + +## How to verify + +```bash +cd examples/templates/aws-multi-agent +coder templates push verify-multi-agent + +coder create test-multi-agent --template verify-multi-agent + +coder list +``` + +After the workspace starts, verify that both agents are connected in the Coder +Dashboard for `test-multi-agent`. You can also connect to each agent directly: + +```bash +coder ssh test-multi-agent -a main true +coder ssh test-multi-agent -a dev true +``` + +## Expected behavior + +- Both agents authenticate independently using AWS instance identity. +- Each agent receives its own session token. +- The workspace shows two connected agents in the Coder Dashboard. +- If `CODER_AGENT_NAME` is omitted, the server should return `409 Conflict` + because the shared instance identity is ambiguous. + +## Troubleshooting + +- If one agent gets `409 Conflict`, `CODER_AGENT_NAME` is not being set + correctly for that agent. +- If both agents fail, instance identity authentication is not working. Check + EC2 metadata service access from the instance. +- Check cloud-init logs with `journalctl -u cloud-init`. +- Check agent logs at `/tmp/coder-agent-main.log` and + `/tmp/coder-agent-dev.log`. 
+ +## Cleanup + +```bash +coder delete test-multi-agent +coder templates delete verify-multi-agent +``` diff --git a/examples/templates/aws-multi-agent/cloud-init/userdata.sh.tftpl b/examples/templates/aws-multi-agent/cloud-init/userdata.sh.tftpl new file mode 100644 index 0000000000..52cc1cb8e3 --- /dev/null +++ b/examples/templates/aws-multi-agent/cloud-init/userdata.sh.tftpl @@ -0,0 +1,18 @@ +#!/bin/bash +set -euo pipefail + +# Create the user if it doesn't exist. +if ! id -u "${linux_user}" >/dev/null 2>&1; then + useradd -m -s /bin/bash "${linux_user}" +fi + +# Start main agent; set its name via env(1) after sudo, since sudo's env_reset drops caller-set variables. +sudo -u '${linux_user}' env CODER_AGENT_NAME=main sh -c '${main_init_script}' \ + >/tmp/coder-agent-main.log 2>&1 & + +# Start dev agent; set its name via env(1) after sudo, since sudo's env_reset drops caller-set variables. +sudo -u '${linux_user}' env CODER_AGENT_NAME=dev sh -c '${dev_init_script}' \ + >/tmp/coder-agent-dev.log 2>&1 & + +# Block until both background agent processes exit. +wait diff --git a/examples/templates/aws-multi-agent/main.tf b/examples/templates/aws-multi-agent/main.tf new file mode 100644 index 0000000000..9f5be93914 --- /dev/null +++ b/examples/templates/aws-multi-agent/main.tf @@ -0,0 +1,340 @@ +terraform { + required_providers { + coder = { + source = "coder/coder" + } + aws = { + source = "hashicorp/aws" + } + cloudinit = { + source = "hashicorp/cloudinit" + } + } +} + +# Last updated 2023-03-14 +# aws ec2 describe-regions | jq -r '[.Regions[].RegionName] | sort' +data "coder_parameter" "region" { + name = "region" + display_name = "Region" + description = "The region to deploy the workspace in." 
+ default = "us-east-1" + mutable = false + option { + name = "Asia Pacific (Tokyo)" + value = "ap-northeast-1" + icon = "/emojis/1f1ef-1f1f5.png" + } + option { + name = "Asia Pacific (Seoul)" + value = "ap-northeast-2" + icon = "/emojis/1f1f0-1f1f7.png" + } + option { + name = "Asia Pacific (Osaka)" + value = "ap-northeast-3" + icon = "/emojis/1f1ef-1f1f5.png" + } + option { + name = "Asia Pacific (Mumbai)" + value = "ap-south-1" + icon = "/emojis/1f1ee-1f1f3.png" + } + option { + name = "Asia Pacific (Singapore)" + value = "ap-southeast-1" + icon = "/emojis/1f1f8-1f1ec.png" + } + option { + name = "Asia Pacific (Sydney)" + value = "ap-southeast-2" + icon = "/emojis/1f1e6-1f1fa.png" + } + option { + name = "Canada (Central)" + value = "ca-central-1" + icon = "/emojis/1f1e8-1f1e6.png" + } + option { + name = "EU (Frankfurt)" + value = "eu-central-1" + icon = "/emojis/1f1ea-1f1fa.png" + } + option { + name = "EU (Stockholm)" + value = "eu-north-1" + icon = "/emojis/1f1ea-1f1fa.png" + } + option { + name = "EU (Ireland)" + value = "eu-west-1" + icon = "/emojis/1f1ea-1f1fa.png" + } + option { + name = "EU (London)" + value = "eu-west-2" + icon = "/emojis/1f1ea-1f1fa.png" + } + option { + name = "EU (Paris)" + value = "eu-west-3" + icon = "/emojis/1f1ea-1f1fa.png" + } + option { + name = "South America (São Paulo)" + value = "sa-east-1" + icon = "/emojis/1f1e7-1f1f7.png" + } + option { + name = "US East (N. Virginia)" + value = "us-east-1" + icon = "/emojis/1f1fa-1f1f8.png" + } + option { + name = "US East (Ohio)" + value = "us-east-2" + icon = "/emojis/1f1fa-1f1f8.png" + } + option { + name = "US West (N. California)" + value = "us-west-1" + icon = "/emojis/1f1fa-1f1f8.png" + } + option { + name = "US West (Oregon)" + value = "us-west-2" + icon = "/emojis/1f1fa-1f1f8.png" + } +} + +data "coder_parameter" "instance_type" { + name = "instance_type" + display_name = "Instance type" + description = "What instance type should your workspace use?" 
+ default = "t3.micro" + mutable = false + option { + name = "2 vCPU, 1 GiB RAM" + value = "t3.micro" + } + option { + name = "2 vCPU, 2 GiB RAM" + value = "t3.small" + } + option { + name = "2 vCPU, 4 GiB RAM" + value = "t3.medium" + } + option { + name = "2 vCPU, 8 GiB RAM" + value = "t3.large" + } + option { + name = "4 vCPU, 16 GiB RAM" + value = "t3.xlarge" + } + option { + name = "8 vCPU, 32 GiB RAM" + value = "t3.2xlarge" + } +} + +provider "aws" { + region = data.coder_parameter.region.value +} + +data "coder_workspace" "me" {} +data "coder_workspace_owner" "me" {} + +data "aws_ami" "ubuntu" { + most_recent = true + filter { + name = "name" + values = ["ubuntu/images/hvm-ssd-gp3/ubuntu-noble-24.04-amd64-server-*"] + } + filter { + name = "virtualization-type" + values = ["hvm"] + } + owners = ["099720109477"] # Canonical +} + +resource "coder_agent" "main" { + count = data.coder_workspace.me.start_count + os = "linux" + arch = "amd64" + auth = "aws-instance-identity" + startup_script = <<-EOT + #!/bin/bash + set -e + echo "Agent 'main' started successfully" + echo "CODER_AGENT_NAME=$CODER_AGENT_NAME" + EOT + + metadata { + key = "agent-identity" + display_name = "Agent Identity" + interval = 60 + timeout = 5 + script = "echo main" + } +} + +resource "coder_agent" "dev" { + count = data.coder_workspace.me.start_count + os = "linux" + arch = "amd64" + auth = "aws-instance-identity" + startup_script = <<-EOT + #!/bin/bash + set -e + echo "Agent 'dev' started successfully" + echo "CODER_AGENT_NAME=$CODER_AGENT_NAME" + EOT + + metadata { + key = "agent-identity" + display_name = "Agent Identity" + interval = 60 + timeout = 5 + script = "echo dev" + } +} + +locals { + aws_availability_zone = "${data.coder_parameter.region.value}a" + hostname = lower(data.coder_workspace.me.name) + linux_user = "coder" +} + +data "cloudinit_config" "user_data" { + gzip = false + base64_encode = false + + boundary = "//" + + part { + filename = "userdata.sh" + content_type = 
"text/x-shellscript" + + content = templatefile("${path.module}/cloud-init/userdata.sh.tftpl", { + linux_user = local.linux_user + main_init_script = try(coder_agent.main[0].init_script, "") + dev_init_script = try(coder_agent.dev[0].init_script, "") + }) + } +} + +resource "aws_vpc" "workspace" { + cidr_block = "10.0.0.0/16" + enable_dns_hostnames = true + enable_dns_support = true + + tags = { + Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}" + } +} + +resource "aws_subnet" "workspace" { + vpc_id = aws_vpc.workspace.id + cidr_block = "10.0.1.0/24" + availability_zone = local.aws_availability_zone + map_public_ip_on_launch = true + + tags = { + Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}" + } +} + +resource "aws_internet_gateway" "workspace" { + vpc_id = aws_vpc.workspace.id + + tags = { + Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}" + } +} + +resource "aws_route_table" "workspace" { + vpc_id = aws_vpc.workspace.id + + route { + cidr_block = "0.0.0.0/0" + gateway_id = aws_internet_gateway.workspace.id + } + + tags = { + Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}" + } +} + +resource "aws_route_table_association" "workspace" { + subnet_id = aws_subnet.workspace.id + route_table_id = aws_route_table.workspace.id +} + +resource "aws_security_group" "workspace" { + name_prefix = "coder-${local.hostname}-" + description = "Allow SSH access for testing." 
+ vpc_id = aws_vpc.workspace.id + + ingress { + description = "SSH" + from_port = 22 + to_port = 22 + protocol = "tcp" + cidr_blocks = ["0.0.0.0/0"] + } + + egress { + from_port = 0 + to_port = 0 + protocol = "-1" + cidr_blocks = ["0.0.0.0/0"] + } + + tags = { + Name = "coder-${data.coder_workspace_owner.me.name}-${local.hostname}" + } +} + +resource "aws_instance" "dev" { + ami = data.aws_ami.ubuntu.id + availability_zone = local.aws_availability_zone + instance_type = data.coder_parameter.instance_type.value + subnet_id = aws_subnet.workspace.id + vpc_security_group_ids = [aws_security_group.workspace.id] + associate_public_ip_address = true + + user_data = data.cloudinit_config.user_data.rendered + tags = { + Name = "coder-${data.coder_workspace_owner.me.name}-${data.coder_workspace.me.name}" + # Required if you are using our example policy, see template README + Coder_Provisioned = "true" + } + lifecycle { + ignore_changes = [ami] + } + + depends_on = [aws_route_table_association.workspace] +} + +resource "coder_metadata" "workspace_info" { + resource_id = aws_instance.dev.id + item { + key = "region" + value = data.coder_parameter.region.value + } + item { + key = "instance type" + value = aws_instance.dev.instance_type + } + item { + key = "ami" + value = aws_instance.dev.ami + } +} + +resource "aws_ec2_instance_state" "dev" { + instance_id = aws_instance.dev.id + state = data.coder_workspace.me.transition == "start" ? "running" : "stopped" +}