refactor: move aibridged out of enterprise to AGPL (#25570)

In order to allow Coder Agents to use AI Gateway in OSS, we need to rehome the `aibridged`\-related code into the AGPL path.

The HTTP API is only registered under enterprise so will still require the AI Governance Add-on to be present in order to use it, whereas Coder Agents uses an in-memory pipe to the same handlers.
This commit is contained in:
Danny Kopping
2026-05-22 09:11:37 +02:00
committed by GitHub
parent c50b0e84b9
commit ddec110b0e
32 changed files with 631 additions and 603 deletions
+11 -11
View File
@@ -53,8 +53,8 @@ endif
tailnet/tailnettest/coordinateemock.go \
tailnet/tailnettest/workspaceupdatesprovidermock.go \
tailnet/tailnettest/subscriptionmock.go \
enterprise/aibridged/aibridgedmock/clientmock.go \
enterprise/aibridged/aibridgedmock/poolmock.go \
coderd/aibridged/aibridgedmock/clientmock.go \
coderd/aibridged/aibridgedmock/poolmock.go \
tailnet/proto/tailnet.pb.go \
agent/proto/agent.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
@@ -62,7 +62,7 @@ endif
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
coderd/aibridged/proto/aibridged.pb.go \
site/src/api/typesGenerated.ts \
site/e2e/provisionerGenerated.ts \
site/src/api/chatModelOptionsGenerated.json \
@@ -956,8 +956,8 @@ TAILNETTEST_MOCKS := \
tailnet/tailnettest/subscriptionmock.go
AIBRIDGED_MOCKS := \
enterprise/aibridged/aibridgedmock/clientmock.go \
enterprise/aibridged/aibridgedmock/poolmock.go
coderd/aibridged/aibridgedmock/clientmock.go \
coderd/aibridged/aibridgedmock/poolmock.go
GEN_FILES := \
tailnet/proto/tailnet.pb.go \
@@ -967,7 +967,7 @@ GEN_FILES := \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
coderd/aibridged/proto/aibridged.pb.go \
$(DB_GEN_FILES) \
$(SITE_GEN_FILES) \
coderd/rbac/object_gen.go \
@@ -1032,7 +1032,7 @@ gen/mark-fresh:
agent/agentsocket/proto/agentsocket.pb.go \
agent/boundarylogproxy/codec/boundary.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
coderd/aibridged/proto/aibridged.pb.go \
coderd/database/dump.sql \
coderd/database/querier.go \
coderd/database/unique_constraint.go \
@@ -1121,8 +1121,8 @@ codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agen
./scripts/format_go_file.sh "$@"
touch "$@"
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
go generate ./enterprise/aibridged/aibridgedmock/
$(AIBRIDGED_MOCKS): coderd/aibridged/client.go coderd/aibridged/pool.go
go generate ./coderd/aibridged/aibridgedmock/
touch "$@"
agent/agentcontainers/dcspec/dcspec_gen.go: \
@@ -1189,13 +1189,13 @@ agent/boundarylogproxy/codec/boundary.pb.go: agent/boundarylogproxy/codec/bounda
--go_opt=paths=source_relative \
./agent/boundarylogproxy/codec/boundary.proto
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
coderd/aibridged/proto/aibridged.pb.go: coderd/aibridged/proto/aibridged.proto
./scripts/atomic_protoc.sh \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
--go-drpc_opt=paths=source_relative \
./enterprise/aibridged/proto/aibridged.proto
./coderd/aibridged/proto/aibridged.proto
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) \
$(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') \
@@ -11,10 +11,10 @@ import (
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/aibridged"
"github.com/coder/coder/v2/enterprise/coderd"
"github.com/coder/quartz"
)
@@ -44,13 +44,13 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*ai
return srv, nil
}
// buildProviders constructs the list of AI providers from config.
// BuildProviders constructs the list of AI providers from config.
// It merges legacy single-provider env vars and indexed provider configs:
// 1. Legacy providers (from CODER_AI_GATEWAY_OPENAI_KEY, etc.) are added first.
// If a legacy name conflicts with an indexed provider, startup fails with
// a clear error asking the admin to remove one or the other.
// 2. Indexed providers (from CODER_AI_GATEWAY_PROVIDER_<N>_*) are added next.
func buildProviders(cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) {
func BuildProviders(cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) {
var cbConfig *config.CircuitBreaker
if cfg.CircuitBreakerEnabled.Value() {
cbConfig = &config.CircuitBreaker{
@@ -19,7 +19,7 @@ func TestBuildProviders(t *testing.T) {
t.Run("EmptyConfig", func(t *testing.T) {
t.Parallel()
providers, err := buildProviders(codersdk.AIBridgeConfig{})
providers, err := BuildProviders(codersdk.AIBridgeConfig{})
require.NoError(t, err)
assert.Empty(t, providers)
})
@@ -30,7 +30,7 @@ func TestBuildProviders(t *testing.T) {
cfg.LegacyOpenAI.Key = serpent.String("sk-openai")
cfg.LegacyAnthropic.Key = serpent.String("sk-anthropic")
providers, err := buildProviders(cfg)
providers, err := BuildProviders(cfg)
require.NoError(t, err)
names := providerNames(providers)
@@ -59,7 +59,7 @@ func TestBuildProviders(t *testing.T) {
},
}
providers, err := buildProviders(cfg)
providers, err := BuildProviders(cfg)
require.NoError(t, err)
names := providerNames(providers)
@@ -77,7 +77,7 @@ func TestBuildProviders(t *testing.T) {
}
cfg.LegacyOpenAI.Key = serpent.String("sk-legacy")
_, err := buildProviders(cfg)
_, err := BuildProviders(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "conflicts with indexed provider")
})
@@ -91,7 +91,7 @@ func TestBuildProviders(t *testing.T) {
}
cfg.LegacyAnthropic.Key = serpent.String("sk-legacy")
_, err := buildProviders(cfg)
_, err := BuildProviders(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "conflicts with indexed provider")
})
@@ -106,7 +106,7 @@ func TestBuildProviders(t *testing.T) {
cfg.LegacyOpenAI.Key = serpent.String("sk-openai")
cfg.LegacyAnthropic.Key = serpent.String("sk-anthropic")
providers, err := buildProviders(cfg)
providers, err := BuildProviders(cfg)
require.NoError(t, err)
names := providerNames(providers)
@@ -123,7 +123,7 @@ func TestBuildProviders(t *testing.T) {
cfg.LegacyBedrock.AccessKey = serpent.String("AKID")
cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret")
providers, err := buildProviders(cfg)
providers, err := BuildProviders(cfg)
require.NoError(t, err)
names := providerNames(providers)
@@ -139,7 +139,7 @@ func TestBuildProviders(t *testing.T) {
cfg.LegacyBedrock.AccessKey = serpent.String("AKID")
cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret")
providers, err := buildProviders(cfg)
providers, err := BuildProviders(cfg)
require.NoError(t, err)
require.Len(t, providers, 1)
@@ -156,7 +156,7 @@ func TestBuildProviders(t *testing.T) {
},
}
_, err := buildProviders(cfg)
_, err := BuildProviders(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown provider type")
})
@@ -173,7 +173,7 @@ func TestBuildProviders(t *testing.T) {
},
}
providers, err := buildProviders(cfg)
providers, err := BuildProviders(cfg)
require.NoError(t, err)
require.Len(t, providers, 3)
@@ -195,7 +195,7 @@ func TestBuildProviders(t *testing.T) {
},
}
providers, err := buildProviders(cfg)
providers, err := BuildProviders(cfg)
require.NoError(t, err)
require.Len(t, providers, 1)
@@ -211,73 +211,3 @@ func providerNames(providers []aibridge.Provider) []string {
}
return names
}
func TestDomainsFromProviders(t *testing.T) {
t.Parallel()
t.Run("ExtractsHostnames", func(t *testing.T) {
t.Parallel()
providers, err := buildProviders(codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{Type: aibridge.ProviderOpenAI, Name: "openai", Keys: []string{"k"}},
{Type: aibridge.ProviderAnthropic, Name: "anthropic", Keys: []string{"k"}},
{Type: aibridge.ProviderOpenAI, Name: "custom", Keys: []string{"k"}, BaseURL: "https://custom-llm.example.com:8443/api"},
},
})
require.NoError(t, err)
domains, mapping := domainsFromProviders(providers)
assert.Contains(t, domains, "api.openai.com")
assert.Contains(t, domains, "api.anthropic.com")
assert.Contains(t, domains, "custom-llm.example.com")
assert.Equal(t, "openai", mapping("api.openai.com"))
assert.Equal(t, "anthropic", mapping("api.anthropic.com"))
assert.Equal(t, "custom", mapping("custom-llm.example.com"))
assert.Empty(t, mapping("unknown.com"))
})
t.Run("DeduplicatesSameHost", func(t *testing.T) {
t.Parallel()
providers, err := buildProviders(codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{Type: aibridge.ProviderOpenAI, Name: "first", Keys: []string{"k"}, BaseURL: "https://api.example.com/v1"},
{Type: aibridge.ProviderOpenAI, Name: "second", Keys: []string{"k"}, BaseURL: "https://api.example.com/v2"},
},
})
require.NoError(t, err)
domains, mapping := domainsFromProviders(providers)
// Count occurrences of api.example.com.
count := 0
for _, d := range domains {
if d == "api.example.com" {
count++
}
}
assert.Equal(t, 1, count)
// First provider wins.
assert.Equal(t, "first", mapping("api.example.com"))
})
t.Run("CaseInsensitive", func(t *testing.T) {
t.Parallel()
providers, err := buildProviders(codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{Type: aibridge.ProviderOpenAI, Name: "provider", Keys: []string{"k"}, BaseURL: "https://API.Example.COM/v1"},
},
})
require.NoError(t, err)
domains, mapping := domainsFromProviders(providers)
assert.Contains(t, domains, "api.example.com")
assert.Equal(t, "provider", mapping("API.Example.COM"))
assert.Equal(t, "provider", mapping("api.example.com"))
})
}
+23
View File
@@ -1026,6 +1026,29 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("seed ai providers from env: %w", err)
}
// In-memory aibridge daemon. Registered on coderd so chatd can
// dispatch LLM requests via the in-process transport without
// crossing the gated /api/v2/aibridge HTTP route. The HTTP route
// itself is registered (and license-gated) only by enterprise/coderd;
// in AGPL builds it does not exist at all. The daemon starts here
// unconditionally when the bridge feature is enabled by config so
// chatd can use it regardless of license entitlement.
if vals.AI.BridgeConfig.Enabled.Value() {
providers, err := BuildProviders(vals.AI.BridgeConfig)
if err != nil {
return xerrors.Errorf("build AI providers: %w", err)
}
aibridgeDaemon, err := newAIBridgeDaemon(coderAPI, providers)
if err != nil {
return xerrors.Errorf("create aibridged: %w", err)
}
coderAPI.RegisterInMemoryAIBridgedHTTPHandler(aibridgeDaemon)
// The handler is bound to coderAPI's lifecycle; Close() on the
// daemon does not affect in-flight requests but is needed to
// release pool/recorder resources at shutdown.
defer aibridgeDaemon.Close()
}
if vals.Prometheus.Enable {
// Agent metrics require reference to the tailnet coordinator, so must be initiated after Coder API.
closeAgentsFunc, err := prometheusmetrics.Agents(ctx, logger, options.PrometheusRegistry, coderAPI.Database, &coderAPI.TailnetCoordinator, coderAPI.DERPMap, coderAPI.Options.AgentInactiveDisconnectTimeout, 0)
@@ -11,13 +11,22 @@ import (
"storj.io/drpc/drpcserver"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/aibridged"
aibridgedproto "github.com/coder/coder/v2/coderd/aibridged/proto"
"github.com/coder/coder/v2/coderd/aibridgedserver"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/enterprise/aibridged"
aibridgedproto "github.com/coder/coder/v2/enterprise/aibridged/proto"
"github.com/coder/coder/v2/enterprise/aibridgedserver"
)
// GetAIBridgedHandler returns the in-memory aibridge HTTP handler set by
// [API.RegisterInMemoryAIBridgedHTTPHandler], or nil if the daemon has not
// been wired in. Used by the enterprise /api/v2/aibridge route (license-gated)
// to forward requests into the same in-memory handler that chatd dispatches
// to in-process.
func (api *API) GetAIBridgedHandler() http.Handler {
return api.aibridgedHandler
}
// RegisterInMemoryAIBridgedHTTPHandler mounts [aibridged.Server]'s HTTP router onto
// [API]'s router, so that requests to aibridged will be relayed from Coder's API server
// to the in-memory aibridged.
@@ -48,7 +57,7 @@ func (api *API) CreateInMemoryAIBridgeServer(dialCtx context.Context) (client ai
mux := drpcmux.New()
srv, err := aibridgedserver.NewServer(api.ctx, api.Database, api.Logger.Named("aibridgedserver"),
api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.AGPL.Experiments, api.aiSeatTracker)
api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.Experiments, api.AISeatTracker)
if err != nil {
return nil, err
}
@@ -78,11 +87,11 @@ func (api *API) CreateInMemoryAIBridgeServer(dialCtx context.Context) (client ai
// in-mem pipes aren't technically "websockets" but they have the same properties as far as the
// API is concerned: they are long-lived connections that we need to close before completing
// shutdown of the API.
api.AGPL.WebsocketWaitMutex.Lock()
api.AGPL.WebsocketWaitGroup.Add(1)
api.AGPL.WebsocketWaitMutex.Unlock()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
go func() {
defer api.AGPL.WebsocketWaitGroup.Done()
defer api.WebsocketWaitGroup.Done()
// Here we pass the background context, since we want the server to keep serving until the
// client hangs up. The aibridged is local, in-mem, so there isn't a danger of losing contact with it and
// having a dead connection we don't know the status of.
@@ -19,10 +19,10 @@ import (
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/intercept"
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/aibridged"
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
"github.com/coder/coder/v2/coderd/aibridged/proto"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/aibridged"
mock "github.com/coder/coder/v2/enterprise/aibridged/aibridgedmock"
"github.com/coder/coder/v2/enterprise/aibridged/proto"
"github.com/coder/coder/v2/testutil"
)
@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/enterprise/aibridged (interfaces: DRPCClient)
// Source: github.com/coder/coder/v2/coderd/aibridged (interfaces: DRPCClient)
//
// Generated by this command:
//
// mockgen -destination ./clientmock.go -package aibridgedmock github.com/coder/coder/v2/enterprise/aibridged DRPCClient
// mockgen -destination ./clientmock.go -package aibridgedmock github.com/coder/coder/v2/coderd/aibridged DRPCClient
//
// Package aibridgedmock is a generated GoMock package.
@@ -13,7 +13,7 @@ import (
context "context"
reflect "reflect"
proto "github.com/coder/coder/v2/enterprise/aibridged/proto"
proto "github.com/coder/coder/v2/coderd/aibridged/proto"
gomock "go.uber.org/mock/gomock"
drpc "storj.io/drpc"
)
@@ -1,4 +1,4 @@
package aibridgedmock
//go:generate go tool mockgen -destination ./clientmock.go -package aibridgedmock github.com/coder/coder/v2/enterprise/aibridged DRPCClient
//go:generate go tool mockgen -destination ./poolmock.go -package aibridgedmock github.com/coder/coder/v2/enterprise/aibridged Pooler
//go:generate go tool mockgen -destination ./clientmock.go -package aibridgedmock github.com/coder/coder/v2/coderd/aibridged DRPCClient
//go:generate go tool mockgen -destination ./poolmock.go -package aibridgedmock github.com/coder/coder/v2/coderd/aibridged Pooler
@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/enterprise/aibridged (interfaces: Pooler)
// Source: github.com/coder/coder/v2/coderd/aibridged (interfaces: Pooler)
//
// Generated by this command:
//
// mockgen -destination ./poolmock.go -package aibridgedmock github.com/coder/coder/v2/enterprise/aibridged Pooler
// mockgen -destination ./poolmock.go -package aibridgedmock github.com/coder/coder/v2/coderd/aibridged Pooler
//
// Package aibridgedmock is a generated GoMock package.
@@ -14,7 +14,7 @@ import (
http "net/http"
reflect "reflect"
aibridged "github.com/coder/coder/v2/enterprise/aibridged"
aibridged "github.com/coder/coder/v2/coderd/aibridged"
gomock "go.uber.org/mock/gomock"
)
@@ -5,7 +5,7 @@ import (
"storj.io/drpc"
"github.com/coder/coder/v2/enterprise/aibridged/proto"
"github.com/coder/coder/v2/coderd/aibridged/proto"
)
type Dialer func(ctx context.Context) (DRPCClient, error)
@@ -11,7 +11,7 @@ import (
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/recorder"
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/enterprise/aibridged/proto"
"github.com/coder/coder/v2/coderd/aibridged/proto"
)
var _ http.Handler = &Server{}
@@ -11,7 +11,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/enterprise/aibridged/proto"
"github.com/coder/coder/v2/coderd/aibridged/proto"
)
var (
@@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/coder/coder/v2/enterprise/aibridged/proto"
"github.com/coder/coder/v2/coderd/aibridged/proto"
"github.com/coder/coder/v2/testutil"
)
@@ -14,8 +14,8 @@ import (
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/mcpmock"
"github.com/coder/coder/v2/enterprise/aibridged"
mock "github.com/coder/coder/v2/enterprise/aibridged/aibridgedmock"
"github.com/coder/coder/v2/coderd/aibridged"
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
)
// TestPool validates the published behavior of [aibridged.CachedBridgePool].
File diff suppressed because it is too large Load Diff
@@ -1,5 +1,5 @@
syntax = "proto3";
option go_package = "github.com/coder/coder/v2/aibridged/proto";
option go_package = "github.com/coder/coder/v2/coderd/aibridged/proto";
package proto;
@@ -1,6 +1,6 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.34
// source: enterprise/aibridged/proto/aibridged.proto
// source: coderd/aibridged/proto/aibridged.proto
package proto
@@ -13,25 +13,25 @@ import (
drpcerr "storj.io/drpc/drpcerr"
)
type drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto struct{}
type drpcEncoding_File_coderd_aibridged_proto_aibridged_proto struct{}
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) Marshal(msg drpc.Message) ([]byte, error) {
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) Marshal(msg drpc.Message) ([]byte, error) {
return proto.Marshal(msg.(proto.Message))
}
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message))
}
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) Unmarshal(buf []byte, msg drpc.Message) error {
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) Unmarshal(buf []byte, msg drpc.Message) error {
return proto.Unmarshal(buf, msg.(proto.Message))
}
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
return protojson.Marshal(msg.(proto.Message))
}
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
return protojson.Unmarshal(buf, msg.(proto.Message))
}
@@ -58,7 +58,7 @@ func (c *drpcRecorderClient) DRPCConn() drpc.Conn { return c.cc }
func (c *drpcRecorderClient) RecordInterception(ctx context.Context, in *RecordInterceptionRequest) (*RecordInterceptionResponse, error) {
out := new(RecordInterceptionResponse)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordInterception", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordInterception", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
if err != nil {
return nil, err
}
@@ -67,7 +67,7 @@ func (c *drpcRecorderClient) RecordInterception(ctx context.Context, in *RecordI
func (c *drpcRecorderClient) RecordInterceptionEnded(ctx context.Context, in *RecordInterceptionEndedRequest) (*RecordInterceptionEndedResponse, error) {
out := new(RecordInterceptionEndedResponse)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordInterceptionEnded", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordInterceptionEnded", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
if err != nil {
return nil, err
}
@@ -76,7 +76,7 @@ func (c *drpcRecorderClient) RecordInterceptionEnded(ctx context.Context, in *Re
func (c *drpcRecorderClient) RecordTokenUsage(ctx context.Context, in *RecordTokenUsageRequest) (*RecordTokenUsageResponse, error) {
out := new(RecordTokenUsageResponse)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordTokenUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordTokenUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
if err != nil {
return nil, err
}
@@ -85,7 +85,7 @@ func (c *drpcRecorderClient) RecordTokenUsage(ctx context.Context, in *RecordTok
func (c *drpcRecorderClient) RecordPromptUsage(ctx context.Context, in *RecordPromptUsageRequest) (*RecordPromptUsageResponse, error) {
out := new(RecordPromptUsageResponse)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordPromptUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordPromptUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
if err != nil {
return nil, err
}
@@ -94,7 +94,7 @@ func (c *drpcRecorderClient) RecordPromptUsage(ctx context.Context, in *RecordPr
func (c *drpcRecorderClient) RecordToolUsage(ctx context.Context, in *RecordToolUsageRequest) (*RecordToolUsageResponse, error) {
out := new(RecordToolUsageResponse)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordToolUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordToolUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
if err != nil {
return nil, err
}
@@ -103,7 +103,7 @@ func (c *drpcRecorderClient) RecordToolUsage(ctx context.Context, in *RecordTool
func (c *drpcRecorderClient) RecordModelThought(ctx context.Context, in *RecordModelThoughtRequest) (*RecordModelThoughtResponse, error) {
out := new(RecordModelThoughtResponse)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordModelThought", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordModelThought", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
if err != nil {
return nil, err
}
@@ -152,7 +152,7 @@ func (DRPCRecorderDescription) NumMethods() int { return 6 }
func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
case 0:
return "/proto.Recorder/RecordInterception", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
return "/proto.Recorder/RecordInterception", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCRecorderServer).
RecordInterception(
@@ -161,7 +161,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
)
}, DRPCRecorderServer.RecordInterception, true
case 1:
return "/proto.Recorder/RecordInterceptionEnded", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
return "/proto.Recorder/RecordInterceptionEnded", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCRecorderServer).
RecordInterceptionEnded(
@@ -170,7 +170,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
)
}, DRPCRecorderServer.RecordInterceptionEnded, true
case 2:
return "/proto.Recorder/RecordTokenUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
return "/proto.Recorder/RecordTokenUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCRecorderServer).
RecordTokenUsage(
@@ -179,7 +179,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
)
}, DRPCRecorderServer.RecordTokenUsage, true
case 3:
return "/proto.Recorder/RecordPromptUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
return "/proto.Recorder/RecordPromptUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCRecorderServer).
RecordPromptUsage(
@@ -188,7 +188,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
)
}, DRPCRecorderServer.RecordPromptUsage, true
case 4:
return "/proto.Recorder/RecordToolUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
return "/proto.Recorder/RecordToolUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCRecorderServer).
RecordToolUsage(
@@ -197,7 +197,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
)
}, DRPCRecorderServer.RecordToolUsage, true
case 5:
return "/proto.Recorder/RecordModelThought", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
return "/proto.Recorder/RecordModelThought", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCRecorderServer).
RecordModelThought(
@@ -224,7 +224,7 @@ type drpcRecorder_RecordInterceptionStream struct {
}
func (x *drpcRecorder_RecordInterceptionStream) SendAndClose(m *RecordInterceptionResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
return err
}
return x.CloseSend()
@@ -240,7 +240,7 @@ type drpcRecorder_RecordInterceptionEndedStream struct {
}
func (x *drpcRecorder_RecordInterceptionEndedStream) SendAndClose(m *RecordInterceptionEndedResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
return err
}
return x.CloseSend()
@@ -256,7 +256,7 @@ type drpcRecorder_RecordTokenUsageStream struct {
}
func (x *drpcRecorder_RecordTokenUsageStream) SendAndClose(m *RecordTokenUsageResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
return err
}
return x.CloseSend()
@@ -272,7 +272,7 @@ type drpcRecorder_RecordPromptUsageStream struct {
}
func (x *drpcRecorder_RecordPromptUsageStream) SendAndClose(m *RecordPromptUsageResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
return err
}
return x.CloseSend()
@@ -288,7 +288,7 @@ type drpcRecorder_RecordToolUsageStream struct {
}
func (x *drpcRecorder_RecordToolUsageStream) SendAndClose(m *RecordToolUsageResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
return err
}
return x.CloseSend()
@@ -304,7 +304,7 @@ type drpcRecorder_RecordModelThoughtStream struct {
}
func (x *drpcRecorder_RecordModelThoughtStream) SendAndClose(m *RecordModelThoughtResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
return err
}
return x.CloseSend()
@@ -329,7 +329,7 @@ func (c *drpcMCPConfiguratorClient) DRPCConn() drpc.Conn { return c.cc }
func (c *drpcMCPConfiguratorClient) GetMCPServerConfigs(ctx context.Context, in *GetMCPServerConfigsRequest) (*GetMCPServerConfigsResponse, error) {
out := new(GetMCPServerConfigsResponse)
err := c.cc.Invoke(ctx, "/proto.MCPConfigurator/GetMCPServerConfigs", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
err := c.cc.Invoke(ctx, "/proto.MCPConfigurator/GetMCPServerConfigs", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
if err != nil {
return nil, err
}
@@ -338,7 +338,7 @@ func (c *drpcMCPConfiguratorClient) GetMCPServerConfigs(ctx context.Context, in
func (c *drpcMCPConfiguratorClient) GetMCPServerAccessTokensBatch(ctx context.Context, in *GetMCPServerAccessTokensBatchRequest) (*GetMCPServerAccessTokensBatchResponse, error) {
out := new(GetMCPServerAccessTokensBatchResponse)
err := c.cc.Invoke(ctx, "/proto.MCPConfigurator/GetMCPServerAccessTokensBatch", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
err := c.cc.Invoke(ctx, "/proto.MCPConfigurator/GetMCPServerAccessTokensBatch", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
if err != nil {
return nil, err
}
@@ -367,7 +367,7 @@ func (DRPCMCPConfiguratorDescription) NumMethods() int { return 2 }
func (DRPCMCPConfiguratorDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
case 0:
return "/proto.MCPConfigurator/GetMCPServerConfigs", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
return "/proto.MCPConfigurator/GetMCPServerConfigs", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCMCPConfiguratorServer).
GetMCPServerConfigs(
@@ -376,7 +376,7 @@ func (DRPCMCPConfiguratorDescription) Method(n int) (string, drpc.Encoding, drpc
)
}, DRPCMCPConfiguratorServer.GetMCPServerConfigs, true
case 1:
return "/proto.MCPConfigurator/GetMCPServerAccessTokensBatch", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
return "/proto.MCPConfigurator/GetMCPServerAccessTokensBatch", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCMCPConfiguratorServer).
GetMCPServerAccessTokensBatch(
@@ -403,7 +403,7 @@ type drpcMCPConfigurator_GetMCPServerConfigsStream struct {
}
func (x *drpcMCPConfigurator_GetMCPServerConfigsStream) SendAndClose(m *GetMCPServerConfigsResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
return err
}
return x.CloseSend()
@@ -419,7 +419,7 @@ type drpcMCPConfigurator_GetMCPServerAccessTokensBatchStream struct {
}
func (x *drpcMCPConfigurator_GetMCPServerAccessTokensBatchStream) SendAndClose(m *GetMCPServerAccessTokensBatchResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
return err
}
return x.CloseSend()
@@ -443,7 +443,7 @@ func (c *drpcAuthorizerClient) DRPCConn() drpc.Conn { return c.cc }
func (c *drpcAuthorizerClient) IsAuthorized(ctx context.Context, in *IsAuthorizedRequest) (*IsAuthorizedResponse, error) {
out := new(IsAuthorizedResponse)
err := c.cc.Invoke(ctx, "/proto.Authorizer/IsAuthorized", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
err := c.cc.Invoke(ctx, "/proto.Authorizer/IsAuthorized", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
if err != nil {
return nil, err
}
@@ -467,7 +467,7 @@ func (DRPCAuthorizerDescription) NumMethods() int { return 1 }
func (DRPCAuthorizerDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
case 0:
return "/proto.Authorizer/IsAuthorized", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
return "/proto.Authorizer/IsAuthorized", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAuthorizerServer).
IsAuthorized(
@@ -494,7 +494,7 @@ type drpcAuthorizer_IsAuthorizedStream struct {
}
func (x *drpcAuthorizer_IsAuthorizedStream) SendAndClose(m *IsAuthorizedResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
return err
}
return x.CloseSend()
@@ -1,6 +1,6 @@
package aibridged
import "github.com/coder/coder/v2/enterprise/aibridged/proto"
import "github.com/coder/coder/v2/coderd/aibridged/proto"
type DRPCServer interface {
proto.DRPCRecorderServer
@@ -11,8 +11,8 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/coderd/aibridged/proto"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/enterprise/aibridged/proto"
)
var _ aibridge.Recorder = &recorderTranslation{}
@@ -3,8 +3,12 @@ package aibridged_test
import (
"net/http"
"sync/atomic"
"go.opentelemetry.io/otel"
)
var testTracer = otel.Tracer("aibridged_test")
var _ http.Handler = &mockAIUpstreamServer{}
type mockAIUpstreamServer struct {
@@ -16,6 +16,8 @@ import (
"google.golang.org/protobuf/types/known/structpb"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/aibridged/proto"
"github.com/coder/coder/v2/coderd/aiseats"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/database"
@@ -25,8 +27,6 @@ import (
"github.com/coder/coder/v2/coderd/httpmw"
codermcp "github.com/coder/coder/v2/coderd/mcp"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/aibridged"
"github.com/coder/coder/v2/enterprise/aibridged/proto"
)
var (
@@ -24,6 +24,9 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogjson"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/aibridged/proto"
"github.com/coder/coder/v2/coderd/aibridgedserver"
agplaiseats "github.com/coder/coder/v2/coderd/aiseats"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/database"
@@ -36,9 +39,6 @@ import (
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/enterprise/aibridged"
"github.com/coder/coder/v2/enterprise/aibridged/proto"
"github.com/coder/coder/v2/enterprise/aibridgedserver"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
+4
View File
@@ -2198,6 +2198,10 @@ type API struct {
// UsageInserter is a pointer to an atomic pointer because it is passed to
// multiple components.
UsageInserter *atomic.Pointer[usage.Inserter]
// aibridgedHandler is the in-memory aibridge HTTP handler. Set by
// RegisterInMemoryAIBridgedHTTPHandler; read by the enterprise
// /api/v2/aibridge route (license-gated).
aibridgedHandler http.Handler
UpdatesProvider tailnet.WorkspaceUpdatesProvider
@@ -1,4 +1,4 @@
package aibridged_test
package enterprise_test
import (
"bytes"
@@ -22,6 +22,8 @@ import (
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/config"
aibtracing "github.com/coder/coder/v2/aibridge/tracing"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/aibridgedserver"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
@@ -29,13 +31,11 @@ import (
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/aibridged"
"github.com/coder/coder/v2/enterprise/aibridgedserver"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/testutil"
)
var testTracer = otel.Tracer("aibridged_test")
var testTracer = otel.Tracer("aibridged_inttest")
// TestIntegration is not an exhaustive test against the upstream AI providers' SDKs (see coder/aibridge for those).
// This test validates that:
@@ -179,7 +179,7 @@ func TestIntegration(t *testing.T) {
require.NoError(t, err)
// Create aibridge server & client.
aiBridgeClient, err := api.CreateInMemoryAIBridgeServer(ctx)
aiBridgeClient, err := api.AGPL.CreateInMemoryAIBridgeServer(ctx)
require.NoError(t, err)
logger := testutil.Logger(t)
@@ -379,7 +379,7 @@ func TestIntegrationWithMetrics(t *testing.T) {
require.NoError(t, err)
// Create aibridge client.
aiBridgeClient, err := api.CreateInMemoryAIBridgeServer(ctx)
aiBridgeClient, err := api.AGPL.CreateInMemoryAIBridgeServer(ctx)
require.NoError(t, err)
logger := testutil.Logger(t)
@@ -476,7 +476,7 @@ func TestIntegrationCircuitBreaker(t *testing.T) {
require.NoError(t, err)
// Create aibridge client.
aiBridgeClient, err := api.CreateInMemoryAIBridgeServer(ctx)
aiBridgeClient, err := api.AGPL.CreateInMemoryAIBridgeServer(ctx)
require.NoError(t, err)
logger := testutil.Logger(t)
@@ -0,0 +1,84 @@
//go:build !slim
package cli
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/aibridge"
agplcli "github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/codersdk"
)
func TestDomainsFromProviders(t *testing.T) {
t.Parallel()
t.Run("ExtractsHostnames", func(t *testing.T) {
t.Parallel()
providers, err := agplcli.BuildProviders(codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{Type: aibridge.ProviderOpenAI, Name: "openai", Keys: []string{"k"}},
{Type: aibridge.ProviderAnthropic, Name: "anthropic", Keys: []string{"k"}},
{Type: aibridge.ProviderOpenAI, Name: "custom", Keys: []string{"k"}, BaseURL: "https://custom-llm.example.com:8443/api"},
},
})
require.NoError(t, err)
domains, mapping := domainsFromProviders(providers)
assert.Contains(t, domains, "api.openai.com")
assert.Contains(t, domains, "api.anthropic.com")
assert.Contains(t, domains, "custom-llm.example.com")
assert.Equal(t, "openai", mapping("api.openai.com"))
assert.Equal(t, "anthropic", mapping("api.anthropic.com"))
assert.Equal(t, "custom", mapping("custom-llm.example.com"))
assert.Empty(t, mapping("unknown.com"))
})
t.Run("DeduplicatesSameHost", func(t *testing.T) {
t.Parallel()
providers, err := agplcli.BuildProviders(codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{Type: aibridge.ProviderOpenAI, Name: "first", Keys: []string{"k"}, BaseURL: "https://api.example.com/v1"},
{Type: aibridge.ProviderOpenAI, Name: "second", Keys: []string{"k"}, BaseURL: "https://api.example.com/v2"},
},
})
require.NoError(t, err)
domains, mapping := domainsFromProviders(providers)
// Count occurrences of api.example.com.
count := 0
for _, d := range domains {
if d == "api.example.com" {
count++
}
}
assert.Equal(t, 1, count)
// First provider wins.
assert.Equal(t, "first", mapping("api.example.com"))
})
t.Run("CaseInsensitive", func(t *testing.T) {
t.Parallel()
providers, err := agplcli.BuildProviders(codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{Type: aibridge.ProviderOpenAI, Name: "provider", Keys: []string{"k"}, BaseURL: "https://API.Example.COM/v1"},
},
})
require.NoError(t, err)
domains, mapping := domainsFromProviders(providers)
assert.Contains(t, domains, "api.example.com")
assert.Equal(t, "provider", mapping("API.Example.COM"))
assert.Equal(t, "provider", mapping("api.example.com"))
})
}
+14 -40
View File
@@ -15,6 +15,7 @@ import (
"tailscale.com/derp"
"tailscale.com/types/key"
agplcli "github.com/coder/coder/v2/cli"
agplcoderd "github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/cryptorand"
@@ -161,51 +162,24 @@ func (r *RootCmd) Server(_ func()) *serpent.Command {
usageCron.Start(ctx)
closers.Add(usageCron)
// Build the provider list and start AI Bridge daemons only when
// at least one of the bridge or proxy features is enabled. The
// ai_providers env-seed runs unconditionally in the AGPL
// codepath, regardless of whether bridge or proxy are enabled.
bridgeEnabled := options.DeploymentValues.AI.BridgeConfig.Enabled.Value()
proxyEnabled := options.DeploymentValues.AI.BridgeProxyConfig.Enabled.Value()
if bridgeEnabled || proxyEnabled {
providers, err := buildProviders(options.DeploymentValues.AI.BridgeConfig)
// In-memory AI Bridge Proxy daemon. The bridge daemon itself is
// started unconditionally by AGPL cli/server.go (chatd uses its
// in-memory roundtripper regardless of license); only the proxy
// daemon remains enterprise-gated by config.
if options.DeploymentValues.AI.BridgeProxyConfig.Enabled.Value() {
providers, err := agplcli.BuildProviders(options.DeploymentValues.AI.BridgeConfig)
if err != nil {
return nil, nil, xerrors.Errorf("build AI providers: %w", err)
}
// In-memory aibridge daemon.
// TODO(@deansheather): the lifecycle of the aibridged server is
// probably better managed by the enterprise API type itself. Managing
// it in the API type means we can avoid starting it up when the license
// is not entitled to the feature.
if bridgeEnabled {
aibridgeDaemon, err := newAIBridgeDaemon(api, providers)
if err != nil {
return nil, nil, xerrors.Errorf("create aibridged: %w", err)
}
api.RegisterInMemoryAIBridgedHTTPHandler(aibridgeDaemon)
// When running as an in-memory daemon, the HTTP handler is
// wired into the coderd API and therefore is subject to its
// context. Calling Close() on aibridged will NOT affect
// in-flight requests but those will be closed once the API
// server is itself shutdown.
closers.Add(aibridgeDaemon)
aiBridgeProxyServer, err := newAIBridgeProxyDaemon(api, providers)
if err != nil {
_ = closers.Close()
return nil, nil, xerrors.Errorf("create aibridgeproxyd: %w", err)
}
closers.Add(aiBridgeProxyServer)
// In-memory AI Bridge Proxy daemon.
if proxyEnabled {
aiBridgeProxyServer, err := newAIBridgeProxyDaemon(api, providers)
if err != nil {
_ = closers.Close()
return nil, nil, xerrors.Errorf("create aibridgeproxyd: %w", err)
}
closers.Add(aiBridgeProxyServer)
// Register the handler so coderd can serve the proxy endpoints.
api.RegisterInMemoryAIBridgeProxydHTTPHandler(aiBridgeProxyServer.Handler())
}
// Register the handler so coderd can serve the proxy endpoints.
api.RegisterInMemoryAIBridgeProxydHTTPHandler(aiBridgeProxyServer.Handler())
}
return api.AGPL, closers, nil
+2 -2
View File
@@ -69,7 +69,7 @@ func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) f
// This is a bit funky but since aibridge only exposes a HTTP
// handler, this is how it has to be.
r.HandleFunc("/*", func(rw http.ResponseWriter, r *http.Request) {
if api.aibridgedHandler == nil {
if api.AGPL.GetAIBridgedHandler() == nil {
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
Message: "aibridged handler not mounted",
})
@@ -86,7 +86,7 @@ func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) f
return
}
http.StripPrefix("/api/v2/aibridge", api.aibridgedHandler).ServeHTTP(rw, r)
http.StripPrefix("/api/v2/aibridge", api.AGPL.GetAIBridgedHandler()).ServeHTTP(rw, r)
})
})
}
+4 -4
View File
@@ -1844,7 +1844,7 @@ func TestAIBridgeRouting(t *testing.T) {
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte(r.URL.Path))
})
api.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
api.AGPL.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
cases := []struct {
name string
@@ -1907,7 +1907,7 @@ func TestAIBridgeRateLimiting(t *testing.T) {
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})
api.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
api.AGPL.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
ctx := testutil.Context(t, testutil.WaitLong)
httpClient := &http.Client{}
@@ -1967,7 +1967,7 @@ func TestAIBridgeConcurrencyLimiting(t *testing.T) {
<-unblock
rw.WriteHeader(http.StatusOK)
})
api.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
api.AGPL.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
ctx := testutil.Context(t, testutil.WaitLong)
httpClient := &http.Client{}
@@ -2583,7 +2583,7 @@ func TestAIBridgeAllowBYOK(t *testing.T) {
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusOK)
})
api.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
api.AGPL.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
ctx := testutil.Context(t, testutil.WaitLong)
reqURL := client.URL.String() + "/api/v2/aibridge/test"
-1
View File
@@ -797,7 +797,6 @@ type API struct {
licenseMetricsCollector *license.MetricsCollector
tailnetService *tailnet.ClientService
aibridgedHandler http.Handler
aibridgeproxydHandler http.Handler
aiSeatTracker *aiseats.SeatTracker
}