refactor: route TransportFor by provider name (#25650)

Delegate `aibridge` routing responsibility to the in-memory transport
layer.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Danny Kopping
2026-05-25 18:04:12 +02:00
committed by GitHub
parent 0a45f96d30
commit 8652ef3e3b
5 changed files with 101 additions and 32 deletions
+8 -4
View File
@@ -3,8 +3,6 @@ package aibridge
import (
"context"
"net/http"
"github.com/google/uuid"
)
// Source identifies the call site that asked aibridge for a transport. It is
@@ -53,14 +51,20 @@ func DelegatedAPIKeyIDFromContext(ctx context.Context) (string, bool) {
}
// TransportFactory returns an [http.RoundTripper] that dispatches an aibridge
// request in-process for a given ai_providers row.
// request in-process for a given provider instance name.
//
// Implementations live in coderd/aibridged. coderd registers an in-process
// factory on coderd.API.AIBridgeTransportFactory at startup so callers route
// traffic through the daemon without going through the gated HTTP route.
//
// The returned RoundTripper is responsible for adapting the caller's request
// to the aibridge daemon's mount path: callers hand it an upstream-shaped
// request and the transport rewrites URL.Path to "/api/v2/aibridge/<name>/..."
// before dispatching. Routing keys on the provider's instance name so callers
// can use the same string the proxy daemon and the bridge mount use.
//
// Source is informational: implementations must not gate on it. It is attached
// to the request context so handlers can include it in logs and metrics.
type TransportFactory interface {
TransportFor(providerID uuid.UUID, source Source) (http.RoundTripper, error)
TransportFor(providerName string, source Source) (http.RoundTripper, error)
}
+6 -7
View File
@@ -6,7 +6,6 @@ import (
"net/http/httptest"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/aibridge"
@@ -23,12 +22,12 @@ type stubTransportFactory struct {
}
type callRecord struct {
providerID uuid.UUID
providerName string
source aibridge.Source
}
func (f *stubTransportFactory) TransportFor(providerID uuid.UUID, source aibridge.Source) (http.RoundTripper, error) {
f.calls <- callRecord{providerID: providerID, source: source}
func (f *stubTransportFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
f.calls <- callRecord{providerName: providerName, source: source}
return &handlerRoundTripper{handler: f.handler}, nil
}
@@ -71,14 +70,14 @@ func TestAIBridgeTransportFactory_Registration(t *testing.T) {
loaded := api.AIBridgeTransportFactory.Load()
require.NotNil(t, loaded)
providerID := uuid.New()
rt, err := (*loaded).TransportFor(providerID, aibridge.SourceAgents)
providerName := "openai"
rt, err := (*loaded).TransportFor(providerName, aibridge.SourceAgents)
require.NoError(t, err)
require.NotNil(t, rt)
select {
case got := <-stub.calls:
require.Equal(t, providerID, got.providerID)
require.Equal(t, providerName, got.providerName)
require.Equal(t, aibridge.SourceAgents, got.source)
default:
t.Fatal("factory was not invoked")
+2 -2
View File
@@ -321,7 +321,7 @@ func TestServeHTTP_DelegatedAPIKey_BYOK_Integration(t *testing.T) {
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)
factory := aibridged.NewTransportFactory(srv)
rt, err := factory.TransportFor(uuid.New(), agplaibridge.SourceAgents)
rt, err := factory.TransportFor("openai", agplaibridge.SourceAgents)
require.NoError(t, err)
ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
@@ -373,7 +373,7 @@ func TestServeHTTP_DelegatedAPIKey_Integration(t *testing.T) {
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)
factory := aibridged.NewTransportFactory(srv)
rt, err := factory.TransportFor(uuid.New(), agplaibridge.SourceAgents)
rt, err := factory.TransportFor("openai", agplaibridge.SourceAgents)
require.NoError(t, err)
ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
+31 -5
View File
@@ -4,14 +4,21 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"sync"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/aibridge"
)
// aibridgeRootPath is the URL prefix the in-memory aibridged handler
// registers all of its routes under. The in-process round-tripper
// prepends this plus the provider name to every request before
// dispatch so callers can hand it upstream-shaped requests without
// knowing the daemon's mount layout.
const aibridgeRootPath = "/api/v2/aibridge"
// NewTransportFactory returns an [aibridge.TransportFactory] whose RoundTripper
// dispatches requests to handler in-process, streaming the response body
// through an [io.Pipe] so SSE/NDJSON/chunked responses propagate token-by-token
@@ -28,19 +35,27 @@ type transportFactory struct {
}
// TransportFor returns an in-process [http.RoundTripper] that dispatches
// requests through the aibridged handler. The source is attached to the
// request context for downstream logging; routing does not depend on it.
func (f *transportFactory) TransportFor(_ uuid.UUID, source aibridge.Source) (http.RoundTripper, error) {
// requests through the aibridged handler. The provider name is the routing
// key the daemon mounts on; the round-tripper rewrites each request's URL
// path to "/api/v2/aibridge/<providerName>/..." before dispatching so
// callers can build upstream-shaped requests and stay agnostic of the
// daemon's mount layout. The source is attached to the request context for
// downstream logging; routing does not depend on it.
func (f *transportFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
if f.handler == nil {
return nil, xerrors.New("aibridged handler not registered")
}
return &inMemoryRoundTripper{handler: f.handler, source: source}, nil
if providerName == "" {
return nil, xerrors.New("provider name is required")
}
return &inMemoryRoundTripper{handler: f.handler, providerName: providerName, source: source}, nil
}
// inMemoryRoundTripper implements [http.RoundTripper] by invoking handler
// in a goroutine and streaming its response back through an [io.Pipe].
type inMemoryRoundTripper struct {
handler http.Handler
providerName string
source aibridge.Source
}
@@ -53,6 +68,17 @@ func (t *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
return nil, xerrors.New("aibridged in-memory transport requires WithDelegatedAPIKeyID on the request context")
}
// Adapt the caller's upstream-shaped URL to the daemon's mount layout:
// "/api/v2/aibridge/<providerName>/<original-path>". Done here so
// callers do not need to encode the mount prefix or the provider
// routing key into the requests they hand to the transport.
newPath, err := url.JoinPath(aibridgeRootPath, t.providerName, req.URL.Path)
if err != nil {
return nil, xerrors.Errorf("rewrite request URL for provider %q: %w", t.providerName, err)
}
req = req.Clone(req.Context())
req.URL.Path = newPath
pr, pw := io.Pipe()
rw := &pipeResponseWriter{
header: http.Header{},
+51 -11
View File
@@ -10,7 +10,6 @@ import (
"sync"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
@@ -26,7 +25,7 @@ func TestTransportFactory_TransportFor(t *testing.T) {
t.Run("ReturnsTransport", func(t *testing.T) {
t.Parallel()
f := aibridged.NewTransportFactory(http.NotFoundHandler())
rt, err := f.TransportFor(uuid.New(), aibridge.SourceAgents)
rt, err := f.TransportFor("openai", aibridge.SourceAgents)
require.NoError(t, err)
require.NotNil(t, rt)
})
@@ -34,10 +33,51 @@ func TestTransportFactory_TransportFor(t *testing.T) {
t.Run("NilHandlerErrors", func(t *testing.T) {
t.Parallel()
f := aibridged.NewTransportFactory(nil)
_, err := f.TransportFor(uuid.New(), aibridge.SourceAgents)
_, err := f.TransportFor("openai", aibridge.SourceAgents)
require.Error(t, err)
})
t.Run("EmptyProviderErrors", func(t *testing.T) {
t.Parallel()
f := aibridged.NewTransportFactory(http.NotFoundHandler())
_, err := f.TransportFor("", aibridge.SourceAgents)
require.Error(t, err)
})
t.Run("RewritesURLToAibridgeMount", func(t *testing.T) {
t.Parallel()
// The round-tripper must adapt an upstream-shaped URL.Path
// ("/v1/messages") to the aibridge mount layout
// ("/api/v2/aibridge/<provider>/v1/messages") so callers don't
// have to encode the daemon's routing key into their requests.
got := make(chan string, 1)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got <- r.URL.Path
w.WriteHeader(http.StatusOK)
})
rt, err := aibridged.NewTransportFactory(handler).TransportFor("my-anthropic", aibridge.SourceAgents)
require.NoError(t, err)
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://upstream/v1/messages", nil)
require.NoError(t, err)
// The caller's req.URL.Path is the upstream shape. Capture it so
// we can prove the transport mutates a clone, not the caller's
// request, after RoundTrip returns.
origPath := req.URL.Path
resp, err := rt.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, "/api/v2/aibridge/my-anthropic/v1/messages", <-got)
require.Equal(t, origPath, req.URL.Path,
"caller's request URL must not be mutated by RoundTrip")
})
t.Run("AttachesSourceToContext", func(t *testing.T) {
t.Parallel()
@@ -47,7 +87,7 @@ func TestTransportFactory_TransportFor(t *testing.T) {
w.WriteHeader(http.StatusOK)
})
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
require.NoError(t, err)
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
@@ -72,7 +112,7 @@ func TestInMemoryRoundTripper_PassesHeadersAndStatus(t *testing.T) {
_, _ = w.Write([]byte(`{"ok":true}`))
})
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
require.NoError(t, err)
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
@@ -122,7 +162,7 @@ func TestInMemoryRoundTripper_Streams(t *testing.T) {
}
})
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
require.NoError(t, err)
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
@@ -161,7 +201,7 @@ func TestInMemoryRoundTripper_CancelCloses(t *testing.T) {
close(handlerCtxObserved)
})
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
require.NoError(t, err)
parentCtx := testutil.Context(t, testutil.WaitShort)
@@ -199,7 +239,7 @@ func TestInMemoryRoundTripper_ConcurrentRequests(t *testing.T) {
_, _ = w.Write(body)
})
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
require.NoError(t, err)
const n = 16
@@ -248,7 +288,7 @@ func TestInMemoryRoundTripper_HandlerPanic(t *testing.T) {
panic("unexpected nil pointer")
})
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
require.NoError(t, err)
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
@@ -307,7 +347,7 @@ func TestInMemoryRoundTripper_RequiresDelegatedAPIKeyID(t *testing.T) {
w.WriteHeader(http.StatusOK)
})
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
require.NoError(t, err)
ctx := tc.withCtx(testutil.Context(t, testutil.WaitShort))
@@ -340,7 +380,7 @@ func TestInMemoryRoundTripper_HandlerReturnsWithoutWriting(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
require.NoError(t, err)
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")