Files
Danny Kopping 12520ee964 feat: add ai provider status and reload freshness metrics (#25770)
Add metrics for `aibridged` and `aibridgeproxyd`'s provider statuses. AI providers can be modified, and possibly misconfigured, at runtime. These metrics help operators understand the state of these provider definitions in case unexpected behaviour is observed.
2026-05-28 14:57:33 +02:00

2402 lines
80 KiB
Go

package aibridgeproxyd_test
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge"
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/enterprise/aibridgeproxyd"
"github.com/coder/coder/v2/testutil"
)
// Process-wide MITM CA state shared by every test in this package.
//
// goproxy keeps its CA in a global (goproxy.GoproxyCa). Generating one CA
// a single time and handing the same files to every test therefore avoids
// races between parallel tests. These variables are written only inside
// testMITMCertOnce and only read afterwards.
var (
	// testMITMCertOnce guards generation. sync.Once runs the initializer at
	// most once even under t.Parallel(). A failed generation is NOT retried;
	// the error is simply reported to every caller.
	testMITMCertOnce sync.Once

	testMITMCert          string // path of the PEM-encoded CA certificate
	testMITMKey           string // path of the PEM-encoded CA private key
	errTestSharedMITMCert error  // error from generation, if any
)
// getSharedTestMITMCert returns the certificate and key file paths of the
// package-wide MITM CA, generating them on first use.
//
// Every test receives the same pair. This means goproxy's global CA
// (goproxy.GoproxyCa) is only ever configured from one certificate, which
// avoids races between parallel tests. Generation happens at most once per
// test binary. If it failed, each caller fails here with the stored error.
func getSharedTestMITMCert(t *testing.T) (certFile, keyFile string) {
	t.Helper()

	testMITMCertOnce.Do(func() {
		testMITMCert, testMITMKey, errTestSharedMITMCert = generateSharedTestMITMCert()
	})
	require.NoError(t, errTestSharedMITMCert, "failed to generate shared test MITM certificate")

	certFile, keyFile = testMITMCert, testMITMKey
	return certFile, keyFile
}
// generateSharedTestMITMCert creates a self-signed CA certificate and a
// 2048-bit RSA key for MITM interception. It writes both as PEM files into
// a freshly created, uniquely named temporary directory and returns their
// paths.
//
// A unique directory is used instead of fixed file names directly under
// os.TempDir(). Otherwise, concurrently running test binaries (or other
// users on the same host) could overwrite each other's files or be denied
// access to them, and could end up loading one process's certificate with
// another process's key. The directory is intentionally not removed: this
// runs once per test binary inside sync.Once, where no testing.T is
// available to register cleanup.
func generateSharedTestMITMCert() (certFile, keyFile string, err error) {
	mitmKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return "", "", xerrors.Errorf("generate MITM key: %w", err)
	}

	// Create a self-signed root CA certificate used to sign per-hostname
	// leaf certificates during MITM interception.
	mitmTemplate := x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "Shared Test MITM Cert"},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour),
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	mitmCertDER, err := x509.CreateCertificate(rand.Reader, &mitmTemplate, &mitmTemplate, &mitmKey.PublicKey, mitmKey)
	if err != nil {
		return "", "", xerrors.Errorf("create MITM certificate: %w", err)
	}

	// MkdirTemp picks a name no other process can collide with and creates
	// the directory with owner-only (0o700) permissions.
	tmpDir, err := os.MkdirTemp("", "aibridgeproxyd_test_mitm_*")
	if err != nil {
		return "", "", xerrors.Errorf("create temp dir: %w", err)
	}
	certPath := filepath.Join(tmpDir, "aibridgeproxyd_test_mitm.crt")
	keyPath := filepath.Join(tmpDir, "aibridgeproxyd_test_mitm.key")

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: mitmCertDER})
	if err := os.WriteFile(certPath, certPEM, 0o600); err != nil {
		return "", "", xerrors.Errorf("write cert file: %w", err)
	}

	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(mitmKey)})
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		return "", "", xerrors.Errorf("write key file: %w", err)
	}

	return certPath, keyPath, nil
}
// generateListenerCert produces a self-signed server certificate and a
// 2048-bit RSA key for the proxy's own TLS listener. It returns the paths
// of the PEM files it wrote.
//
// Both files live in t.TempDir(), so the testing package deletes them when
// the test finishes. Any failure aborts the test via require.
func generateListenerCert(t *testing.T) (certFile, keyFile string) {
	t.Helper()

	privKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err, "generate listener key")

	tmpl := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: "Test Listener"},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		// Clients dial the proxy by IP, so 127.0.0.1 has to appear as an IP
		// SAN or hostname verification fails.
		IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
	}
	der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &privKey.PublicKey, privKey)
	require.NoError(t, err, "create listener certificate")

	outDir := t.TempDir()
	// writePEM encodes one PEM block into outDir/name and returns the path.
	writePEM := func(name, blockType string, body []byte, failMsg string) string {
		p := filepath.Join(outDir, name)
		encoded := pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: body})
		require.NoError(t, os.WriteFile(p, encoded, 0o600), failMsg)
		return p
	}

	certFile = writePEM("listener.crt", "CERTIFICATE", der, "write listener cert file")
	keyFile = writePEM("listener.key", "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(privKey), "write listener key file")
	return certFile, keyFile
}
// testProxyConfig collects the settings that testProxyOption functions
// apply. newTestProxy then translates it into aibridgeproxyd.Options.
type testProxyConfig struct {
	// Listener address and optional TLS material for the proxy itself.
	listenAddr  string
	tlsCertFile string
	tlsKeyFile  string

	// Where intercepted traffic is sent, and which destinations are allowed.
	coderAccessURL      string
	allowedPorts        []string
	allowedPrivateCIDRs []string

	// Static provider set. When refreshProviders is non-nil it is used
	// instead and providers is ignored.
	providers        []aibridgeproxyd.ReloadedProvider
	refreshProviders aibridgeproxyd.RefreshProvidersFunc

	// Optional upstream proxy chaining.
	upstreamProxy   string
	upstreamProxyCA string

	// Optional collaborators injected into the server.
	certStore *aibridgeproxyd.CertCache
	newDumper func(string, string) aibridgeproxyd.RoundTripDumper
	metrics   *aibridgeproxyd.Metrics
}

// testProxyOption mutates a testProxyConfig. Any number of them may be
// passed to newTestProxy and they are applied in order.
type testProxyOption func(*testProxyConfig)
// withAllowedPorts restricts the destination ports the proxy will accept.
func withAllowedPorts(ports ...string) testProxyOption {
	return func(c *testProxyConfig) { c.allowedPorts = ports }
}

// withCoderAccessURL sets the Coder access URL the proxy forwards
// intercepted traffic to.
func withCoderAccessURL(coderAccessURL string) testProxyOption {
	return func(c *testProxyConfig) { c.coderAccessURL = coderAccessURL }
}

// withCertStore injects a certificate cache so tests can inspect it.
func withCertStore(store *aibridgeproxyd.CertCache) testProxyOption {
	return func(c *testProxyConfig) { c.certStore = store }
}

// withProviders replaces the provider set with the given classified
// providers. Unless a custom refresher is supplied, newTestProxy serves
// this set from a synthesized RefreshProvidersFunc and loads it into the
// router before returning.
func withProviders(providers ...aibridgeproxyd.ReloadedProvider) testProxyOption {
	return func(c *testProxyConfig) { c.providers = providers }
}
// withProviderHosts is a convenience that builds one enabled
// ReloadedProvider per host and installs them as the provider set,
// replacing any previously configured providers. For each entry:
//   - an optional port ("host:port") is stripped and the host is
//     lowercased,
//   - the provider name comes from testProviderFromHost, falling back to
//     "test-provider" for hosts without a well-known mapping,
//   - the provider Type is always "openai".
//
// This is equivalent to passing each entry individually to withProviders.
func withProviderHosts(hosts ...string) testProxyOption {
	return func(cfg *testProxyConfig) {
		providers := make([]aibridgeproxyd.ReloadedProvider, 0, len(hosts))
		for _, h := range hosts {
			// Strip any port BEFORE looking up the provider name, so a
			// well-known host given as "host:port" still resolves to its
			// provider instead of falling back to "test-provider".
			host, _, splitErr := net.SplitHostPort(h)
			if splitErr != nil {
				host = h
			}
			host = strings.ToLower(host)

			name := testProviderFromHost(host)
			if name == "" {
				name = "test-provider"
			}

			providers = append(providers, aibridgeproxyd.ReloadedProvider{
				ProviderOutcome: aibridged.ProviderOutcome{
					Name:   name,
					Type:   "openai",
					Status: aibridged.ProviderStatusEnabled,
				},
				Host: host,
			})
		}
		cfg.providers = providers
	}
}
// testProviderFromHost returns the provider name for a well-known AI
// provider hostname, for use in tests. Matching is case-insensitive. The
// input must be a bare host (no port). Unknown hosts yield "".
func testProviderFromHost(host string) string {
	wellKnown := []struct{ host, provider string }{
		{aibridgeproxyd.HostAnthropic, aibridge.ProviderAnthropic},
		{aibridgeproxyd.HostOpenAI, aibridge.ProviderOpenAI},
		{aibridgeproxyd.HostCopilot, aibridge.ProviderCopilot},
		{agplaibridge.HostCopilotBusiness, agplaibridge.ProviderCopilotBusiness},
		{agplaibridge.HostCopilotEnterprise, agplaibridge.ProviderCopilotEnterprise},
		{agplaibridge.HostChatGPT, agplaibridge.ProviderChatGPT},
	}

	normalized := strings.ToLower(host)
	for _, entry := range wellKnown {
		if entry.host == normalized {
			return entry.provider
		}
	}
	return ""
}
// withUpstreamProxy chains outbound traffic through the given proxy URL.
func withUpstreamProxy(upstreamProxy string) testProxyOption {
	return func(c *testProxyConfig) { c.upstreamProxy = upstreamProxy }
}

