Files
coder/coderd/aibridged/aibridged_test.go
T
Danny Kopping 5d8ca2e5ce fix: extract key when BYOK header is given with delegated auth (#25688)
Previously we were only extracting the API when _not_ delegating auth;
this is incorrect.

We need to extract the key _always_ when BYOK is intended.

---------

Signed-off-by: Danny Kopping <danny@coder.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-26 19:46:26 +02:00

869 lines
30 KiB
Go

package aibridged_test
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"storj.io/drpc"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/intercept"
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/aibridged"
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
"github.com/coder/coder/v2/coderd/aibridged/proto"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func newTestServer(t *testing.T) (*aibridged.Server, *mock.MockDRPCClient, *mock.MockPooler) {
t.Helper()
logger := slogtest.Make(t, nil)
ctrl := gomock.NewController(t)
client := mock.NewMockDRPCClient(ctrl)
pool := mock.NewMockPooler(ctrl)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
pool.EXPECT().Shutdown(gomock.Any()).MinTimes(1).Return(nil)
srv, err := aibridged.New(
t.Context(),
pool,
func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, logger, testTracer)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
srv.Shutdown(context.Background())
})
return srv, client, pool
}
// mockDRPCConn is a mock implementation of drpc.Conn
type mockDRPCConn struct{}
func (*mockDRPCConn) Close() error { return nil }
func (*mockDRPCConn) Closed() <-chan struct{} { ch := make(chan struct{}); return ch }
func (*mockDRPCConn) Transport() drpc.Transport { return nil }
func (*mockDRPCConn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
return nil
}
func (*mockDRPCConn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
// nolint:nilnil // Chillchill.
return nil, nil
}
func TestServeHTTP_FailureModes(t *testing.T) {
t.Parallel()
defaultHeaders := map[string]string{"Authorization": "Bearer key"}
httpClient := &http.Client{}
cases := []struct {
name string
reqHeaders map[string]string
applyMocksFn func(client *mock.MockDRPCClient, pool *mock.MockPooler)
dialerFn aibridged.Dialer
contextFn func() context.Context
expectedErr error
expectedStatus int
}{
// Authnz-related failures.
{
name: "no auth key",
reqHeaders: make(map[string]string),
expectedErr: aibridged.ErrNoAuthKey,
expectedStatus: http.StatusBadRequest,
},
{
name: "unrecognized header",
reqHeaders: map[string]string{
codersdk.SessionTokenHeader: "key", // Coder-Session-Token is not supported; requests originate with AI clients, not coder CLI.
},
applyMocksFn: func(client *mock.MockDRPCClient, _ *mock.MockPooler) {},
expectedErr: aibridged.ErrNoAuthKey,
expectedStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
applyMocksFn: func(client *mock.MockDRPCClient, _ *mock.MockPooler) {
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, xerrors.New("not authorized"))
},
expectedErr: aibridged.ErrUnauthorized,
expectedStatus: http.StatusForbidden,
},
{
name: "invalid key owner ID",
applyMocksFn: func(client *mock.MockDRPCClient, _ *mock.MockPooler) {
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.IsAuthorizedResponse{OwnerId: "oops"}, nil)
},
expectedErr: aibridged.ErrUnauthorized,
expectedStatus: http.StatusForbidden,
},
// TODO: coderd connection-related failures.
// Pool-related failures.
{
name: "pool instance",
applyMocksFn: func(client *mock.MockDRPCClient, pool *mock.MockPooler) {
// Should pass authorization.
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.IsAuthorizedResponse{OwnerId: uuid.NewString()}, nil)
// But fail when acquiring a pool instance.
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, xerrors.New("oops"))
},
expectedErr: aibridged.ErrAcquireRequestHandler,
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
if tc.applyMocksFn != nil {
tc.applyMocksFn(client, pool)
}
httpSrv := httptest.NewServer(srv)
ctx := testutil.Context(t, testutil.WaitShort)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpSrv.URL+"/openai/v1/chat/completions", nil)
require.NoError(t, err, "make request to test server")
headers := defaultHeaders
if tc.reqHeaders != nil {
headers = tc.reqHeaders
}
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := httpClient.Do(req)
t.Cleanup(func() {
if resp == nil || resp.Body == nil {
return
}
resp.Body.Close()
})
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err, "read response body")
require.Contains(t, string(body), tc.expectedErr.Error())
require.Equal(t, tc.expectedStatus, resp.StatusCode)
})
}
}
// When the request context carries a delegated API key ID (set by the
// in-process transport on behalf of a trusted caller like chatd), the handler
// must authenticate via the key_id field, skipping the header-based key
// extraction entirely. Validation succeeds or fails exactly as it would for a
// real API key. Delegation is orthogonal to BYOK: in BYOK mode the user's own
// LLM credentials must still be forwarded upstream while the Coder governance
// token is stripped.
func TestServeHTTP_DelegatedAPIKey(t *testing.T) {
t.Parallel()
const testKeyID = "abcdef1234"
tests := []struct {
name string
reqHeaders map[string]string
applyMocks func(t *testing.T, client *mock.MockDRPCClient, pool *mock.MockPooler, mockH *mockHandler)
expectStatus int
expectHandled bool
expectPresent map[string]string
expectAbsent []string
}{
{
// Delegated + centralized: identity comes from the
// api key ID on the context, in lieu of a session
// token. No header credentials are sent and SessionKey
// is empty downstream.
name: "valid centralized",
applyMocks: func(t *testing.T, client *mock.MockDRPCClient, pool *mock.MockPooler, mockH *mockHandler) {
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, in *proto.IsAuthorizedRequest) (*proto.IsAuthorizedResponse, error) {
assert.Equal(t, testKeyID, in.GetKeyId(), "handler must use KeyId for delegated requests")
assert.Empty(t, in.GetKey(), "handler must not set Key for delegated requests")
return &proto.IsAuthorizedResponse{
OwnerId: uuid.NewString(),
ApiKeyId: testKeyID,
Username: "u",
}, nil
})
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, req aibridged.Request, _ aibridged.ClientFunc, _ aibridged.MCPProxyBuilder) (http.Handler, error) {
assert.Empty(t, req.SessionKey,
"delegated centralized request carries no session token")
return mockH, nil
})
},
expectStatus: http.StatusOK,
expectHandled: true,
expectAbsent: []string{
"Authorization",
"X-Api-Key",
agplaibridge.HeaderCoderToken,
},
},
{
name: "valid BYOK preserves user credentials",
reqHeaders: map[string]string{
// Marks BYOK; this header must be stripped before
// forwarding upstream. Its value is what gets
// surfaced downstream as the SessionKey because
// ExtractAuthToken prefers HeaderCoderToken.
agplaibridge.HeaderCoderToken: "coder-token-byok",
// The user's own LLM credential; must be preserved.
"Authorization": "Bearer sk-ant-oat01-user-token",
},
applyMocks: func(t *testing.T, client *mock.MockDRPCClient, pool *mock.MockPooler, mockH *mockHandler) {
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).Return(&proto.IsAuthorizedResponse{
OwnerId: uuid.NewString(),
ApiKeyId: testKeyID,
Username: "u",
}, nil)
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, req aibridged.Request, _ aibridged.ClientFunc, _ aibridged.MCPProxyBuilder) (http.Handler, error) {
assert.Equal(t, "coder-token-byok", req.SessionKey,
"BYOK delegated request must still surface the extracted Coder token as SessionKey")
return mockH, nil
})
},
expectStatus: http.StatusOK,
expectHandled: true,
expectPresent: map[string]string{
"Authorization": "Bearer sk-ant-oat01-user-token",
},
expectAbsent: []string{
agplaibridge.HeaderCoderToken,
},
},
{
name: "invalid",
applyMocks: func(_ *testing.T, client *mock.MockDRPCClient, _ *mock.MockPooler, _ *mockHandler) {
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).Return(nil, xerrors.New("unknown key"))
},
expectStatus: http.StatusForbidden,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
mockH := &mockHandler{}
tc.applyMocks(t, client, pool, mockH)
ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/openai/v1/chat/completions", nil)
require.NoError(t, err)
for k, v := range tc.reqHeaders {
req.Header.Set(k, v)
}
rw := httptest.NewRecorder()
srv.ServeHTTP(rw, req)
require.Equal(t, tc.expectStatus, rw.Code)
if tc.expectHandled {
require.NotNil(t, mockH.headersReceived, "downstream handler must be invoked")
for h, v := range tc.expectPresent {
require.Equal(t, v, mockH.headersReceived.Get(h), "header %q must be preserved", h)
}
for _, h := range tc.expectAbsent {
require.Empty(t, mockH.headersReceived.Get(h), "header %q must be stripped", h)
}
} else {
require.Nil(t, mockH.headersReceived, "downstream handler must not be invoked on auth failure")
}
})
}
}
// End-to-end: a real transport factory wired to a real server, with BYOK in
// effect. The delegated key ID identifies the user (no Coder token over the
// wire) while the user's own LLM credentials in Authorization must flow
// through to the downstream handler. The Coder governance token, if set by
// the caller, must be stripped.
func TestServeHTTP_DelegatedAPIKey_BYOK_Integration(t *testing.T) {
t.Parallel()
const (
testKeyID = "abcdef1234"
// nolint:gosec // Fake LLM credential for assertion comparison.
userLLMToken = "Bearer sk-ant-oat01-user-byok-token"
)
srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
mockH := &mockHandler{}
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, in *proto.IsAuthorizedRequest) (*proto.IsAuthorizedResponse, error) {
assert.Equal(t, testKeyID, in.GetKeyId(), "delegated identity must be carried in KeyId")
assert.Empty(t, in.GetKey(), "Key must not be set on delegated requests")
return &proto.IsAuthorizedResponse{
OwnerId: uuid.NewString(),
ApiKeyId: testKeyID,
Username: "u",
}, nil
})
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)
factory := aibridged.NewTransportFactory(srv)
rt, err := factory.TransportFor("openai", agplaibridge.SourceAgents)
require.NoError(t, err)
ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/anthropic/v1/messages", nil)
require.NoError(t, err)
// HeaderCoderToken marks the request as BYOK. Its value is irrelevant on
// the delegated path (identity comes from context) and it must be
// stripped before forwarding upstream.
req.Header.Set(agplaibridge.HeaderCoderToken, "ignored-on-delegated-path")
// The user's own LLM credential; must reach the downstream handler.
req.Header.Set("Authorization", userLLMToken)
resp, err := rt.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.NotNil(t, mockH.headersReceived, "downstream handler must be invoked")
require.Equal(t, userLLMToken, mockH.headersReceived.Get("Authorization"),
"user's BYOK credential must be preserved end-to-end")
require.Empty(t, mockH.headersReceived.Get(agplaibridge.HeaderCoderToken),
"Coder governance token must be stripped before forwarding upstream")
}
// End-to-end: a real transport factory wired to a real server. Verifies the
// delegated key ID survives the in-memory round-trip and is treated as the
// authoritative caller identity by the handler, without any HTTP-layer header
// extraction.
func TestServeHTTP_DelegatedAPIKey_Integration(t *testing.T) {
t.Parallel()
const testKeyID = "abcdef1234"
srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
mockH := &mockHandler{}
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, in *proto.IsAuthorizedRequest) (*proto.IsAuthorizedResponse, error) {
assert.Equal(t, testKeyID, in.GetKeyId())
assert.Empty(t, in.GetKey())
return &proto.IsAuthorizedResponse{
OwnerId: uuid.NewString(),
ApiKeyId: testKeyID,
Username: "u",
}, nil
})
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)
factory := aibridged.NewTransportFactory(srv)
rt, err := factory.TransportFor("openai", agplaibridge.SourceAgents)
require.NoError(t, err)
ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/openai/v1/chat/completions", nil)
require.NoError(t, err)
resp, err := rt.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.NotNil(t, mockH.headersReceived, "downstream handler must observe the delegated request")
}
func TestServeHTTP_StripCoderToken(t *testing.T) {
t.Parallel()
cases := []struct {
name string
reqHeaders map[string]string
expectPresent map[string]string // header → expected value
expectAbsent []string // headers that must be gone
}{
{
// Centralized: the client sets Authorization and X-Api-Key,
// but does not include HeaderCoderToken.
// All auth headers are stripped.
name: "centralized",
reqHeaders: map[string]string{
"Authorization": "Bearer coder-token",
"X-Api-Key": "sk-ant-api03-user-key",
},
expectAbsent: []string{
"Authorization",
"X-Api-Key",
agplaibridge.HeaderCoderToken,
},
},
{
// BYOK with access token: Coder token in BYOK header,
// user's access token in Authorization. Only the
// BYOK header is stripped.
name: "byok bearer token",
reqHeaders: map[string]string{
agplaibridge.HeaderCoderToken: "coder-token",
"Authorization": "Bearer sk-ant-oat01-user-oauth-token",
},
expectPresent: map[string]string{
"Authorization": "Bearer sk-ant-oat01-user-oauth-token",
},
expectAbsent: []string{
agplaibridge.HeaderCoderToken,
},
},
{
// BYOK with personal API key: Coder token in BYOK header,
// user's API key in X-Api-Key. Only the BYOK header is
// stripped.
name: "byok api key",
reqHeaders: map[string]string{
agplaibridge.HeaderCoderToken: "coder-token",
"X-Api-Key": "sk-ant-api03-user-key",
},
expectPresent: map[string]string{
"X-Api-Key": "sk-ant-api03-user-key",
},
expectAbsent: []string{
agplaibridge.HeaderCoderToken,
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
mockH := &mockHandler{}
srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.IsAuthorizedResponse{OwnerId: uuid.NewString()}, nil)
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(mockH, nil)
httpSrv := httptest.NewServer(srv)
t.Cleanup(httpSrv.Close)
ctx := testutil.Context(t, testutil.WaitShort)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpSrv.URL+"/openai/v1/chat/completions", nil)
require.NoError(t, err)
for k, v := range tc.reqHeaders {
req.Header.Set(k, v)
}
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.NotNil(t, mockH.headersReceived)
for header, expected := range tc.expectPresent {
require.Equal(t, expected, mockH.headersReceived.Get(header),
"header %q should be preserved with value %q", header, expected)
}
for _, header := range tc.expectAbsent {
require.Empty(t, mockH.headersReceived.Get(header),
"header %q should be stripped", header)
}
// HeaderCoderToken should always be stripped
require.Empty(t, mockH.headersReceived.Get(agplaibridge.HeaderCoderToken),
"header %q should be stripped", agplaibridge.HeaderCoderToken)
})
}
}
func TestExtractAuthToken(t *testing.T) {
t.Parallel()
cases := []struct {
name string
headers map[string]string
expectedKey string
}{
{
name: "none",
},
{
name: "authorization/invalid",
headers: map[string]string{"authorization": "invalid"},
},
{
name: "authorization/bearer empty",
headers: map[string]string{"authorization": "bearer"},
},
{
name: "authorization/bearer ok",
headers: map[string]string{"authorization": "bearer key"},
expectedKey: "key",
},
{
name: "authorization/case",
headers: map[string]string{"AUTHORIZATION": "BEARer key"},
expectedKey: "key",
},
{
name: "authorization/priority over x-api-key",
headers: map[string]string{
"Authorization": "Bearer auth-token",
"X-Api-Key": "api-key",
},
expectedKey: "auth-token",
},
{
name: "x-api-key/empty",
headers: map[string]string{"X-Api-Key": ""},
},
{
name: "x-api-key/ok",
headers: map[string]string{"X-Api-Key": "key"},
expectedKey: "key",
},
// BYOK: X-Coder-AI-Governance-Token carries the Coder
// token and has the highest priority.
{
name: "byok/empty",
headers: map[string]string{agplaibridge.HeaderCoderToken: ""},
},
{
name: "byok/ok",
headers: map[string]string{agplaibridge.HeaderCoderToken: "coder-token"},
expectedKey: "coder-token",
},
{
name: "byok/priority over all",
headers: map[string]string{
agplaibridge.HeaderCoderToken: "coder-token",
"Authorization": "Bearer oauth-token",
"X-Api-Key": "api-key",
},
expectedKey: "coder-token",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
headers := make(http.Header, len(tc.headers))
for k, v := range tc.headers {
headers.Add(k, v)
}
key := agplaibridge.ExtractAuthToken(headers)
require.Equal(t, tc.expectedKey, key)
})
}
}
var _ http.Handler = &mockHandler{}
type mockHandler struct {
headersReceived http.Header
}
func (h *mockHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
h.headersReceived = r.Header.Clone()
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte(r.URL.Path))
}
// TestServeHTTP_ActorHeaders validates that actor headers are correctly forwarded to
// upstream AI providers when SendActorHeaders is enabled in the provider configuration.
// These headers allow upstream providers to identify the user making the request for
// tracking and auditing purposes.
func TestServeHTTP_ActorHeaders(t *testing.T) {
t.Parallel()
testUsername := "testuser"
testUserID := uuid.New()
cases := []struct {
path string
}{
// Not a complete set of paths; we're not testing the specific APIs - just the provider configs.
{
path: "/openai/v1/chat/completions",
},
{
path: "/anthropic/v1/messages",
},
}
for _, tc := range cases {
t.Run(tc.path, func(t *testing.T) {
t.Parallel()
// Setup mock upstream AI server that captures headers.
var receivedHeaders http.Header
upstreamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte(`i am a teapot`))
}))
t.Cleanup(upstreamSrv.Close)
// Setup with SendActorHeaders enabled.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
client := mock.NewMockDRPCClient(ctrl)
// Create providers with SendActorHeaders=true.
providers := []aibridge.Provider{
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{
BaseURL: upstreamSrv.URL,
SendActorHeaders: true,
}),
aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{
BaseURL: upstreamSrv.URL,
SendActorHeaders: true,
}, nil),
}
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger, nil, testTracer)
require.NoError(t, err)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
// Return authorization response with user ID and username.
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.IsAuthorizedResponse{
OwnerId: testUserID.String(),
Username: testUsername,
}, nil)
client.EXPECT().GetMCPServerConfigs(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.GetMCPServerConfigsResponse{}, nil)
client.EXPECT().RecordInterception(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.RecordInterceptionResponse{}, nil)
client.EXPECT().RecordInterceptionEnded(gomock.Any(), gomock.Any()).AnyTimes()
// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, logger, testTracer)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(testutil.Context(t, testutil.WaitShort))
})
// When: a request is made to aibridged.
ctx := testutil.Context(t, testutil.WaitShort)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tc.path, bytes.NewBufferString(`{}`))
require.NoError(t, err, "make request to test server")
req.Header.Add("Authorization", "Bearer key")
req.Header.Add("Accept", "application/json")
// When: aibridged handles the request.
rec := httptest.NewRecorder()
srv.ServeHTTP(rec, req)
// Then: the actor headers should be present in the upstream request.
require.NotEmpty(t, receivedHeaders, "upstream server should have received headers")
// Verify the actor ID header is present with the correct value.
actorIDHeader := receivedHeaders.Get(intercept.ActorIDHeader())
assert.Equal(t, testUserID.String(), actorIDHeader, "actor ID header should contain user ID")
// Verify the actor metadata header for username is present.
usernameHeader := receivedHeaders.Get(intercept.ActorMetadataHeader("Username"))
assert.Equal(t, testUsername, usernameHeader, "actor metadata username header should contain username")
})
}
}
// TestRouting validates that a request which originates with aibridged will be handled
// by coder/aibridge's handling logic in a provider-specific manner.
// We must validate that logic that pertains to coder/coder is exercised.
// aibridge will only handle certain routes; we don't need to test these exhaustively
// (that's coder/aibridge's responsibility), but we do need to validate that it handles
// requests correctly.
func TestRouting(t *testing.T) {
t.Parallel()
cases := []struct {
name string
path string
expectedStatus int
expectedHits int // Expected hits to the upstream server.
}{
{
name: "unsupported",
path: "/this-route-does-not-exist",
expectedStatus: http.StatusNotFound,
expectedHits: 0,
},
{
name: "openai chat completions",
path: "/openai/v1/chat/completions",
expectedStatus: http.StatusTeapot, // Nonsense status to indicate server was hit.
expectedHits: 1,
},
{
name: "anthropic messages",
path: "/anthropic/v1/messages",
expectedStatus: http.StatusTeapot, // Nonsense status to indicate server was hit.
expectedHits: 1,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Setup mock upstream AI server.
upstreamSrv := &mockAIUpstreamServer{}
openaiSrv := httptest.NewServer(upstreamSrv)
antSrv := httptest.NewServer(upstreamSrv)
t.Cleanup(openaiSrv.Close)
t.Cleanup(antSrv.Close)
// Setup.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
client := mock.NewMockDRPCClient(ctrl)
providers := []aibridge.Provider{
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{BaseURL: openaiSrv.URL}),
aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{BaseURL: antSrv.URL}, nil),
}
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger, nil, testTracer)
require.NoError(t, err)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.IsAuthorizedResponse{OwnerId: uuid.NewString()}, nil)
client.EXPECT().GetMCPServerConfigs(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.GetMCPServerConfigsResponse{}, nil)
// This is the only recording we really care about in this test. This is called before the provider-specific logic processes
// the incoming request, and anything beyond that is the responsibility of coder/aibridge to test.
var interceptionID string
client.EXPECT().RecordInterception(gomock.Any(), gomock.Any()).Times(tc.expectedHits).DoAndReturn(func(ctx context.Context, in *proto.RecordInterceptionRequest) (*proto.RecordInterceptionResponse, error) {
interceptionID = in.GetId()
return &proto.RecordInterceptionResponse{}, nil
})
client.EXPECT().RecordInterceptionEnded(gomock.Any(), gomock.Any()).Times(tc.expectedHits)
// Given: aibridged is started.
srv, err := aibridged.New(t.Context(), pool, func(ctx context.Context) (aibridged.DRPCClient, error) {
return client, nil
}, logger, testTracer)
require.NoError(t, err, "create new aibridged")
t.Cleanup(func() {
_ = srv.Shutdown(testutil.Context(t, testutil.WaitShort))
})
// When: a request is made to aibridged.
ctx := testutil.Context(t, testutil.WaitShort)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tc.path, bytes.NewBufferString(`{}`))
require.NoError(t, err, "make request to test server")
req.Header.Add("Authorization", "Bearer key")
req.Header.Add("Accept", "application/json")
// When: aibridged handles the request.
rec := httptest.NewRecorder()
srv.ServeHTTP(rec, req)
// Then: the upstream server will have received a number of hits.
// NOTE: we *expect* the interceptions to fail because [mockAIUpstreamServer] returns a nonsense status code.
// We only need to test that the request was routed, NOT processed.
require.Equal(t, tc.expectedStatus, rec.Code)
assert.EqualValues(t, tc.expectedHits, upstreamSrv.Hits())
if tc.expectedHits > 0 {
_, err = uuid.Parse(interceptionID)
require.NoError(t, err, "parse interception ID")
}
})
}
}
// TestServeHTTP_StripInternalHeaders verifies that internal X-Coder-*
// headers are never forwarded to upstream LLM providers.
func TestServeHTTP_StripInternalHeaders(t *testing.T) {
t.Parallel()
cases := []struct {
name string
header string
value string
}{
{
name: "X-Coder-AI-Governance-Token",
header: agplaibridge.HeaderCoderToken,
value: "coder-token",
},
{
name: "X-Coder-AI-Governance-Request-Id",
header: agplaibridge.HeaderCoderRequestID,
value: uuid.NewString(),
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
mockH := &mockHandler{}
srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).AnyTimes().Return(&proto.IsAuthorizedResponse{OwnerId: uuid.NewString()}, nil)
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(mockH, nil)
httpSrv := httptest.NewServer(srv)
t.Cleanup(httpSrv.Close)
ctx := testutil.Context(t, testutil.WaitShort)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpSrv.URL+"/anthropic/v1/messages", nil)
require.NoError(t, err)
// Always set a valid auth token so the request reaches
// the upstream handler.
req.Header.Set("Authorization", "Bearer coder-token")
req.Header.Set(tc.header, tc.value)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.NotNil(t, mockH.headersReceived)
// Assert no X-Coder-* headers were forwarded upstream.
for name := range mockH.headersReceived {
require.NotContains(t, name, "X-Coder-",
"internal header %q must not be forwarded to upstream providers", name)
}
})
}
}