mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
8652ef3e3b
Delegate `aibridge` routing responsibility to the in-memory transport layer. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
399 lines
12 KiB
Go
399 lines
12 KiB
Go
package aibridged_test
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/coderd/aibridge"
|
|
"github.com/coder/coder/v2/coderd/aibridged"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestTransportFactory_TransportFor(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("ReturnsTransport", func(t *testing.T) {
|
|
t.Parallel()
|
|
f := aibridged.NewTransportFactory(http.NotFoundHandler())
|
|
rt, err := f.TransportFor("openai", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, rt)
|
|
})
|
|
|
|
t.Run("NilHandlerErrors", func(t *testing.T) {
|
|
t.Parallel()
|
|
f := aibridged.NewTransportFactory(nil)
|
|
_, err := f.TransportFor("openai", aibridge.SourceAgents)
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("EmptyProviderErrors", func(t *testing.T) {
|
|
t.Parallel()
|
|
f := aibridged.NewTransportFactory(http.NotFoundHandler())
|
|
_, err := f.TransportFor("", aibridge.SourceAgents)
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("RewritesURLToAibridgeMount", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// The round-tripper must adapt an upstream-shaped URL.Path
|
|
// ("/v1/messages") to the aibridge mount layout
|
|
// ("/api/v2/aibridge/<provider>/v1/messages") so callers don't
|
|
// have to encode the daemon's routing key into their requests.
|
|
got := make(chan string, 1)
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
got <- r.URL.Path
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
rt, err := aibridged.NewTransportFactory(handler).TransportFor("my-anthropic", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
|
|
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://upstream/v1/messages", nil)
|
|
require.NoError(t, err)
|
|
|
|
// The caller's req.URL.Path is the upstream shape. Capture it so
|
|
// we can prove the transport mutates a clone, not the caller's
|
|
// request, after RoundTrip returns.
|
|
origPath := req.URL.Path
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
require.Equal(t, "/api/v2/aibridge/my-anthropic/v1/messages", <-got)
|
|
require.Equal(t, origPath, req.URL.Path,
|
|
"caller's request URL must not be mutated by RoundTrip")
|
|
})
|
|
|
|
t.Run("AttachesSourceToContext", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := make(chan aibridge.Source, 1)
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
got <- aibridge.SourceFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
|
|
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/test", nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
require.Equal(t, aibridge.SourceAgents, <-got)
|
|
})
|
|
}
|
|
|
|
func TestInMemoryRoundTripper_PassesHeadersAndStatus(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Custom", "yes")
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusTeapot)
|
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
|
})
|
|
|
|
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
|
|
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/test", nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
require.Equal(t, http.StatusTeapot, resp.StatusCode)
|
|
require.Equal(t, "418 I'm a teapot", resp.Status)
|
|
require.Equal(t, "yes", resp.Header.Get("X-Custom"))
|
|
require.Equal(t, "application/json", resp.Header.Get("Content-Type"))
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, `{"ok":true}`, string(body))
|
|
}
|
|
|
|
// Verify that response chunks become readable on the client side before the
|
|
// handler has finished writing. This is the property SSE/NDJSON streaming
|
|
// depends on.
|
|
func TestInMemoryRoundTripper_Streams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
const chunks = 4
|
|
released := make([]chan struct{}, chunks)
|
|
for i := range released {
|
|
released[i] = make(chan struct{})
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.WriteHeader(http.StatusOK)
|
|
flusher, ok := w.(http.Flusher)
|
|
if !assert.True(t, ok, "ResponseWriter must implement http.Flusher") {
|
|
return
|
|
}
|
|
for i := range chunks {
|
|
<-released[i]
|
|
_, err := fmt.Fprintf(w, "data: chunk-%d\n\n", i)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
}
|
|
})
|
|
|
|
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
|
|
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/stream", nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
br := bufio.NewReader(resp.Body)
|
|
for i := range chunks {
|
|
close(released[i])
|
|
dataLine, err := br.ReadString('\n')
|
|
require.NoError(t, err)
|
|
require.Equal(t, fmt.Sprintf("data: chunk-%d\n", i), dataLine)
|
|
// Consume blank-line separator.
|
|
_, err = br.ReadString('\n')
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
// Canceling the request context must surface as a body-read error, matching
|
|
// real-network behavior, and the handler must observe the cancellation
|
|
// through its own request context.
|
|
func TestInMemoryRoundTripper_CancelCloses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handlerCtxObserved := make(chan struct{})
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
if f, ok := w.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
<-r.Context().Done()
|
|
close(handlerCtxObserved)
|
|
})
|
|
|
|
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
|
|
parentCtx := testutil.Context(t, testutil.WaitShort)
|
|
ctx, cancel := context.WithCancel(parentCtx)
|
|
ctx = aibridge.WithDelegatedAPIKeyID(ctx, "test-key-id")
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/stream", nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
cancel()
|
|
_, err = io.ReadAll(resp.Body)
|
|
require.Error(t, err)
|
|
|
|
select {
|
|
case <-handlerCtxObserved:
|
|
case <-parentCtx.Done():
|
|
t.Fatal("handler did not observe context cancellation")
|
|
}
|
|
}
|
|
|
|
// Many independent in-flight requests on a shared handler must not interfere.
|
|
func TestInMemoryRoundTripper_ConcurrentRequests(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write(body)
|
|
})
|
|
|
|
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
|
|
const n = 16
|
|
errs := make(chan error, n)
|
|
var wg sync.WaitGroup
|
|
for i := range n {
|
|
wg.Go(func() {
|
|
payload := fmt.Sprintf("payload-%d", i)
|
|
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/echo", strings.NewReader(payload))
|
|
if err != nil {
|
|
errs <- err
|
|
return
|
|
}
|
|
resp, err := rt.RoundTrip(req)
|
|
if err != nil {
|
|
errs <- err
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
got, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
errs <- err
|
|
return
|
|
}
|
|
if string(got) != payload {
|
|
errs <- xerrors.Errorf("payload mismatch: want %q got %q", payload, string(got))
|
|
return
|
|
}
|
|
errs <- nil
|
|
})
|
|
}
|
|
wg.Wait()
|
|
close(errs)
|
|
for err := range errs {
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
// A panicking handler must not crash the process; it should produce a 500
|
|
// response with an error on the body read, mirroring net/http.Server behavior.
|
|
func TestInMemoryRoundTripper_HandlerPanic(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
panic("unexpected nil pointer")
|
|
})
|
|
|
|
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
|
|
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/panic", nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
|
_, err = io.ReadAll(resp.Body)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "handler panicked")
|
|
}
|
|
|
|
// The in-memory transport must reject any RoundTrip whose context does not
|
|
// carry a delegated API key ID. The handler relies on this invariant to know
|
|
// the request has a delegated identity attached.
|
|
func TestInMemoryRoundTripper_RequiresDelegatedAPIKeyID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
withCtx func(context.Context) context.Context
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "missing delegated key ID",
|
|
withCtx: func(ctx context.Context) context.Context { return ctx },
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "empty delegated key ID",
|
|
withCtx: func(ctx context.Context) context.Context {
|
|
return aibridge.WithDelegatedAPIKeyID(ctx, "")
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "valid delegated key ID",
|
|
withCtx: func(ctx context.Context) context.Context {
|
|
return aibridge.WithDelegatedAPIKeyID(ctx, "test-key-id")
|
|
},
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handlerCalled := make(chan struct{}, 1)
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled <- struct{}{}
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
|
|
ctx := tc.withCtx(testutil.Context(t, testutil.WaitShort))
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/v1/test", nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
if tc.wantErr {
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "WithDelegatedAPIKeyID")
|
|
// Handler must not have been invoked.
|
|
select {
|
|
case <-handlerCalled:
|
|
t.Fatal("handler invoked despite transport rejecting the request")
|
|
default:
|
|
}
|
|
return
|
|
}
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
})
|
|
}
|
|
}
|
|
|
|
// A handler that returns without writing must not block RoundTrip; the caller
|
|
// gets a zero-length 200 OK.
|
|
func TestInMemoryRoundTripper_HandlerReturnsWithoutWriting(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
|
|
rt, err := aibridged.NewTransportFactory(handler).TransportFor("openai", aibridge.SourceAgents)
|
|
require.NoError(t, err)
|
|
|
|
ctx := aibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), "test-key-id")
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/noop", nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := rt.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Empty(t, body)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
}
|