Files
coder/coderd/workspaceresourceauth_test.go
T
Michael Suchacz e5707a13d6 feat: support multiple agents with shared instance-identity auth (#24325)
> This PR was authored by Mux on behalf of Mike.

## Summary

Adds support for multiple peer root workspace agents sharing the same
`auth_instance_id`, so AWS, Azure, and GCP instance-identity auth can
issue the correct session token for a selected agent instead of assuming
a single root agent per instance.

## Problem

When a Terraform template attaches two or more `coder_agent` resources
(with `auth = "aws-instance-identity"`) to a single compute instance,
every agent shares the same cloud instance ID. The existing singular
lookup picks whichever agent was created most recently, silently
ignoring the others.

## Solution

Introduce an optional pre-auth agent selector (`CODER_AGENT_NAME`) and
make the server-side lookup ambiguity-aware.

**Database layer:**
- `GetWorkspaceAgentsByInstanceID` (`:many`): returns all matching root
  agents for an instance ID.
- `GetWorkspaceAgentByInstanceIDAndName` (`:one`): returns the named
  root agent for disambiguation.

**SDK and CLI:**
- `agent_name` field added to AWS, Azure, and GCP request structs
  (`omitempty` for backward compatibility).
- `CODER_AGENT_NAME` env var and `--agent-name` flag wired into the
  agent bootstrap before instance-identity auth runs.

**Server handler (`handleAuthInstanceID`):**
- When `agent_name` is present: direct lookup by (instance ID, name).
- When absent: legacy lookup, then resource-scoped ambiguity check.
  Returns 409 with available agent names if multiple root agents match.
- Whitespace-only names are trimmed and treated as unspecified.
- Sub-agents remain excluded (`parent_id IS NULL` filter).

**Verification template:**
- `examples/templates/aws-multi-agent/` provisions one EC2 instance with
  two agents (`main` and `dev`), both using instance-identity auth with
  `CODER_AGENT_NAME` set in the cloud-init user data.

## Backward compatibility

Existing single-agent deployments work unchanged. The `agent_name` field
is optional with `omitempty`, and the unnamed path preserves today's
behavior when only one root agent matches.
2026-04-16 13:59:09 +02:00

434 lines
15 KiB
Go

package coderd_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
// TestPostWorkspaceAuthAzureInstanceIdentity exercises agent token refresh
// through the Azure instance-identity option of the agent SDK, both for a
// workspace with one agent and for a workspace whose two agents share an
// instance ID and are told apart by an agent-name selector.
func TestPostWorkspaceAuthAzureInstanceIdentity(t *testing.T) {
	t.Parallel()

	// A single agent on the instance: refreshing the token must not error.
	t.Run("Success", func(t *testing.T) {
		t.Parallel()

		id := newTestInstanceID(t)
		certs, metadataHTTP := coderdtest.NewAzureInstanceIdentity(t, id)
		opts := &coderdtest.Options{
			AzureCertificates: certs,
		}
		client, _ := setupInstanceIDWorkspace(t, opts, workspaceAgentsForInstanceID(id, "dev"))

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		agent := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity())
		// Route the SDK's HTTP traffic through the client returned by the
		// Azure identity fixture.
		agent.SDK.HTTPClient = metadataHTTP
		require.NoError(t, agent.RefreshToken(ctx))
	})

	// Two agents share the instance ID; selecting "alpha" by name must leave
	// the SDK holding alpha's auth token as its session token.
	t.Run("Ambiguous/AzureWithSelector", func(t *testing.T) {
		t.Parallel()

		id := newTestInstanceID(t)
		certs, metadataHTTP := coderdtest.NewAzureInstanceIdentity(t, id)
		opts := &coderdtest.Options{
			AzureCertificates: certs,
		}
		client, store := setupInstanceIDWorkspace(t, opts, workspaceAgentsForInstanceID(id, "alpha", "beta"))
		want := requireWorkspaceAgentByInstanceIDAndName(t, store, id, "alpha")

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		selector := agentsdk.WithInstanceIdentityAgentName("alpha")
		agent := agentsdk.New(client.URL, agentsdk.WithAzureInstanceIdentity(selector))
		agent.SDK.HTTPClient = metadataHTTP
		require.NoError(t, agent.RefreshToken(ctx))
		require.Equal(t, want.AuthToken.String(), agent.SDK.SessionToken())
	})
}
// TestPostWorkspaceAuthAWSInstanceIdentity exercises the AWS instance-identity
// auth endpoint for workspaces whose agents share one instance ID. It covers
// the single-agent case, the ambiguous multi-agent case with no selector, with
// an empty or whitespace-only selector, with a matching selector, with an
// unknown selector, and the case where a sub-agent carries the same instance
// ID as its root agent.
func TestPostWorkspaceAuthAWSInstanceIdentity(t *testing.T) {
	t.Parallel()

	// requireConflictForRawAgentName provisions a workspace with agents
	// "alpha" and "beta" on one instance ID, fetches the signature and
	// document from the metadata client, and posts them to the
	// aws-instance-identity endpoint with agentName as the "agent_name"
	// value. The body is marshaled from a map, so the agent_name key is
	// always present in the JSON exactly as given. It asserts the response is
	// a 409 whose message mentions CODER_AGENT_NAME and lists "alpha, beta".
	requireConflictForRawAgentName := func(t *testing.T, agentName string) {
		t.Helper()

		instanceID := newTestInstanceID(t)
		certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
		client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
			AWSCertificates: certificates,
		}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		// fetchMetadata GETs url through the metadata client and returns the
		// full response body, failing the test on any error.
		fetchMetadata := func(url string) []byte {
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
			require.NoError(t, err)
			res, err := metadataClient.Do(req)
			require.NoError(t, err)
			defer res.Body.Close()
			body, err := io.ReadAll(res.Body)
			require.NoError(t, err)
			return body
		}
		signature := fetchMetadata("http://169.254.169.254/latest/dynamic/instance-identity/signature")
		document := fetchMetadata("http://169.254.169.254/latest/dynamic/instance-identity/document")

		reqBody, err := json.Marshal(map[string]string{
			"signature":  string(signature),
			"document":   string(document),
			"agent_name": agentName,
		})
		require.NoError(t, err)

		res, err := client.RequestWithoutSessionToken(ctx, http.MethodPost, "/api/v2/workspaceagents/aws-instance-identity", reqBody)
		require.NoError(t, err)
		defer res.Body.Close()
		require.Equal(t, http.StatusConflict, res.StatusCode)

		err = codersdk.ReadBodyAsError(res)
		var apiErr *codersdk.Error
		require.ErrorAs(t, err, &apiErr)
		require.Equal(t, http.StatusConflict, apiErr.StatusCode())
		require.Contains(t, apiErr.Message, "CODER_AGENT_NAME")
		require.Contains(t, apiErr.Message, "alpha, beta")
	}

	// One agent on the instance and no selector: refresh must succeed.
	t.Run("Ambiguous/SingleAgent", func(t *testing.T) {
		t.Parallel()
		instanceID := newTestInstanceID(t)
		certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
		client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
			AWSCertificates: certificates,
		}, workspaceAgentsForInstanceID(instanceID, "dev"))
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity())
		agentClient.SDK.HTTPClient = metadataClient
		err := agentClient.RefreshToken(ctx)
		require.NoError(t, err)
	})

	// Two agents and no selector: the server must answer 409 and name the
	// candidates.
	t.Run("Ambiguous/MultipleAgentsNoSelector", func(t *testing.T) {
		t.Parallel()
		instanceID := newTestInstanceID(t)
		certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
		client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
			AWSCertificates: certificates,
		}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity())
		agentClient.SDK.HTTPClient = metadataClient
		err := agentClient.RefreshToken(ctx)
		var apiErr *codersdk.Error
		require.ErrorAs(t, err, &apiErr)
		require.Equal(t, http.StatusConflict, apiErr.StatusCode())
		require.Contains(t, apiErr.Message, "CODER_AGENT_NAME")
		require.Contains(t, apiErr.Message, "alpha, beta")
	})

	// An explicitly empty agent_name must behave like no selector at all.
	t.Run("Ambiguous/EmptyAgentNameTreatedAsUnset", func(t *testing.T) {
		t.Parallel()
		requireConflictForRawAgentName(t, "")
	})

	// A whitespace-only agent_name must behave like no selector at all.
	t.Run("Ambiguous/WhitespaceAgentNameTreatedAsUnset", func(t *testing.T) {
		t.Parallel()
		requireConflictForRawAgentName(t, " ")
	})

	// Two agents with selector "alpha": the SDK must end up holding alpha's
	// auth token as its session token.
	t.Run("Ambiguous/MultipleAgentsWithSelector", func(t *testing.T) {
		t.Parallel()
		instanceID := newTestInstanceID(t)
		certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
		client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{
			AWSCertificates: certificates,
		}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
		expectedAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "alpha")
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity(
			agentsdk.WithInstanceIdentityAgentName("alpha"),
		))
		agentClient.SDK.HTTPClient = metadataClient
		err := agentClient.RefreshToken(ctx)
		require.NoError(t, err)
		require.Equal(t, expectedAgent.AuthToken.String(), agentClient.SDK.SessionToken())
	})

	// A selector that matches neither agent must produce a 404.
	t.Run("Ambiguous/MultipleAgentsUnknownSelector", func(t *testing.T) {
		t.Parallel()
		instanceID := newTestInstanceID(t)
		certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
		client, _ := setupInstanceIDWorkspace(t, &coderdtest.Options{
			AWSCertificates: certificates,
		}, workspaceAgentsForInstanceID(instanceID, "alpha", "beta"))
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity(
			agentsdk.WithInstanceIdentityAgentName("nonexistent"),
		))
		agentClient.SDK.HTTPClient = metadataClient
		err := agentClient.RefreshToken(ctx)
		var apiErr *codersdk.Error
		require.ErrorAs(t, err, &apiErr)
		require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
	})

	// A sub-agent is inserted under the root agent with the same instance ID.
	// Unnamed auth must still succeed and yield the root agent's token.
	t.Run("Ambiguous/SubAgentExcluded", func(t *testing.T) {
		t.Parallel()
		instanceID := newTestInstanceID(t)
		certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
		client, store := setupInstanceIDWorkspace(t, &coderdtest.Options{
			AWSCertificates: certificates,
		}, workspaceAgentsForInstanceID(instanceID, "dev"))
		rootAgent := requireWorkspaceAgentByInstanceIDAndName(t, store, instanceID, "dev")
		_ = dbgen.WorkspaceSubAgent(t, store, rootAgent, database.WorkspaceAgent{
			Name: "sub",
			AuthInstanceID: sql.NullString{
				String: instanceID,
				Valid:  true,
			},
		})
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		agentClient := agentsdk.New(client.URL, agentsdk.WithAWSInstanceIdentity())
		agentClient.SDK.HTTPClient = metadataClient
		err := agentClient.RefreshToken(ctx)
		require.NoError(t, err)
		require.Equal(t, rootAgent.AuthToken.String(), agentClient.SDK.SessionToken())
	})
}
// TestPostWorkspaceAuthGoogleInstanceIdentity exercises agent token refresh
// through the Google instance-identity option of the agent SDK: the
// "Expired" and "InstanceNotFound" rejection cases, the single-agent success
// case, and selection by agent name when two agents share an instance ID.
func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) {
	t.Parallel()

	// The identity fixture is built with its third argument set to true
	// (presumably "expired" — confirm against coderdtest); the refresh must
	// fail with 401.
	t.Run("Expired", func(t *testing.T) {
		t.Parallel()

		id := newTestInstanceID(t)
		validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, id, true)
		opts := &coderdtest.Options{
			GoogleTokenValidator: validator,
		}
		client := coderdtest.New(t, opts)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		agent := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata))
		refreshErr := agent.RefreshToken(ctx)

		var sdkErr *codersdk.Error
		require.ErrorAs(t, refreshErr, &sdkErr)
		require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
	})

	// No workspace is provisioned for the instance ID, so the refresh must
	// fail with 404.
	t.Run("InstanceNotFound", func(t *testing.T) {
		t.Parallel()

		id := newTestInstanceID(t)
		validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, id, false)
		opts := &coderdtest.Options{
			GoogleTokenValidator: validator,
		}
		client := coderdtest.New(t, opts)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		agent := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata))
		refreshErr := agent.RefreshToken(ctx)

		var sdkErr *codersdk.Error
		require.ErrorAs(t, refreshErr, &sdkErr)
		require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
	})

	// A single agent on the instance: refreshing the token must not error.
	t.Run("Success", func(t *testing.T) {
		t.Parallel()

		id := newTestInstanceID(t)
		validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, id, false)
		opts := &coderdtest.Options{
			GoogleTokenValidator: validator,
		}
		client, _ := setupInstanceIDWorkspace(t, opts, workspaceAgentsForInstanceID(id, "dev"))

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		agent := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata))
		require.NoError(t, agent.RefreshToken(ctx))
	})

	// Two agents share the instance ID; selecting "alpha" by name must leave
	// the SDK holding alpha's auth token as its session token.
	t.Run("Ambiguous/GoogleWithSelector", func(t *testing.T) {
		t.Parallel()

		id := newTestInstanceID(t)
		validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, id, false)
		opts := &coderdtest.Options{
			GoogleTokenValidator: validator,
		}
		client, store := setupInstanceIDWorkspace(t, opts, workspaceAgentsForInstanceID(id, "alpha", "beta"))
		want := requireWorkspaceAgentByInstanceIDAndName(t, store, id, "alpha")

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		selector := agentsdk.WithInstanceIdentityAgentName("alpha")
		agent := agentsdk.New(client.URL, agentsdk.WithGoogleInstanceIdentity("", metadata, selector))
		require.NoError(t, agent.RefreshToken(ctx))
		require.Equal(t, want.AuthToken.String(), agent.SDK.SessionToken())
	})
}
// setupInstanceIDWorkspace starts a test coderd with a provisioner daemon,
// creates the first user, publishes a template version whose provision graph
// holds one resource ("resource" of type "instance") carrying the given
// agents, then creates a workspace from that template and waits for its build
// job to complete.
//
// opts may be nil. It is copied before IncludeProvisionerDaemon is forced to
// true, so the caller's value is not modified. The returned client and store
// are the ones produced by coderdtest.NewWithDatabase.
func setupInstanceIDWorkspace(t *testing.T, opts *coderdtest.Options, agents []*proto.Agent) (*codersdk.Client, database.Store) {
	t.Helper()

	// Work on a private copy of the options.
	var effective coderdtest.Options
	if opts != nil {
		effective = *opts
	}
	effective.IncludeProvisionerDaemon = true

	client, store := coderdtest.NewWithDatabase(t, &effective)
	owner := coderdtest.CreateFirstUser(t, client)

	// All agents hang off one resource in the provision graph.
	resource := &proto.Resource{
		Name:   "resource",
		Type:   "instance",
		Agents: agents,
	}
	graph := &proto.Response{
		Type: &proto.Response_Graph{
			Graph: &proto.GraphComplete{
				Resources: []*proto.Resource{resource},
			},
		},
	}
	version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, &echo.Responses{
		Parse:          echo.ParseComplete,
		ProvisionGraph: []*proto.Response{graph},
	})

	tpl := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
	coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)

	ws := coderdtest.CreateWorkspace(t, client, tpl.ID)
	coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)

	return client, store
}
// workspaceAgentsForInstanceID returns one proto.Agent per entry in names, in
// the same order, each using instanceID as its instance-ID auth. With no names
// the result is an empty, non-nil slice.
func workspaceAgentsForInstanceID(instanceID string, names ...string) []*proto.Agent {
	out := make([]*proto.Agent, len(names))
	for i, agentName := range names {
		out[i] = &proto.Agent{
			Name: agentName,
			Auth: &proto.Agent_InstanceId{InstanceId: instanceID},
		}
	}
	return out
}
// requireWorkspaceAgentByInstanceIDAndName loads the agents that
// GetWorkspaceAgentsByInstanceID returns for instanceID, using a
// system-restricted context, and returns the first one whose Name equals
// name. The test is failed immediately if the query errors or no agent
// matches.
func requireWorkspaceAgentByInstanceIDAndName(t testing.TB, store database.Store, instanceID string, name string) database.WorkspaceAgent {
	t.Helper()

	sysCtx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong))
	candidates, err := store.GetWorkspaceAgentsByInstanceID(sysCtx, instanceID)
	require.NoError(t, err)

	for i := range candidates {
		if candidates[i].Name != name {
			continue
		}
		return candidates[i]
	}

	require.FailNow(t, "workspace agent not found", "instance ID %q, name %q", instanceID, name)
	// Only here to satisfy the compiler's need for a terminating return.
	return database.WorkspaceAgent{}
}
// newTestInstanceID returns an instance ID of the form "instance-<n>", where
// <n> is the current wall-clock time in nanoseconds since the Unix epoch.
func newTestInstanceID(t testing.TB) string {
	t.Helper()

	nanos := time.Now().UnixNano()
	return fmt.Sprintf("instance-%d", nanos)
}