// withUpstreamProxyCA sets the CA file used to trust the upstream proxy.
func withUpstreamProxyCA(upstreamProxyCA string) testProxyOption {
	return func(c *testProxyConfig) { c.upstreamProxyCA = upstreamProxyCA }
}

// withAllowedPrivateCIDRs replaces the list of private CIDRs the proxy may
// connect to.
func withAllowedPrivateCIDRs(cidrs ...string) testProxyOption {
	return func(c *testProxyConfig) { c.allowedPrivateCIDRs = cidrs }
}

// withNewDumper installs a factory for round-trip dumpers.
func withNewDumper(fn func(string, string) aibridgeproxyd.RoundTripDumper) testProxyOption {
	return func(c *testProxyConfig) { c.newDumper = fn }
}

// withMetrics attaches a metrics instance to the server.
func withMetrics(metrics *aibridgeproxyd.Metrics) testProxyOption {
	return func(c *testProxyConfig) { c.metrics = metrics }
}

// withListenerTLS serves the proxy listener over TLS with the given
// certificate and key files.
func withListenerTLS(certFile, keyFile string) testProxyOption {
	return func(c *testProxyConfig) {
		c.tlsCertFile, c.tlsKeyFile = certFile, keyFile
	}
}

// withRefreshProviders overrides the provider refresher. When set, the
// static provider list is not used.
func withRefreshProviders(fn aibridgeproxyd.RefreshProvidersFunc) testProxyOption {
	return func(c *testProxyConfig) { c.refreshProviders = fn }
}
// newTestProxy constructs an AI Bridge Proxy server for a test and returns
// it once its listener accepts TCP connections.
//
// Defaults, each overridable via opts:
//   - listen on an ephemeral loopback port (127.0.0.1:0),
//   - Coder access URL http://localhost:3000,
//   - allow the private CIDR 127.0.0.1/32,
//   - route 127.0.0.1 and localhost as an enabled "test-provider".
//
// The shared MITM CA is used. The provider set is loaded with Reload
// before the readiness wait, and the server is closed via t.Cleanup.
func newTestProxy(t *testing.T, opts ...testProxyOption) *aibridgeproxyd.Server {
	t.Helper()

	// loopbackProvider builds an enabled default provider entry for host.
	loopbackProvider := func(host string) aibridgeproxyd.ReloadedProvider {
		return aibridgeproxyd.ReloadedProvider{
			ProviderOutcome: aibridged.ProviderOutcome{Name: "test-provider", Type: "openai", Status: aibridged.ProviderStatusEnabled},
			Host:            host,
		}
	}

	cfg := &testProxyConfig{
		listenAddr:     "127.0.0.1:0",
		coderAccessURL: "http://localhost:3000",
		// Test servers always bind to loopback, so permit it by default;
		// tests that exercise IP blocking replace this list.
		allowedPrivateCIDRs: []string{"127.0.0.1/32"},
		providers: []aibridgeproxyd.ReloadedProvider{
			loopbackProvider("127.0.0.1"),
			loopbackProvider("localhost"),
		},
	}
	for _, apply := range opts {
		apply(cfg)
	}

	// Absent a custom refresher, serve the configured providers verbatim.
	// Reload below consumes it synchronously, mirroring production's first
	// reload at startup.
	if cfg.refreshProviders == nil {
		staticProviders := cfg.providers
		cfg.refreshProviders = func(context.Context) (aibridgeproxyd.ProviderReload, error) {
			return aibridgeproxyd.ProviderReload{Providers: staticProviders}, nil
		}
	}

	certPath, keyPath := getSharedTestMITMCert(t)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

	proxyOpts := aibridgeproxyd.Options{
		ListenAddr:          cfg.listenAddr,
		TLSCertFile:         cfg.tlsCertFile,
		TLSKeyFile:          cfg.tlsKeyFile,
		CoderAccessURL:      cfg.coderAccessURL,
		MITMCertFile:        certPath,
		MITMKeyFile:         keyPath,
		AllowedPorts:        cfg.allowedPorts,
		UpstreamProxy:       cfg.upstreamProxy,
		UpstreamProxyCA:     cfg.upstreamProxyCA,
		AllowedPrivateCIDRs: cfg.allowedPrivateCIDRs,
		NewDumper:           cfg.newDumper,
		Metrics:             cfg.metrics,
		RefreshProviders:    cfg.refreshProviders,
	}
	// Only assign when set. NOTE(review): presumably this avoids storing a
	// typed-nil *CertCache if Options.CertStore is an interface — confirm.
	if store := cfg.certStore; store != nil {
		proxyOpts.CertStore = store
	}

	srv, err := aibridgeproxyd.New(t.Context(), logger, proxyOpts)
	require.NoError(t, err)
	t.Cleanup(func() { _ = srv.Close() })

	// Populate the router before any test traffic is sent.
	require.NoError(t, srv.Reload(t.Context()))

	addr := srv.Addr()
	require.NotEmpty(t, addr)
	require.Eventually(t, func() bool {
		conn, dialErr := net.Dial("tcp", addr)
		if dialErr == nil {
			_ = conn.Close()
		}
		return dialErr == nil
	}, testutil.WaitShort, testutil.IntervalFast)

	return srv
}
// getProxyCertPool returns a certificate pool that trusts only the shared
// MITM CA. Clients in tests whose requests the proxy intercepts need it,
// because the proxy presents leaf certificates signed by that CA.
func getProxyCertPool(t *testing.T) *x509.CertPool {
	t.Helper()

	caFile, _ := getSharedTestMITMCert(t)
	caPEM, err := os.ReadFile(caFile)
	require.NoError(t, err)

	pool := x509.NewCertPool()
	require.True(t, pool.AppendCertsFromPEM(caPEM))
	return pool
}
// newProxyClient creates an HTTP(S) client configured to use the proxy.
// It adds a Proxy-Authorization header with the provided token for authentication.
// The certPool and insecureSkipVerify parameters control TLS verification:
// - If the proxy listener is TLS, include the listener certificate.
// - For MITM'd requests, include the proxy's MITM certificate.
// - For tunneled requests, include the target server's certificate.
// - Set insecureSkipVerify when the target cert SANs do not match the hostname.
func newProxyClient(t *testing.T, srv *aibridgeproxyd.Server, proxyAuth string, certPool *x509.CertPool, insecureSkipVerify bool) *http.Client {
t.Helper()
// Create an HTTP(S) client configured to use the proxy.
scheme := "http"
if srv.IsTLSListener() {
scheme = "https"
}
proxyURL, err := url.Parse(scheme + "://" + srv.Addr())
require.NoError(t, err)
transport := &http.Transport{
Proxy: http.ProxyURL(proxyURL),
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: certPool,
InsecureSkipVerify: insecureSkipVerify, //nolint:gosec
},
}
// Only set the header if proxyAuth is provided. This allows tests to
// verify behavior when the Proxy-Authorization header is missing.
if proxyAuth != "" {
transport.ProxyConnectHeader = http.Header{
"Proxy-Authorization": []string{proxyAuth},
}
}
return &http.Client{Transport: transport}
}
// newTargetServer starts an httptest TLS server running handler, to act as
// the destination of proxied requests. It returns the server and its
// parsed URL. The server is shut down automatically via t.Cleanup.
func newTargetServer(t *testing.T, handler http.HandlerFunc) (target *httptest.Server, targetURL *url.URL) {
	t.Helper()

	target = httptest.NewTLSServer(handler)
	t.Cleanup(target.Close)

	targetURL, err := url.Parse(target.URL)
	require.NoError(t, err)
	return target, targetURL
}
// makeProxyAuthHeader creates a Proxy-Authorization header value with the given token.
// Format: "Basic base64(username:token)" where username is "ignored".
func makeProxyAuthHeader(token string) string {
credentials := base64.StdEncoding.EncodeToString([]byte("ignored:" + token))
return "Basic " + credentials
}
// sendConnect writes a raw CONNECT request for targetHost to the proxy at
// proxyAddr and returns the parsed response. It is used to test proxy
// authentication challenges, because Go's HTTP client does not expose the
// response when CONNECT fails with a non-2xx status.
//
// proxyAuth, when non-empty, is sent verbatim as the Proxy-Authorization
// header. The connection is closed when the test ends.
//
// The dial, write, and response-header read are bounded by
// testutil.WaitShort, so an unresponsive proxy fails the test promptly
// instead of hanging until the global test timeout. The deadline is
// cleared before returning so callers may keep reading resp.Body.
func sendConnect(t *testing.T, proxyAddr, targetHost, proxyAuth string) *http.Response {
	t.Helper()

	conn, err := net.DialTimeout("tcp", proxyAddr, testutil.WaitShort)
	require.NoError(t, err)
	t.Cleanup(func() { _ = conn.Close() })

	require.NoError(t, conn.SetDeadline(time.Now().Add(testutil.WaitShort)))

	// Build CONNECT request.
	var reqBuf bytes.Buffer
	_, err = fmt.Fprintf(&reqBuf, "CONNECT %s HTTP/1.1\r\n", targetHost)
	require.NoError(t, err)
	_, err = fmt.Fprintf(&reqBuf, "Host: %s\r\n", targetHost)
	require.NoError(t, err)
	if proxyAuth != "" {
		_, err = fmt.Fprintf(&reqBuf, "Proxy-Authorization: %s\r\n", proxyAuth)
		require.NoError(t, err)
	}
	_, err = reqBuf.WriteString("\r\n")
	require.NoError(t, err)

	// Send the CONNECT request to the proxy.
	_, err = conn.Write(reqBuf.Bytes())
	require.NoError(t, err)

	// Read and parse the proxy's response.
	// On success (200), the proxy establishes a tunnel.
	// On auth failure (407), the proxy returns a challenge with Proxy-Authenticate header.
	resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
	require.NoError(t, err)

	// Lift the deadline so body reads by the caller are not cut short.
	require.NoError(t, conn.SetDeadline(time.Time{}))
	return resp
}
// TestNew covers aibridgeproxyd.New: option validation (each failing case
// asserts on a substring of the returned error) and successful
// construction under various configurations.
//
// Every successfully constructed server is closed via t.Cleanup.
// newTestProxy shows the listener is already accepting right after New, so
// leaving these servers open would leak listeners past the subtest. Close
// is safe to call repeatedly (see TestClose).
func TestNew(t *testing.T) {
	t.Parallel()

	t.Run("MissingListenAddr", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "listen address is required")
	})

	t.Run("EmptyListenAddr", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "listen address is required")
	})

	t.Run("TLSCertWithoutKey", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			TLSCertFile:    "cert.pem",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "tls cert file and tls key file must both be set")
	})

	t.Run("TLSKeyWithoutCert", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			TLSKeyFile:     "key.pem",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "tls cert file and tls key file must both be set")
	})

	t.Run("InvalidListenerTLSFiles", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			TLSCertFile:    "/nonexistent/cert.pem",
			TLSKeyFile:     "/nonexistent/key.pem",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "load listener TLS certificate")
	})

	t.Run("MissingCoderAccessURL", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:   "127.0.0.1:0",
			MITMCertFile: mitmCertFile,
			MITMKeyFile:  mitmKeyFile,
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "coder access URL is required")
	})

	t.Run("EmptyCoderAccessURL", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: " ",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "coder access URL is required")
	})

	t.Run("InvalidCoderAccessURL", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "://invalid",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "invalid coder access URL")
	})

	// The access URL's port is defaulted from its scheme when omitted.
	t.Run("CoderAccessURLDefaultHTTPPort", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.Equal(t, "localhost", srv.CoderAccessURL().Hostname())
		require.Equal(t, "80", srv.CoderAccessURL().Port())
	})

	t.Run("CoderAccessURLDefaultHTTPSPort", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "https://localhost",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.Equal(t, "localhost", srv.CoderAccessURL().Hostname())
		require.Equal(t, "443", srv.CoderAccessURL().Port())
	})

	t.Run("CoderAccessURLExplicitPort", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.Equal(t, "localhost", srv.CoderAccessURL().Hostname())
		require.Equal(t, "3000", srv.CoderAccessURL().Port())
	})

	// The MITM file cases below deliberately avoid the shared certificate.
	t.Run("MissingCertFile", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     ":0",
			CoderAccessURL: "http://localhost:3000",
			MITMKeyFile:    "key.pem",
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "cert file and key file are required")
	})

	t.Run("MissingKeyFile", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     ":0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   "cert.pem",
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "cert file and key file are required")
	})

	t.Run("InvalidCertFile", func(t *testing.T) {
		t.Parallel()
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     ":0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   "/nonexistent/cert.pem",
			MITMKeyFile:    "/nonexistent/key.pem",
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "failed to load MITM certificate")
	})

	t.Run("InvalidUpstreamProxy", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
			UpstreamProxy:  "://invalid-url",
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "invalid upstream proxy URL")
	})

	t.Run("UpstreamProxyCAFileNotFound", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:      "127.0.0.1:0",
			CoderAccessURL:  "http://localhost:3000",
			MITMCertFile:    mitmCertFile,
			MITMKeyFile:     mitmKeyFile,
			UpstreamProxy:   "https://proxy.example.com:8080",
			UpstreamProxyCA: "/nonexistent/ca.pem",
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "failed to read upstream proxy CA certificate")
	})

	t.Run("UpstreamProxyAuthWithBothEmpty", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
			UpstreamProxy:  "http://:@proxy.example.com:8080",
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "invalid credentials: both username and password are empty")
	})

	t.Run("InvalidAllowedPrivateCIDR", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:          "127.0.0.1:0",
			CoderAccessURL:      "http://localhost:3000",
			MITMCertFile:        mitmCertFile,
			MITMKeyFile:         mitmKeyFile,
			AllowedPrivateCIDRs: []string{"not-a-cidr"},
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), "invalid allowed private CIDR")
	})

	t.Run("Success", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})

	t.Run("SuccessWithListenerTLS", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		listenerCertFile, listenerKeyFile := generateListenerCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			TLSCertFile:    listenerCertFile,
			TLSKeyFile:     listenerKeyFile,
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})

	t.Run("SuccessWithUpstreamProxy", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
			UpstreamProxy:  "http://proxy.example.com:8080",
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})

	t.Run("SuccessWithHTTPSUpstreamProxyAndCA", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		// Use the shared MITM certificate as the upstream proxy CA (it's a valid PEM cert)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:      "127.0.0.1:0",
			CoderAccessURL:  "http://localhost:3000",
			MITMCertFile:    mitmCertFile,
			MITMKeyFile:     mitmKeyFile,
			UpstreamProxy:   "https://proxy.example.com:8080",
			UpstreamProxyCA: mitmCertFile,
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})

	t.Run("SuccessWithUpstreamProxyAuth", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
			UpstreamProxy:  "http://proxyuser:proxypass@proxy.example.com:8080",
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})

	t.Run("SuccessWithUpstreamProxyUsernameAuthColon", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
			UpstreamProxy:  "http://proxyuser:@proxy.example.com:8080",
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})

	t.Run("SuccessWithUpstreamProxyUsernameAuth", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		// Username only (no colon) should also succeed (password is optional)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
			UpstreamProxy:  "http://proxyuser@proxy.example.com:8080",
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})

	t.Run("SuccessWithUpstreamProxyTokenAuth", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
			UpstreamProxy:  "http://:proxypass@proxy.example.com:8080",
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})

	t.Run("SuccessWithMetrics", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		// Create metrics instance to verify it can be passed and stored.
		reg := prometheus.NewRegistry()
		metrics := aibridgeproxyd.NewMetrics(reg)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   mitmCertFile,
			MITMKeyFile:    mitmKeyFile,
			Metrics:        metrics,
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})

	t.Run("SuccessWithAllowedPrivateCIDRs", func(t *testing.T) {
		t.Parallel()
		mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
		logger := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
			ListenAddr:          "127.0.0.1:0",
			CoderAccessURL:      "http://localhost:3000",
			MITMCertFile:        mitmCertFile,
			MITMKeyFile:         mitmKeyFile,
			AllowedPrivateCIDRs: []string{"127.0.0.1/32"},
		})
		require.NoError(t, err)
		t.Cleanup(func() { _ = srv.Close() })
		require.NotNil(t, srv)
	})
}
// TestClose verifies that Server.Close succeeds and may be called more than
// once, both with and without metrics attached.
func TestClose(t *testing.T) {
	t.Parallel()

	t.Run("Success", func(t *testing.T) {
		t.Parallel()

		certPath, keyPath := getSharedTestMITMCert(t)
		log := slogtest.Make(t, nil)
		srv, err := aibridgeproxyd.New(t.Context(), log, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   certPath,
			MITMKeyFile:    keyPath,
		})
		require.NoError(t, err)

		// The first Close releases the server; the second must be a no-op.
		for range 2 {
			require.NoError(t, srv.Close())
		}
	})

	t.Run("WithMetrics", func(t *testing.T) {
		t.Parallel()

		certPath, keyPath := getSharedTestMITMCert(t)
		log := slogtest.Make(t, nil)
		registry := prometheus.NewRegistry()
		srv, err := aibridgeproxyd.New(t.Context(), log, aibridgeproxyd.Options{
			ListenAddr:     "127.0.0.1:0",
			CoderAccessURL: "http://localhost:3000",
			MITMCertFile:   certPath,
			MITMKeyFile:    keyPath,
			Metrics:        aibridgeproxyd.NewMetrics(registry),
		})
		require.NoError(t, err)

		require.NoError(t, srv.Close())

		// Registering a fresh set of metrics on the same registry only works
		// if Close unregistered the previous set.
		reRegistered := aibridgeproxyd.NewMetrics(registry)
		require.NotNil(t, reRegistered, "should be able to create new metrics after Close() unregisters old ones")

		// A repeated Close must still succeed.
		require.NoError(t, srv.Close())
	})
}
// TestProxy_CertCaching checks that the proxy stores generated leaf
// certificates in the configured CertCache only for hosts it intercepts
// (provider hosts). Hosts it merely tunnels must not be cached.
func TestProxy_CertCaching(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		// providerHosts == nil means "use the target server's hostname".
		providerHosts []string
		tunneled      bool
	}{
		{name: "ProviderHostCached", providerHosts: nil, tunneled: false},
		{name: "NonProviderHostNotCached", providerHosts: []string{"other.example.com"}, tunneled: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			respondOK := func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
			}

			// HTTPS destination of the proxied request.
			targetServer, targetURL := newTargetServer(t, respondOK)

			// Stand-in for aibridged, which receives intercepted provider traffic.
			aibridgedServer := httptest.NewServer(http.HandlerFunc(respondOK))
			t.Cleanup(aibridgedServer.Close)

			certCache := aibridgeproxyd.NewCertCache()

			hosts := tc.providerHosts
			if hosts == nil {
				hosts = []string{targetURL.Hostname()}
			}

			srv := newTestProxy(t,
				withCoderAccessURL(aibridgedServer.URL),
				withAllowedPorts(targetURL.Port()),
				withCertStore(certCache),
				withProviderHosts(hosts...),
			)

			// Tunneled traffic reaches the target directly, so trust its
			// self-signed cert. Intercepted traffic is served with leaf
			// certs signed by the proxy's MITM CA.
			var roots *x509.CertPool
			switch {
			case tc.tunneled:
				roots = x509.NewCertPool()
				roots.AddCert(targetServer.Certificate())
			default:
				roots = getProxyCertPool(t)
			}

			client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), roots, false)
			req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
			require.NoError(t, err)
			resp, err := client.Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()

			// Probe the cache: the generator only runs on a cache miss.
			generated := 0
			_, err = certCache.Fetch(targetURL.Hostname(), func() (*tls.Certificate, error) {
				generated++
				return &tls.Certificate{}, nil
			})
			require.NoError(t, err)

			if tc.tunneled {
				require.Equal(t, 1, generated, "certificate should NOT have been cached for non-provider-host")
				return
			}
			require.Equal(t, 0, generated, "certificate should have been cached during request")
		})
	}
}
// TestProxy_PortValidation checks the proxy's CONNECT port allow-list. A
// request to an allowed port is MITM'd and answered by the mock aibridged.
// A request to a port outside the list makes the client request fail.
func TestProxy_PortValidation(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name         string
		allowedPorts func(targetURL *url.URL) []string
		expectError  bool
	}{
		{
			name: "AllowedPort",
			// The target listens on a random port; allow exactly that port.
			allowedPorts: func(targetURL *url.URL) []string {
				return []string{targetURL.Port()}
			},
		},
		{
			name: "RejectedPort",
			// 443 never matches the target's random port.
			allowedPorts: func(_ *url.URL) []string {
				return []string{"443"}
			},
			expectError: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// HTTPS destination of the proxied request.
			_, targetURL := newTargetServer(t, func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("hello from target"))
			})

			// Stand-in for aibridged: MITM'd provider-host traffic lands here.
			bridge := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("hello from aibridged"))
			}))
			t.Cleanup(bridge.Close)

			ports := tc.allowedPorts(targetURL)
			proxy := newTestProxy(t,
				withCoderAccessURL(bridge.URL),
				withAllowedPorts(ports...),
				withProviderHosts(targetURL.Hostname()),
			)

			client := newProxyClient(t, proxy, makeProxyAuthHeader("test-token"), getProxyCertPool(t), false)
			req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
			require.NoError(t, err)

			resp, err := client.Do(req)
			if tc.expectError {
				// A disallowed port must make the CONNECT fail.
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			defer resp.Body.Close()

			// An allowed port must have been routed to aibridged.
			payload, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Equal(t, http.StatusOK, resp.StatusCode)
			require.Equal(t, "hello from aibridged", string(payload))
		})
	}
}
// TestProxy_Authentication exercises the proxy's Proxy-Authorization
// handling.
//
// With valid credentials, the CONNECT is accepted and the MITM'd request
// reaches aibridged. With missing, malformed, or empty credentials, the proxy
// answers 407 with a Basic Proxy-Authenticate challenge and the standard
// status text as the body.
func TestProxy_Authentication(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name          string
		proxyAuth     string
		expectSuccess bool
	}{
		{name: "ValidCredentials", proxyAuth: makeProxyAuthHeader("test-coder-token"), expectSuccess: true},
		{name: "MissingCredentials", proxyAuth: "", expectSuccess: false},
		{name: "InvalidBase64", proxyAuth: "Basic not-valid-base64!", expectSuccess: false},
		{name: "EmptyToken", proxyAuth: makeProxyAuthHeader(""), expectSuccess: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// HTTPS destination of the proxied request.
			_, targetURL := newTargetServer(t, func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("hello from target"))
			})

			// Stand-in for aibridged: MITM'd provider-host traffic lands here.
			bridge := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("hello from aibridged"))
			}))
			t.Cleanup(bridge.Close)

			proxy := newTestProxy(t,
				withCoderAccessURL(bridge.URL),
				withAllowedPorts(targetURL.Port()),
				withProviderHosts(targetURL.Hostname()),
			)

			if !tc.expectSuccess {
				// Go's HTTP client does not surface the CONNECT response when it
				// is non-2xx, so issue a raw CONNECT to inspect the rejection.
				resp := sendConnect(t, proxy.Addr(), targetURL.Host, tc.proxyAuth)
				defer resp.Body.Close()

				require.Equal(t, http.StatusProxyAuthRequired, resp.StatusCode)
				// The challenge header tells clients how to authenticate.
				challenge := resp.Header.Get("Proxy-Authenticate")
				require.Equal(t, "Basic realm="+aibridgeproxyd.ProxyAuthRealm, challenge)

				msg, err := io.ReadAll(resp.Body)
				require.NoError(t, err)
				require.Equal(t, http.StatusText(http.StatusProxyAuthRequired), string(msg))
				return
			}

			// Accepted credentials: a normal proxied request must reach aibridged.
			client := newProxyClient(t, proxy, tc.proxyAuth, getProxyCertPool(t), false)
			req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
			require.NoError(t, err)
			resp, err := client.Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()

			payload, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Equal(t, http.StatusOK, resp.StatusCode)
			require.Equal(t, "hello from aibridged", string(payload))
		})
	}
}
// TestProxy_MITM verifies how the proxy routes CONNECT targets:
//   - Requests to configured provider hosts, including on allow-listed
//     non-default ports, are MITM'd and forwarded to aibridged under
//     /api/v2/aibridge/<provider>/.... The caller's Authorization header is
//     forwarded unchanged, the Coder token from Proxy-Authorization is
//     injected as the BYOK header, and a UUID request ID header is attached.
//   - Requests to hosts that are not provider hosts are tunneled directly to
//     their destination and never reach aibridged.
//
// It also checks the Prometheus metrics recorded for each path.
func TestProxy_MITM(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name          string
		providerHosts []string
		// allowedPorts is the proxy's CONNECT port allow-list. nil means "use
		// the tunneled target server's random port".
		allowedPorts []string
		// buildTargetURL returns the URL the client requests. It receives the
		// tunneled target server's URL, which only the tunneled case uses.
		buildTargetURL func(tunneledURL *url.URL) (string, error)
		tunneled       bool
		// expectedPath is the path aibridged should observe for MITM'd requests.
		expectedPath string
		// provider is the provider label expected on MITM metrics. It is
		// empty for the tunneled case.
		provider string
	}{
		{
			name:          "MitmdAnthropic",
			providerHosts: []string{aibridgeproxyd.HostAnthropic},
			allowedPorts:  []string{"443"},
			buildTargetURL: func(_ *url.URL) (string, error) {
				return "https://api.anthropic.com/v1/messages", nil
			},
			expectedPath: "/api/v2/aibridge/anthropic/v1/messages",
			provider:     "anthropic",
		},
		{
			// Same routing as above; the explicit port must not leak into the
			// rewritten aibridged path.
			name:          "MitmdAnthropicNonDefaultPort",
			providerHosts: []string{aibridgeproxyd.HostAnthropic},
			allowedPorts:  []string{"8443"},
			buildTargetURL: func(_ *url.URL) (string, error) {
				return "https://api.anthropic.com:8443/v1/messages", nil
			},
			expectedPath: "/api/v2/aibridge/anthropic/v1/messages",
			provider:     "anthropic",
		},
		{
			name:          "MitmdOpenAI",
			providerHosts: []string{aibridgeproxyd.HostOpenAI},
			allowedPorts:  []string{"443"},
			buildTargetURL: func(_ *url.URL) (string, error) {
				return "https://api.openai.com/v1/chat/completions", nil
			},
			expectedPath: "/api/v2/aibridge/openai/v1/chat/completions",
			provider:     "openai",
		},
		{
			name:          "MitmdOpenAINonDefaultPort",
			providerHosts: []string{aibridgeproxyd.HostOpenAI},
			allowedPorts:  []string{"8443"},
			buildTargetURL: func(_ *url.URL) (string, error) {
				return "https://api.openai.com:8443/v1/chat/completions", nil
			},
			expectedPath: "/api/v2/aibridge/openai/v1/chat/completions",
			provider:     "openai",
		},
		{
			// The local target server is not a provider host, so the proxy
			// must tunnel to it rather than MITM it.
			name:          "TunneledUnknownHost",
			providerHosts: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI},
			allowedPorts:  nil, // will use tunneledURL.Port()
			buildTargetURL: func(tunneledURL *url.URL) (string, error) {
				return url.JoinPath(tunneledURL.String(), "/some/path")
			},
			tunneled: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// Create metrics for verification.
			reg := prometheus.NewRegistry()
			metrics := aibridgeproxyd.NewMetrics(reg)
			// Track what aibridged receives.
			// NOTE(review): these are written on the server's handler goroutine
			// and read here only after client.Do has returned; there is no
			// explicit mutex/atomic guarding them.
			var receivedPath, receivedAuthz, receivedBYOK, receivedRequestID string
			// Create a mock aibridged server that captures requests.
			aibridgedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				receivedPath = r.URL.Path
				receivedAuthz = r.Header.Get("Authorization")
				receivedBYOK = r.Header.Get(agplaibridge.HeaderCoderToken)
				receivedRequestID = r.Header.Get(agplaibridge.HeaderCoderRequestID)
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("hello from aibridged"))
			}))
			t.Cleanup(func() { aibridgedServer.Close() })
			// Create a mock target server for tunneled tests.
			tunneledServer, tunneledURL := newTargetServer(t, func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("hello from tunneled"))
			})
			// Configure allowed ports.
			allowedPorts := tt.allowedPorts
			if allowedPorts == nil {
				allowedPorts = []string{tunneledURL.Port()}
			}
			// Configure provider hosts.
			// Every current case sets providerHosts, so this fallback is not
			// exercised by the table above.
			providerHosts := tt.providerHosts
			if providerHosts == nil {
				providerHosts = []string{tunneledURL.Hostname()}
			}
			// Start the proxy server pointing to our mock aibridged.
			srv := newTestProxy(t,
				withCoderAccessURL(aibridgedServer.URL),
				withAllowedPorts(allowedPorts...),
				withProviderHosts(providerHosts...),
				withMetrics(metrics),
			)
			// Build the target URL:
			targetURL, err := tt.buildTargetURL(tunneledURL)
			require.NoError(t, err)
			// Build the cert pool for the client to trust:
			// - For tunneled requests, the client connects directly to the target server
			//   through a tunnel, so it needs to trust the target's self-signed certificate.
			// - For MITM'd requests, the client connects through the proxy which generates
			//   certificates signed by the MITM certificate, so it needs to trust the MITM certificate.
			var certPool *x509.CertPool
			if tt.tunneled {
				certPool = x509.NewCertPool()
				certPool.AddCert(tunneledServer.Certificate())
			} else {
				certPool = getProxyCertPool(t)
			}
			// Simulate the primary proxy use case: the Coder
			// token is in Proxy-Authorization, and the user's
			// own LLM token is in Authorization.
			client := newProxyClient(t, srv, makeProxyAuthHeader("coder-token"), certPool, false)
			req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, targetURL, strings.NewReader(`{}`))
			require.NoError(t, err)
			req.Header.Set("Content-Type", "application/json")
			req.Header.Set("Authorization", "Bearer user-llm-token")
			resp, err := client.Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()
			body, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Equal(t, http.StatusOK, resp.StatusCode)
			// Gather metrics for verification.
			// This happens after the response body is fully read, so the
			// in-flight gauge is expected to be back at 0 for MITM'd requests.
			gatheredMetrics, err := reg.Gather()
			require.NoError(t, err)
			if tt.tunneled {
				// Verify request went to target server, not aibridged.
				require.Equal(t, "hello from tunneled", string(body))
				require.Empty(t, receivedPath, "aibridged should not receive tunneled requests")
				require.Empty(t, receivedAuthz, "tunneled requests should not reach aibridged")
				require.Empty(t, receivedRequestID, "tunneled requests should not have request ID header")
				// Verify metrics for tunneled requests.
				require.True(t, testutil.PromCounterHasValue(t, gatheredMetrics, 1, "connect_sessions_total", aibridgeproxyd.RequestTypeTunneled))
				// Verify MITM-specific metrics were not set.
				require.False(t, testutil.PromCounterGathered(t, gatheredMetrics, "connect_sessions_total", aibridgeproxyd.RequestTypeMITM))
				require.False(t, testutil.PromCounterGathered(t, gatheredMetrics, "mitm_requests_total", tt.provider))
				require.False(t, testutil.PromGaugeGathered(t, gatheredMetrics, "inflight_mitm_requests", tt.provider))
				require.False(t, testutil.PromCounterGathered(t, gatheredMetrics, "mitm_responses_total", "200", tt.provider))
			} else {
				// Verify the request was routed to aibridged correctly.
				require.Equal(t, "hello from aibridged", string(body))
				require.Equal(t, tt.expectedPath, receivedPath)
				require.Equal(t, "Bearer user-llm-token", receivedAuthz, "user's LLM credentials must be forwarded")
				require.Equal(t, "coder-token", receivedBYOK, "proxy must inject BYOK header with Coder token")
				require.NotEmpty(t, receivedRequestID, "MITM'd requests must include request ID header")
				_, err := uuid.Parse(receivedRequestID)
				require.NoError(t, err, "request ID must be a valid UUID")
				// Verify metrics for MITM requests.
				require.True(t, testutil.PromCounterHasValue(t, gatheredMetrics, 1, "connect_sessions_total", aibridgeproxyd.RequestTypeMITM))
				require.True(t, testutil.PromCounterHasValue(t, gatheredMetrics, 1, "mitm_requests_total", tt.provider))
				require.True(t, testutil.PromGaugeHasValue(t, gatheredMetrics, 0, "inflight_mitm_requests", tt.provider))
				require.True(t, testutil.PromCounterHasValue(t, gatheredMetrics, 1, "mitm_responses_total", "200", tt.provider))
				// Verify tunneled counter was not set.
				require.False(t, testutil.PromCounterGathered(t, gatheredMetrics, "connect_sessions_total", aibridgeproxyd.RequestTypeTunneled))
			}
		})
	}
}
// TestProxy_MITM_BYOKInjection verifies how the proxy handles the BYOK header
// on MITM'd requests:
//   - If Authorization carries the Coder token, no BYOK header is added.
//   - If Authorization carries a different bearer token, the proxy injects the
//     Coder token as the BYOK header. This covers clients that send per-user
//     LLM credentials but cannot set custom headers.
//   - If the client already set the BYOK header, the proxy leaves it as is.
//
// In all cases, Authorization is forwarded to aibridged unchanged.
func TestProxy_MITM_BYOKInjection(t *testing.T) {
	t.Parallel()

	coderToken := "coder-token"

	tests := []struct {
		name           string
		authzHeader    string
		byokHeader     string // pre-set by client; empty means not set
		expectBYOK     bool
		expectBYOKVal  string
	}{
		{
			// Centralized: Authorization carries the Coder token (same
			// value as Proxy-Authorization). No BYOK header is set.
			name:        "Authorization matches Coder token",
			authzHeader: "Bearer " + coderToken,
			expectBYOK:  false,
		},
		{
			// BYOK: Authorization carries the user's token,
			// which differs from the Coder token. The proxy injects
			// the BYOK header.
			name:          "Authorization differs from Coder token",
			authzHeader:   "Bearer client-access-token",
			expectBYOK:    true,
			expectBYOKVal: coderToken,
		},
		{
			// Client already set the BYOK header (Claude Code, Codex).
			// The proxy must not overwrite it.
			name:          "BYOK header already set by client — not overwritten",
			authzHeader:   "Bearer client-access-token",
			byokHeader:    "client-set-coder-token",
			expectBYOK:    true,
			expectBYOKVal: "client-set-coder-token",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var receivedBYOKHeader, receivedAuthz string
			aibridgedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				receivedAuthz = r.Header.Get("Authorization")
				receivedBYOKHeader = r.Header.Get(agplaibridge.HeaderCoderToken)
				w.WriteHeader(http.StatusOK)
			}))
			t.Cleanup(aibridgedServer.Close)

			srv := newTestProxy(t,
				withCoderAccessURL(aibridgedServer.URL),
				withProviderHosts(aibridgeproxyd.HostCopilot),
			)

			certPool := getProxyCertPool(t)
			client := newProxyClient(t, srv, makeProxyAuthHeader(coderToken), certPool, false)
			req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://"+aibridgeproxyd.HostCopilot+"/chat/completions", strings.NewReader(`{}`))
			require.NoError(t, err)
			req.Header.Set("Content-Type", "application/json")
			req.Header.Set("Authorization", tt.authzHeader)
			if tt.byokHeader != "" {
				req.Header.Set(agplaibridge.HeaderCoderToken, tt.byokHeader)
			}

			resp, err := client.Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()
			require.Equal(t, http.StatusOK, resp.StatusCode)

			require.Equal(t, tt.authzHeader, receivedAuthz, "Authorization must be forwarded to aibridged")
			if tt.expectBYOK {
				// Covers both injection by the proxy and preservation of a
				// client-supplied value, so the message must not assume either.
				require.Equal(t, tt.expectBYOKVal, receivedBYOKHeader, "unexpected BYOK header value forwarded to aibridged")
			} else {
				require.Empty(t, receivedBYOKHeader, "BYOK header must not be set when Authorization carries the Coder token")
			}
		})
	}
}
// TestListenerTLS verifies that the proxy works correctly when its listener is wrapped in TLS.
// It tests both tunneled and MITM'd requests through an HTTPS proxy listener.
//
// NOTE(review): the MITM case sets no explicit provider hosts. It relies on
// newTestProxy's defaults (defined outside this view) routing the local
// target to aibridged.
func TestListenerTLS(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name         string
		tunneled     bool
		expectedBody string
	}{
		{
			name:         "Tunneled",
			tunneled:     true,
			expectedBody: "hello from tunneled",
		},
		{
			name:         "MITM",
			tunneled:     false,
			expectedBody: "hello from aibridged",
		},
	}

	// Shared across subtests since all use the same TLS listener certificate.
	listenerCertFile, listenerKeyFile := generateListenerCert(t)
	// The listener certificate is identical for every subtest, so read it once.
	listenerCertPEM, err := os.ReadFile(listenerCertFile)
	require.NoError(t, err)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Mock aibridged server that receives MITM'd requests.
			aibridgedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("hello from aibridged"))
			}))
			t.Cleanup(func() { aibridgedServer.Close() })

			// Target server: response is returned directly for tunneled, intercepted for MITM.
			tunneledServer, targetURL := newTargetServer(t, func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("hello from tunneled"))
			})

			proxyOpts := []testProxyOption{
				withListenerTLS(listenerCertFile, listenerKeyFile),
				withCoderAccessURL(aibridgedServer.URL),
				withAllowedPorts(targetURL.Port()),
			}
			if tt.tunneled {
				// Configure provider hosts that exclude the target server so requests are tunneled.
				proxyOpts = append(proxyOpts, withProviderHosts(aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI))
			}
			srv := newTestProxy(t, proxyOpts...)

			// Cert pool must include two certificates: the listener certificate to connect
			// to the proxy over TLS, and the MITM or target certificate for the inner
			// TLS handshake.
			var certPool *x509.CertPool
			if tt.tunneled {
				certPool = x509.NewCertPool()
				certPool.AddCert(tunneledServer.Certificate())
			} else {
				certPool = getProxyCertPool(t)
			}
			// AppendCertsFromPEM reports false when no certificate could be
			// parsed. Fail here rather than with an opaque TLS handshake error.
			require.True(t, certPool.AppendCertsFromPEM(listenerCertPEM), "listener certificate must be valid PEM")

			client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), certPool, false)
			req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
			require.NoError(t, err)
			resp, err := client.Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()

			body, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Equal(t, http.StatusOK, resp.StatusCode)
			require.Equal(t, tt.expectedBody, string(body))
		})
	}
}
// TestServeCACert checks that the /ca-cert.pem endpoint serves the configured
// MITM certificate as a downloadable PEM file whose contents match the shared
// test certificate byte for byte.
//
// Note: Tests for certificate file errors (missing file, invalid PEM) are
// covered by [TestNew] since certificate validation happens at initialization.
// The serveCACert handler returns the pre-loaded, pre-validated certificate.
func TestServeCACert(t *testing.T) {
	t.Parallel()

	t.Run("Success", func(t *testing.T) {
		t.Parallel()

		proxy := newTestProxy(t)

		// Hit the endpoint in-process through the proxy's HTTP handler.
		rec := httptest.NewRecorder()
		proxy.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/ca-cert.pem", nil))

		require.Equal(t, http.StatusOK, rec.Code)
		hdr := rec.Header()
		require.Equal(t, "application/x-pem-file", hdr.Get("Content-Type"))
		require.Equal(t, "attachment; filename=ca-cert.pem", hdr.Get("Content-Disposition"))

		served := rec.Body.Bytes()

		// The payload must decode as a PEM CERTIFICATE block...
		block, _ := pem.Decode(served)
		require.NotNil(t, block, "response should be valid PEM")
		require.Equal(t, "CERTIFICATE", block.Type)

		// ...holding a parseable X.509 certificate...
		parsed, err := x509.ParseCertificate(block.Bytes)
		require.NoError(t, err)
		require.NotNil(t, parsed)

		// ...and equal the shared test certificate file exactly.
		certPath, _ := getSharedTestMITMCert(t)
		want, err := os.ReadFile(certPath)
		require.NoError(t, err)
		require.Equal(t, want, served)
	})
}
// TestServeCACert_CompoundPEM checks that when the configured MITM certificate
// file also contains the private key, the /ca-cert.pem endpoint returns only
// the CERTIFICATE block and no key material.
func TestServeCACert_CompoundPEM(t *testing.T) {
	t.Parallel()

	certFile, keyFile := getSharedTestMITMCert(t)

	// Assemble a compound PEM: certificate first, then the private key.
	certPEM, err := os.ReadFile(certFile)
	require.NoError(t, err)
	keyPEM, err := os.ReadFile(keyFile)
	require.NoError(t, err)

	compoundPEM := make([]byte, 0, len(certPEM)+len(keyPEM))
	compoundPEM = append(append(compoundPEM, certPEM...), keyPEM...)

	compoundCertFile := filepath.Join(t.TempDir(), "compound.pem")
	require.NoError(t, os.WriteFile(compoundCertFile, compoundPEM, 0o600))

	srv, err := aibridgeproxyd.New(t.Context(), slogtest.Make(t, nil), aibridgeproxyd.Options{
		ListenAddr:     "127.0.0.1:0",
		CoderAccessURL: "http://localhost:3000",
		MITMCertFile:   compoundCertFile,
		MITMKeyFile:    keyFile,
	})
	require.NoError(t, err)
	t.Cleanup(func() { _ = srv.Close() })

	// Fetch the certificate in-process through the proxy's HTTP handler.
	rec := httptest.NewRecorder()
	srv.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/ca-cert.pem", nil))
	require.Equal(t, http.StatusOK, rec.Code)

	body := rec.Body.Bytes()

	// Walk every PEM block in the response.
	var blocks []*pem.Block
	for block, rest := pem.Decode(body); block != nil; block, rest = pem.Decode(rest) {
		blocks = append(blocks, block)
	}

	// Only the certificate may be present.
	require.Len(t, blocks, 1, "response should contain exactly one PEM block")
	require.Equal(t, "CERTIFICATE", blocks[0].Type, "the PEM block should be a certificate")

	// "PRIVATE KEY" is also a substring of "RSA PRIVATE KEY" and
	// "EC PRIVATE KEY", so this one check rules out all of them.
	require.NotContains(t, string(body), "PRIVATE KEY", "response should not contain any private key")

	// The surviving block must be the shared test certificate.
	cert, err := x509.ParseCertificate(blocks[0].Bytes)
	require.NoError(t, err)
	require.Equal(t, "Shared Test MITM Cert", cert.Subject.CommonName)
}
// TestUpstreamProxy verifies aiproxy's behavior when an upstream HTTP(S) proxy
// is configured:
//   - CONNECTs for non-provider hosts are forwarded through the upstream
//     proxy and tunneled to the final destination. Credentials embedded in
//     the upstream proxy URL are sent as a Basic Proxy-Authorization header.
//   - Requests to provider hosts are MITM'd by aiproxy and sent to aibridged
//     without the upstream proxy ever seeing a CONNECT.
func TestUpstreamProxy(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		// tunneled determines whether the request should be tunneled through
		// the upstream proxy (true) or MITM'd by aiproxy (false).
		// When true, the target domain has no configured provider.
		// When false, the target domain has a configured provider.
		tunneled bool
		// upstreamProxyTLS determines whether the upstream proxy uses TLS.
		// When true, aiproxy must be configured with the upstream proxy's CA.
		upstreamProxyTLS bool
		// buildTargetURL constructs the request URL. For tunneled requests, it uses
		// the final destination URL. For MITM, it uses api.anthropic.com.
		buildTargetURL func(finalDestinationURL *url.URL) string
		// expectedAIBridgePath is the path aibridge should receive for MITM requests.
		expectedAIBridgePath string
		// upstreamProxyAuth is optional "user:pass" credentials for the upstream proxy.
		// If set, the test verifies the Proxy-Authorization header is sent correctly.
		upstreamProxyAuth string
	}{
		{
			name:             "NonProviderHost_TunneledToHTTPUpstreamProxy",
			tunneled:         true,
			upstreamProxyTLS: false,
			buildTargetURL: func(finalDestinationURL *url.URL) string {
				return fmt.Sprintf("https://%s/tunneled-path", finalDestinationURL.Host)
			},
		},
		{
			name:             "NonProviderHost_TunneledToHTTPSUpstreamProxy",
			tunneled:         true,
			upstreamProxyTLS: true,
			buildTargetURL: func(finalDestinationURL *url.URL) string {
				return fmt.Sprintf("https://%s/tunneled-path", finalDestinationURL.Host)
			},
		},
		{
			name:              "NonProviderHost_TunneledToHTTPUpstreamProxyWithAuth",
			tunneled:          true,
			upstreamProxyTLS:  false,
			upstreamProxyAuth: "proxyuser:proxypass",
			buildTargetURL: func(finalDestinationURL *url.URL) string {
				return fmt.Sprintf("https://%s/tunneled-path", finalDestinationURL.Host)
			},
		},
		{
			name:              "NonProviderHost_TunneledToHTTPUpstreamProxyWithUsernameOnly",
			tunneled:          true,
			upstreamProxyTLS:  false,
			upstreamProxyAuth: "proxyuser",
			buildTargetURL: func(finalDestinationURL *url.URL) string {
				return fmt.Sprintf("https://%s/tunneled-path", finalDestinationURL.Host)
			},
		},
		{
			name:              "NonProviderHost_TunneledToHTTPUpstreamProxyWithUsernameAndColon",
			tunneled:          true,
			upstreamProxyTLS:  false,
			upstreamProxyAuth: "proxyuser:",
			buildTargetURL: func(finalDestinationURL *url.URL) string {
				return fmt.Sprintf("https://%s/tunneled-path", finalDestinationURL.Host)
			},
		},
		{
			name:              "NonProviderHost_TunneledToHTTPUpstreamProxyWithTokenAuth",
			tunneled:          true,
			upstreamProxyTLS:  false,
			upstreamProxyAuth: ":proxypass",
			buildTargetURL: func(finalDestinationURL *url.URL) string {
				return fmt.Sprintf("https://%s/tunneled-path", finalDestinationURL.Host)
			},
		},
		{
			name:             "ProviderHost_MITMByAIProxy",
			tunneled:         false,
			upstreamProxyTLS: false,
			buildTargetURL: func(_ *url.URL) string {
				return "https://api.anthropic.com:443/v1/messages"
			},
			expectedAIBridgePath: "/api/v2/aibridge/anthropic/v1/messages",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Track requests received by each component to verify the flow.
			var (
				upstreamProxyCONNECTReceived bool
				upstreamProxyCONNECTHost     string
				upstreamProxyAuthHeader      string
				finalDestinationReceived     bool
				finalDestinationPath         string
				finalDestinationBody         string
				aibridgeReceived             bool
				aibridgePath                 string
				aibridgeAuthz                string
				aibridgeBYOK                 string
				aibridgeBody                 string
			)

			// Create mock final destination server representing the actual target:
			// - For tunneled requests, traffic should reach this server.
			// - For MITM requests, traffic should NOT reach this server.
			finalDestination, finalDestinationURL := newTargetServer(t, func(w http.ResponseWriter, r *http.Request) {
				finalDestinationReceived = true
				finalDestinationPath = r.URL.Path
				body, err := io.ReadAll(r.Body)
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
					return
				}
				finalDestinationBody = string(body)
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("final destination response"))
			})

			// Upstream proxy handler: same logic for both HTTP and HTTPS.
			upstreamProxyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if r.Method != http.MethodConnect {
					http.Error(w, "expected CONNECT request", http.StatusBadRequest)
					return
				}
				upstreamProxyCONNECTReceived = true
				upstreamProxyCONNECTHost = r.Host
				upstreamProxyAuthHeader = r.Header.Get("Proxy-Authorization")

				// Connect to the mock final destination server.
				targetConn, err := net.Dial("tcp", finalDestinationURL.Host)
				if err != nil {
					http.Error(w, err.Error(), http.StatusBadGateway)
					return
				}
				defer targetConn.Close()

				// Hijack the connection to take over the raw TCP socket.
				// After responding "200 Connection Established", the proxy stops being
				// an HTTP server and becomes a transparent tunnel that copies bytes
				// bidirectionally. The http package can't handle this mode, so we
				// hijack and manage the connection ourselves.
				hijacker, ok := w.(http.Hijacker)
				if !ok {
					http.Error(w, "hijacking not supported", http.StatusInternalServerError)
					return
				}
				clientConn, _, err := hijacker.Hijack()
				if err != nil {
					return
				}
				defer clientConn.Close()

				// Send 200 Connection Established to signal tunnel is ready.
				_, _ = clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))

				// Copy data bidirectionally between aiproxy and final destination.
				// Each io.Copy returns only when its own source is closed. When
				// one direction finishes, close the opposite connection so the
				// other copy unblocks too, instead of lingering until the
				// servers are torn down at cleanup. The deferred Close calls
				// above then become harmless no-ops.
				var wg sync.WaitGroup
				wg.Add(2)
				go func() {
					defer wg.Done()
					_, _ = io.Copy(targetConn, clientConn)
					_ = targetConn.Close()
				}()
				go func() {
					defer wg.Done()
					_, _ = io.Copy(clientConn, targetConn)
					_ = clientConn.Close()
				}()
				wg.Wait()
			})

			// Create upstream proxy: HTTP or HTTPS based on test case.
			var upstreamProxy *httptest.Server
			var upstreamProxyCAFile string
			if tt.upstreamProxyTLS {
				upstreamProxy = httptest.NewTLSServer(upstreamProxyHandler)
				// Write the upstream proxy's CA cert to a temp file for aiproxy to trust.
				upstreamProxyCAFile = filepath.Join(t.TempDir(), "upstream-proxy-ca.pem")
				certPEM := pem.EncodeToMemory(&pem.Block{
					Type:  "CERTIFICATE",
					Bytes: upstreamProxy.Certificate().Raw,
				})
				err := os.WriteFile(upstreamProxyCAFile, certPEM, 0o600)
				require.NoError(t, err)
			} else {
				upstreamProxy = httptest.NewServer(upstreamProxyHandler)
			}
			t.Cleanup(upstreamProxy.Close)

			// Create a mock aibridged server:
			// - For tunneled requests, traffic should NOT reach this server.
			// - For MITM requests, aiproxy rewrites the URL and forwards here.
			aibridgeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				aibridgeReceived = true
				aibridgePath = r.URL.Path
				aibridgeAuthz = r.Header.Get("Authorization")
				aibridgeBYOK = r.Header.Get(agplaibridge.HeaderCoderToken)
				body, err := io.ReadAll(r.Body)
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
					return
				}
				aibridgeBody = string(body)
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("aibridge response"))
			}))
			t.Cleanup(aibridgeServer.Close)

			// Build the target URL for this test case.
			targetURL := tt.buildTargetURL(finalDestinationURL)
			parsedTargetURL, err := url.Parse(targetURL)
			require.NoError(t, err)

			// Configure provider hosts based on test case:
			// - For tunneled requests, api.anthropic.com has a configured provider, but we target a different host.
			// - For MITM, api.anthropic.com must have a configured provider.
			providerHosts := []string{aibridgeproxyd.HostAnthropic}

			// Build upstream proxy URL with optional auth credentials.
			upstreamProxyURLStr := upstreamProxy.URL
			if tt.upstreamProxyAuth != "" {
				parsed, err := url.Parse(upstreamProxy.URL)
				require.NoError(t, err)
				upstreamProxyURLStr = fmt.Sprintf("%s://%s@%s", parsed.Scheme, tt.upstreamProxyAuth, parsed.Host)
			}

			// Create aiproxy with upstream proxy configured.
			proxyOpts := []testProxyOption{
				withCoderAccessURL(aibridgeServer.URL),
				withProviderHosts(providerHosts...),
				withUpstreamProxy(upstreamProxyURLStr),
				withAllowedPorts("80", "443", parsedTargetURL.Port()),
			}
			if upstreamProxyCAFile != "" {
				proxyOpts = append(proxyOpts, withUpstreamProxyCA(upstreamProxyCAFile))
			}
			srv := newTestProxy(t, proxyOpts...)

			// Configure certificate trust based on test case:
			// - For tunneled requests: client trusts final destination's CA.
			// - For MITM: client trusts aiproxy's MITM certificate (for generated leaf certs).
			var certPool *x509.CertPool
			if tt.tunneled {
				certPool = x509.NewCertPool()
				certPool.AddCert(finalDestination.Certificate())
			} else {
				certPool = getProxyCertPool(t)
			}

			// Create HTTP client configured to use aiproxy. Coder token
			// in Proxy-Authorization, user's LLM token in Authorization.
			client := newProxyClient(t, srv, makeProxyAuthHeader("test-coder-token"), certPool, false)

			// Make request through aiproxy.
			requestBody := `{"test": "data", "foo": "bar"}`
			req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, targetURL, strings.NewReader(requestBody))
			require.NoError(t, err)
			req.Header.Set("Content-Type", "application/json")
			req.Header.Set("Authorization", "Bearer user-llm-token")
			resp, err := client.Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()
			require.Equal(t, http.StatusOK, resp.StatusCode)

			// Verify the request flow based on test case.
			if tt.tunneled {
				require.True(t, upstreamProxyCONNECTReceived,
					"upstream proxy should receive CONNECT for non-provider-host")
				require.Equal(t, finalDestinationURL.Host, upstreamProxyCONNECTHost,
					"upstream proxy should receive CONNECT to correct host")
				require.True(t, finalDestinationReceived,
					"final destination should receive the tunneled request")
				require.Equal(t, parsedTargetURL.Path, finalDestinationPath,
					"final destination should receive correct path")
				require.Equal(t, requestBody, finalDestinationBody,
					"final destination should receive the exact request body")
				require.False(t, aibridgeReceived,
					"aibridge should NOT receive request for non-provider-host")
				require.Empty(t, aibridgeAuthz,
					"tunneled requests should not reach aibridge")
			} else {
				require.False(t, upstreamProxyCONNECTReceived,
					"upstream proxy should NOT receive CONNECT for provider host")
				require.True(t, aibridgeReceived,
					"aibridge should receive the MITM'd request")
				require.Equal(t, tt.expectedAIBridgePath, aibridgePath,
					"aibridge should receive rewritten path")
				require.Equal(t, "Bearer user-llm-token", aibridgeAuthz,
					"user's LLM credentials must be forwarded")
				require.Equal(t, "test-coder-token", aibridgeBYOK,
					"proxy must inject BYOK header with Coder token")
				require.Equal(t, requestBody, aibridgeBody,
					"aibridge should receive the exact request body")
				require.False(t, finalDestinationReceived,
					"final destination should NOT receive request for provider host")
			}

			// Verify upstream proxy authentication if configured.
			if tt.upstreamProxyAuth != "" {
				expectedAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(tt.upstreamProxyAuth))
				require.Equal(t, expectedAuth, upstreamProxyAuthHeader,
					"Proxy-Authorization header should contain correct credentials")
			}
		})
	}
}
// TestProxy_MITM_CustomProvider checks that a provider outside the builtin set
// (OpenRouter here) is intercepted once its domain is registered as a provider
// host: the CONNECT is MITM'd and the decrypted request is forwarded to the
// aibridge endpoint under the custom provider's name.
func TestProxy_MITM_CustomProvider(t *testing.T) {
	t.Parallel()

	const (
		openrouterDomain   = "openrouter.ai"
		openrouterProvider = "openrouter"
	)

	// Values observed by the fake aibridged handler.
	var (
		gotPath  string
		gotToken string
	)

	bridge := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotToken = r.Header.Get(agplaibridge.HeaderCoderToken)
		w.WriteHeader(http.StatusOK)
		_, _ = io.WriteString(w, "hello from aibridged")
	}))
	t.Cleanup(bridge.Close)

	// Register the custom host -> provider mapping directly; this mirrors the
	// snapshot that the daemon's Reload builds from classified providers.
	provider := aibridgeproxyd.ReloadedProvider{
		ProviderOutcome: aibridged.ProviderOutcome{
			Name:   openrouterProvider,
			Type:   "openai",
			Status: aibridged.ProviderStatusEnabled,
		},
		Host: openrouterDomain,
	}
	srv := newTestProxy(t,
		withCoderAccessURL(bridge.URL),
		withProviders(provider),
	)

	pool := getProxyCertPool(t)
	client := newProxyClient(t, srv, makeProxyAuthHeader("coder-token"), pool, false)

	target := "https://" + openrouterDomain + "/api/v1/chat/completions"
	req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, target, strings.NewReader(`{}`))
	require.NoError(t, err)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer user-llm-token")

	resp, err := client.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, "hello from aibridged", string(body))

	// The path is rewritten onto the aibridge route keyed by the custom
	// provider name, and the Coder token arrives in the BYOK header.
	wantPath := "/api/v2/aibridge/" + openrouterProvider + "/api/v1/chat/completions"
	require.Equal(t, wantPath, gotPath)
	require.Equal(t, "coder-token", gotToken)
}
// TestProxy_PrivateIPBlocking exercises the proxy's private/reserved IP guard
// on the CONNECT tunnel path. Each case targets a loopback listener, either by
// literal IP or by a hostname that resolves to it. The case then checks one of
// three outcomes:
//   - the tunnel is refused (403);
//   - the dial fails (502);
//   - the tunnel succeeds because the destination is exempted, either by an
//     allowed CIDR or by matching the Coder access URL.
func TestProxy_PrivateIPBlocking(t *testing.T) {
	t.Parallel()

	// sameHostAccessURL points the Coder access URL at the target itself so
	// that the access-URL exemption covers it.
	sameHostAccessURL := func(targetHostname, port string) string {
		return fmt.Sprintf("http://%s:%s", targetHostname, port)
	}

	cases := []struct {
		name             string
		targetHostname   string
		useUpstreamProxy bool
		allowedCIDRs     []string
		coderAccessURLFn func(targetHostname, port string) string
		expectBlocked    bool
		expectDialFail   bool
	}{
		{
			// Direct IP: by default, all private/reserved IPs are blocked.
			name:           "BlockedDirectDial",
			targetHostname: "127.0.0.1",
			expectBlocked:  true,
		},
		{
			// Hostname: DNS resolves to 127.0.0.1, which is then blocked.
			name:           "BlockedDirectDialByHostname",
			targetHostname: "localhost",
			expectBlocked:  true,
		},
		{
			// Direct IP: block applies even with an upstream proxy configured.
			name:             "BlockedViaUpstreamProxy",
			targetHostname:   "127.0.0.1",
			useUpstreamProxy: true,
			expectBlocked:    true,
		},
		{
			// Hostname: DNS resolves to 127.0.0.1, which is then blocked.
			name:             "BlockedViaUpstreamProxyByHostname",
			targetHostname:   "localhost",
			useUpstreamProxy: true,
			expectBlocked:    true,
		},
		{
			// Direct IP: a configured CIDR exception allows the range.
			name:           "AllowedByPrivateCIDR",
			targetHostname: "127.0.0.1",
			allowedCIDRs:   []string{"127.0.0.1/32"},
			expectBlocked:  false,
		},
		{
			// Hostname: DNS resolves to 127.0.0.1, which the CIDR exception allows.
			name:           "AllowedByPrivateCIDRByHostname",
			targetHostname: "localhost",
			allowedCIDRs:   []string{"127.0.0.1/32"},
			expectBlocked:  false,
		},
		{
			// Direct IP: the Coder access URL host:port is always exempt.
			name:             "AllowedByCoderAccessURL",
			targetHostname:   "127.0.0.1",
			coderAccessURLFn: sameHostAccessURL,
			expectBlocked:    false,
		},
		{
			// Hostname: DNS resolves to 127.0.0.1, exempt as the Coder access URL.
			name:             "AllowedByCoderAccessURLByHostname",
			targetHostname:   "localhost",
			coderAccessURLFn: sameHostAccessURL,
			expectBlocked:    false,
		},
		{
			// An RFC 2606 reserved domain never resolves, so this is a plain dial
			// failure rather than a blocked IP. The proxy must answer 502, not
			// 403, which shows the two error paths are kept apart.
			name:           "DialFailureReturns502",
			targetHostname: "host.invalid",
			expectDialFail: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// The listener is always bound to 127.0.0.1; "localhost" cases
			// reach it through DNS, exercising the resolved-address check.
			targetServer, targetURL := newTargetServer(t, func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("hello from target"))
			})
			port := targetURL.Port()
			connectTarget := fmt.Sprintf("%s:%s", tc.targetHostname, port)

			// Only Anthropic is a provider host, so the target is tunneled
			// rather than MITM'd.
			opts := []testProxyOption{
				withProviderHosts(aibridgeproxyd.HostAnthropic),
				withAllowedPorts(port),
			}
			if tc.useUpstreamProxy {
				// The IP check fires inside ConnectDial before anything reaches
				// the upstream proxy, so a do-nothing server suffices.
				upstream := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
				t.Cleanup(upstream.Close)
				opts = append(opts, withUpstreamProxy(upstream.URL))
			}
			// Always set the allowed CIDRs explicitly so the loopback default
			// cannot accidentally exempt the blocked cases.
			opts = append(opts, withAllowedPrivateCIDRs(tc.allowedCIDRs...))
			if accessURLFn := tc.coderAccessURLFn; accessURLFn != nil {
				opts = append(opts, withCoderAccessURL(accessURLFn(tc.targetHostname, port)))
			}
			srv := newTestProxy(t, opts...)

			// rawConnect issues a bare CONNECT and returns its status and body.
			// Go's HTTP client does not expose non-2xx CONNECT responses, so the
			// refusal and dial-failure cases need this lower-level path.
			rawConnect := func() (int, string) {
				resp := sendConnect(t, srv.Addr(), connectTarget, makeProxyAuthHeader("test-token"))
				defer resp.Body.Close()
				raw, err := io.ReadAll(resp.Body)
				require.NoError(t, err)
				return resp.StatusCode, string(raw)
			}

			if tc.expectBlocked {
				// ConnectDial refused a private/reserved IP.
				status, text := rawConnect()
				require.Equal(t, http.StatusForbidden, status)
				require.Equal(t, "Forbidden", text, "error details should not be leaked to the client")
				return
			}
			if tc.expectDialFail {
				// ConnectDial failed for a non-policy reason (unresolvable host).
				status, text := rawConnect()
				require.Equal(t, http.StatusBadGateway, status)
				require.Equal(t, "Bad Gateway", text)
				return
			}

			// Exempted destination: a full TLS request must succeed through the tunnel.
			pool := x509.NewCertPool()
			pool.AddCert(targetServer.Certificate())
			// By default the cert SAN is 127.0.0.1, so hostname-based cases
			// need InsecureSkipVerify.
			skipVerify := tc.targetHostname != "127.0.0.1"
			client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), pool, skipVerify)

			req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "https://"+connectTarget+"/", nil)
			require.NoError(t, err)
			resp, err := client.Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()

			payload, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Equal(t, http.StatusOK, resp.StatusCode)
			require.Equal(t, "hello from target", string(payload))
		})
	}
}
// TestProxy_APIDump verifies that when NewDumper is configured, the proxy
// calls DumpRequest and DumpResponse for MITM'd requests.
func TestProxy_APIDump(t *testing.T) {
t.Parallel()
aibridgedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
}))
t.Cleanup(aibridgedServer.Close)
var (
dumpedProvider string
dumpedRequestID string
reqDumped bool
respDumped bool
)
srv := newTestProxy(t,
withCoderAccessURL(aibridgedServer.URL),
withAllowedPorts("443"),
withProviderHosts(aibridgeproxyd.HostAnthropic),
withNewDumper(func(provider, requestID string) aibridgeproxyd.RoundTripDumper {
dumpedProvider = provider
dumpedRequestID = requestID
return &mockDumper{
onRequest: func() { reqDumped = true },
onResponse: func() { respDumped = true },
}
}),
)
certPool := getProxyCertPool(t)
client := newProxyClient(t, srv, makeProxyAuthHeader("coder-token"), certPool, false)
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.anthropic.com/v1/messages", strings.NewReader(`{}`))
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer user-llm-token")
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "anthropic", dumpedProvider)
assert.NotEmpty(t, dumpedRequestID)
_, err = uuid.Parse(dumpedRequestID)
require.NoError(t, err, "request ID passed to NewDumper must be a valid UUID")
assert.True(t, reqDumped, "DumpRequest should have been called")
assert.True(t, respDumped, "DumpResponse should have been called")
}
// TestProxy_APIDump_ErrorsDoNotAffectProxy verifies that dump failures
// do not break the proxied request/response flow.
func TestProxy_APIDump_ErrorsDoNotAffectProxy(t *testing.T) {
t.Parallel()
aibridgedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
}))
t.Cleanup(aibridgedServer.Close)
srv := newTestProxy(t,
withCoderAccessURL(aibridgedServer.URL),
withAllowedPorts("443"),
withProviderHosts(aibridgeproxyd.HostAnthropic),
withNewDumper(func(_, _ string) aibridgeproxyd.RoundTripDumper {
return &failingDumper{}
}),
)
certPool := getProxyCertPool(t)
client := newProxyClient(t, srv, makeProxyAuthHeader("coder-token"), certPool, false)
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.anthropic.com/v1/messages", strings.NewReader(`{}`))
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer user-token")
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// The proxy must return the upstream response despite dump errors.
require.Equal(t, http.StatusOK, resp.StatusCode)
require.JSONEq(t, `{"ok":true}`, string(body))
}
// mockDumper is a RoundTripDumper test double. Each Dump* call runs the
// matching optional hook (nil hooks are skipped) and always returns nil.
type mockDumper struct {
	onRequest  func()
	onResponse func()
	onError    func()
}

// invoke runs hook when it is set and reports success unconditionally.
func (*mockDumper) invoke(hook func()) error {
	if hook == nil {
		return nil
	}
	hook()
	return nil
}

// DumpRequest triggers onRequest; the request itself is ignored.
func (m *mockDumper) DumpRequest(_ *http.Request) error {
	return m.invoke(m.onRequest)
}

// DumpResponse triggers onResponse; the response itself is ignored.
func (m *mockDumper) DumpResponse(_ *http.Response) error {
	return m.invoke(m.onResponse)
}

// DumpError triggers onError; the error value itself is ignored.
func (m *mockDumper) DumpError(_ error) error {
	return m.invoke(m.onError)
}
// failingDumper is a RoundTripDumper whose every method returns an error. It
// lets tests confirm that dump failures leave proxy behavior unaffected.
type failingDumper struct{}

// DumpRequest always fails.
func (*failingDumper) DumpRequest(*http.Request) error {
	return xerrors.New("dump request failed")
}

// DumpResponse always fails.
func (*failingDumper) DumpResponse(*http.Response) error {
	return xerrors.New("dump response failed")
}

// DumpError always fails.
func (*failingDumper) DumpError(error) error {
	return xerrors.New("dump error failed")
}