mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add in-memory transport for chatd -> aibridge routing (#25576)
### TL;DR Introduces an in-process `TransportFactory` for aibridge so that chatd (coder-agent LLM traffic) can route requests through the aibridged handler without crossing the HTTP route or requiring a license entitlement check. ### What changed? - Added a new `coderd/aibridge` package with a `TransportFactory` interface and a `Source` type for tagging the call site on request contexts. `SourceAgents` is defined as the constant for coder-agent traffic. - Implemented `NewTransportFactory` in `coderd/aibridged/transport.go`, which returns an `http.RoundTripper` that dispatches requests to the aibridged handler in-process. The response body is streamed through an `io.Pipe` so SSE/NDJSON/chunked responses propagate token-by-token. Handler panics are recovered and surfaced as 500 responses, and context cancellation closes the pipe with the appropriate error. - `RegisterInMemoryAIBridgedHTTPHandler` now also constructs a `TransportFactory` from the registered handler and stores it on `API.AIBridgeTransportFactory` (an `atomic.Pointer`), making it available to chatd without going through the license-gated HTTP route. - Added `API.AIBridgeTransportFactory` as a public `atomic.Pointer[aibridge.TransportFactory]` field on `coderd.API`. ### How to test? - `coderd/aibridged/transport_test.go` covers: transport creation, nil-handler errors, source attachment to context, header/status passthrough, streaming (SSE-style chunked writes visible before handler completion), context cancellation closing the body with an error, concurrent requests, handler panics producing 500s, and handlers that return without writing. - `coderd/aibridge_test.go` verifies that `AIBridgeTransportFactory` starts as nil on AGPL coderd, can be stored and loaded atomically, and that the stored factory correctly dispatches requests through the stub handler. ### Why make this change? Chatd needs to send LLM requests through aibridge in-process rather than via the external HTTP route, which is license-gated. The `TransportFactory` abstraction provides a clean seam: the entitlement check remains on the HTTP route for external callers, while in-process coder-agent traffic bypasses it through the factory. The `Source` type allows downstream handlers and logs to attribute traffic without gating behavior on the caller identity.
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
package aibridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Source identifies the call site that asked aibridge for a transport. It is
|
||||
// attached to the request context so downstream handlers and logs can attribute
|
||||
// traffic without changing behavior based on the value.
|
||||
type Source string
|
||||
|
||||
// SourceAgents is chatd traffic originating from a Coder agent.
|
||||
const SourceAgents Source = "agents"
|
||||
|
||||
type sourceCtxKey struct{}
|
||||
|
||||
// WithSource returns a copy of ctx carrying the given Source. Use this on the
|
||||
// request context before invoking a downstream handler so [SourceFromContext]
|
||||
// can recover it for logging.
|
||||
func WithSource(ctx context.Context, src Source) context.Context {
|
||||
return context.WithValue(ctx, sourceCtxKey{}, src)
|
||||
}
|
||||
|
||||
// SourceFromContext returns the Source attached by [WithSource], or the empty
|
||||
// string when no Source is set.
|
||||
func SourceFromContext(ctx context.Context) Source {
|
||||
src, _ := ctx.Value(sourceCtxKey{}).(Source)
|
||||
return src
|
||||
}
|
||||
|
||||
// TransportFactory returns an [http.RoundTripper] that dispatches an aibridge
|
||||
// request in-process for a given ai_providers row.
|
||||
//
|
||||
// Implementations live in coderd/aibridged. coderd registers an in-process
|
||||
// factory on coderd.API.AIBridgeTransportFactory at startup so callers route
|
||||
// traffic through the daemon without going through the gated HTTP route.
|
||||
//
|
||||
// Source is informational: implementations must not gate on it. It is attached
|
||||
// to the request context so handlers can include it in logs and metrics.
|
||||
type TransportFactory interface {
|
||||
TransportFor(providerID uuid.UUID, source Source) (http.RoundTripper, error)
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// stubTransportFactory wires a deterministic handler through the
|
||||
// AIBridgeTransportFactory hook so the AGPL side of the in-memory pipe can be
|
||||
// exercised without pulling coderd/aibridged in.
|
||||
type stubTransportFactory struct {
|
||||
handler http.Handler
|
||||
calls chan callRecord
|
||||
}
|
||||
|
||||
type callRecord struct {
|
||||
providerID uuid.UUID
|
||||
source aibridge.Source
|
||||
}
|
||||
|
||||
func (f *stubTransportFactory) TransportFor(providerID uuid.UUID, source aibridge.Source) (http.RoundTripper, error) {
|
||||
f.calls <- callRecord{providerID: providerID, source: source}
|
||||
return &handlerRoundTripper{handler: f.handler}, nil
|
||||
}
|
||||
|
||||
// handlerRoundTripper is a minimal http.RoundTripper for the AGPL test. It
|
||||
// does not stream; coderd/aibridged.transport_test.go already covers
|
||||
// streaming semantics.
|
||||
type handlerRoundTripper struct{ handler http.Handler }
|
||||
|
||||
func (h *handlerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rec := httptest.NewRecorder()
|
||||
h.handler.ServeHTTP(rec, req)
|
||||
resp := rec.Result()
|
||||
resp.Request = req
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Verify that a factory stored on coderd.API.AIBridgeTransportFactory is
|
||||
// observable through the normal API lifecycle: cli/server.go registers it
|
||||
// when the bridge daemon starts (see RegisterInMemoryAIBridgedHTTPHandler).
|
||||
func TestAIBridgeTransportFactory_Registration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, api := coderdtest.NewWithAPI(t, nil)
|
||||
|
||||
require.Nil(t, api.AIBridgeTransportFactory.Load(),
|
||||
"AGPL coderd must not pre-populate the factory")
|
||||
|
||||
stub := &stubTransportFactory{
|
||||
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"bridged":true}`))
|
||||
}),
|
||||
calls: make(chan callRecord, 4),
|
||||
}
|
||||
|
||||
var asInterface aibridge.TransportFactory = stub
|
||||
api.AIBridgeTransportFactory.Store(&asInterface)
|
||||
|
||||
loaded := api.AIBridgeTransportFactory.Load()
|
||||
require.NotNil(t, loaded)
|
||||
|
||||
providerID := uuid.New()
|
||||
rt, err := (*loaded).TransportFor(providerID, aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rt)
|
||||
|
||||
select {
|
||||
case got := <-stub.calls:
|
||||
require.Equal(t, providerID, got.providerID)
|
||||
require.Equal(t, aibridge.SourceAgents, got.source)
|
||||
default:
|
||||
t.Fatal("factory was not invoked")
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/messages", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := &http.Client{Transport: rt}
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, `{"bridged":true}`, string(body))
|
||||
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
aibridgedproto "github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
"github.com/coder/coder/v2/coderd/aibridgedserver"
|
||||
@@ -30,12 +31,22 @@ func (api *API) GetAIBridgedHandler() http.Handler {
|
||||
// RegisterInMemoryAIBridgedHTTPHandler mounts [aibridged.Server]'s HTTP router onto
|
||||
// [API]'s router, so that requests to aibridged will be relayed from Coder's API server
|
||||
// to the in-memory aibridged.
|
||||
//
|
||||
// This also registers an in-process [agplaibridge.TransportFactory] so that
|
||||
// chatd can route coder-agent LLM traffic through aibridge without crossing
|
||||
// the HTTP route. No license entitlement gate is applied at the factory layer:
|
||||
// the entitlement check stays on the HTTP route for external callers, while
|
||||
// in-process coder-agent traffic is the explicit carve-out.
|
||||
func (api *API) RegisterInMemoryAIBridgedHTTPHandler(srv http.Handler) {
|
||||
if srv == nil {
|
||||
panic("aibridged cannot be nil")
|
||||
}
|
||||
|
||||
api.aibridgedHandler = srv
|
||||
|
||||
factory := aibridged.NewTransportFactory(srv)
|
||||
var asInterface agplaibridge.TransportFactory = factory
|
||||
api.AIBridgeTransportFactory.Store(&asInterface)
|
||||
}
|
||||
|
||||
// CreateInMemoryAIBridgeServer creates a [aibridged.DRPCServer] and returns a
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
package aibridged
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
)
|
||||
|
||||
// NewTransportFactory returns an [aibridge.TransportFactory] whose RoundTripper
|
||||
// dispatches requests to handler in-process, streaming the response body
|
||||
// through an [io.Pipe] so SSE/NDJSON/chunked responses propagate token-by-token
|
||||
// just as they would over the wire.
|
||||
//
|
||||
// handler is typically the aibridged HTTP entrypoint registered via
|
||||
// [API.RegisterInMemoryAIBridgedHTTPHandler].
|
||||
func NewTransportFactory(handler http.Handler) aibridge.TransportFactory {
|
||||
return &transportFactory{handler: handler}
|
||||
}
|
||||
|
||||
type transportFactory struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
// TransportFor returns an in-process [http.RoundTripper] that dispatches
|
||||
// requests through the aibridged handler. The source is attached to the
|
||||
// request context for downstream logging; routing does not depend on it.
|
||||
func (f *transportFactory) TransportFor(_ uuid.UUID, source aibridge.Source) (http.RoundTripper, error) {
|
||||
if f.handler == nil {
|
||||
return nil, xerrors.New("aibridged handler not registered")
|
||||
}
|
||||
return &inMemoryRoundTripper{handler: f.handler, source: source}, nil
|
||||
}
|
||||
|
||||
// inMemoryRoundTripper implements [http.RoundTripper] by invoking handler
|
||||
// in a goroutine and streaming its response back through an [io.Pipe].
|
||||
type inMemoryRoundTripper struct {
|
||||
handler http.Handler
|
||||
source aibridge.Source
|
||||
}
|
||||
|
||||
func (t *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
pr, pw := io.Pipe()
|
||||
rw := &pipeResponseWriter{
|
||||
header: http.Header{},
|
||||
body: pw,
|
||||
gotHeaders: make(chan struct{}),
|
||||
status: http.StatusOK,
|
||||
}
|
||||
|
||||
// Cloning preserves caller-supplied headers and context but lets the
|
||||
// handler operate on its own request value without surprising the caller
|
||||
// if it mutates Headers or stores the request. The Source is attached to
|
||||
// the served context so downstream handlers can log the call site.
|
||||
served := req.Clone(aibridge.WithSource(req.Context(), t.source))
|
||||
|
||||
handlerDone := make(chan struct{})
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Mirror net/http.Server behavior: a panicking handler
|
||||
// produces a 500 instead of crashing the process.
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
_ = pw.CloseWithError(xerrors.Errorf("handler panicked: %v", r))
|
||||
}
|
||||
// Make sure we always unblock RoundTrip even if the handler
|
||||
// returns before writing headers (e.g. handler returns early
|
||||
// without writing).
|
||||
rw.ensureHeaders()
|
||||
// If the request context was canceled, surface that as a
|
||||
// body-read error so the caller sees a network-style failure
|
||||
// rather than EOF. Otherwise close cleanly.
|
||||
if cerr := served.Context().Err(); cerr != nil {
|
||||
_ = pw.CloseWithError(cerr)
|
||||
} else {
|
||||
_ = pw.Close()
|
||||
}
|
||||
close(handlerDone)
|
||||
}()
|
||||
t.handler.ServeHTTP(rw, served)
|
||||
}()
|
||||
|
||||
// Close the pipe eagerly when the caller cancels, so an unresponsive
|
||||
// handler does not strand the consumer's body read. The handler's own
|
||||
// context derives from req.Context(), so it observes the same
|
||||
// cancellation independently. The goroutine also exits when the handler
|
||||
// completes normally (handlerDone closes) to avoid leaking a parked
|
||||
// goroutine per successful request.
|
||||
go func() {
|
||||
select {
|
||||
case <-served.Context().Done():
|
||||
_ = pw.CloseWithError(served.Context().Err())
|
||||
case <-handlerDone:
|
||||
// Handler finished; nothing to cancel.
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-rw.gotHeaders:
|
||||
case <-served.Context().Done():
|
||||
return nil, served.Context().Err()
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
Status: fmt.Sprintf("%d %s", rw.status, http.StatusText(rw.status)),
|
||||
StatusCode: rw.status,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: rw.frozenHeader,
|
||||
Body: pr,
|
||||
Request: req,
|
||||
ContentLength: -1, // streaming; unknown length
|
||||
}, nil
|
||||
}
|
||||
|
||||
// pipeResponseWriter is an [http.ResponseWriter] that streams the response
|
||||
// body into an [io.PipeWriter]. The first call to WriteHeader (implicit or
|
||||
// explicit) closes gotHeaders so the RoundTrip caller can return an
|
||||
// *http.Response while the handler keeps writing.
|
||||
type pipeResponseWriter struct {
|
||||
header http.Header
|
||||
frozenHeader http.Header
|
||||
body *io.PipeWriter
|
||||
|
||||
once sync.Once
|
||||
gotHeaders chan struct{}
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *pipeResponseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *pipeResponseWriter) WriteHeader(status int) {
|
||||
w.once.Do(func() {
|
||||
w.status = status
|
||||
w.frozenHeader = w.header.Clone()
|
||||
close(w.gotHeaders)
|
||||
})
|
||||
}
|
||||
|
||||
func (w *pipeResponseWriter) Write(p []byte) (int, error) {
|
||||
// net/http semantics: an implicit 200 OK on first Write if the handler
|
||||
// did not call WriteHeader explicitly.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return w.body.Write(p)
|
||||
}
|
||||
|
||||
// Flush is a no-op: pipe writes are already synchronous with the reader, so
|
||||
// each Write is observed as soon as the reader consumes it. We satisfy
|
||||
// [http.Flusher] so handlers that type-assert it (the aibridge library does
|
||||
// for SSE) do not fall back to buffered mode.
|
||||
func (*pipeResponseWriter) Flush() {}
|
||||
|
||||
// ensureHeaders closes gotHeaders if it has not already been closed, with the
|
||||
// current status. Used to unblock RoundTrip on handler return-without-write.
|
||||
func (w *pipeResponseWriter) ensureHeaders() {
|
||||
w.once.Do(func() {
|
||||
close(w.gotHeaders)
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.ResponseWriter = (*pipeResponseWriter)(nil)
|
||||
_ http.Flusher = (*pipeResponseWriter)(nil)
|
||||
)
|
||||
@@ -0,0 +1,289 @@
|
||||
package aibridged_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestTransportFactory_TransportFor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsTransport", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := aibridged.NewTransportFactory(http.NotFoundHandler())
|
||||
rt, err := f.TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rt)
|
||||
})
|
||||
|
||||
t.Run("NilHandlerErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
f := aibridged.NewTransportFactory(nil)
|
||||
_, err := f.TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("AttachesSourceToContext", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := make(chan aibridge.Source, 1)
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
got <- aibridge.SourceFromContext(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/test", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, aibridge.SourceAgents, <-got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInMemoryRoundTripper_PassesHeadersAndStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom", "yes")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/test", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusTeapot, resp.StatusCode)
|
||||
require.Equal(t, "418 I'm a teapot", resp.Status)
|
||||
require.Equal(t, "yes", resp.Header.Get("X-Custom"))
|
||||
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `{"ok":true}`, string(body))
|
||||
}
|
||||
|
||||
// Verify that response chunks become readable on the client side before the
|
||||
// handler has finished writing. This is the property SSE/NDJSON streaming
|
||||
// depends on.
|
||||
func TestInMemoryRoundTripper_Streams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const chunks = 4
|
||||
released := make([]chan struct{}, chunks)
|
||||
for i := range released {
|
||||
released[i] = make(chan struct{})
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !assert.True(t, ok, "ResponseWriter must implement http.Flusher") {
|
||||
return
|
||||
}
|
||||
for i := range chunks {
|
||||
<-released[i]
|
||||
_, err := fmt.Fprintf(w, "data: chunk-%d\n\n", i)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/stream", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
br := bufio.NewReader(resp.Body)
|
||||
for i := range chunks {
|
||||
close(released[i])
|
||||
dataLine, err := br.ReadString('\n')
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fmt.Sprintf("data: chunk-%d\n", i), dataLine)
|
||||
// Consume blank-line separator.
|
||||
_, err = br.ReadString('\n')
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Canceling the request context must surface as a body-read error, matching
|
||||
// real-network behavior, and the handler must observe the cancellation
|
||||
// through its own request context.
|
||||
func TestInMemoryRoundTripper_CancelCloses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handlerCtxObserved := make(chan struct{})
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
<-r.Context().Done()
|
||||
close(handlerCtxObserved)
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
parentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/stream", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
cancel()
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.Error(t, err)
|
||||
|
||||
select {
|
||||
case <-handlerCtxObserved:
|
||||
case <-parentCtx.Done():
|
||||
t.Fatal("handler did not observe context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
// Many independent in-flight requests on a shared handler must not interfere.
|
||||
func TestInMemoryRoundTripper_ConcurrentRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(body)
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
const n = 16
|
||||
errs := make(chan error, n)
|
||||
var wg sync.WaitGroup
|
||||
for i := range n {
|
||||
wg.Go(func() {
|
||||
payload := fmt.Sprintf("payload-%d", i)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/echo", strings.NewReader(payload))
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
got, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
if string(got) != payload {
|
||||
errs <- xerrors.Errorf("payload mismatch: want %q got %q", payload, string(got))
|
||||
return
|
||||
}
|
||||
errs <- nil
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
for err := range errs {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// A panicking handler must not crash the process; it should produce a 500
|
||||
// response with an error on the body read, mirroring net/http.Server behavior.
|
||||
func TestInMemoryRoundTripper_HandlerPanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("unexpected nil pointer")
|
||||
})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/panic", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "handler panicked")
|
||||
}
|
||||
|
||||
// A handler that returns without writing must not block RoundTrip; the caller
|
||||
// gets a zero-length 200 OK.
|
||||
func TestInMemoryRoundTripper_HandlerReturnsWithoutWriting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
rt, err := aibridged.NewTransportFactory(handler).TransportFor(uuid.New(), aibridge.SourceAgents)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/noop", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, body)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
+9
-2
@@ -45,6 +45,7 @@ import (
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/aibridge/prices"
|
||||
"github.com/coder/coder/v2/coderd/aiseats"
|
||||
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
|
||||
@@ -2198,9 +2199,15 @@ type API struct {
|
||||
// UsageInserter is a pointer to an atomic pointer because it is passed to
|
||||
// multiple components.
|
||||
UsageInserter *atomic.Pointer[usage.Inserter]
|
||||
// AIBridgeTransportFactory, when non-nil, lets chatd route LLM requests
|
||||
// through an in-process aibridge transport instead of calling upstream
|
||||
// providers directly. Registered by coderd at startup once aibridged is
|
||||
// wired in-memory.
|
||||
AIBridgeTransportFactory atomic.Pointer[aibridge.TransportFactory]
|
||||
// aibridgedHandler is the in-memory aibridge HTTP handler. Set by
|
||||
// RegisterInMemoryAIBridgedHTTPHandler; read by the enterprise
|
||||
// /api/v2/aibridge route (license-gated).
|
||||
// RegisterInMemoryAIBridgedHTTPHandler; read both by the enterprise
|
||||
// /api/v2/aibridge route (license-gated) and by the in-memory transport
|
||||
// (used by chatd, license-exempt).
|
||||
aibridgedHandler http.Handler
|
||||
|
||||
UpdatesProvider tailnet.WorkspaceUpdatesProvider
|
||||
|
||||
Reference in New Issue
Block a user