feat: hot-reload aibridged and aibridgeproxyd providers on DB changes (#25673)

Previously the in-process aibridge daemon and the enterprise aibridgeproxy daemon both snapshotted their provider routing once at boot. Any `ai_providers` or `ai_provider_keys` mutation required a restart for either to pick it up.

Add an `ai_providers_changed` pubsub channel that the CRUD handlers publish on after Create / Update / Delete. Both daemons subscribe:

- **aibridged** rebuilds its `[]aibridge.Provider` snapshot via `BuildProviders` and swaps it into the pool atomically. Inflight requests keep serving against the bridge they already acquired; new acquires build against the new snapshot. Per-provider construction errors stay scoped to the offending row.
- **aibridgeproxyd** rebuilds its routing snapshot from `GetAIProviders` and swaps the host→provider map atomically. The MITM listener picks up new providers without restart.

DB read for aibridgeproxyd uses the existing `AsAIProviderMetadataReader` subject for routing-only access.
This commit is contained in:
Danny Kopping
2026-05-27 11:58:43 +02:00
committed by GitHub
parent 6acfe6c835
commit 79e007cf30
19 changed files with 1678 additions and 227 deletions
+87 -63
View File
@@ -18,6 +18,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@@ -30,6 +31,18 @@ import (
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
)
// ProviderRoute is the routing entry for a single AI provider: the
// instance name (the routing key) and the upstream base URL (the
// source of the MITM allowlist host).
type ProviderRoute struct {
Name string
BaseURL string
}
// RefreshProvidersFunc returns the live provider set used by Reload to
// rebuild the proxy's routing snapshot.
type RefreshProvidersFunc func(ctx context.Context) ([]ProviderRoute, error)
// Known AI provider hosts.
const (
HostAnthropic = "api.anthropic.com"
@@ -119,14 +132,21 @@ var blockedIPRanges = func() []net.IPNet {
// - decrypting requests using the configured MITM CA certificate
// - forwarding requests to aibridged for processing
type Server struct {
ctx context.Context
logger slog.Logger
proxy *goproxy.ProxyHttpServer
httpServer *http.Server
listener net.Listener
tlsEnabled bool
coderAccessURL *url.URL
aibridgeProviderFromHost func(host string) string
ctx context.Context
logger slog.Logger
proxy *goproxy.ProxyHttpServer
httpServer *http.Server
listener net.Listener
tlsEnabled bool
coderAccessURL *url.URL
// refreshProviders fetches the live provider snapshot on Reload.
// Nil disables hot-reload.
refreshProviders RefreshProvidersFunc
// providerRouter holds the live (mitmHosts, nameByHost) pair.
providerRouter atomic.Pointer[providerRouter]
// allowedPorts is the port allowlist for CONNECT requests. Fixed at
// construction; not reloadable.
allowedPorts []string
// caCert is the PEM-encoded MITM CA certificate loaded during initialization.
// This is served to clients who need to trust the proxy's generated certificates.
caCert []byte
@@ -139,6 +159,21 @@ type Server struct {
metrics *Metrics
}
// providerRouter keeps CONNECT matching and provider lookup in sync.
type providerRouter struct {
mitmHosts []string // host:port allowlist for the goproxy condition.
nameByHost map[string]string // lowercase hostname -> provider name.
}
// emptyProviderRouter is used before the first Reload (or when the
// operator deconfigures every provider) so handlers can safely call
// loadProviderRouter without a nil check.
var emptyProviderRouter = &providerRouter{nameByHost: map[string]string{}}
func (r *providerRouter) providerFromHost(host string) string {
return r.nameByHost[strings.ToLower(host)]
}
// requestContext holds metadata propagated through the proxy request/response chain.
// It is stored in goproxy's ProxyCtx.UserData and enriched as the request progresses
// through the proxy handlers.
@@ -183,13 +218,12 @@ type Options struct {
// CertStore is an optional certificate cache for MITM. If nil, a default
// cache is created. Exposed for testing.
CertStore goproxy.CertStorage
// DomainAllowlist is the list of domains to intercept and route through AI Bridge.
// Only requests to these domains will be MITM'd and forwarded to aibridged.
// Requests to other domains will be tunneled directly without decryption.
// DomainAllowlist seeds the boot-time MITM allowlist. Production
// callers should leave this empty and rely on RefreshProviders;
// tests use it to skip the refresh round-trip.
DomainAllowlist []string
// AIBridgeProviderFromHost maps a hostname to a known aibridge provider
// name. Must be non-nil; the caller derives it from the configured
// provider list.
// AIBridgeProviderFromHost seeds the boot-time host -> provider
// name mapping. Required iff DomainAllowlist is non-empty.
AIBridgeProviderFromHost func(host string) string
// UpstreamProxy is the URL of an upstream HTTP proxy to chain tunneled
// (non-allowlisted) requests through. If empty, tunneled requests connect
@@ -214,6 +248,10 @@ type Options struct {
// Metrics is the prometheus metrics instance for recording proxy metrics.
// If nil, metrics will not be recorded.
Metrics *Metrics
// RefreshProviders, when set, is invoked by Server.Reload to fetch
// the live provider snapshot used to derive the MITM allowlist and
// host -> provider-name routing. Nil disables hot-reload.
RefreshProviders RefreshProvidersFunc
}
func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) {
@@ -258,29 +296,12 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
allowedPorts = []string{"80", "443"}
}
// An empty allowlist is permitted so the server can boot before any
// ai_providers row exists; every intercept attempt is then rejected
// until providers are configured.
// TODO: refresh the allowlist when ai_providers changes so a restart
// is not required after the first provider is configured.
mitmHosts, err := convertDomainsToHosts(opts.DomainAllowlist, allowedPorts)
// Build the boot-time router from DomainAllowlist + the lookup fn.
// Both empty is fine: the server fails closed (no MITM until
// Reload populates the router from the database).
bootRouter, err := buildBootRouter(opts.DomainAllowlist, opts.AIBridgeProviderFromHost, allowedPorts)
if err != nil {
return nil, xerrors.Errorf("invalid domain allowlist: %w", err)
}
if opts.AIBridgeProviderFromHost == nil {
return nil, xerrors.New("AIBridgeProviderFromHost is required")
}
aibridgeProviderFromHost := opts.AIBridgeProviderFromHost
for _, domain := range opts.DomainAllowlist {
domain = strings.TrimSpace(strings.ToLower(domain))
if domain == "" {
continue
}
if aibridgeProviderFromHost(domain) == "" {
return nil, xerrors.Errorf("domain %q is in allowlist but has no provider mapping", domain)
}
return nil, err
}
// Parse configured exceptions to the blocked IP ranges.
@@ -319,17 +340,22 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
}
srv := &Server{
ctx: ctx,
logger: logger,
proxy: proxy,
tlsEnabled: opts.TLSCertFile != "",
coderAccessURL: coderAccessURL,
aibridgeProviderFromHost: aibridgeProviderFromHost,
caCert: certPEM,
allowedPrivateRanges: allowedPrivateRanges,
newDumper: opts.NewDumper,
metrics: opts.Metrics,
ctx: ctx,
logger: logger,
proxy: proxy,
tlsEnabled: opts.TLSCertFile != "",
coderAccessURL: coderAccessURL,
refreshProviders: opts.RefreshProviders,
allowedPorts: allowedPorts,
caCert: certPEM,
allowedPrivateRanges: allowedPrivateRanges,
newDumper: opts.NewDumper,
metrics: opts.Metrics,
}
// Seed the boot-time router from the constructor inputs so the
// proxy can serve immediately. Reload may swap this snapshot at any
// point after construction.
srv.providerRouter.Store(bootRouter)
// Configure upstream proxy for tunneled (non-allowlisted) CONNECT requests.
// Allowlisted domains are MITM'd and forwarded to aibridge directly,
@@ -417,12 +443,11 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
// Reject CONNECT requests to non-standard ports.
proxy.OnRequest().HandleConnectFunc(srv.portMiddleware(allowedPorts))
// Apply MITM with authentication only to allowlisted hosts.
proxy.OnRequest(
// Only CONNECT requests to these hosts will be intercepted and decrypted.
// All other requests will be tunneled directly to their destination.
goproxy.ReqHostIs(mitmHosts...),
).HandleConnectFunc(
// Apply MITM with authentication only to allowlisted hosts. The host
// list is loaded from the atomic router on every CONNECT so a
// Reload while inflight requests are in progress takes effect on
// the next CONNECT without touching the already-MITM'd ones.
proxy.OnRequest(srv.mitmHostsCondition()).HandleConnectFunc(
// Extract Coder token from proxy authentication to forward to aibridged.
srv.authMiddleware,
)
@@ -470,7 +495,7 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
slog.F("listen_addr", listener.Addr().String()),
slog.F("tls_listener_enabled", srv.tlsEnabled),
slog.F("coder_access_url", coderAccessURL.String()),
slog.F("domain_allowlist", mitmHosts),
slog.F("domain_allowlist", bootRouter.mitmHosts),
slog.F("upstream_proxy", opts.UpstreamProxy),
slog.F("allowed_private_cidrs", opts.AllowedPrivateCIDRs),
slog.F("api_dump_enabled", opts.NewDumper != nil),
@@ -651,11 +676,11 @@ func (s *Server) authMiddleware(host string, ctx *goproxy.ProxyCtx) (*goproxy.Co
)
// Determine the provider from the request hostname.
provider := s.aibridgeProviderFromHost(ctx.Req.URL.Hostname())
// This should never happen: startup validation ensures all allowlisted
// domains have known aibridge provider mappings.
provider := s.loadProviderRouter().providerFromHost(ctx.Req.URL.Hostname())
// A concurrent Reload can swap the router between CONNECT matching
// and provider lookup, so treat a missing mapping as a runtime miss.
if provider == "" {
logger.Error(s.ctx, "rejecting CONNECT request with no provider mapping")
logger.Warn(s.ctx, "rejecting CONNECT request with no provider mapping")
return goproxy.RejectConnect, host
}
@@ -922,12 +947,11 @@ func (s *Server) handleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http.
}
if reqCtx.Provider == "" {
// This should never happen: startup validation ensures all allowlisted
// domains have known aibridge provider mappings.
// The request is MITM'd (decrypted) but since there is no mapping,
// there is no known route to aibridge.
// Log error and forward to the original destination as a fallback.
s.logger.Error(s.ctx, "decrypted request has no provider mapping, passing through",
// A concurrent Reload can remove the provider after CONNECT
// authentication. The request is MITM'd (decrypted), but without a
// mapping there is no known route to aibridge. Log and forward
// to the original destination as a fallback.
s.logger.Warn(s.ctx, "decrypted request has no provider mapping, passing through",
slog.F("connect_id", reqCtx.ConnectSessionID.String()),
slog.F("host", req.Host),
slog.F("method", req.Method),
@@ -158,6 +158,7 @@ type testProxyConfig struct {
allowedPrivateCIDRs []string
newDumper func(string, string) aibridgeproxyd.RoundTripDumper
metrics *aibridgeproxyd.Metrics
refreshProviders aibridgeproxyd.RefreshProvidersFunc
}
type testProxyOption func(*testProxyConfig)
@@ -250,6 +251,12 @@ func withListenerTLS(certFile, keyFile string) testProxyOption {
}
}
func withRefreshProviders(fn aibridgeproxyd.RefreshProvidersFunc) testProxyOption {
return func(cfg *testProxyConfig) {
cfg.refreshProviders = fn
}
}
// newTestProxy creates a new AI Bridge Proxy server for testing.
// It uses the shared MITM certificate and registers cleanup automatically.
// It waits for the proxy server to be ready before returning.
@@ -289,6 +296,7 @@ func newTestProxy(t *testing.T, opts ...testProxyOption) *aibridgeproxyd.Server
AllowedPrivateCIDRs: cfg.allowedPrivateCIDRs,
NewDumper: cfg.newDumper,
Metrics: cfg.metrics,
RefreshProviders: cfg.refreshProviders,
}
if cfg.certStore != nil {
aibridgeOpts.CertStore = cfg.certStore
+124
View File
@@ -0,0 +1,124 @@
package aibridgeproxyd
import (
"context"
"net/http"
"net/url"
"slices"
"strings"
"github.com/elazarl/goproxy"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
// Reload refreshes proxy routing from the configured provider source.
// A refresh failure leaves the previous snapshot in place.
func (s *Server) Reload(ctx context.Context) error {
if s.refreshProviders == nil {
return nil
}
providers, err := s.refreshProviders(ctx)
if err != nil {
return xerrors.Errorf("refresh ai providers for proxy routing: %w", err)
}
router, err := buildProviderRouter(ctx, s.logger, providers, s.allowedPorts)
if err != nil {
return xerrors.Errorf("build provider router (provider_count=%d): %w", len(providers), err)
}
s.providerRouter.Store(router)
s.logger.Debug(s.ctx, "aibridgeproxyd router reloaded",
slog.F("mitm_host_count", len(router.mitmHosts)),
)
return nil
}
func (s *Server) loadProviderRouter() *providerRouter {
if p := s.providerRouter.Load(); p != nil {
return p
}
return emptyProviderRouter
}
// mitmHostsCondition returns a goproxy ReqConditionFunc that reads the
// allowlist from the atomic router on every match. Using a closure
// instead of goproxy.ReqHostIs(...) lets Reload affect every later
// CONNECT without re-registering handlers.
func (s *Server) mitmHostsCondition() goproxy.ReqConditionFunc {
return func(req *http.Request, _ *goproxy.ProxyCtx) bool {
if req == nil {
return false
}
return slices.Contains(s.loadProviderRouter().mitmHosts, strings.ToLower(req.URL.Host))
}
}
// buildProviderRouter constructs a router snapshot from a refreshed
// provider list. First provider wins on duplicate hostnames.
func buildProviderRouter(ctx context.Context, logger slog.Logger, providers []ProviderRoute, allowedPorts []string) (*providerRouter, error) {
nameByHost := make(map[string]string, len(providers))
var domains []string
for _, p := range providers {
if p.BaseURL == "" {
logger.Warn(ctx, "skipping ai provider without base url",
slog.F("provider_name", p.Name),
)
continue
}
u, err := url.Parse(p.BaseURL)
if err != nil {
logger.Warn(ctx, "skipping ai provider with invalid base url",
slog.F("provider_name", p.Name),
slog.F("base_url", p.BaseURL),
slog.Error(err),
)
continue
}
if u.Hostname() == "" {
logger.Warn(ctx, "skipping ai provider base url without hostname",
slog.F("provider_name", p.Name),
slog.F("base_url", p.BaseURL),
)
continue
}
host := strings.ToLower(u.Hostname())
if _, exists := nameByHost[host]; exists {
continue
}
nameByHost[host] = p.Name
domains = append(domains, host)
}
mitmHosts, err := convertDomainsToHosts(domains, allowedPorts)
if err != nil {
return nil, err
}
return &providerRouter{mitmHosts: mitmHosts, nameByHost: nameByHost}, nil
}
// buildBootRouter seeds the providerRouter from the boot-time inputs.
// The lookup function is consulted only for hosts in the allowlist; a
// nil function with an empty allowlist is fine and yields an empty
// router (the proxy fails closed until Reload populates it).
func buildBootRouter(domainAllowlist []string, providerFromHost func(string) string, allowedPorts []string) (*providerRouter, error) {
mitmHosts, err := convertDomainsToHosts(domainAllowlist, allowedPorts)
if err != nil {
return nil, xerrors.Errorf("invalid domain allowlist: %w", err)
}
nameByHost := make(map[string]string, len(domainAllowlist))
for _, domain := range domainAllowlist {
domain = strings.TrimSpace(strings.ToLower(domain))
if domain == "" {
continue
}
var name string
if providerFromHost != nil {
name = providerFromHost(domain)
}
if name == "" {
return nil, xerrors.Errorf("domain %q is in allowlist but has no provider mapping", domain)
}
nameByHost[domain] = name
}
return &providerRouter{mitmHosts: mitmHosts, nameByHost: nameByHost}, nil
}
@@ -0,0 +1,147 @@
package aibridgeproxyd
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/testutil"
)
func TestServerReloadSwapsProviderRouter(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
providers := []ProviderRoute{{Name: "old", BaseURL: "https://old.example.com/"}}
srv := &Server{
ctx: ctx,
logger: slogtest.Make(t, nil),
allowedPorts: []string{"443"},
refreshProviders: func(context.Context) ([]ProviderRoute, error) {
return providers, nil
},
}
srv.providerRouter.Store(emptyProviderRouter)
require.NoError(t, srv.Reload(ctx))
assert.Equal(t, "old", srv.loadProviderRouter().providerFromHost("old.example.com"))
assert.Empty(t, srv.loadProviderRouter().providerFromHost("new.example.com"))
providers = []ProviderRoute{{Name: "new", BaseURL: "https://new.example.com/"}}
require.NoError(t, srv.Reload(ctx))
router := srv.loadProviderRouter()
assert.Empty(t, router.providerFromHost("old.example.com"))
assert.Equal(t, "new", router.providerFromHost("new.example.com"))
assert.Equal(t, []string{"new.example.com:443"}, router.mitmHosts)
}
func TestServerReloadPreservesProviderRouterOnRefreshError(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
refreshErr := xerrors.New("refresh failed")
providers := []ProviderRoute{{Name: "old", BaseURL: "https://old.example.com/"}}
failRefresh := false
srv := &Server{
ctx: ctx,
logger: slogtest.Make(t, nil),
allowedPorts: []string{"443"},
refreshProviders: func(context.Context) ([]ProviderRoute, error) {
if failRefresh {
return nil, refreshErr
}
return providers, nil
},
}
srv.providerRouter.Store(emptyProviderRouter)
require.NoError(t, srv.Reload(ctx))
before := srv.loadProviderRouter()
assert.Equal(t, "old", before.providerFromHost("old.example.com"))
failRefresh = true
require.ErrorIs(t, srv.Reload(ctx), refreshErr)
after := srv.loadProviderRouter()
assert.Same(t, before, after)
assert.Equal(t, "old", after.providerFromHost("old.example.com"))
assert.Equal(t, []string{"old.example.com:443"}, after.mitmHosts)
}
// TestBuildProviderRouter covers the host-and-routing derivation that
// Reload feeds into the providerRouter.
func TestBuildProviderRouter(t *testing.T) {
t.Parallel()
t.Run("ExtractsHostnames", func(t *testing.T) {
t.Parallel()
providers := []ProviderRoute{
{Name: "openai", BaseURL: "https://api.openai.com/v1/"},
{Name: "anthropic", BaseURL: "https://api.anthropic.com/"},
{Name: "custom", BaseURL: "https://custom-llm.example.com:8443/api"},
}
router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"})
require.NoError(t, err)
assert.Equal(t, "openai", router.providerFromHost("api.openai.com"))
assert.Equal(t, "anthropic", router.providerFromHost("api.anthropic.com"))
assert.Equal(t, "custom", router.providerFromHost("custom-llm.example.com"))
assert.Empty(t, router.providerFromHost("unknown.com"))
assert.Contains(t, router.mitmHosts, "api.openai.com:443")
assert.Contains(t, router.mitmHosts, "api.anthropic.com:443")
})
t.Run("DeduplicatesSameHost", func(t *testing.T) {
t.Parallel()
providers := []ProviderRoute{
{Name: "first", BaseURL: "https://api.example.com/v1"},
{Name: "second", BaseURL: "https://api.example.com/v2"},
}
router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"})
require.NoError(t, err)
// First provider wins on duplicate host.
assert.Equal(t, "first", router.providerFromHost("api.example.com"))
})
t.Run("CaseInsensitive", func(t *testing.T) {
t.Parallel()
providers := []ProviderRoute{
{Name: "provider", BaseURL: "https://API.Example.COM/v1"},
}
router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"})
require.NoError(t, err)
assert.Equal(t, "provider", router.providerFromHost("API.Example.COM"))
assert.Equal(t, "provider", router.providerFromHost("api.example.com"))
})
t.Run("SkipsEmptyOrMalformedBaseURL", func(t *testing.T) {
t.Parallel()
providers := []ProviderRoute{
{Name: "no-url"},
{Name: "scheme-only", BaseURL: "https://"},
{Name: "good", BaseURL: "https://api.good.example.com/"},
}
router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"})
require.NoError(t, err)
assert.Equal(t, "good", router.providerFromHost("api.good.example.com"))
assert.Empty(t, router.providerFromHost("scheme-only"))
assert.Equal(t, []string{"api.good.example.com:443"}, router.mitmHosts)
})
}
+379
View File
@@ -0,0 +1,379 @@
package aibridgeproxyd_test
import (
"context"
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/enterprise/aibridgeproxyd"
"github.com/coder/coder/v2/testutil"
)
// reloadTestHarness wires a real proxy server to a mutable provider
// store and a mock aibridged backend so tests can drive Reload through
// a CRUD-style sequence and observe routing via real proxy requests.
type reloadTestHarness struct {
srv *aibridgeproxyd.Server
store *providerStore
client *http.Client
bridged *httptest.Server
recorder *aibridgedRecorder
}
// aibridgedRecorder captures the path of the last request received by
// the mock aibridged backend. Access is mutex-guarded so the test
// goroutine and the proxy's response goroutine can read/write safely.
type aibridgedRecorder struct {
mu sync.Mutex
path string
}
func (r *aibridgedRecorder) record(path string) {
r.mu.Lock()
defer r.mu.Unlock()
r.path = path
}
func (r *aibridgedRecorder) load() string {
r.mu.Lock()
defer r.mu.Unlock()
return r.path
}
func (r *aibridgedRecorder) reset() {
r.mu.Lock()
defer r.mu.Unlock()
r.path = ""
}
// providerStore is a mutable [aibridgeproxyd.RefreshProvidersFunc]
// backing for integration tests. set / setErr mutate the snapshot
// returned by the next Reload, mimicking CRUD against the database.
type providerStore struct {
mu sync.Mutex
providers []aibridgeproxyd.ProviderRoute
err error
}
func (s *providerStore) set(providers []aibridgeproxyd.ProviderRoute) {
s.mu.Lock()
defer s.mu.Unlock()
s.providers = providers
s.err = nil
}
func (s *providerStore) setErr(err error) {
s.mu.Lock()
defer s.mu.Unlock()
s.err = err
}
func (s *providerStore) refresh(context.Context) ([]aibridgeproxyd.ProviderRoute, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.err != nil {
return nil, s.err
}
// Return a copy so callers can't mutate our internal snapshot.
return slices.Clone(s.providers), nil
}
// newReloadTestHarness boots a proxy with an empty boot allowlist and a
// store-backed RefreshProviders. Production wiring is identical: the
// daemon constructs the proxy without a static allowlist and lets
// Reload populate the router from the database.
func newReloadTestHarness(t *testing.T) *reloadTestHarness {
t.Helper()
recorder := &aibridgedRecorder{}
bridged := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
recorder.record(r.URL.Path)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("aibridged"))
}))
t.Cleanup(bridged.Close)
store := &providerStore{}
srv := newTestProxy(t,
withCoderAccessURL(bridged.URL),
withAllowedPorts("443"),
// Empty boot allowlist: the router must be populated by Reload,
// matching the production daemon's behavior.
withDomainAllowlist(),
withAIBridgeProviderFromHost(nil),
withRefreshProviders(store.refresh),
)
certPool := getProxyCertPool(t)
client := newProxyClient(t, srv, makeProxyAuthHeader("coder-token"), certPool, false)
// Disable keep-alives so each request opens a fresh CONNECT through
// the proxy. Per the Reload contract, already-MITM'd tunnels keep
// the provider name they captured at CONNECT time; only new
// connections see the post-Reload snapshot. Tests need a fresh
// CONNECT between phases to assert on the new routing.
client.Transport.(*http.Transport).DisableKeepAlives = true
return &reloadTestHarness{
srv: srv,
store: store,
client: client,
bridged: bridged,
recorder: recorder,
}
}
// requestResult is the outcome of sending a request through the proxy.
// Either err is set (CONNECT failed for a non-MITM'd host whose dial
// fell through to the tunneled path and could not be resolved) or
// status/body carry the MITM'd response from the mock aibridged.
type requestResult struct {
status int
body string
err error
}
// sendRequest issues a single POST through the proxy. It returns rather
// than asserting so callers can branch on whether the host is currently
// routed (MITM'd to aibridged) or not (tunneled, dial of an unresolvable
// host fails).
func (h *reloadTestHarness) sendRequest(t *testing.T, targetURL string) requestResult {
t.Helper()
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitShort)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, strings.NewReader(`{}`))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
resp, err := h.client.Do(req)
if err != nil {
return requestResult{err: err}
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
return requestResult{status: resp.StatusCode, body: string(body)}
}
// expectRoutedTo asserts the proxy MITM'd the request and forwarded it
// to aibridged with the expected /api/v2/aibridge/<name>/<path>.
func (h *reloadTestHarness) expectRoutedTo(t *testing.T, targetURL, expectedPath string) {
t.Helper()
h.recorder.reset()
res := h.sendRequest(t, targetURL)
require.NoError(t, res.err, "request to routed host must succeed")
require.Equal(t, http.StatusOK, res.status)
require.Equal(t, "aibridged", res.body)
require.Equal(t, expectedPath, h.recorder.load(),
"aibridged must observe the rewritten path for %s", targetURL)
}
// expectNotRouted asserts the proxy did not MITM the request for the
// given host. The CONNECT either falls through to the tunneled path
// (where the .invalid hostname fails to dial) or to a 502 from the
// proxy. Either way, aibridged never sees the request.
func (h *reloadTestHarness) expectNotRouted(t *testing.T, targetURL string) {
t.Helper()
h.recorder.reset()
_ = h.sendRequest(t, targetURL)
require.Empty(t, h.recorder.load(),
"aibridged must not be reached for non-routed host %s", targetURL)
}
// TestProxy_HotReloadRoutingCRUD drives the proxy through a CRUD-style
// sequence of provider changes and asserts on routing after each
// Reload via real HTTPS requests. Each sub-test mutates the store and
// validates that:
// - newly created providers are MITM'd to aibridged with the right
// /api/v2/aibridge/<name>/<path>
// - renamed providers route under the new name
// - providers whose BaseURL host changes route the new host and stop
// MITM'ing the old host
// - deleted providers stop being MITM'd; aibridged sees nothing
//
// Hostnames are .invalid (RFC 2606) so a request that escapes the MITM
// path fails fast via DNS rather than reaching a real upstream.
func TestProxy_HotReloadRoutingCRUD(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// InitialEmptyRouter: no Reload has been called and the boot
// allowlist is empty, so any host falls through to the tunneled
// middleware.
h.expectNotRouted(t, "https://alpha.invalid/v1/messages")
// CreateProvider.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha", BaseURL: "https://alpha.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages")
// UpdateProviderName: the same BaseURL with a new name must route
// under the new name on the next Reload.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha-v2", BaseURL: "https://alpha.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages")
// UpdateProviderBaseURLHost: moving the provider to a new host must
// start MITM'ing the new host and stop MITM'ing the old one.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha-v2", BaseURL: "https://alpha-new.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha-new.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages")
h.expectNotRouted(t, "https://alpha.invalid/v1/messages")
// AddSecondProvider: a second provider added in the same Reload must
// route independently from the first.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha-v2", BaseURL: "https://alpha-new.invalid/v1"},
{Name: "beta", BaseURL: "https://beta.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha-new.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages")
h.expectRoutedTo(t, "https://beta.invalid/v1/chat/completions", "/api/v2/aibridge/beta/v1/chat/completions")
// DeleteOneProvider: removing alpha must keep beta routed and stop
// routing alpha.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "beta", BaseURL: "https://beta.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://beta.invalid/v1/chat/completions", "/api/v2/aibridge/beta/v1/chat/completions")
h.expectNotRouted(t, "https://alpha-new.invalid/v1/messages")
// DeleteAllProviders: an empty Reload must collapse the router to
// the fail-closed state with no host MITM'd.
h.store.set(nil)
require.NoError(t, h.srv.Reload(t.Context()))
h.expectNotRouted(t, "https://beta.invalid/v1/chat/completions")
h.expectNotRouted(t, "https://alpha-new.invalid/v1/messages")
// RecreateAfterDelete: reintroducing a previously-deleted provider
// must route again without restart, confirming the swap is
// symmetric.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha", BaseURL: "https://alpha.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages")
}
// TestProxy_HotReloadRoutingInvalidProviders covers the resilience
// requirements stated in the [aibridgeproxyd.Server.Reload] contract:
// individual invalid provider entries do not poison the snapshot, and
// a refresh-level error does not collapse the previous snapshot to
// empty.
func TestProxy_HotReloadRoutingInvalidProviders(t *testing.T) {
t.Parallel()
t.Run("EmptyBaseURLSkipped", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// One valid provider and one with an empty BaseURL. The empty
// entry must be silently dropped; the valid one must still
// route.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "no-url"},
{Name: "valid", BaseURL: "https://valid.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://valid.invalid/v1/messages", "/api/v2/aibridge/valid/v1/messages")
})
t.Run("MalformedBaseURLSkipped", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// A BaseURL that fails url.Parse and one whose Hostname() is
// empty must both be dropped. Mixed with a valid entry, only
// the valid one routes.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "malformed", BaseURL: "://not-a-url"},
{Name: "no-host", BaseURL: "https://"},
{Name: "valid", BaseURL: "https://valid.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://valid.invalid/v1/messages", "/api/v2/aibridge/valid/v1/messages")
})
t.Run("DuplicateHostFirstWins", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// Two providers with the same BaseURL host: the first one wins,
// matching buildProviderRouter's documented contract.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "first", BaseURL: "https://shared.invalid/v1"},
{Name: "second", BaseURL: "https://shared.invalid/v2"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://shared.invalid/v1/messages", "/api/v2/aibridge/first/v1/messages")
})
t.Run("AllInvalidYieldsEmptyRouter", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// When every provider is invalid, the router contains no
// entries and the proxy fails closed: no host is MITM'd.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "no-url"},
{Name: "malformed", BaseURL: "://not-a-url"},
{Name: "no-host", BaseURL: "https://"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectNotRouted(t, "https://anything.invalid/v1/messages")
})
t.Run("RefreshErrorPreservesPreviousSnapshot", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// Seed a valid snapshot so we have something to preserve.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha", BaseURL: "https://alpha.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages")
// A refresh error must NOT clear the router: dropping the
// allowlist on every transient DB hiccup would amplify the
// fault into a denial of service.
h.store.setErr(xerrors.New("simulated db failure"))
err := h.srv.Reload(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "refresh ai providers for proxy routing")
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages")
// Recovery: once the store returns providers again, the next
// Reload applies the new snapshot.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "beta", BaseURL: "https://beta.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://beta.invalid/v1/messages", "/api/v2/aibridge/beta/v1/messages")
h.expectNotRouted(t, "https://alpha.invalid/v1/messages")
})
}
+63 -49
View File
@@ -4,27 +4,45 @@ package cli
import (
"context"
"net/url"
"io"
"path/filepath"
"strings"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/aibridge"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/intercept/apidump"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/enterprise/aibridgeproxyd"
"github.com/coder/coder/v2/enterprise/coderd"
)
func newAIBridgeProxyDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*aibridgeproxyd.Server, error) {
// aiBridgeProxyDaemon bundles the proxy server and its pubsub
// subscription so both are torn down by a single Close call.
type aiBridgeProxyDaemon struct {
server *aibridgeproxyd.Server
unsubscribe func()
}
func (d *aiBridgeProxyDaemon) Close() error {
if d.unsubscribe != nil {
d.unsubscribe()
}
return d.server.Close()
}
// newAIBridgeProxyDaemon starts the enterprise aibridge proxy daemon,
// subscribes to ai_providers changes so the proxy's routing snapshot
// tracks the database, and registers the HTTP handler on the API.
// The returned io.Closer tears down both the subscription and server.
func newAIBridgeProxyDaemon(coderAPI *coderd.API) (io.Closer, error) {
ctx := context.Background()
coderAPI.Logger.Debug(ctx, "starting in-memory aibridgeproxy daemon")
logger := coderAPI.Logger.Named("aibridgeproxyd")
domains, providerFromHost := domainsFromProviders(providers)
reg := prometheus.WrapRegistererWithPrefix("coder_aibridgeproxyd_", coderAPI.PrometheusRegistry)
metrics := aibridgeproxyd.NewMetrics(reg)
@@ -36,55 +54,51 @@ func newAIBridgeProxyDaemon(coderAPI *coderd.API, providers []aibridge.Provider)
}
srv, err := aibridgeproxyd.New(ctx, logger, aibridgeproxyd.Options{
ListenAddr: coderAPI.DeploymentValues.AI.BridgeProxyConfig.ListenAddr.String(),
TLSCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSCertFile.String(),
TLSKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSKeyFile.String(),
CoderAccessURL: coderAPI.AccessURL.String(),
MITMCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMCertFile.String(),
MITMKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMKeyFile.String(),
DomainAllowlist: domains,
AIBridgeProviderFromHost: providerFromHost,
UpstreamProxy: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxy.String(),
UpstreamProxyCA: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxyCA.String(),
AllowedPrivateCIDRs: coderAPI.DeploymentValues.AI.BridgeProxyConfig.AllowedPrivateCIDRs.Value(),
NewDumper: newDumper,
Metrics: metrics,
ListenAddr: coderAPI.DeploymentValues.AI.BridgeProxyConfig.ListenAddr.String(),
TLSCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSCertFile.String(),
TLSKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSKeyFile.String(),
CoderAccessURL: coderAPI.AccessURL.String(),
MITMCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMCertFile.String(),
MITMKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMKeyFile.String(),
UpstreamProxy: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxy.String(),
UpstreamProxyCA: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxyCA.String(),
AllowedPrivateCIDRs: coderAPI.DeploymentValues.AI.BridgeProxyConfig.AllowedPrivateCIDRs.Value(),
NewDumper: newDumper,
Metrics: metrics,
RefreshProviders: refreshProxyProviders(coderAPI.Database),
})
if err != nil {
return nil, xerrors.Errorf("failed to start in-memory aibridgeproxy daemon: %w", err)
}
return srv, nil
}
// domainsFromProviders extracts distinct hostnames from providers' base
// URLs and builds a host-to-provider-name mapping function. The returned
// domain list is suitable for use as DomainAllowlist and the mapping
// function is suitable for use as AIBridgeProviderFromHost.
func domainsFromProviders(providers []aibridge.Provider) ([]string, func(string) string) {
hostToProvider := make(map[string]string, len(providers))
var domains []string
for _, p := range providers {
raw := p.BaseURL()
if raw == "" {
continue
}
u, err := url.Parse(raw)
if err != nil || u.Hostname() == "" {
continue
}
host := strings.ToLower(u.Hostname())
if _, exists := hostToProvider[host]; exists {
// First provider wins; duplicates are expected when
// multiple providers share a base URL host (e.g. two
// OpenAI providers using the same proxy).
continue
}
hostToProvider[host] = p.Name()
domains = append(domains, host)
unsubscribe, err := aibridged.SubscribeProviderReload(ctx, coderAPI.Pubsub, srv, logger.Named("provider-reload"))
if err != nil {
logger.Warn(ctx, "subscribe aibridgeproxyd to ai providers change channel", slog.Error(err))
unsubscribe = func() {}
}
return domains, func(host string) string {
return hostToProvider[strings.ToLower(host)]
// Register the handler so coderd can serve the proxy endpoints.
coderAPI.RegisterInMemoryAIBridgeProxydHTTPHandler(srv.Handler())
return &aiBridgeProxyDaemon{
server: srv,
unsubscribe: unsubscribe,
}, nil
}
func refreshProxyProviders(db database.Store) aibridgeproxyd.RefreshProvidersFunc {
return func(ctx context.Context) ([]aibridgeproxyd.ProviderRoute, error) {
//nolint:gocritic // AsAIProviderMetadataReader is the correct subject for routing-only access.
rows, err := db.GetAIProviders(dbauthz.AsAIProviderMetadataReader(ctx), database.GetAIProvidersParams{
IncludeDisabled: false,
})
if err != nil {
return nil, xerrors.Errorf("load ai providers: %w", err)
}
out := make([]aibridgeproxyd.ProviderRoute, 0, len(rows))
for _, row := range rows {
out = append(out, aibridgeproxyd.ProviderRoute{Name: row.Name, BaseURL: row.BaseUrl})
}
return out, nil
}
}
@@ -1,72 +0,0 @@
//go:build !slim
package cli
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/aibridge"
)
func TestDomainsFromProviders(t *testing.T) {
t.Parallel()
t.Run("ExtractsHostnames", func(t *testing.T) {
t.Parallel()
providers := []aibridge.Provider{
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "custom", BaseURL: "https://custom-llm.example.com:8443/api"}),
}
domains, mapping := domainsFromProviders(providers)
assert.Contains(t, domains, "api.openai.com")
assert.Contains(t, domains, "api.anthropic.com")
assert.Contains(t, domains, "custom-llm.example.com")
assert.Equal(t, "openai", mapping("api.openai.com"))
assert.Equal(t, "anthropic", mapping("api.anthropic.com"))
assert.Equal(t, "custom", mapping("custom-llm.example.com"))
assert.Empty(t, mapping("unknown.com"))
})
t.Run("DeduplicatesSameHost", func(t *testing.T) {
t.Parallel()
providers := []aibridge.Provider{
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "first", BaseURL: "https://api.example.com/v1"}),
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "second", BaseURL: "https://api.example.com/v2"}),
}
domains, mapping := domainsFromProviders(providers)
// Count occurrences of api.example.com.
count := 0
for _, d := range domains {
if d == "api.example.com" {
count++
}
}
assert.Equal(t, 1, count)
// First provider wins.
assert.Equal(t, "first", mapping("api.example.com"))
})
t.Run("CaseInsensitive", func(t *testing.T) {
t.Parallel()
providers := []aibridge.Provider{
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "provider", BaseURL: "https://API.Example.COM/v1"}),
}
domains, mapping := domainsFromProviders(providers)
assert.Contains(t, domains, "api.example.com")
assert.Equal(t, "provider", mapping("API.Example.COM"))
assert.Equal(t, "provider", mapping("api.example.com"))
})
}
+10 -17
View File
@@ -15,7 +15,6 @@ import (
"tailscale.com/derp"
"tailscale.com/types/key"
agplcli "github.com/coder/coder/v2/cli"
agplcoderd "github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/cryptorand"
@@ -167,13 +166,14 @@ func (r *RootCmd) Server(_ func()) *serpent.Command {
// in-memory roundtripper regardless of license); only the proxy
// daemon remains enterprise-gated by config.
if options.DeploymentValues.AI.BridgeProxyConfig.Enabled.Value() {
// Seed env-derived providers before reading them back so the
// proxy observes them on first startup. options.Database is
// dbcrypt-wrapped at this point (set by coderd.New above),
// so env-seeded keys are also written encrypted. Detached
// ctx for the same reason as in agplcli below: an early
// return would orphan newAPI's goroutines. Seeding is
// idempotent; the agplcli path seeds again post-newAPI.
// Seed env-derived providers before the proxy daemon's reloader
// reads them back so the proxy observes them on first startup.
// options.Database is dbcrypt-wrapped at this point (set by
// coderd.New above), so env-seeded keys are also written
// encrypted. Detached ctx for the same reason as in agplcli
// below: an early return would orphan newAPI's goroutines.
// Seeding is idempotent; the agplcli path seeds again
// post-newAPI.
//nolint:gocritic // Production timeout, not a test wait.
aibridgeInitCtx, aibridgeInitCancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
defer aibridgeInitCancel()
@@ -185,19 +185,12 @@ func (r *RootCmd) Server(_ func()) *serpent.Command {
); err != nil {
return nil, nil, xerrors.Errorf("seed ai providers from env: %w", err)
}
providers, err := agplcli.BuildProviders(aibridgeInitCtx, options.Database, options.DeploymentValues.AI.BridgeConfig, options.Logger.Named("aibridge.providers"))
if err != nil {
return nil, nil, xerrors.Errorf("build AI providers: %w", err)
}
aiBridgeProxyServer, err := newAIBridgeProxyDaemon(api, providers)
aiBridgeProxyCloser, err := newAIBridgeProxyDaemon(api)
if err != nil {
_ = closers.Close()
return nil, nil, xerrors.Errorf("create aibridgeproxyd: %w", err)
}
closers.Add(aiBridgeProxyServer)
// Register the handler so coderd can serve the proxy endpoints.
api.RegisterInMemoryAIBridgeProxydHTTPHandler(aiBridgeProxyServer.Handler())
closers.Add(aiBridgeProxyCloser)
}
return api.AGPL, closers, nil
+233
View File
@@ -0,0 +1,233 @@
package coderd_test
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
// mockUpstream is a single httptest server identified by a unique
// marker that it echoes in every response body, so callers can verify
// which upstream a proxied request actually reached. The hit counter
// supports asserting the upstream was touched at all.
type mockUpstream struct {
server *httptest.Server
name string
hits atomic.Int32
}
func newMockUpstream(t *testing.T, name string) *mockUpstream {
t.Helper()
m := &mockUpstream{name: name}
m.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
m.hits.Add(1)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
assert.NoError(t, json.NewEncoder(w).Encode(map[string]string{"upstream": name}))
}))
t.Cleanup(m.server.Close)
return m
}
// startTestAIBridgeDaemon wires an in-process aibridged daemon onto
// the supplied API and subscribes it to ai_providers change events.
// This mirrors what cli/server.go does in production so /api/v2/aibridge
// requests dispatch through the real pool and reloader.
func startTestAIBridgeDaemon(t *testing.T, api *coderd.API) {
t.Helper()
ctx := context.Background()
logger := slogtest.Make(t, nil).Named("aibridged").Leveled(slog.LevelDebug)
cfg := api.DeploymentValues.AI.BridgeConfig
tracer := otel.Tracer("aibridge-reload-test")
providers, err := cli.BuildProviders(ctx, api.Database, cfg, logger)
require.NoError(t, err)
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger.Named("pool"), nil, tracer)
require.NoError(t, err)
t.Cleanup(func() { _ = pool.Shutdown(context.Background()) })
reloader := &testPoolReloader{pool: pool, db: api.Database, cfg: cfg, logger: logger.Named("reloader")}
unsubscribe, err := aibridged.SubscribeProviderReload(ctx, api.Pubsub, reloader, logger.Named("subscriber"))
require.NoError(t, err)
t.Cleanup(unsubscribe)
srv, err := aibridged.New(ctx, pool, func(dialCtx context.Context) (aibridged.DRPCClient, error) {
return api.CreateInMemoryAIBridgeServer(dialCtx)
}, logger, tracer)
require.NoError(t, err)
t.Cleanup(func() { _ = srv.Close() })
api.RegisterInMemoryAIBridgedHTTPHandler(srv)
}
type testPoolReloader struct {
pool *aibridged.CachedBridgePool
db database.Store
cfg codersdk.AIBridgeConfig
logger slog.Logger
}
func (r *testPoolReloader) Reload(ctx context.Context) error {
providers, err := cli.BuildProviders(ctx, r.db, r.cfg, r.logger)
if err != nil {
return err
}
r.pool.ReplaceProviders(providers)
return nil
}
// TestAIBridgeProviderHotReload exercises the end-to-end CRUD ->
// reload -> routing path: every provider mutation made through codersdk
// must, within a short window, change the routing observed at
// /api/v2/aibridge/{name}/v1/models. The OpenAI passthrough route
// /v1/models reverse-proxies to BaseURL, so the upstream that responds
// identifies which provider the daemon's mux dispatched to.
func TestAIBridgeProviderHotReload(t *testing.T) {
t.Parallel()
// Two distinct upstreams so an Update that swings the BaseURL is
// observable: which upstream answers tells us which BaseURL the
// freshly-built provider is pointed at.
upstreamA := newMockUpstream(t, "a")
upstreamB := newMockUpstream(t, "b")
dv := coderdtest.DeploymentValues(t)
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{DeploymentValues: dv},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureAIBridge: 1},
},
})
startTestAIBridgeDaemon(t, api.AGPL)
ctx := testutil.Context(t, testutil.WaitLong)
// sendRequest issues GET /api/v2/aibridge/{name}/v1/models and
// returns the status and the upstream marker decoded from the
// JSON body (empty if the body was not the marker JSON).
sendRequest := func(providerName string) (int, string) {
url := client.URL.String() + "/api/v2/aibridge/" + providerName + "/v1/models"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+client.SessionToken())
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return resp.StatusCode, ""
}
var decoded map[string]string
_ = json.Unmarshal(body, &decoded)
return resp.StatusCode, decoded["upstream"]
}
// requireRoutesTo polls until the routing reflects the expected
// upstream. The pool reloads asynchronously from a pubsub event;
// require.Eventually is the natural fit.
requireRoutesTo := func(t *testing.T, providerName string, upstream *mockUpstream) {
t.Helper()
before := upstream.hits.Load()
require.Eventuallyf(t, func() bool {
status, marker := sendRequest(providerName)
return status == http.StatusOK && marker == upstream.name
}, testutil.WaitShort, testutil.IntervalFast,
"expected provider %q to route to upstream %q", providerName, upstream.name)
require.Greater(t, upstream.hits.Load(), before,
"upstream %q must have observed at least one request", upstream.name)
}
// requireRoutingGone polls until the provider name yields a 404
// from the aibridge mux's catch-all, indicating the provider has
// been removed from the pool snapshot.
requireRoutingGone := func(t *testing.T, providerName string) {
t.Helper()
require.Eventuallyf(t, func() bool {
status, _ := sendRequest(providerName)
return status == http.StatusNotFound
}, testutil.WaitShort, testutil.IntervalFast,
"expected provider %q to stop routing", providerName)
}
// 1. Create: provider points at upstream A.
created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "primary",
Enabled: true,
BaseURL: upstreamA.server.URL,
APIKeys: []string{"sk-primary-key"},
})
require.NoError(t, err)
require.Equal(t, "primary", created.Name)
requireRoutesTo(t, "primary", upstreamA)
// 2. Update BaseURL: same name, now points at upstream B.
newBaseURL := upstreamB.server.URL
_, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{
BaseURL: &newBaseURL,
})
require.NoError(t, err)
requireRoutesTo(t, "primary", upstreamB)
// 3. Disable: the provider drops out of the snapshot, requests
// stop reaching any upstream.
disabled := false
_, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{
Enabled: &disabled,
})
require.NoError(t, err)
requireRoutingGone(t, "primary")
// 4. Re-enable: routing comes back at the most recent BaseURL.
enabled := true
_, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{
Enabled: &enabled,
})
require.NoError(t, err)
requireRoutesTo(t, "primary", upstreamB)
// 5. Add a second provider; both names must route independently.
_, err = client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "secondary",
Enabled: true,
BaseURL: upstreamA.server.URL,
APIKeys: []string{"sk-secondary-key"},
})
require.NoError(t, err)
requireRoutesTo(t, "primary", upstreamB)
requireRoutesTo(t, "secondary", upstreamA)
// 6. Delete primary: only secondary remains routable.
require.NoError(t, client.DeleteAIProvider(ctx, "primary"))
requireRoutingGone(t, "primary")
requireRoutesTo(t, "secondary", upstreamA)
}