Files
coder/codersdk/agentsdk/instanceidentity_internal_test.go
Michael Suchacz e5707a13d6 feat: support multiple agents with shared instance-identity auth (#24325)
> This PR was authored by Mux on behalf of Mike.

## Summary

Adds support for multiple peer root workspace agents sharing the same
`auth_instance_id`, so AWS, Azure, and GCP instance-identity auth can
issue the correct session token for a selected agent instead of assuming
a single root agent per instance.

## Problem

When a Terraform template attaches two or more `coder_agent` resources
(with `auth = "aws-instance-identity"`) to a single compute instance,
every agent shares the same cloud instance ID. The existing singular
lookup picks whichever agent was created most recently, silently
ignoring the others.

## Solution

Introduce an optional pre-auth agent selector (`CODER_AGENT_NAME`) and
make the server-side lookup ambiguity-aware.

**Database layer:**
- `GetWorkspaceAgentsByInstanceID` (`:many`): returns all matching root
  agents for an instance ID.
- `GetWorkspaceAgentByInstanceIDAndName` (`:one`): returns the named
  root agent for disambiguation.

**SDK and CLI:**
- `agent_name` field added to AWS, Azure, and GCP request structs
  (`omitempty` for backward compatibility).
- `CODER_AGENT_NAME` env var and `--agent-name` flag wired into the
  agent bootstrap before instance-identity auth runs.

**Server handler (`handleAuthInstanceID`):**
- When `agent_name` is present: direct lookup by (instance ID, name).
- When absent: legacy lookup, then resource-scoped ambiguity check.
  Returns 409 with available agent names if multiple root agents match.
- Whitespace-only names are trimmed and treated as unspecified.
- Sub-agents remain excluded (`parent_id IS NULL` filter).

**Verification template:**
- `examples/templates/aws-multi-agent/` provisions one EC2 instance with
  two agents (`main` and `dev`), both using instance-identity auth with
  `CODER_AGENT_NAME` set in the cloud-init user data.

## Backward compatibility

Existing single-agent deployments work unchanged. The `agent_name` field
is optional with `omitempty`, and the unnamed path preserves today's
behavior when only one root agent matches.
2026-04-16 13:59:09 +02:00

218 lines
7.0 KiB
Go

package agentsdk
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"cloud.google.com/go/compute/metadata"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk"
)
// roundTripFunc adapts a plain function to http.RoundTripper so tests can
// fake HTTP responses inline without declaring a dedicated transport type.
type roundTripFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripFunc(nil)

