From 8652ef3e3b328790d6fe730c6843381febf36ed3 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Mon, 25 May 2026 18:04:12 +0200 Subject: [PATCH] refactor: route `TransportFor` by provider name (#25650) Delegate `aibridge` routing responsibility to the in-memory transport layer. Co-authored-by: Claude Opus 4.7 (1M context) --- coderd/aibridge/factory.go | 12 ++++-- coderd/aibridge_test.go | 15 ++++---- coderd/aibridged/aibridged_test.go | 4 +- coderd/aibridged/transport.go | 40 +++++++++++++++---- coderd/aibridged/transport_test.go | 62 ++++++++++++++++++++++++------ 5 files changed, 101 insertions(+), 32 deletions(-) diff --git a/coderd/aibridge/factory.go b/coderd/aibridge/factory.go index 41f25fb454..2746195c22 100644 --- a/coderd/aibridge/factory.go +++ b/coderd/aibridge/factory.go @@ -3,8 +3,6 @@ package aibridge import ( "context" "net/http" - - "github.com/google/uuid" ) // Source identifies the call site that asked aibridge for a transport. It is @@ -53,14 +51,20 @@ func DelegatedAPIKeyIDFromContext(ctx context.Context) (string, bool) { } // TransportFactory returns an [http.RoundTripper] that dispatches an aibridge -// request in-process for a given ai_providers row. +// request in-process for a given provider instance name. // // Implementations live in coderd/aibridged. coderd registers an in-process // factory on coderd.API.AIBridgeTransportFactory at startup so callers route // traffic through the daemon without going through the gated HTTP route. // +// The returned RoundTripper is responsible for adapting the caller's request +// to the aibridge daemon's mount path: callers hand it an upstream-shaped +// request and the transport rewrites URL.Path to "/api/v2/aibridge//..." +// before dispatching. Routing keys on the provider's instance name so callers +// can use the same string the proxy daemon and the bridge mount use. +// // Source is informational: implementations must not gate on it. It is attached // to the request context so handlers can include it in logs and metrics. type TransportFactory interface { - TransportFor(providerID uuid.UUID, source Source) (http.RoundTripper, error) + TransportFor(providerName string, source Source) (http.RoundTripper, error) } diff --git a/coderd/aibridge_test.go b/coderd/aibridge_test.go index de16bdaf27..b73e366161 100644 --- a/coderd/aibridge_test.go +++ b/coderd/aibridge_test.go @@ -6,7 +6,6 @@ import ( "net/http/httptest" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/aibridge" @@ -23,12 +22,12 @@ type stubTransportFactory struct { } type callRecord struct { - providerID uuid.UUID - source aibridge.Source + providerName string + source aibridge.Source } -func (f *stubTransportFactory) TransportFor(providerID uuid.UUID, source aibridge.Source) (http.RoundTripper, error) { - f.calls <- callRecord{providerID: providerID, source: source} +func (f *stubTransportFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) { + f.calls <- callRecord{providerName: providerName, source: source} return &handlerRoundTripper{handler: f.handler}, nil } @@ -71,14 +70,14 @@ func TestAIBridgeTransportFactory_Registration(t *testing.T) { loaded := api.AIBridgeTransportFactory.Load() require.NotNil(t, loaded) - providerID := uuid.New() - rt, err := (*loaded).TransportFor(providerID, aibridge.SourceAgents) + providerName := "openai" + rt, err := (*loaded).TransportFor(providerName, aibridge.SourceAgents) require.NoError(t, err) require.NotNil(t, rt) select { case got := <-stub.calls: - require.Equal(t, providerID, got.providerID) + require.Equal(t, providerName, got.providerName) require.Equal(t, aibridge.SourceAgents, got.source) default: t.Fatal("factory was not invoked") diff --git a/coderd/aibridged/aibridged_test.go b/coderd/aibridged/aibridged_test.go index 229ff260a5..caa162888c 100644 --- a/coderd/aibridged/aibridged_test.go +++ b/coderd/aibridged/aibridged_test.go @@ -321,7 +321,7 @@ func TestServeHTTP_DelegatedAPIKey_BYOK_Integration(t *testing.T) { pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil) factory := aibridged.NewTransportFactory(srv) - rt, err := factory.TransportFor(uuid.New(), agplaibridge.SourceAgents) + rt, err := factory.TransportFor("openai", agplaibridge.SourceAgents) require.NoError(t, err) ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID) @@ -373,7 +373,7 @@ func TestServeHTTP_DelegatedAPIKey_Integration(t *testing.T) { pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil) factory := aibridged.NewTransportFactory(srv) - rt, err := factory.TransportFor(uuid.New(), agplaibridge.SourceAgents) + rt, err := factory.TransportFor("openai", agplaibridge.SourceAgents) require.NoError(t, err) ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID) diff --git a/coderd/aibridged/transport.go b/coderd/aibridged/transport.go index 214e3d6399..95b41f860e 100644 --- a/coderd/aibridged/transport.go +++ b/coderd/aibridged/transport.go @@ -4,14 +4,21 @@ import ( "fmt" "io" "net/http" + "net/url" "sync" - "github.com/google/uuid" "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/aibridge" ) +// aibridgeRootPath is the URL prefix the in-memory aibridged handler +// registers all of its routes under. The in-process round-tripper +// prepends this plus the provider name to every request before +// dispatch so callers can hand it upstream-shaped requests without +// knowing the daemon's mount layout. +const aibridgeRootPath = "/api/v2/aibridge" + // NewTransportFactory returns an [aibridge.TransportFactory] whose RoundTripper // dispatches requests to handler in-process, streaming the response body // through an [io.Pipe] so SSE/NDJSON/chunked responses propagate token-by-token @@ -28,20 +35,28 @@ type transportFactory struct { } // TransportFor returns an in-process [http.RoundTripper] that dispatches -// requests through the aibridged handler. The source is attached to the -// request context for downstream logging; routing does not depend on it. -func (f *transportFactory) TransportFor(_ uuid.UUID, source aibridge.Source) (http.RoundTripper, error) { +// requests through the aibridged handler. The provider name is the routing +// key the daemon mounts on; the round-tripper rewrites each request's URL +// path to "/api/v2/aibridge//..." before dispatching so +// callers can build upstream-shaped requests and stay agnostic of the +// daemon's mount layout. The source is attached to the request context for +// downstream logging; routing does not depend on it. +func (f *transportFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) { if f.handler == nil { return nil, xerrors.New("aibridged handler not registered") } - return &inMemoryRoundTripper{handler: f.handler, source: source}, nil + if providerName == "" { + return nil, xerrors.New("provider name is required") + } + return &inMemoryRoundTripper{handler: f.handler, providerName: providerName, source: source}, nil } // inMemoryRoundTripper implements [http.RoundTripper] by invoking handler // in a goroutine and streaming its response back through an [io.Pipe]. type inMemoryRoundTripper struct { - handler http.Handler - source aibridge.Source + handler http.Handler + providerName string + source aibridge.Source } func (t *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -53,6 +68,17 @@ func (t *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, err return nil, xerrors.New("aibridged in-memory transport requires WithDelegatedAPIKeyID on the request context") } + // Adapt the caller's upstream-shaped URL to the daemon's mount layout: + // "/api/v2/aibridge//". Done here so + // callers do not need to encode the mount prefix or the provider + // routing key into the requests they hand to the transport. + newPath, err := url.JoinPath(aibridgeRootPath, t.providerName, req.URL.Path) + if err != nil { + return nil, xerrors.Errorf("rewrite request URL for provider %q: %w", t.providerName, err) + } + req = req.Clone(req.Context()) + req.URL.Path = newPath + pr, pw := io.Pipe() rw := &pipeResponseWriter{ header: http.Header{}, diff --git a/coderd/aibridged/transport_test.go b/coderd/aibridged/transport_test.go index 9dcae475fe..6be4862c99 100644 --- a/coderd/aibridged/transport_test.go +++ b/coderd/aibridged/transport_test.go @@ -10,7 +10,6 @@ import ( "sync" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/xerrors" @@ -26,7 +25,7 @@ func TestTransportFactory_TransportFor(t *testing.T) { t.Run("ReturnsTransport", func(t *testing.T) { t.Parallel() f := aibridged.NewTransportFactory(http.NotFoundHandler()) - rt, err := f.TransportFor(uuid.New(), aibridge.SourceAgents) + rt, err := f.TransportFor("openai", aibridge.SourceAgents) require.NoError(t, err) require.NotNil(t, rt) }) @@ -34,10 +33,51 @@ func TestTransportFactory_TransportFor(t *testing.T) { t.Run("NilHandlerErrors", func(t *testing.T) { t.Parallel() f := aibridged.NewTransportFactory(nil) - _, err := f.TransportFor(uuid.New(), aibridge.SourceAgents) + _, err := f.TransportFor("openai", aibridge.SourceAgents) require.Error(t, err) }) + t.Run("EmptyProviderErrors", func(t *testing.T) { + t.Parallel() + f := aibridged.NewTransportFactory(http.NotFoundHandler()) + _, err := f.TransportFor("", aibridge.SourceAgents) + require.Error(t, err) + }) + + t.Run("RewritesURLToAibridgeMount", func(t *testing.T) { + t.Parallel() + + // The round-tripper must adapt an upstream-shaped URL.Path + // ("/v1/messages") to the aibridge mount layout + // ("/api/v2/aibridge//v1/messages") so callers don't + // have to encode the daemon's routing key into their requests. + got := make(chan string, 1) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got <- r.URL.Path + w.WriteHeader(http.StatusOK) + }) + + rt, err := aibridged.NewTransportFactory(handler).TransportFor("my-anthropic", aibridge.SourceAgents) + require.NoError(t, err) + + ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://upstream/v1/messages", nil) + require.NoError(t, err) + + // The caller's req.URL.Path is the upstream shape. Capture it so + // we can prove the transport mutates a clone, not the caller's + // request, after RoundTrip returns. + origPath := req.URL.Path + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, "/api/v2/aibridge/my-anthropic/v1/messages", <-got) + require.Equal(t, origPath, req.URL.Path, + "caller's request URL must not be mutated by RoundTrip") + }) + t.Run("AttachesSourceToContext", func(t *testing.T) { t.Parallel() @@ -47,7 +87,7 @@ func TestTransportFactory_TransportFor(t *testing.T) { w.WriteHeader(http.StatusOK) }) - rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents) + rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents) require.NoError(t, err) ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id") @@ -72,7 +112,7 @@ func TestInMemoryRoundTripper_PassesHeadersAndStatus(t *testing.T) { _, _ = w.Write([]byte(`{"ok":true}`)) }) - rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents) + rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents) require.NoError(t, err) ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id") @@ -122,7 +162,7 @@ func TestInMemoryRoundTripper_Streams(t *testing.T) { } }) - rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents) + rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents) require.NoError(t, err) ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id") @@ -161,7 +201,7 @@ func TestInMemoryRoundTripper_CancelCloses(t *testing.T) { close(handlerCtxObserved) }) - rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents) + rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents) require.NoError(t, err) parentCtx := testutil.Context(t, testutil.WaitShort) @@ -199,7 +239,7 @@ func TestInMemoryRoundTripper_ConcurrentRequests(t *testing.T) { _, _ = w.Write(body) }) - rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents) + rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents) require.NoError(t, err) const n = 16 @@ -248,7 +288,7 @@ func TestInMemoryRoundTripper_HandlerPanic(t *testing.T) { panic("unexpected nil pointer") }) - rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents) + rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents) require.NoError(t, err) ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id") @@ -307,7 +347,7 @@ func TestInMemoryRoundTripper_RequiresDelegatedAPIKeyID(t *testing.T) { w.WriteHeader(http.StatusOK) }) - rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents) + rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents) require.NoError(t, err) ctx := tc.withCtx(testutil.Context(t, testutil.WaitShort)) @@ -340,7 +380,7 @@ func TestInMemoryRoundTripper_HandlerReturnsWithoutWriting(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents) + rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents) require.NoError(t, err) ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")