diff --git a/enterprise/coderd/aibridge.go b/enterprise/coderd/aibridge.go index bdd2a99166..3f3b1d6789 100644 --- a/enterprise/coderd/aibridge.go +++ b/enterprise/coderd/aibridge.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net/http" + "strings" + "github.com/go-chi/chi/v5" "github.com/google/uuid" "golang.org/x/xerrors" @@ -23,6 +25,38 @@ const ( defaultListInterceptionsLimit = 100 ) +// aibridgeHandler handles all aibridged-related endpoints. +func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) func(r chi.Router) { + return func(r chi.Router) { + r.Use(api.RequireFeatureMW(codersdk.FeatureAIBridge)) + r.Group(func(r chi.Router) { + r.Use(middlewares...) + r.Get("/interceptions", api.aiBridgeListInterceptions) + }) + + // This is a bit funky but since aibridge only exposes a HTTP + // handler, this is how it has to be. + r.HandleFunc("/*", func(rw http.ResponseWriter, r *http.Request) { + if api.aibridgedHandler == nil { + httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{ + Message: "aibridged handler not mounted", + }) + return + } + + // Strip either the experimental or stable prefix. + // TODO: experimental route is deprecated and must be removed with Beta. + prefixes := []string{"/api/experimental/aibridge", "/api/v2/aibridge"} + for _, prefix := range prefixes { + if strings.Contains(r.URL.String(), prefix) { + http.StripPrefix(prefix, api.aibridgedHandler).ServeHTTP(rw, r) + break + } + } + }) + } +} + // aiBridgeListInterceptions returns all AIBridge interceptions a user can read. // Optional filters with query params // diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index 17e5df56fb..db1698f353 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -1,6 +1,7 @@ package coderd_test import ( + "io" "net/http" "testing" "time" @@ -592,3 +593,68 @@ func TestAIBridgeListInterceptions(t *testing.T) { } }) } + +func TestAIBridgeRouting(t *testing.T) { + t.Parallel() + + dv := coderdtest.DeploymentValues(t) + client, closer, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + DeploymentValues: dv, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureAIBridge: 1, + }, + }, + }) + t.Cleanup(func() { + _ = closer.Close() + }) + + // Register a simple test handler that echoes back the request path. + testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte(r.URL.Path)) + }) + api.RegisterInMemoryAIBridgedHTTPHandler(testHandler) + + cases := []struct { + name string + path string + expectedPath string + }{ + { + name: "StablePrefix", + path: "/api/v2/aibridge/openai/v1/chat/completions", + expectedPath: "/openai/v1/chat/completions", + }, + { + name: "ExperimentalPrefix", + path: "/api/experimental/aibridge/openai/v1/chat/completions", + expectedPath: "/openai/v1/chat/completions", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, client.URL.String()+tc.path, nil) + require.NoError(t, err) + req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) + + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify that the prefix was stripped correctly and the path was forwarded. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, tc.expectedPath, string(body)) + }) + } +} diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index a4adb0479b..00a78c0bd9 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -226,26 +226,14 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { return api.refreshEntitlements(ctx) } - api.AGPL.APIHandler.Group(func(r chi.Router) { - r.Route("/aibridge", func(r chi.Router) { - r.Use(api.RequireFeatureMW(codersdk.FeatureAIBridge)) - r.Group(func(r chi.Router) { - r.Use(apiKeyMiddleware) - r.Get("/interceptions", api.aiBridgeListInterceptions) - }) + api.AGPL.ExperimentalHandler.Group(func(r chi.Router) { + // Deprecated. + // TODO: remove with Beta release. + r.Route("/aibridge", aibridgeHandler(api, apiKeyMiddleware)) + }) - // This is a bit funky but since aibridge only exposes a HTTP - // handler, this is how it has to be. - r.HandleFunc("/*", func(rw http.ResponseWriter, r *http.Request) { - if api.aibridgedHandler == nil { - httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{ - Message: "aibridged handler not mounted", - }) - return - } - http.StripPrefix("/api/v2/aibridge", api.aibridgedHandler).ServeHTTP(rw, r) - }) - }) + api.AGPL.APIHandler.Group(func(r chi.Router) { + r.Route("/aibridge", aibridgeHandler(api, apiKeyMiddleware)) }) api.AGPL.APIHandler.Group(func(r chi.Router) {