// RoundTrip calls the wrapped function with req and returns its result as-is.
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
// TestAWSInstanceIdentityExchange_AgentName checks that a configured agent
// name is sent as "agent_name" in the AWS instance-identity request body.
func TestAWSInstanceIdentityExchange_AgentName(t *testing.T) {
	t.Parallel()

	const agentName = "test-agent"
	body := runAWSInstanceIdentityExchange(t, WithInstanceIdentityAgentName(agentName))
	assertJSONField(t, body, "agent_name", agentName)
}
func TestAWSInstanceIdentityExchange_OmitsAgentName(t *testing.T) {
t.Parallel()
capturedBody := runAWSInstanceIdentityExchange(t)
assertJSONFieldAbsent(t, capturedBody, "agent_name")
}
// TestAzureInstanceIdentityExchange_AgentName checks that a configured agent
// name is sent as "agent_name" in the Azure instance-identity request body.
func TestAzureInstanceIdentityExchange_AgentName(t *testing.T) {
	t.Parallel()

	const agentName = "test-agent"
	body := runAzureInstanceIdentityExchange(t, WithInstanceIdentityAgentName(agentName))
	assertJSONField(t, body, "agent_name", agentName)
}
func TestAzureInstanceIdentityExchange_OmitsAgentName(t *testing.T) {
t.Parallel()
capturedBody := runAzureInstanceIdentityExchange(t)
assertJSONFieldAbsent(t, capturedBody, "agent_name")
}
// TestGoogleInstanceIdentityExchange_AgentName checks that a configured agent
// name is sent as "agent_name" in the Google instance-identity request body.
func TestGoogleInstanceIdentityExchange_AgentName(t *testing.T) {
	t.Parallel()

	const agentName = "test-agent"
	body := runGoogleInstanceIdentityExchange(t, WithInstanceIdentityAgentName(agentName))
	assertJSONField(t, body, "agent_name", agentName)
}
func TestGoogleInstanceIdentityExchange_OmitsAgentName(t *testing.T) {
t.Parallel()
capturedBody := runGoogleInstanceIdentityExchange(t)
assertJSONFieldAbsent(t, capturedBody, "agent_name")
}
// runAWSInstanceIdentityExchange runs the AWS instance-identity token exchange
// against a fake Coder server. It asserts the returned session token and
// returns the raw JSON body the SDK POSTed to
// /api/v2/workspaceagents/aws-instance-identity.
//
// The IMDS token (PUT), signature (GET), and document (GET) endpoints on
// 169.254.169.254 are answered in-process. Requests addressed to the fake
// Coder server are forwarded over the default transport. Any other request is
// reported with t.Errorf and answered with 404 instead of being sent to the
// network. This stops a drift in the SDK's IMDS calls from reaching a real
// metadata service (e.g. when tests run on an EC2 host).
func runAWSInstanceIdentityExchange(t *testing.T, opts ...InstanceIdentityOption) []byte {
	t.Helper()
	var capturedBody []byte
	server := newInstanceIdentityServer(t, "/api/v2/workspaceagents/aws-instance-identity", &capturedBody)
	defer server.Close()
	serverURL, err := url.Parse(server.URL)
	require.NoError(t, err)
	client := newCodersdkClient(t, server, roundTripFunc(func(req *http.Request) (*http.Response, error) {
		switch {
		case req.URL.Host == "169.254.169.254" && req.Method == http.MethodPut && req.URL.Path == "/latest/api/token":
			return httpResponse(req, http.StatusOK, "fake-imds-token", nil), nil
		case req.URL.Host == "169.254.169.254" && req.Method == http.MethodGet && req.URL.Path == "/latest/dynamic/instance-identity/signature":
			return httpResponse(req, http.StatusOK, "fakesig", nil), nil
		case req.URL.Host == "169.254.169.254" && req.Method == http.MethodGet && req.URL.Path == "/latest/dynamic/instance-identity/document":
			return httpResponse(req, http.StatusOK, "fakedoc", nil), nil
		case req.URL.Host == serverURL.Host:
			// Only the fake Coder server is allowed to see real traffic.
			return http.DefaultTransport.RoundTrip(req)
		default:
			// t.Errorf (not require) keeps this safe even if the SDK ever
			// issues the request from a different goroutine.
			t.Errorf("unexpected request: %s %s", req.Method, req.URL)
			return httpResponse(req, http.StatusNotFound, "unexpected request", nil), nil
		}
	}))
	provider := requireInstanceIdentityProvider(t, WithAWSInstanceIdentity(opts...)(client))
	resp, err := provider.TokenExchanger.exchange(context.Background())
	require.NoError(t, err)
	require.Equal(t, "test-session-token", resp.SessionToken)
	return capturedBody
}
// runAzureInstanceIdentityExchange runs the Azure instance-identity token
// exchange against a fake Coder server. It asserts the returned session token
// and returns the raw JSON body the SDK POSTed to
// /api/v2/workspaceagents/azure-instance-identity.
//
// The attested-document endpoint on 169.254.169.254 is answered in-process
// with a fixed JSON payload. Requests addressed to the fake Coder server are
// forwarded over the default transport. Any other request is reported with
// t.Errorf and answered with 404 rather than being sent to the network, where
// it could reach a real metadata service.
func runAzureInstanceIdentityExchange(t *testing.T, opts ...InstanceIdentityOption) []byte {
	t.Helper()
	var capturedBody []byte
	server := newInstanceIdentityServer(t, "/api/v2/workspaceagents/azure-instance-identity", &capturedBody)
	defer server.Close()
	serverURL, err := url.Parse(server.URL)
	require.NoError(t, err)
	client := newCodersdkClient(t, server, roundTripFunc(func(req *http.Request) (*http.Response, error) {
		switch {
		case req.URL.Host == "169.254.169.254" && req.Method == http.MethodGet && req.URL.Path == "/metadata/attested/document":
			return httpResponse(req, http.StatusOK, `{"signature":"fakesig","encoding":"fakeenc"}`, http.Header{"Content-Type": []string{"application/json"}}), nil
		case req.URL.Host == serverURL.Host:
			// Only the fake Coder server is allowed to see real traffic.
			return http.DefaultTransport.RoundTrip(req)
		default:
			// t.Errorf (not require) keeps this safe even if the SDK ever
			// issues the request from a different goroutine.
			t.Errorf("unexpected request: %s %s", req.Method, req.URL)
			return httpResponse(req, http.StatusNotFound, "unexpected request", nil), nil
		}
	}))
	provider := requireInstanceIdentityProvider(t, WithAzureInstanceIdentity(opts...)(client))
	resp, err := provider.TokenExchanger.exchange(context.Background())
	require.NoError(t, err)
	require.Equal(t, "test-session-token", resp.SessionToken)
	return capturedBody
}
// runGoogleInstanceIdentityExchange runs the Google instance-identity token
// exchange against a fake Coder server. It asserts the returned session token
// and returns the raw JSON body the SDK POSTed to
// /api/v2/workspaceagents/google-instance-identity.
//
// The GCE metadata client is backed by an in-process transport. That transport
// requires a GET to 169.254.169.254 for the identity token of
// "test-service-account", with query "audience=coder&format=full" and the
// "Metadata-Flavor: Google" header, and answers it with a fixed JWT. The Coder
// client talks to the fake server over the default transport.
func runGoogleInstanceIdentityExchange(t *testing.T, opts ...InstanceIdentityOption) []byte {
	t.Helper()

	const (
		serviceAccount = "test-service-account"
		identityPath   = "/computeMetadata/v1/instance/service-accounts/" + serviceAccount + "/identity"
	)

	var body []byte
	srv := newInstanceIdentityServer(t, "/api/v2/workspaceagents/google-instance-identity", &body)
	defer srv.Close()

	fakeMetadata := roundTripFunc(func(r *http.Request) (*http.Response, error) {
		require.Equal(t, "169.254.169.254", r.URL.Host)
		require.Equal(t, http.MethodGet, r.Method)
		require.Equal(t, identityPath, r.URL.Path)
		require.Equal(t, "audience=coder&format=full", r.URL.RawQuery)
		require.Equal(t, "Google", r.Header.Get("Metadata-Flavor"))
		return httpResponse(r, http.StatusOK, "fake-jwt", nil), nil
	})
	gceMetadata := metadata.NewClient(&http.Client{Transport: fakeMetadata})

	coderClient := newCodersdkClient(t, srv, http.DefaultTransport)
	tokenProvider := requireInstanceIdentityProvider(t,
		WithGoogleInstanceIdentity(serviceAccount, gceMetadata, opts...)(coderClient))

	got, err := tokenProvider.TokenExchanger.exchange(context.Background())
	require.NoError(t, err)
	require.Equal(t, "test-session-token", got.SessionToken)
	return body
}
func newInstanceIdentityServer(t *testing.T, path string, capturedBody *[]byte) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
require.Equal(t, http.MethodPost, req.Method)
require.Equal(t, path, req.URL.Path)
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.NoError(t, req.Body.Close())
*capturedBody = body
rw.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(rw).Encode(AuthenticateResponse{SessionToken: "test-session-token"}))
}))
}
// newCodersdkClient returns a codersdk.Client aimed at server. Every HTTP
// request the client makes goes through transport, so tests can intercept some
// requests and forward others.
func newCodersdkClient(t *testing.T, server *httptest.Server, transport http.RoundTripper) *codersdk.Client {
	t.Helper()

	baseURL, err := url.Parse(server.URL)
	require.NoError(t, err)

	httpClient := &http.Client{Transport: transport}
	return &codersdk.Client{HTTPClient: httpClient, URL: baseURL}
}
// requireInstanceIdentityProvider asserts that provider is an
// *InstanceIdentitySessionTokenProvider and returns it with that concrete type.
// The test fails immediately if it is not.
func requireInstanceIdentityProvider(t *testing.T, provider RefreshableSessionTokenProvider) *InstanceIdentitySessionTokenProvider {
	t.Helper()

	concrete, isIdentity := provider.(*InstanceIdentitySessionTokenProvider)
	require.True(t, isIdentity)
	return concrete
}
// httpResponse builds a canned *http.Response for fake RoundTrippers.
//
// The response carries statusCode, the given headers, and an in-memory body
// containing body. It is linked back to req. ContentLength is the body's byte
// length, and the protocol is reported as HTTP/1.1. This keeps the response
// consistent with the http.Response contract, so consumers that honour
// ContentLength or the protocol version see a normal reply. A nil headers map
// is replaced with an empty one, so Header is always safe to write to.
func httpResponse(req *http.Request, statusCode int, body string, headers http.Header) *http.Response {
	if headers == nil {
		headers = make(http.Header)
	}
	return &http.Response{
		StatusCode:    statusCode,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        headers,
		Body:          io.NopCloser(strings.NewReader(body)),
		ContentLength: int64(len(body)),
		Request:       req,
	}
}
// decodeJSONBody unmarshals body into a generic map[string]any. It fails the
// test if body cannot be decoded into a map, for example invalid JSON or a
// non-object value.
func decodeJSONBody(t *testing.T, body []byte) map[string]any {
	t.Helper()

	var fields map[string]any
	err := json.Unmarshal(body, &fields)
	require.NoError(t, err)
	return fields
}
// assertJSONField decodes body as a JSON object and requires that the value
// stored under key equals the string want.
func assertJSONField(t *testing.T, body []byte, key, want string) {
	t.Helper()

	got := decodeJSONBody(t, body)[key]
	require.Equal(t, want, got)
}
// assertJSONFieldAbsent decodes body as a JSON object and requires that key is
// not present at all. A key that is present with a null or empty value still
// fails.
func assertJSONFieldAbsent(t *testing.T, body []byte, key string) {
	t.Helper()

	_, present := decodeJSONBody(t, body)[key]
	require.False(t, present)
}