diff --git a/coderd/aibridge/factory.go b/coderd/aibridge/factory.go
new file mode 100644
index 0000000000..3db9dc8a1d
--- /dev/null
+++ b/coderd/aibridge/factory.go
@@ -0,0 +1,45 @@
+package aibridge
+
+import (
+	"context"
+	"net/http"
+
+	"github.com/google/uuid"
+)
+
+// Source names the call site that requested a transport from aibridge. It
+// travels on the request context purely so downstream handlers and logs can
+// attribute traffic; behavior must not change based on its value.
+type Source string
+
+// SourceAgents marks chatd traffic that originates from a Coder agent.
+const SourceAgents Source = "agents"
+
+// sourceCtxKey is the private context key under which a Source is stored.
+type sourceCtxKey struct{}
+
+// WithSource derives a context from ctx that carries src. Apply it to the
+// request context before invoking a downstream handler so that
+// [SourceFromContext] can read the value back for logging.
+func WithSource(ctx context.Context, src Source) context.Context {
+	return context.WithValue(ctx, sourceCtxKey{}, src)
+}
+
+// SourceFromContext reports the Source stored by [WithSource], or the empty
+// string when ctx carries none.
+func SourceFromContext(ctx context.Context) Source {
+	if src, ok := ctx.Value(sourceCtxKey{}).(Source); ok {
+		return src
+	}
+	return ""
+}
+
+// TransportFactory hands out [http.RoundTripper]s that dispatch an aibridge
+// request in-process for a given ai_providers row.
+//
+// Implementations live in coderd/aibridged. coderd registers an in-process
+// factory on coderd.API.AIBridgeTransportFactory at startup so callers route
+// traffic through the daemon without going through the gated HTTP route.
+//
+// Source is informational: implementations must not gate on it. It is attached
+// to the request context so handlers can include it in logs and metrics.
+type TransportFactory interface {
+	// TransportFor returns the in-process transport for the ai_providers row
+	// identified by providerID. source is informational only.
+	TransportFor(providerID uuid.UUID, source Source) (http.RoundTripper, error)
+}
diff --git a/coderd/aibridge_test.go b/coderd/aibridge_test.go
new file mode 100644
index 0000000000..de16bdaf27
--- /dev/null
+++ b/coderd/aibridge_test.go
@@ -0,0 +1,101 @@
+package coderd_test
+
+import (
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/require"
+
+	"github.com/coder/coder/v2/coderd/aibridge"
+	"github.com/coder/coder/v2/coderd/coderdtest"
+	"github.com/coder/coder/v2/testutil"
+)
+
+// stubTransportFactory plugs a deterministic handler into the
+// AIBridgeTransportFactory hook so the AGPL side of the in-memory pipe can be
+// exercised without pulling coderd/aibridged in.
+type stubTransportFactory struct {
+	handler http.Handler
+	calls   chan callRecord
+}
+
+// callRecord captures the arguments of one TransportFor invocation.
+type callRecord struct {
+	providerID uuid.UUID
+	source     aibridge.Source
+}
+
+// TransportFor records the call on f.calls and returns a transport backed by
+// the stub's handler.
+func (f *stubTransportFactory) TransportFor(providerID uuid.UUID, source aibridge.Source) (http.RoundTripper, error) {
+	record := callRecord{providerID: providerID, source: source}
+	f.calls <- record
+	return &handlerRoundTripper{handler: f.handler}, nil
+}
+
+// handlerRoundTripper is a minimal http.RoundTripper for the AGPL test. It
+// does not stream; coderd/aibridged/transport_test.go already covers
+// streaming semantics.
+type handlerRoundTripper struct{ handler http.Handler }
+
+// RoundTrip runs the handler to completion against a recorder and returns the
+// recorded response.
+func (h *handlerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	recorder := httptest.NewRecorder()
+	h.handler.ServeHTTP(recorder, req)
+
+	result := recorder.Result()
+	result.Request = req
+	return result, nil
+}
+
+// Verify that a factory stored on coderd.API.AIBridgeTransportFactory is
+// observable through the normal API lifecycle: cli/server.go registers it
+// when the bridge daemon starts (see RegisterInMemoryAIBridgedHTTPHandler).
+func TestAIBridgeTransportFactory_Registration(t *testing.T) {
+	t.Parallel()
+
+	_, _, api := coderdtest.NewWithAPI(t, nil)
+
+	// Nothing has been stored yet, so the atomic pointer must load as nil.
+	require.Nil(t, api.AIBridgeTransportFactory.Load(),
+		"AGPL coderd must not pre-populate the factory")
+
+	// The buffered calls channel lets TransportFor record its arguments
+	// without blocking on a reader.
+	stub := &stubTransportFactory{
+		handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("Content-Type", "application/json")
+			w.WriteHeader(http.StatusOK)
+			_, _ = w.Write([]byte(`{"bridged":true}`))
+		}),
+		calls: make(chan callRecord, 4),
+	}
+
+	// The field stores a pointer to the interface value, so convert the stub
+	// to the interface type first and store its address.
+	var asInterface aibridge.TransportFactory = stub
+	api.AIBridgeTransportFactory.Store(&asInterface)
+
+	loaded := api.AIBridgeTransportFactory.Load()
+	require.NotNil(t, loaded)
+
+	providerID := uuid.New()
+	rt, err := (*loaded).TransportFor(providerID, aibridge.SourceAgents)
+	require.NoError(t, err)
+	require.NotNil(t, rt)
+
+	// Non-blocking receive: the record was buffered synchronously by
+	// TransportFor above, so an empty channel means the stub was bypassed.
+	select {
+	case got := <-stub.calls:
+		require.Equal(t, providerID, got.providerID)
+		require.Equal(t, aibridge.SourceAgents, got.source)
+	default:
+		t.Fatal("factory was not invoked")
+	}
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/messages", nil)
+	require.NoError(t, err)
+
+	// Drive the returned transport through a regular http.Client and check
+	// that the stub handler's status, headers and body come back unchanged.
+	client := &http.Client{Transport: rt}
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+	require.Equal(t, `{"bridged":true}`, string(body))
+	require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
+}
diff --git a/coderd/aibridged.go b/coderd/aibridged.go
index 2efd186507..30439752cc 100644
--- a/coderd/aibridged.go
+++ b/coderd/aibridged.go
@@ -11,6 +11,7 @@ import (
 	"storj.io/drpc/drpcserver"
 
 	"cdr.dev/slog/v3"
+	agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
 	"github.com/coder/coder/v2/coderd/aibridged"
 	aibridgedproto "github.com/coder/coder/v2/coderd/aibridged/proto"
 	"github.com/coder/coder/v2/coderd/aibridgedserver"
@@ -30,12 +31,22 @@ func (api *API) GetAIBridgedHandler() http.Handler {
 // RegisterInMemoryAIBridgedHTTPHandler mounts [aibridged.Server]'s HTTP router onto
 // [API]'s router, so that requests to aibridged will be relayed from Coder's API server
 // to the in-memory aibridged.
+//
+// This also registers an in-process [agplaibridge.TransportFactory] so that
+// chatd can route coder-agent LLM traffic through aibridge without crossing
+// the HTTP route. No license entitlement gate is applied at the factory layer:
+// the entitlement check stays on the HTTP route for external callers, while
+// in-process coder-agent traffic is the explicit carve-out.
 func (api *API) RegisterInMemoryAIBridgedHTTPHandler(srv http.Handler) {
 	if srv == nil {
 		panic("aibridged cannot be nil")
 	}
 
 	api.aibridgedHandler = srv
+
+	// The factory dispatches to the same handler that was just registered.
+	// Store the address of the interface value, as the atomic field expects.
+	factory := aibridged.NewTransportFactory(srv)
+	var asInterface agplaibridge.TransportFactory = factory
+	api.AIBridgeTransportFactory.Store(&asInterface)
 }
 
 // CreateInMemoryAIBridgeServer creates a [aibridged.DRPCServer] and returns a
diff --git a/coderd/aibridged/transport.go b/coderd/aibridged/transport.go
new file mode 100644
index 0000000000..391f3b09c4
--- /dev/null
+++ b/coderd/aibridged/transport.go
@@ -0,0 +1,172 @@
+package aibridged
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"sync"
+
+	"github.com/google/uuid"
+	"golang.org/x/xerrors"
+
+	"github.com/coder/coder/v2/coderd/aibridge"
+)
+
+// NewTransportFactory returns an [aibridge.TransportFactory] whose RoundTripper
+// dispatches requests to handler in-process, streaming the response body
+// through an [io.Pipe] so SSE/NDJSON/chunked responses propagate token-by-token
+// just as they would over the wire.
+//
+// handler is typically the aibridged HTTP entrypoint registered via
+// [API.RegisterInMemoryAIBridgedHTTPHandler].
+func NewTransportFactory(handler http.Handler) aibridge.TransportFactory {
+	return &transportFactory{handler: handler}
+}
+
+// transportFactory is the in-process [aibridge.TransportFactory]; every
+// transport it returns dispatches to the same handler.
+type transportFactory struct {
+	handler http.Handler
+}
+
+// TransportFor returns an in-process [http.RoundTripper] that dispatches
+// requests through the aibridged handler. The source is attached to the
+// request context for downstream logging; routing does not depend on it.
+// The provider ID is ignored. An error is returned only when the factory was
+// built without a handler.
+func (f *transportFactory) TransportFor(_ uuid.UUID, source aibridge.Source) (http.RoundTripper, error) {
+	if f.handler == nil {
+		return nil, xerrors.New("aibridged handler not registered")
+	}
+	return &inMemoryRoundTripper{handler: f.handler, source: source}, nil
+}
+
+// inMemoryRoundTripper implements [http.RoundTripper] by invoking handler
+// in a goroutine and streaming its response back through an [io.Pipe].
+type inMemoryRoundTripper struct {
+	handler http.Handler
+	source  aibridge.Source
+}
+
+// RoundTrip serves req through the in-process handler on its own goroutine
+// and returns as soon as the handler commits headers, or with the context
+// error if req's context ends first. The response body streams from an
+// [io.Pipe]; reads fail with the context error after cancellation and with a
+// "handler panicked" error if the handler panics. The request body is closed
+// once the handler returns.
+func (t *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	pr, pw := io.Pipe()
+	rw := &pipeResponseWriter{
+		header:     http.Header{},
+		body:       pw,
+		gotHeaders: make(chan struct{}),
+		status:     http.StatusOK,
+	}
+
+	// Cloning preserves caller-supplied headers and context but lets the
+	// handler operate on its own request value without surprising the caller
+	// if it mutates Headers or stores the request. The Source is attached to
+	// the served context so downstream handlers can log the call site.
+	served := req.Clone(aibridge.WithSource(req.Context(), t.source))
+
+	// Client requests may carry a nil Body, but server handlers are promised
+	// a non-nil one. Substitute the empty body so a handler that reads
+	// r.Body unconditionally does not dereference nil.
+	if served.Body == nil {
+		served.Body = http.NoBody
+	}
+
+	handlerDone := make(chan struct{})
+	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				// Mirror net/http.Server behavior: a panicking handler
+				// produces a 500 instead of crashing the process.
+				rw.WriteHeader(http.StatusInternalServerError)
+				_ = pw.CloseWithError(xerrors.Errorf("handler panicked: %v", r))
+			}
+			// Make sure we always unblock RoundTrip even if the handler
+			// returns before writing headers (e.g. handler returns early
+			// without writing).
+			rw.ensureHeaders()
+			// If the request context was canceled, surface that as a
+			// body-read error so the caller sees a network-style failure
+			// rather than EOF. Otherwise close cleanly.
+			if cerr := served.Context().Err(); cerr != nil {
+				_ = pw.CloseWithError(cerr)
+			} else {
+				_ = pw.Close()
+			}
+			// The RoundTripper contract requires the request body to be
+			// closed, which may happen after RoundTrip has returned. The
+			// handler is done with it here; a close error is not actionable.
+			_ = served.Body.Close()
+			close(handlerDone)
+		}()
+		t.handler.ServeHTTP(rw, served)
+	}()
+
+	// Close the pipe eagerly when the caller cancels, so an unresponsive
+	// handler does not strand the consumer's body read. The handler's own
+	// context derives from req.Context(), so it observes the same
+	// cancellation independently. The goroutine also exits when the handler
+	// completes normally (handlerDone closes) to avoid leaking a parked
+	// goroutine per successful request.
+	go func() {
+		select {
+		case <-served.Context().Done():
+			_ = pw.CloseWithError(served.Context().Err())
+		case <-handlerDone:
+			// Handler finished; nothing to cancel.
+		}
+	}()
+
+	// Wait for the handler to commit headers; rw.status and rw.frozenHeader
+	// are only read after gotHeaders is closed.
+	select {
+	case <-rw.gotHeaders:
+	case <-served.Context().Done():
+		return nil, served.Context().Err()
+	}
+
+	return &http.Response{
+		Status:        fmt.Sprintf("%d %s", rw.status, http.StatusText(rw.status)),
+		StatusCode:    rw.status,
+		Proto:         "HTTP/1.1",
+		ProtoMajor:    1,
+		ProtoMinor:    1,
+		Header:        rw.frozenHeader,
+		Body:          pr,
+		Request:       req,
+		ContentLength: -1, // streaming; unknown length
+	}, nil
+}
+
+// pipeResponseWriter is an [http.ResponseWriter] that streams the response
+// body into an [io.PipeWriter]. The first call to WriteHeader (implicit or
+// explicit) closes gotHeaders so the RoundTrip caller can return an
+// *http.Response while the handler keeps writing.
+type pipeResponseWriter struct {
+	// header is the mutable map handed to the handler via Header.
+	header http.Header
+	// frozenHeader is the snapshot of header taken when headers are committed.
+	frozenHeader http.Header
+	// body receives every Write.
+	body *io.PipeWriter
+
+	// once guards the one-time header commit.
+	once sync.Once
+	// gotHeaders is closed when headers are committed.
+	gotHeaders chan struct{}
+	// status is the committed status code.
+	status int
+}
+
+// Header returns the mutable header map. Changes made after the headers have
+// been committed are not reflected in the frozen snapshot.
+func (w *pipeResponseWriter) Header() http.Header {
+	return w.header
+}
+
+// WriteHeader commits the response on its first call: it records status,
+// snapshots the headers and closes gotHeaders. Later calls are ignored.
+func (w *pipeResponseWriter) WriteHeader(status int) {
+	w.once.Do(func() {
+		w.status = status
+		w.frozenHeader = w.header.Clone()
+		close(w.gotHeaders)
+	})
+}
+
+// Write streams p into the pipe and returns the pipe writer's result.
+func (w *pipeResponseWriter) Write(p []byte) (int, error) {
+	// net/http semantics: an implicit 200 OK on first Write if the handler
+	// did not call WriteHeader explicitly.
+	w.WriteHeader(http.StatusOK)
+	return w.body.Write(p)
+}
+
+// Flush commits the headers if the handler has not done so yet, matching
+// net/http where a flush implies a 200 OK. A handler may flush right after
+// setting its SSE headers and before the first event; the headers must be
+// committed at that point so RoundTrip can return instead of waiting for the
+// first body byte. There is no body data to flush: pipe writes are already
+// synchronous with the reader. Satisfying [http.Flusher] also keeps handlers
+// that type-assert it (the aibridge library does for SSE) from falling back
+// to buffered mode.
+func (w *pipeResponseWriter) Flush() {
+	w.WriteHeader(http.StatusOK)
+}
+
+// ensureHeaders closes gotHeaders if it has not already been closed, with the
+// current status. Used to unblock RoundTrip on handler return-without-write.
+func (w *pipeResponseWriter) ensureHeaders() {
+	// Commit through WriteHeader rather than closing gotHeaders directly:
+	// that also snapshots the headers, so a handler that set headers but
+	// never wrote still delivers them and the caller never sees a nil
+	// Header. This is a no-op when headers were already committed.
+	w.WriteHeader(w.status)
+}
+
+// Compile-time checks that pipeResponseWriter satisfies the interfaces
+// handlers rely on.
+var (
+	_ http.ResponseWriter = (*pipeResponseWriter)(nil)
+	_ http.Flusher        = (*pipeResponseWriter)(nil)
+)
diff --git a/coderd/aibridged/transport_test.go b/coderd/aibridged/transport_test.go
new file mode 100644
index 0000000000..9d2d39864d
--- /dev/null
+++ b/coderd/aibridged/transport_test.go
@@ -0,0 +1,289 @@
+package aibridged_test
+
+import (
+	"bufio"
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"sync"
+	"testing"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/xerrors"
+
+	"github.com/coder/coder/v2/coderd/aibridge"
+	"github.com/coder/coder/v2/coderd/aibridged"
+	"github.com/coder/coder/v2/testutil"
+)
+
+// Covers the factory itself: it returns a transport, rejects a nil handler,
+// and the transport exposes the Source to the handler's request context.
+func TestTransportFactory_TransportFor(t *testing.T) {
+	t.Parallel()
+
+	t.Run("ReturnsTransport", func(t *testing.T) {
+		t.Parallel()
+		f := aibridged.NewTransportFactory(http.NotFoundHandler())
+		rt, err := f.TransportFor(uuid.New(), aibridge.SourceAgents)
+		require.NoError(t, err)
+		require.NotNil(t, rt)
+	})
+
+	t.Run("NilHandlerErrors", func(t *testing.T) {
+		t.Parallel()
+		f := aibridged.NewTransportFactory(nil)
+		_, err := f.TransportFor(uuid.New(), aibridge.SourceAgents)
+		require.Error(t, err)
+	})
+
+	t.Run("AttachesSourceToContext", func(t *testing.T) {
+		t.Parallel()
+
+		// Buffered so the handler can report the Source without waiting for
+		// the test to receive it.
+		got := make(chan aibridge.Source, 1)
+		handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			got <- aibridge.SourceFromContext(r.Context())
+			w.WriteHeader(http.StatusOK)
+		})
+
+		rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
+		require.NoError(t, err)
+
+		ctx := testutil.Context(t, testutil.WaitShort)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/test", nil)
+		require.NoError(t, err)
+
+		resp, err := rt.RoundTrip(req)
+		require.NoError(t, err)
+		defer resp.Body.Close()
+
+		require.Equal(t, aibridge.SourceAgents, <-got)
+	})
+}
+
+// The status code, status text, headers and body written by the handler must
+// reach the caller unchanged.
+func TestInMemoryRoundTripper_PassesHeadersAndStatus(t *testing.T) {
+	t.Parallel()
+
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("X-Custom", "yes")
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusTeapot)
+		_, _ = w.Write([]byte(`{"ok":true}`))
+	})
+
+	rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
+	require.NoError(t, err)
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/test", nil)
+	require.NoError(t, err)
+
+	resp, err := rt.RoundTrip(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	require.Equal(t, http.StatusTeapot, resp.StatusCode)
+	require.Equal(t, "418 I'm a teapot", resp.Status)
+	require.Equal(t, "yes", resp.Header.Get("X-Custom"))
+	require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	require.Equal(t, `{"ok":true}`, string(body))
+}
+
+// A handler that sets headers but returns without calling WriteHeader or
+// Write must still deliver those headers with the implicit 200 OK.
+func TestInMemoryRoundTripper_HeadersWithoutWrite(t *testing.T) {
+	t.Parallel()
+
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("X-Custom", "yes")
+	})
+
+	rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
+	require.NoError(t, err)
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/test", nil)
+	require.NoError(t, err)
+
+	resp, err := rt.RoundTrip(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+	require.NotNil(t, resp.Header)
+	require.Equal(t, "yes", resp.Header.Get("X-Custom"))
+}
+
+// Verify that response chunks become readable on the client side before the
+// handler has finished writing. This is the property SSE/NDJSON streaming
+// depends on.
+func TestInMemoryRoundTripper_Streams(t *testing.T) {
+	t.Parallel()
+
+	// released[i] gates chunk i: the handler blocks on it until the test
+	// closes it.
+	const chunks = 4
+	released := make([]chan struct{}, chunks)
+	for i := range released {
+		released[i] = make(chan struct{})
+	}
+
+	// assert (not require) is used inside the handler because it does not
+	// run on the test goroutine.
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.WriteHeader(http.StatusOK)
+		flusher, ok := w.(http.Flusher)
+		if !assert.True(t, ok, "ResponseWriter must implement http.Flusher") {
+			return
+		}
+		for i := range chunks {
+			<-released[i]
+			_, err := fmt.Fprintf(w, "data: chunk-%d\n\n", i)
+			if !assert.NoError(t, err) {
+				return
+			}
+			flusher.Flush()
+		}
+	})
+
+	rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
+	require.NoError(t, err)
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/stream", nil)
+	require.NoError(t, err)
+
+	resp, err := rt.RoundTrip(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	// Release one chunk at a time and read it back before releasing the
+	// next. The handler is still parked on the following gate while each
+	// chunk is read, so a successful read shows the data arrived before the
+	// handler finished.
+	br := bufio.NewReader(resp.Body)
+	for i := range chunks {
+		close(released[i])
+		dataLine, err := br.ReadString('\n')
+		require.NoError(t, err)
+		require.Equal(t, fmt.Sprintf("data: chunk-%d\n", i), dataLine)
+		// Consume blank-line separator.
+		_, err = br.ReadString('\n')
+		require.NoError(t, err)
+	}
+}
+
+// Canceling the request context must surface as a body-read error, matching
+// real-network behavior, and the handler must observe the cancellation
+// through its own request context.
+func TestInMemoryRoundTripper_CancelCloses(t *testing.T) {
+	t.Parallel()
+
+	// Closed by the handler once it has seen its request context end.
+	handlerCtxObserved := make(chan struct{})
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Commit headers so RoundTrip returns, then park until canceled.
+		w.WriteHeader(http.StatusOK)
+		if f, ok := w.(http.Flusher); ok {
+			f.Flush()
+		}
+		<-r.Context().Done()
+		close(handlerCtxObserved)
+	})
+
+	rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
+	require.NoError(t, err)
+
+	// parentCtx carries the test deadline; ctx is the cancelable request
+	// context derived from it.
+	parentCtx := testutil.Context(t, testutil.WaitShort)
+	ctx, cancel := context.WithCancel(parentCtx)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/stream", nil)
+	require.NoError(t, err)
+
+	resp, err := rt.RoundTrip(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	// Cancel only after the response has been returned, so the failure has
+	// to show up on the body read rather than from RoundTrip.
+	cancel()
+	_, err = io.ReadAll(resp.Body)
+	require.Error(t, err)
+
+	// Wait for the handler to observe the cancellation, bounded by the test
+	// deadline.
+	select {
+	case <-handlerCtxObserved:
+	case <-parentCtx.Done():
+		t.Fatal("handler did not observe context cancellation")
+	}
+}
+
+// Many independent in-flight requests on a shared handler must not interfere.
+func TestInMemoryRoundTripper_ConcurrentRequests(t *testing.T) {
+	t.Parallel()
+
+	// Echo handler: replies 200 with the request body, or 500 if the body
+	// cannot be read.
+	echo := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		payload, readErr := io.ReadAll(r.Body)
+		if readErr != nil {
+			w.WriteHeader(http.StatusInternalServerError)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write(payload)
+	})
+
+	rt, err := aibridged.NewTransportFactory(echo).TransportFor(uuid.New(), aibridge.SourceAgents)
+	require.NoError(t, err)
+
+	// roundTrip sends one uniquely tagged payload through the shared
+	// transport and reports whether the same bytes came back.
+	roundTrip := func(i int) error {
+		want := fmt.Sprintf("payload-%d", i)
+		ctx := testutil.Context(t, testutil.WaitShort)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/echo", strings.NewReader(want))
+		if err != nil {
+			return err
+		}
+		resp, err := rt.RoundTrip(req)
+		if err != nil {
+			return err
+		}
+		defer resp.Body.Close()
+		got, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return err
+		}
+		if string(got) != want {
+			return xerrors.Errorf("payload mismatch: want %q got %q", want, string(got))
+		}
+		return nil
+	}
+
+	// Each goroutine reports exactly one result; the channel is sized so no
+	// sender blocks.
+	const n = 16
+	results := make(chan error, n)
+	var wg sync.WaitGroup
+	for i := range n {
+		wg.Go(func() { results <- roundTrip(i) })
+	}
+	wg.Wait()
+	close(results)
+	for res := range results {
+		require.NoError(t, res)
+	}
+}
+
+// A panicking handler must not crash the process; it should produce a 500
+// response with an error on the body read, mirroring net/http.Server behavior.
+func TestInMemoryRoundTripper_HandlerPanic(t *testing.T) {
+	t.Parallel()
+
+	// The handler panics before writing anything.
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		panic("unexpected nil pointer")
+	})
+
+	rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
+	require.NoError(t, err)
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/panic", nil)
+	require.NoError(t, err)
+
+	// RoundTrip itself succeeds; the panic shows up as the 500 status and
+	// as an error when the body is read.
+	resp, err := rt.RoundTrip(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+	_, err = io.ReadAll(resp.Body)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "handler panicked")
+}
+
+// A handler that returns without writing must not block RoundTrip; the caller
+// gets a zero-length 200 OK.
+func TestInMemoryRoundTripper_HandlerReturnsWithoutWriting(t *testing.T) {
+	t.Parallel()
+
+	// Never touches the ResponseWriter.
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
+
+	rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
+	require.NoError(t, err)
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/noop", nil)
+	require.NoError(t, err)
+
+	resp, err := rt.RoundTrip(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	// The body must end with a clean EOF and the status must be 200.
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	require.Empty(t, body)
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+}
diff --git a/coderd/coderd.go b/coderd/coderd.go
index dc659b8263..91d95c5ef2 100644
--- a/coderd/coderd.go
+++ b/coderd/coderd.go
@@ -45,6 +45,7 @@ import (
 	"github.com/coder/coder/v2/buildinfo"
 	"github.com/coder/coder/v2/coderd/agentapi"
 	"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
+	"github.com/coder/coder/v2/coderd/aibridge"
 	"github.com/coder/coder/v2/coderd/aibridge/prices"
 	"github.com/coder/coder/v2/coderd/aiseats"
 	_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
@@ -2198,9 +2199,15 @@ type API struct {
 	// UsageInserter is a pointer to an atomic pointer because it is passed to
 	// multiple components.
 	UsageInserter *atomic.Pointer[usage.Inserter]
+	// AIBridgeTransportFactory, once a factory has been stored, lets chatd
+	// route LLM requests through an in-process aibridge transport instead of
+	// calling upstream providers directly. It loads as nil until
+	// RegisterInMemoryAIBridgedHTTPHandler stores a factory at startup.
+	AIBridgeTransportFactory atomic.Pointer[aibridge.TransportFactory]
 	// aibridgedHandler is the in-memory aibridge HTTP handler. Set by
-	// RegisterInMemoryAIBridgedHTTPHandler; read by the enterprise
-	// /api/v2/aibridge route (license-gated).
+	// RegisterInMemoryAIBridgedHTTPHandler; read both by the enterprise
+	// /api/v2/aibridge route (license-gated) and by the in-memory transport
+	// (used by chatd, license-exempt).
 	aibridgedHandler http.Handler
 
 	UpdatesProvider tailnet.WorkspaceUpdatesProvider