feat: add API to serve proxy certificate (#21391)

Closes https://github.com/coder/internal/issues/1184
This commit is contained in:
Danny Kopping
2025-12-29 20:00:06 +02:00
committed by GitHub
parent a173c38715
commit 733b6b7db9
7 changed files with 337 additions and 7 deletions
+1 -1
View File
@@ -343,7 +343,7 @@ func assertAccept(t *testing.T, comment SwaggerComment) {
}
}
var allowedProduceTypes = []string{"json", "text/event-stream", "text/html"}
var allowedProduceTypes = []string{"json", "text/event-stream", "text/html", "text/plain"}
func assertProduce(t *testing.T, comment SwaggerComment) {
var hasResponseModel bool
+47 -5
View File
@@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"net"
"net/http"
@@ -14,6 +15,7 @@ import (
"time"
"github.com/elazarl/goproxy"
"github.com/go-chi/chi/v5"
"golang.org/x/xerrors"
"cdr.dev/slog"
@@ -46,6 +48,9 @@ type Server struct {
httpServer *http.Server
listener net.Listener
coderAccessURL *url.URL
// caCert is the PEM-encoded CA certificate loaded during initialization.
// This is served to clients who need to trust the proxy.
caCert []byte
}
// Options configures the AI Bridge Proxy server.
@@ -87,7 +92,8 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
}
// Load CA certificate for MITM
if err := loadMitmCertificate(opts.CertFile, opts.KeyFile); err != nil {
certPEM, err := loadMitmCertificate(opts.CertFile, opts.KeyFile)
if err != nil {
return nil, xerrors.Errorf("failed to load MITM certificate: %w", err)
}
@@ -109,6 +115,7 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
logger: logger,
proxy: proxy,
coderAccessURL: coderAccessURL,
caCert: certPEM,
}
// Reject CONNECT requests to non-standard ports.
@@ -173,17 +180,28 @@ func (s *Server) Close() error {
// loadMitmCertificate loads the CA certificate and private key for MITM proxying.
// This function is safe to call concurrently - the certificate is only loaded once
// into the global goproxy.GoproxyCa variable.
func loadMitmCertificate(certFile, keyFile string) error {
// Returns the PEM-encoded certificate for serving to clients.
func loadMitmCertificate(certFile, keyFile string) ([]byte, error) {
tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return xerrors.Errorf("load CA certificate: %w", err)
return nil, xerrors.Errorf("load CA certificate: %w", err)
}
if len(tlsCert.Certificate) == 0 {
return nil, xerrors.Errorf("no certificates found")
}
x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return xerrors.Errorf("parse CA certificate: %w", err)
return nil, xerrors.Errorf("parse CA certificate: %w", err)
}
// Ensure that we only return the certificate and never any included private keys.
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: tlsCert.Certificate[0],
})
// Only protect the global assignment with sync.Once
loadMitmOnce.Do(func() {
goproxy.GoproxyCa = tls.Certificate{
@@ -193,7 +211,7 @@ func loadMitmCertificate(certFile, keyFile string) error {
}
})
return nil
return certPEM, nil
}
// portMiddleware is a CONNECT middleware that rejects requests to non-standard ports.
@@ -389,3 +407,27 @@ func (s *Server) handleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http.
return req, nil
}
// Handler returns an HTTP handler for the AI Bridge Proxy's HTTP endpoints.
// This is separate from the proxy server itself and is used by coderd to
// serve endpoints like the CA certificate.
func (s *Server) Handler() http.Handler {
r := chi.NewRouter()
r.Get("/ca-cert.pem", s.serveCACert)
return r
}
// serveCACert is an HTTP handler that serves the CA certificate used for MITM
// proxying. Clients need this certificate to trust the proxy's intercepted
// connections. The certificate was validated during server initialization.
func (s *Server) serveCACert(rw http.ResponseWriter, _ *http.Request) {
if len(s.caCert) == 0 {
http.Error(rw, "CA certificate not configured", http.StatusNotFound)
return
}
rw.Header().Set("Content-Type", "application/x-pem-file")
rw.Header().Set("Content-Disposition", "attachment; filename=ca-cert.pem")
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(s.caCert)
}
@@ -749,3 +749,128 @@ func TestProxy_MITM(t *testing.T) {
})
}
}
// TestServeCACert validates that a configured certificate file can be served correctly by the API.
//
// Note: Tests for certificate file errors (missing file, invalid PEM) are
// covered by [TestNew] since certificate validation happens at initialization.
// The serveCACert handler returns the pre-loaded, pre-validated certificate.
func TestServeCACert(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
certFile, keyFile := getSharedTestCA(t)
logger := slogtest.Make(t, nil)
srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
ListenAddr: "127.0.0.1:0",
CoderAccessURL: "http://localhost:3000",
CertFile: certFile,
KeyFile: keyFile,
})
require.NoError(t, err)
t.Cleanup(func() { _ = srv.Close() })
// Create a request to the CA cert endpoint via the Handler.
req := httptest.NewRequest(http.MethodGet, "/ca-cert.pem", nil)
rec := httptest.NewRecorder()
srv.Handler().ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "application/x-pem-file", rec.Header().Get("Content-Type"))
require.Equal(t, "attachment; filename=ca-cert.pem", rec.Header().Get("Content-Disposition"))
// Verify the certificate is valid PEM.
body := rec.Body.Bytes()
block, _ := pem.Decode(body)
require.NotNil(t, block, "response should be valid PEM")
require.Equal(t, "CERTIFICATE", block.Type)
// Verify the certificate is valid X.509.
cert, err := x509.ParseCertificate(block.Bytes)
require.NoError(t, err)
require.NotNil(t, cert)
// Verify it matches the original certificate.
expectedCertPEM, err := os.ReadFile(certFile)
require.NoError(t, err)
require.Equal(t, expectedCertPEM, body)
})
}
// TestServeCACert_CompoundPEM validates that a compound PEM certificate which contains a private key
// will only have its certificate type returned from the /ca-cert.pem endpoint.
func TestServeCACert_CompoundPEM(t *testing.T) {
t.Parallel()
certFile, keyFile := getSharedTestCA(t)
// Read the shared CA cert and key to create a compound PEM file.
certPEM, err := os.ReadFile(certFile)
require.NoError(t, err)
keyPEM, err := os.ReadFile(keyFile)
require.NoError(t, err)
// Create a compound PEM file containing both the certificate and the private key.
compoundPEM := make([]byte, 0, len(certPEM)+len(keyPEM))
compoundPEM = append(compoundPEM, certPEM...)
compoundPEM = append(compoundPEM, keyPEM...)
tmpDir := t.TempDir()
compoundCertFile := filepath.Join(tmpDir, "compound.pem")
err = os.WriteFile(compoundCertFile, compoundPEM, 0o600)
require.NoError(t, err)
logger := slogtest.Make(t, nil)
srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
ListenAddr: "127.0.0.1:0",
CoderAccessURL: "http://localhost:3000",
CertFile: compoundCertFile,
KeyFile: keyFile,
})
require.NoError(t, err)
t.Cleanup(func() { _ = srv.Close() })
// Create a request to the CA cert endpoint via the Handler.
req := httptest.NewRequest(http.MethodGet, "/ca-cert.pem", nil)
rec := httptest.NewRecorder()
srv.Handler().ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
// Verify the response contains only the certificate, not the private key.
body := rec.Body.Bytes()
// Parse all PEM blocks from the response.
var pemBlocks []*pem.Block
remaining := body
for {
var block *pem.Block
block, remaining = pem.Decode(remaining)
if block == nil {
break
}
pemBlocks = append(pemBlocks, block)
}
// There should be exactly one PEM block (the certificate).
require.Len(t, pemBlocks, 1, "response should contain exactly one PEM block")
require.Equal(t, "CERTIFICATE", pemBlocks[0].Type, "the PEM block should be a certificate")
// Verify no private key material is present by checking for common key block types.
bodyStr := string(body)
require.NotContains(t, bodyStr, "PRIVATE KEY", "response should not contain any private key")
require.NotContains(t, bodyStr, "RSA PRIVATE KEY", "response should not contain RSA private key")
require.NotContains(t, bodyStr, "EC PRIVATE KEY", "response should not contain EC private key")
// Verify the certificate is valid X.509.
cert, err := x509.ParseCertificate(pemBlocks[0].Bytes)
require.NoError(t, err)
require.Equal(t, "Shared Test CA", cert.Subject.CommonName)
}
+3
View File
@@ -173,6 +173,9 @@ func (r *RootCmd) Server(_ func()) *serpent.Command {
return nil, nil, xerrors.Errorf("create aibridgeproxyd: %w", err)
}
closers.Add(aiBridgeProxyServer)
// Register the handler so coderd can serve the proxy endpoints.
api.RegisterInMemoryAIBridgeProxydHTTPHandler(aiBridgeProxyServer.Handler())
}
return api.AGPL, closers, nil
+50
View File
@@ -0,0 +1,50 @@
package coderd
import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
)
// RegisterInMemoryAIBridgeProxydHTTPHandler mounts [aibridgeproxyd.Server]'s HTTP handler
// onto [API]'s router, so that requests to aibridgedproxy will be relayed from Coder's API server
// to the in-memory aibridgedproxy.
func (api *API) RegisterInMemoryAIBridgeProxydHTTPHandler(srv http.Handler) {
if srv == nil {
panic("aibridgeproxyd cannot be nil")
}
api.aibridgeproxydHandler = srv
}
// aibridgeproxyHandler handles AI Bridge Proxy endpoints.
func aibridgeproxyHandler(api *API, middlewares ...func(http.Handler) http.Handler) func(r chi.Router) {
return func(r chi.Router) {
r.Use(api.RequireFeatureMW(codersdk.FeatureAIBridge))
r.Use(middlewares...)
r.HandleFunc("/*", func(rw http.ResponseWriter, r *http.Request) {
// Check if the proxy is enabled.
if !api.DeploymentValues.AI.BridgeProxyConfig.Enabled.Value() {
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
Message: "AI Bridge Proxy is not enabled.",
})
return
}
// Check if the handler is registered.
if api.aibridgeproxydHandler == nil {
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
Message: "AI Bridge Proxy handler not mounted.",
})
return
}
// Strip the prefix and relay to the aibridgeproxyd handler.
http.StripPrefix("/api/v2/aibridge/proxy", api.aibridgeproxydHandler).ServeHTTP(rw, r)
})
}
}
+105
View File
@@ -0,0 +1,105 @@
package coderd_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/testutil"
)
func TestAIBridgeProxyCertificateRetrieval(t *testing.T) {
t.Parallel()
t.Run("DisabledReturns404", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
// Proxy is disabled by default, so we don't need to set it explicitly.
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureAIBridge: 1,
},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
// Make a request to the proxy CA cert endpoint.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.URL.String()+"/api/v2/aibridge/proxy/ca-cert.pem", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusNotFound, resp.StatusCode)
})
t.Run("RequiresLicenseFeature", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
// No aibridge feature.
Features: license.Features{},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
// Make a request to the proxy CA cert endpoint.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.URL.String()+"/api/v2/aibridge/proxy/ca-cert.pem", nil)
require.NoError(t, err)
req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusForbidden, resp.StatusCode)
})
t.Run("RequiresAuthentication", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureAIBridge: 1,
},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
// Make a request to the proxy CA cert endpoint without authentication.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.URL.String()+"/api/v2/aibridge/proxy/ca-cert.pem", nil)
require.NoError(t, err)
// No session token header set.
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
})
}
+6 -1
View File
@@ -230,6 +230,10 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
r.Route("/aibridge", aibridgeHandler(api, apiKeyMiddleware))
})
api.AGPL.APIHandler.Group(func(r chi.Router) {
r.Route("/aibridge/proxy", aibridgeproxyHandler(api, apiKeyMiddleware))
})
api.AGPL.APIHandler.Group(func(r chi.Router) {
r.Get("/entitlements", api.serveEntitlements)
// /regions overrides the AGPL /regions endpoint
@@ -691,7 +695,8 @@ type API struct {
licenseMetricsCollector *license.MetricsCollector
tailnetService *tailnet.ClientService
aibridgedHandler http.Handler
aibridgedHandler http.Handler
aibridgeproxydHandler http.Handler
}
// writeEntitlementWarningsHeader writes the entitlement warnings to the response header