mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
refactor: route TransportFor by provider name (#25650)
Delegate `aibridge` routing responsibility to the in-memory transport layer. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -3,8 +3,6 @@ package aibridge
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Source identifies the call site that asked aibridge for a transport. It is
|
||||
@@ -53,14 +51,20 @@ func DelegatedAPIKeyIDFromContext(ctx context.Context) (string, bool) {
|
||||
}
|
||||
|
||||
// TransportFactory returns an [http.RoundTripper] that dispatches an aibridge
|
||||
// request in-process for a given ai_providers row.
|
||||
// request in-process for a given provider instance name.
|
||||
//
|
||||
// Implementations live in coderd/aibridged. coderd registers an in-process
|
||||
// factory on coderd.API.AIBridgeTransportFactory at startup so callers route
|
||||
// traffic through the daemon without going through the gated HTTP route.
|
||||
//
|
||||
// The returned RoundTripper is responsible for adapting the caller's request
|
||||
// to the aibridge daemon's mount path: callers hand it an upstream-shaped
|
||||
// request and the transport rewrites URL.Path to "/api/v2/aibridge/<name>/..."
|
||||
// before dispatching. Routing keys on the provider's instance name so callers
|
||||
// can use the same string the proxy daemon and the bridge mount use.
|
||||
//
|
||||
// Source is informational: implementations must not gate on it. It is attached
|
||||
// to the request context so handlers can include it in logs and metrics.
|
||||
type TransportFactory interface {
|
||||
TransportFor(providerID uuid.UUID, source Source) (http.RoundTripper, error)
|
||||
TransportFor(providerName string, source Source) (http.RoundTripper, error)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
@@ -23,12 +22,12 @@ type stubTransportFactory struct {
|
||||
}
|
||||
|
||||
type callRecord struct {
|
||||
providerID uuid.UUID
|
||||
source aibridge.Source
|
||||
providerName string
|
||||
source aibridge.Source
|
||||
}
|
||||
|
||||
func (f *stubTransportFactory) TransportFor(providerID uuid.UUID, source aibridge.Source) (http.RoundTripper, error) {
|
||||
f.calls <- callRecord{providerID: providerID, source: source}
|
||||
func (f *stubTransportFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
|
||||
f.calls <- callRecord{providerName: providerName, source: source}
|
||||
return &handlerRoundTripper{handler: f.handler}, nil
|
||||
}
|
||||
|
||||
@@ -71,14 +70,14 @@ func TestAIBridgeTransportFactory_Registration(t *testing.T) {
|
||||
loaded := api.AIBridgeTransportFactory.Load()
|
||||
require.NotNil(t, loaded)
|
||||
|
||||
providerID := uuid.New()
|
||||
rt, err := (*loaded).TransportFor(providerID, aibridge.SourceAgents)
|
||||
providerName := "openai"
|
||||
rt, err := (*loaded).TransportFor(providerName, aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rt)
|
||||
|
||||
select {
|
||||
case got := <-stub.calls:
|
||||
require.Equal(t, providerID, got.providerID)
|
||||
require.Equal(t, providerName, got.providerName)
|
||||
require.Equal(t, aibridge.SourceAgents, got.source)
|
||||
default:
|
||||
t.Fatal("factory was not invoked")
|
||||
|
||||
@@ -321,7 +321,7 @@ func TestServeHTTP_DelegatedAPIKey_BYOK_Integration(t *testing.T) {
|
||||
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)
|
||||
|
||||
factory := aibridged.NewTransportFactory(srv)
|
||||
rt, err := factory.TransportFor(uuid.New(), agplaibridge.SourceAgents)
|
||||
rt, err := factory.TransportFor("openai", agplaibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
|
||||
@@ -373,7 +373,7 @@ func TestServeHTTP_DelegatedAPIKey_Integration(t *testing.T) {
|
||||
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)
|
||||
|
||||
factory := aibridged.NewTransportFactory(srv)
|
||||
rt, err := factory.TransportFor(uuid.New(), agplaibridge.SourceAgents)
|
||||
rt, err := factory.TransportFor("openai", agplaibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
|
||||
|
||||
@@ -4,14 +4,21 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
)
|
||||
|
||||
// aibridgeRootPath is the URL prefix the in-memory aibridged handler
|
||||
// registers all of its routes under. The in-process round-tripper
|
||||
// prepends this plus the provider name to every request before
|
||||
// dispatch so callers can hand it upstream-shaped requests without
|
||||
// knowing the daemon's mount layout.
|
||||
const aibridgeRootPath = "/api/v2/aibridge"
|
||||
|
||||
// NewTransportFactory returns an [aibridge.TransportFactory] whose RoundTripper
|
||||
// dispatches requests to handler in-process, streaming the response body
|
||||
// through an [io.Pipe] so SSE/NDJSON/chunked responses propagate token-by-token
|
||||
@@ -28,20 +35,28 @@ type transportFactory struct {
|
||||
}
|
||||
|
||||
// TransportFor returns an in-process [http.RoundTripper] that dispatches
|
||||
// requests through the aibridged handler. The source is attached to the
|
||||
// request context for downstream logging; routing does not depend on it.
|
||||
func (f *transportFactory) TransportFor(_ uuid.UUID, source aibridge.Source) (http.RoundTripper, error) {
|
||||
// requests through the aibridged handler. The provider name is the routing
|
||||
// key the daemon mounts on; the round-tripper rewrites each request's URL
|
||||
// path to "/api/v2/aibridge/<providerName>/..." before dispatching so
|
||||
// callers can build upstream-shaped requests and stay agnostic of the
|
||||
// daemon's mount layout. The source is attached to the request context for
|
||||
// downstream logging; routing does not depend on it.
|
||||
func (f *transportFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
|
||||
if f.handler == nil {
|
||||
return nil, xerrors.New("aibridged handler not registered")
|
||||
}
|
||||
return &inMemoryRoundTripper{handler: f.handler, source: source}, nil
|
||||
if providerName == "" {
|
||||
return nil, xerrors.New("provider name is required")
|
||||
}
|
||||
return &inMemoryRoundTripper{handler: f.handler, providerName: providerName, source: source}, nil
|
||||
}
|
||||
|
||||
// inMemoryRoundTripper implements [http.RoundTripper] by invoking handler
|
||||
// in a goroutine and streaming its response back through an [io.Pipe].
|
||||
type inMemoryRoundTripper struct {
|
||||
handler http.Handler
|
||||
source aibridge.Source
|
||||
handler http.Handler
|
||||
providerName string
|
||||
source aibridge.Source
|
||||
}
|
||||
|
||||
func (t *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@@ -53,6 +68,17 @@ func (t *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
|
||||
return nil, xerrors.New("aibridged in-memory transport requires WithDelegatedAPIKeyID on the request context")
|
||||
}
|
||||
|
||||
// Adapt the caller's upstream-shaped URL to the daemon's mount layout:
|
||||
// "/api/v2/aibridge/<providerName>/<original-path>". Done here so
|
||||
// callers do not need to encode the mount prefix or the provider
|
||||
// routing key into the requests they hand to the transport.
|
||||
newPath, err := url.JoinPath(aibridgeRootPath, t.providerName, req.URL.Path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("rewrite request URL for provider %q: %w", t.providerName, err)
|
||||
}
|
||||
req = req.Clone(req.Context())
|
||||
req.URL.Path = newPath
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
rw := &pipeResponseWriter{
|
||||
header: http.Header{},
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -26,7 +25,7 @@ func TestTransportFactory_TransportFor(t *testing.T) {
|
||||
t.Run("ReturnsTransport", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := aibridged.NewTransportFactory(http.NotFoundHandler())
|
||||
rt, err := f.TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
rt, err := f.TransportFor("openai", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rt)
|
||||
})
|
||||
@@ -34,10 +33,51 @@ func TestTransportFactory_TransportFor(t *testing.T) {
|
||||
t.Run("NilHandlerErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := aibridged.NewTransportFactory(nil)
|
||||
_, err := f.TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
_, err := f.TransportFor("openai", aibridge.SourceAgents)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("EmptyProviderErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := aibridged.NewTransportFactory(http.NotFoundHandler())
|
||||
_, err := f.TransportFor("", aibridge.SourceAgents)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("RewritesURLToAibridgeMount", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The round-tripper must adapt an upstream-shaped URL.Path
|
||||
// ("/v1/messages") to the aibridge mount layout
|
||||
// ("/api/v2/aibridge/<provider>/v1/messages") so callers don't
|
||||
// have to encode the daemon's routing key into their requests.
|
||||
got := make(chan string, 1)
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
got <- r.URL.Path
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor("my-anthropic", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://upstream/v1/messages", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The caller's req.URL.Path is the upstream shape. Capture it so
|
||||
// we can prove the transport mutates a clone, not the caller's
|
||||
// request, after RoundTrip returns.
|
||||
origPath := req.URL.Path
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, "/api/v2/aibridge/my-anthropic/v1/messages", <-got)
|
||||
require.Equal(t, origPath, req.URL.Path,
|
||||
"caller's request URL must not be mutated by RoundTrip")
|
||||
})
|
||||
|
||||
t.Run("AttachesSourceToContext", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -47,7 +87,7 @@ func TestTransportFactory_TransportFor(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
||||
@@ -72,7 +112,7 @@ func TestInMemoryRoundTripper_PassesHeadersAndStatus(t *testing.T) {
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
||||
@@ -122,7 +162,7 @@ func TestInMemoryRoundTripper_Streams(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
||||
@@ -161,7 +201,7 @@ func TestInMemoryRoundTripper_CancelCloses(t *testing.T) {
|
||||
close(handlerCtxObserved)
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
parentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
@@ -199,7 +239,7 @@ func TestInMemoryRoundTripper_ConcurrentRequests(t *testing.T) {
|
||||
_, _ = w.Write(body)
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
const n = 16
|
||||
@@ -248,7 +288,7 @@ func TestInMemoryRoundTripper_HandlerPanic(t *testing.T) {
|
||||
panic("unexpected nil pointer")
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
||||
@@ -307,7 +347,7 @@ func TestInMemoryRoundTripper_RequiresDelegatedAPIKeyID(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := tc.withCtx(testutil.Context(t, testutil.WaitShort))
|
||||
@@ -340,7 +380,7 @@ func TestInMemoryRoundTripper_HandlerReturnsWithoutWriting(t *testing.T) {
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
||||
|
||||
Reference in New Issue
Block a user