mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
refactor: move aibridged out of enterprise to AGPL (#25570)
In order to allow Coder Agents to use AI Gateway in OSS, we need to rehome the `aibridged`\-related code into the AGPL path. The HTTP API is only registered under enterprise so will still require the AI Governance Add-on to be present in order to use it, whereas Coder Agents uses an in-memory pipe to the same handlers.
This commit is contained in:
@@ -53,8 +53,8 @@ endif
|
||||
tailnet/tailnettest/coordinateemock.go \
|
||||
tailnet/tailnettest/workspaceupdatesprovidermock.go \
|
||||
tailnet/tailnettest/subscriptionmock.go \
|
||||
enterprise/aibridged/aibridgedmock/clientmock.go \
|
||||
enterprise/aibridged/aibridgedmock/poolmock.go \
|
||||
coderd/aibridged/aibridgedmock/clientmock.go \
|
||||
coderd/aibridged/aibridgedmock/poolmock.go \
|
||||
tailnet/proto/tailnet.pb.go \
|
||||
agent/proto/agent.pb.go \
|
||||
agent/agentsocket/proto/agentsocket.pb.go \
|
||||
@@ -62,7 +62,7 @@ endif
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/aibridged/proto/aibridged.pb.go \
|
||||
coderd/aibridged/proto/aibridged.pb.go \
|
||||
site/src/api/typesGenerated.ts \
|
||||
site/e2e/provisionerGenerated.ts \
|
||||
site/src/api/chatModelOptionsGenerated.json \
|
||||
@@ -956,8 +956,8 @@ TAILNETTEST_MOCKS := \
|
||||
tailnet/tailnettest/subscriptionmock.go
|
||||
|
||||
AIBRIDGED_MOCKS := \
|
||||
enterprise/aibridged/aibridgedmock/clientmock.go \
|
||||
enterprise/aibridged/aibridgedmock/poolmock.go
|
||||
coderd/aibridged/aibridgedmock/clientmock.go \
|
||||
coderd/aibridged/aibridgedmock/poolmock.go
|
||||
|
||||
GEN_FILES := \
|
||||
tailnet/proto/tailnet.pb.go \
|
||||
@@ -967,7 +967,7 @@ GEN_FILES := \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/aibridged/proto/aibridged.pb.go \
|
||||
coderd/aibridged/proto/aibridged.pb.go \
|
||||
$(DB_GEN_FILES) \
|
||||
$(SITE_GEN_FILES) \
|
||||
coderd/rbac/object_gen.go \
|
||||
@@ -1032,7 +1032,7 @@ gen/mark-fresh:
|
||||
agent/agentsocket/proto/agentsocket.pb.go \
|
||||
agent/boundarylogproxy/codec/boundary.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/aibridged/proto/aibridged.pb.go \
|
||||
coderd/aibridged/proto/aibridged.pb.go \
|
||||
coderd/database/dump.sql \
|
||||
coderd/database/querier.go \
|
||||
coderd/database/unique_constraint.go \
|
||||
@@ -1121,8 +1121,8 @@ codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agen
|
||||
./scripts/format_go_file.sh "$@"
|
||||
touch "$@"
|
||||
|
||||
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
|
||||
go generate ./enterprise/aibridged/aibridgedmock/
|
||||
$(AIBRIDGED_MOCKS): coderd/aibridged/client.go coderd/aibridged/pool.go
|
||||
go generate ./coderd/aibridged/aibridgedmock/
|
||||
touch "$@"
|
||||
|
||||
agent/agentcontainers/dcspec/dcspec_gen.go: \
|
||||
@@ -1189,13 +1189,13 @@ agent/boundarylogproxy/codec/boundary.pb.go: agent/boundarylogproxy/codec/bounda
|
||||
--go_opt=paths=source_relative \
|
||||
./agent/boundarylogproxy/codec/boundary.proto
|
||||
|
||||
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
|
||||
coderd/aibridged/proto/aibridged.pb.go: coderd/aibridged/proto/aibridged.proto
|
||||
./scripts/atomic_protoc.sh \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
./coderd/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) \
|
||||
$(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') \
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged"
|
||||
"github.com/coder/coder/v2/enterprise/coderd"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
@@ -44,13 +44,13 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*ai
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// buildProviders constructs the list of AI providers from config.
|
||||
// BuildProviders constructs the list of AI providers from config.
|
||||
// It merges legacy single-provider env vars and indexed provider configs:
|
||||
// 1. Legacy providers (from CODER_AI_GATEWAY_OPENAI_KEY, etc.) are added first.
|
||||
// If a legacy name conflicts with an indexed provider, startup fails with
|
||||
// a clear error asking the admin to remove one or the other.
|
||||
// 2. Indexed providers (from CODER_AI_GATEWAY_PROVIDER_<N>_*) are added next.
|
||||
func buildProviders(cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) {
|
||||
func BuildProviders(cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) {
|
||||
var cbConfig *config.CircuitBreaker
|
||||
if cfg.CircuitBreakerEnabled.Value() {
|
||||
cbConfig = &config.CircuitBreaker{
|
||||
@@ -19,7 +19,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
|
||||
t.Run("EmptyConfig", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
providers, err := buildProviders(codersdk.AIBridgeConfig{})
|
||||
providers, err := BuildProviders(codersdk.AIBridgeConfig{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, providers)
|
||||
})
|
||||
@@ -30,7 +30,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
cfg.LegacyOpenAI.Key = serpent.String("sk-openai")
|
||||
cfg.LegacyAnthropic.Key = serpent.String("sk-anthropic")
|
||||
|
||||
providers, err := buildProviders(cfg)
|
||||
providers, err := BuildProviders(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
names := providerNames(providers)
|
||||
@@ -59,7 +59,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
providers, err := buildProviders(cfg)
|
||||
providers, err := BuildProviders(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
names := providerNames(providers)
|
||||
@@ -77,7 +77,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
}
|
||||
cfg.LegacyOpenAI.Key = serpent.String("sk-legacy")
|
||||
|
||||
_, err := buildProviders(cfg)
|
||||
_, err := BuildProviders(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "conflicts with indexed provider")
|
||||
})
|
||||
@@ -91,7 +91,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
}
|
||||
cfg.LegacyAnthropic.Key = serpent.String("sk-legacy")
|
||||
|
||||
_, err := buildProviders(cfg)
|
||||
_, err := BuildProviders(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "conflicts with indexed provider")
|
||||
})
|
||||
@@ -106,7 +106,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
cfg.LegacyOpenAI.Key = serpent.String("sk-openai")
|
||||
cfg.LegacyAnthropic.Key = serpent.String("sk-anthropic")
|
||||
|
||||
providers, err := buildProviders(cfg)
|
||||
providers, err := BuildProviders(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
names := providerNames(providers)
|
||||
@@ -123,7 +123,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
cfg.LegacyBedrock.AccessKey = serpent.String("AKID")
|
||||
cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret")
|
||||
|
||||
providers, err := buildProviders(cfg)
|
||||
providers, err := BuildProviders(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
names := providerNames(providers)
|
||||
@@ -139,7 +139,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
cfg.LegacyBedrock.AccessKey = serpent.String("AKID")
|
||||
cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret")
|
||||
|
||||
providers, err := buildProviders(cfg)
|
||||
providers, err := BuildProviders(cfg)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 1)
|
||||
|
||||
@@ -156,7 +156,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
_, err := buildProviders(cfg)
|
||||
_, err := BuildProviders(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown provider type")
|
||||
})
|
||||
@@ -173,7 +173,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
providers, err := buildProviders(cfg)
|
||||
providers, err := BuildProviders(cfg)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 3)
|
||||
|
||||
@@ -195,7 +195,7 @@ func TestBuildProviders(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
providers, err := buildProviders(cfg)
|
||||
providers, err := BuildProviders(cfg)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 1)
|
||||
|
||||
@@ -211,73 +211,3 @@ func providerNames(providers []aibridge.Provider) []string {
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func TestDomainsFromProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ExtractsHostnames", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providers, err := buildProviders(codersdk.AIBridgeConfig{
|
||||
Providers: []codersdk.AIProviderConfig{
|
||||
{Type: aibridge.ProviderOpenAI, Name: "openai", Keys: []string{"k"}},
|
||||
{Type: aibridge.ProviderAnthropic, Name: "anthropic", Keys: []string{"k"}},
|
||||
{Type: aibridge.ProviderOpenAI, Name: "custom", Keys: []string{"k"}, BaseURL: "https://custom-llm.example.com:8443/api"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
domains, mapping := domainsFromProviders(providers)
|
||||
|
||||
assert.Contains(t, domains, "api.openai.com")
|
||||
assert.Contains(t, domains, "api.anthropic.com")
|
||||
assert.Contains(t, domains, "custom-llm.example.com")
|
||||
|
||||
assert.Equal(t, "openai", mapping("api.openai.com"))
|
||||
assert.Equal(t, "anthropic", mapping("api.anthropic.com"))
|
||||
assert.Equal(t, "custom", mapping("custom-llm.example.com"))
|
||||
assert.Empty(t, mapping("unknown.com"))
|
||||
})
|
||||
|
||||
t.Run("DeduplicatesSameHost", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providers, err := buildProviders(codersdk.AIBridgeConfig{
|
||||
Providers: []codersdk.AIProviderConfig{
|
||||
{Type: aibridge.ProviderOpenAI, Name: "first", Keys: []string{"k"}, BaseURL: "https://api.example.com/v1"},
|
||||
{Type: aibridge.ProviderOpenAI, Name: "second", Keys: []string{"k"}, BaseURL: "https://api.example.com/v2"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
domains, mapping := domainsFromProviders(providers)
|
||||
|
||||
// Count occurrences of api.example.com.
|
||||
count := 0
|
||||
for _, d := range domains {
|
||||
if d == "api.example.com" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
// First provider wins.
|
||||
assert.Equal(t, "first", mapping("api.example.com"))
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providers, err := buildProviders(codersdk.AIBridgeConfig{
|
||||
Providers: []codersdk.AIProviderConfig{
|
||||
{Type: aibridge.ProviderOpenAI, Name: "provider", Keys: []string{"k"}, BaseURL: "https://API.Example.COM/v1"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
domains, mapping := domainsFromProviders(providers)
|
||||
|
||||
assert.Contains(t, domains, "api.example.com")
|
||||
assert.Equal(t, "provider", mapping("API.Example.COM"))
|
||||
assert.Equal(t, "provider", mapping("api.example.com"))
|
||||
})
|
||||
}
|
||||
@@ -1026,6 +1026,29 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
return xerrors.Errorf("seed ai providers from env: %w", err)
|
||||
}
|
||||
|
||||
// In-memory aibridge daemon. Registered on coderd so chatd can
|
||||
// dispatch LLM requests via the in-process transport without
|
||||
// crossing the gated /api/v2/aibridge HTTP route. The HTTP route
|
||||
// itself is registered (and license-gated) only by enterprise/coderd;
|
||||
// in AGPL builds it does not exist at all. The daemon starts here
|
||||
// unconditionally when the bridge feature is enabled by config so
|
||||
// chatd can use it regardless of license entitlement.
|
||||
if vals.AI.BridgeConfig.Enabled.Value() {
|
||||
providers, err := BuildProviders(vals.AI.BridgeConfig)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("build AI providers: %w", err)
|
||||
}
|
||||
aibridgeDaemon, err := newAIBridgeDaemon(coderAPI, providers)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create aibridged: %w", err)
|
||||
}
|
||||
coderAPI.RegisterInMemoryAIBridgedHTTPHandler(aibridgeDaemon)
|
||||
// The handler is bound to coderAPI's lifecycle; Close() on the
|
||||
// daemon does not affect in-flight requests but is needed to
|
||||
// release pool/recorder resources at shutdown.
|
||||
defer aibridgeDaemon.Close()
|
||||
}
|
||||
|
||||
if vals.Prometheus.Enable {
|
||||
// Agent metrics require reference to the tailnet coordinator, so must be initiated after Coder API.
|
||||
closeAgentsFunc, err := prometheusmetrics.Agents(ctx, logger, options.PrometheusRegistry, coderAPI.Database, &coderAPI.TailnetCoordinator, coderAPI.DERPMap, coderAPI.Options.AgentInactiveDisconnectTimeout, 0)
|
||||
|
||||
@@ -11,13 +11,22 @@ import (
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
aibridgedproto "github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
"github.com/coder/coder/v2/coderd/aibridgedserver"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged"
|
||||
aibridgedproto "github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
"github.com/coder/coder/v2/enterprise/aibridgedserver"
|
||||
)
|
||||
|
||||
// GetAIBridgedHandler returns the in-memory aibridge HTTP handler set by
|
||||
// [API.RegisterInMemoryAIBridgedHTTPHandler], or nil if the daemon has not
|
||||
// been wired in. Used by the enterprise /api/v2/aibridge route (license-gated)
|
||||
// to forward requests into the same in-memory handler that chatd dispatches
|
||||
// to in-process.
|
||||
func (api *API) GetAIBridgedHandler() http.Handler {
|
||||
return api.aibridgedHandler
|
||||
}
|
||||
|
||||
// RegisterInMemoryAIBridgedHTTPHandler mounts [aibridged.Server]'s HTTP router onto
|
||||
// [API]'s router, so that requests to aibridged will be relayed from Coder's API server
|
||||
// to the in-memory aibridged.
|
||||
@@ -48,7 +57,7 @@ func (api *API) CreateInMemoryAIBridgeServer(dialCtx context.Context) (client ai
|
||||
|
||||
mux := drpcmux.New()
|
||||
srv, err := aibridgedserver.NewServer(api.ctx, api.Database, api.Logger.Named("aibridgedserver"),
|
||||
api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.AGPL.Experiments, api.aiSeatTracker)
|
||||
api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.Experiments, api.AISeatTracker)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -78,11 +87,11 @@ func (api *API) CreateInMemoryAIBridgeServer(dialCtx context.Context) (client ai
|
||||
// in-mem pipes aren't technically "websockets" but they have the same properties as far as the
|
||||
// API is concerned: they are long-lived connections that we need to close before completing
|
||||
// shutdown of the API.
|
||||
api.AGPL.WebsocketWaitMutex.Lock()
|
||||
api.AGPL.WebsocketWaitGroup.Add(1)
|
||||
api.AGPL.WebsocketWaitMutex.Unlock()
|
||||
api.WebsocketWaitMutex.Lock()
|
||||
api.WebsocketWaitGroup.Add(1)
|
||||
api.WebsocketWaitMutex.Unlock()
|
||||
go func() {
|
||||
defer api.AGPL.WebsocketWaitGroup.Done()
|
||||
defer api.WebsocketWaitGroup.Done()
|
||||
// Here we pass the background context, since we want the server to keep serving until the
|
||||
// client hangs up. The aibridged is local, in-mem, so there isn't a danger of losing contact with it and
|
||||
// having a dead connection we don't know the status of.
|
||||
@@ -19,10 +19,10 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
|
||||
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged"
|
||||
mock "github.com/coder/coder/v2/enterprise/aibridged/aibridgedmock"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
+3
-3
@@ -1,9 +1,9 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/coder/coder/v2/enterprise/aibridged (interfaces: DRPCClient)
|
||||
// Source: github.com/coder/coder/v2/coderd/aibridged (interfaces: DRPCClient)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination ./clientmock.go -package aibridgedmock github.com/coder/coder/v2/enterprise/aibridged DRPCClient
|
||||
// mockgen -destination ./clientmock.go -package aibridgedmock github.com/coder/coder/v2/coderd/aibridged DRPCClient
|
||||
//
|
||||
|
||||
// Package aibridgedmock is a generated GoMock package.
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
proto "github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
proto "github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
drpc "storj.io/drpc"
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
package aibridgedmock
|
||||
|
||||
//go:generate go tool mockgen -destination ./clientmock.go -package aibridgedmock github.com/coder/coder/v2/enterprise/aibridged DRPCClient
|
||||
//go:generate go tool mockgen -destination ./poolmock.go -package aibridgedmock github.com/coder/coder/v2/enterprise/aibridged Pooler
|
||||
//go:generate go tool mockgen -destination ./clientmock.go -package aibridgedmock github.com/coder/coder/v2/coderd/aibridged DRPCClient
|
||||
//go:generate go tool mockgen -destination ./poolmock.go -package aibridgedmock github.com/coder/coder/v2/coderd/aibridged Pooler
|
||||
+3
-3
@@ -1,9 +1,9 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/coder/coder/v2/enterprise/aibridged (interfaces: Pooler)
|
||||
// Source: github.com/coder/coder/v2/coderd/aibridged (interfaces: Pooler)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination ./poolmock.go -package aibridgedmock github.com/coder/coder/v2/enterprise/aibridged Pooler
|
||||
// mockgen -destination ./poolmock.go -package aibridgedmock github.com/coder/coder/v2/coderd/aibridged Pooler
|
||||
//
|
||||
|
||||
// Package aibridgedmock is a generated GoMock package.
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
http "net/http"
|
||||
reflect "reflect"
|
||||
|
||||
aibridged "github.com/coder/coder/v2/enterprise/aibridged"
|
||||
aibridged "github.com/coder/coder/v2/coderd/aibridged"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
"storj.io/drpc"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
)
|
||||
|
||||
type Dialer func(ctx context.Context) (DRPCClient, error)
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
)
|
||||
|
||||
var _ http.Handler = &Server{}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/mcpmock"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged"
|
||||
mock "github.com/coder/coder/v2/enterprise/aibridged/aibridgedmock"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
|
||||
)
|
||||
|
||||
// TestPool validates the published behavior of [aibridged.CachedBridgePool].
|
||||
+387
-386
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
syntax = "proto3";
|
||||
option go_package = "github.com/coder/coder/v2/aibridged/proto";
|
||||
option go_package = "github.com/coder/coder/v2/coderd/aibridged/proto";
|
||||
|
||||
package proto;
|
||||
|
||||
+34
-34
@@ -1,6 +1,6 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.34
|
||||
// source: enterprise/aibridged/proto/aibridged.proto
|
||||
// source: coderd/aibridged/proto/aibridged.proto
|
||||
|
||||
package proto
|
||||
|
||||
@@ -13,25 +13,25 @@ import (
|
||||
drpcerr "storj.io/drpc/drpcerr"
|
||||
)
|
||||
|
||||
type drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto struct{}
|
||||
type drpcEncoding_File_coderd_aibridged_proto_aibridged_proto struct{}
|
||||
|
||||
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) Marshal(msg drpc.Message) ([]byte, error) {
|
||||
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) Marshal(msg drpc.Message) ([]byte, error) {
|
||||
return proto.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
|
||||
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
|
||||
return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) Unmarshal(buf []byte, msg drpc.Message) error {
|
||||
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) Unmarshal(buf []byte, msg drpc.Message) error {
|
||||
return proto.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
|
||||
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
|
||||
return protojson.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
|
||||
func (drpcEncoding_File_coderd_aibridged_proto_aibridged_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
|
||||
return protojson.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (c *drpcRecorderClient) DRPCConn() drpc.Conn { return c.cc }
|
||||
|
||||
func (c *drpcRecorderClient) RecordInterception(ctx context.Context, in *RecordInterceptionRequest) (*RecordInterceptionResponse, error) {
|
||||
out := new(RecordInterceptionResponse)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordInterception", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordInterception", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -67,7 +67,7 @@ func (c *drpcRecorderClient) RecordInterception(ctx context.Context, in *RecordI
|
||||
|
||||
func (c *drpcRecorderClient) RecordInterceptionEnded(ctx context.Context, in *RecordInterceptionEndedRequest) (*RecordInterceptionEndedResponse, error) {
|
||||
out := new(RecordInterceptionEndedResponse)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordInterceptionEnded", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordInterceptionEnded", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (c *drpcRecorderClient) RecordInterceptionEnded(ctx context.Context, in *Re
|
||||
|
||||
func (c *drpcRecorderClient) RecordTokenUsage(ctx context.Context, in *RecordTokenUsageRequest) (*RecordTokenUsageResponse, error) {
|
||||
out := new(RecordTokenUsageResponse)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordTokenUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordTokenUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -85,7 +85,7 @@ func (c *drpcRecorderClient) RecordTokenUsage(ctx context.Context, in *RecordTok
|
||||
|
||||
func (c *drpcRecorderClient) RecordPromptUsage(ctx context.Context, in *RecordPromptUsageRequest) (*RecordPromptUsageResponse, error) {
|
||||
out := new(RecordPromptUsageResponse)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordPromptUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordPromptUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (c *drpcRecorderClient) RecordPromptUsage(ctx context.Context, in *RecordPr
|
||||
|
||||
func (c *drpcRecorderClient) RecordToolUsage(ctx context.Context, in *RecordToolUsageRequest) (*RecordToolUsageResponse, error) {
|
||||
out := new(RecordToolUsageResponse)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordToolUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordToolUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func (c *drpcRecorderClient) RecordToolUsage(ctx context.Context, in *RecordTool
|
||||
|
||||
func (c *drpcRecorderClient) RecordModelThought(ctx context.Context, in *RecordModelThoughtRequest) (*RecordModelThoughtResponse, error) {
|
||||
out := new(RecordModelThoughtResponse)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordModelThought", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/proto.Recorder/RecordModelThought", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -152,7 +152,7 @@ func (DRPCRecorderDescription) NumMethods() int { return 6 }
|
||||
func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
case 0:
|
||||
return "/proto.Recorder/RecordInterception", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
|
||||
return "/proto.Recorder/RecordInterception", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCRecorderServer).
|
||||
RecordInterception(
|
||||
@@ -161,7 +161,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
|
||||
)
|
||||
}, DRPCRecorderServer.RecordInterception, true
|
||||
case 1:
|
||||
return "/proto.Recorder/RecordInterceptionEnded", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
|
||||
return "/proto.Recorder/RecordInterceptionEnded", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCRecorderServer).
|
||||
RecordInterceptionEnded(
|
||||
@@ -170,7 +170,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
|
||||
)
|
||||
}, DRPCRecorderServer.RecordInterceptionEnded, true
|
||||
case 2:
|
||||
return "/proto.Recorder/RecordTokenUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
|
||||
return "/proto.Recorder/RecordTokenUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCRecorderServer).
|
||||
RecordTokenUsage(
|
||||
@@ -179,7 +179,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
|
||||
)
|
||||
}, DRPCRecorderServer.RecordTokenUsage, true
|
||||
case 3:
|
||||
return "/proto.Recorder/RecordPromptUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
|
||||
return "/proto.Recorder/RecordPromptUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCRecorderServer).
|
||||
RecordPromptUsage(
|
||||
@@ -188,7 +188,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
|
||||
)
|
||||
}, DRPCRecorderServer.RecordPromptUsage, true
|
||||
case 4:
|
||||
return "/proto.Recorder/RecordToolUsage", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
|
||||
return "/proto.Recorder/RecordToolUsage", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCRecorderServer).
|
||||
RecordToolUsage(
|
||||
@@ -197,7 +197,7 @@ func (DRPCRecorderDescription) Method(n int) (string, drpc.Encoding, drpc.Receiv
|
||||
)
|
||||
}, DRPCRecorderServer.RecordToolUsage, true
|
||||
case 5:
|
||||
return "/proto.Recorder/RecordModelThought", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
|
||||
return "/proto.Recorder/RecordModelThought", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCRecorderServer).
|
||||
RecordModelThought(
|
||||
@@ -224,7 +224,7 @@ type drpcRecorder_RecordInterceptionStream struct {
|
||||
}
|
||||
|
||||
func (x *drpcRecorder_RecordInterceptionStream) SendAndClose(m *RecordInterceptionResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
@@ -240,7 +240,7 @@ type drpcRecorder_RecordInterceptionEndedStream struct {
|
||||
}
|
||||
|
||||
func (x *drpcRecorder_RecordInterceptionEndedStream) SendAndClose(m *RecordInterceptionEndedResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
@@ -256,7 +256,7 @@ type drpcRecorder_RecordTokenUsageStream struct {
|
||||
}
|
||||
|
||||
func (x *drpcRecorder_RecordTokenUsageStream) SendAndClose(m *RecordTokenUsageResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
@@ -272,7 +272,7 @@ type drpcRecorder_RecordPromptUsageStream struct {
|
||||
}
|
||||
|
||||
func (x *drpcRecorder_RecordPromptUsageStream) SendAndClose(m *RecordPromptUsageResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
@@ -288,7 +288,7 @@ type drpcRecorder_RecordToolUsageStream struct {
|
||||
}
|
||||
|
||||
func (x *drpcRecorder_RecordToolUsageStream) SendAndClose(m *RecordToolUsageResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
@@ -304,7 +304,7 @@ type drpcRecorder_RecordModelThoughtStream struct {
|
||||
}
|
||||
|
||||
func (x *drpcRecorder_RecordModelThoughtStream) SendAndClose(m *RecordModelThoughtResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
@@ -329,7 +329,7 @@ func (c *drpcMCPConfiguratorClient) DRPCConn() drpc.Conn { return c.cc }
|
||||
|
||||
func (c *drpcMCPConfiguratorClient) GetMCPServerConfigs(ctx context.Context, in *GetMCPServerConfigsRequest) (*GetMCPServerConfigsResponse, error) {
|
||||
out := new(GetMCPServerConfigsResponse)
|
||||
err := c.cc.Invoke(ctx, "/proto.MCPConfigurator/GetMCPServerConfigs", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/proto.MCPConfigurator/GetMCPServerConfigs", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -338,7 +338,7 @@ func (c *drpcMCPConfiguratorClient) GetMCPServerConfigs(ctx context.Context, in
|
||||
|
||||
func (c *drpcMCPConfiguratorClient) GetMCPServerAccessTokensBatch(ctx context.Context, in *GetMCPServerAccessTokensBatchRequest) (*GetMCPServerAccessTokensBatchResponse, error) {
|
||||
out := new(GetMCPServerAccessTokensBatchResponse)
|
||||
err := c.cc.Invoke(ctx, "/proto.MCPConfigurator/GetMCPServerAccessTokensBatch", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/proto.MCPConfigurator/GetMCPServerAccessTokensBatch", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -367,7 +367,7 @@ func (DRPCMCPConfiguratorDescription) NumMethods() int { return 2 }
|
||||
func (DRPCMCPConfiguratorDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
case 0:
|
||||
return "/proto.MCPConfigurator/GetMCPServerConfigs", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
|
||||
return "/proto.MCPConfigurator/GetMCPServerConfigs", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCMCPConfiguratorServer).
|
||||
GetMCPServerConfigs(
|
||||
@@ -376,7 +376,7 @@ func (DRPCMCPConfiguratorDescription) Method(n int) (string, drpc.Encoding, drpc
|
||||
)
|
||||
}, DRPCMCPConfiguratorServer.GetMCPServerConfigs, true
|
||||
case 1:
|
||||
return "/proto.MCPConfigurator/GetMCPServerAccessTokensBatch", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
|
||||
return "/proto.MCPConfigurator/GetMCPServerAccessTokensBatch", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCMCPConfiguratorServer).
|
||||
GetMCPServerAccessTokensBatch(
|
||||
@@ -403,7 +403,7 @@ type drpcMCPConfigurator_GetMCPServerConfigsStream struct {
|
||||
}
|
||||
|
||||
func (x *drpcMCPConfigurator_GetMCPServerConfigsStream) SendAndClose(m *GetMCPServerConfigsResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
@@ -419,7 +419,7 @@ type drpcMCPConfigurator_GetMCPServerAccessTokensBatchStream struct {
|
||||
}
|
||||
|
||||
func (x *drpcMCPConfigurator_GetMCPServerAccessTokensBatchStream) SendAndClose(m *GetMCPServerAccessTokensBatchResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
@@ -443,7 +443,7 @@ func (c *drpcAuthorizerClient) DRPCConn() drpc.Conn { return c.cc }
|
||||
|
||||
func (c *drpcAuthorizerClient) IsAuthorized(ctx context.Context, in *IsAuthorizedRequest) (*IsAuthorizedResponse, error) {
|
||||
out := new(IsAuthorizedResponse)
|
||||
err := c.cc.Invoke(ctx, "/proto.Authorizer/IsAuthorized", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/proto.Authorizer/IsAuthorized", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -467,7 +467,7 @@ func (DRPCAuthorizerDescription) NumMethods() int { return 1 }
|
||||
func (DRPCAuthorizerDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
case 0:
|
||||
return "/proto.Authorizer/IsAuthorized", drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{},
|
||||
return "/proto.Authorizer/IsAuthorized", drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAuthorizerServer).
|
||||
IsAuthorized(
|
||||
@@ -494,7 +494,7 @@ type drpcAuthorizer_IsAuthorizedStream struct {
|
||||
}
|
||||
|
||||
func (x *drpcAuthorizer_IsAuthorizedStream) SendAndClose(m *IsAuthorizedResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_enterprise_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_coderd_aibridged_proto_aibridged_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
@@ -1,6 +1,6 @@
|
||||
package aibridged
|
||||
|
||||
import "github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
import "github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
|
||||
type DRPCServer interface {
|
||||
proto.DRPCRecorderServer
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
)
|
||||
|
||||
var _ aibridge.Recorder = &recorderTranslation{}
|
||||
@@ -3,8 +3,12 @@ package aibridged_test
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
)
|
||||
|
||||
var testTracer = otel.Tracer("aibridged_test")
|
||||
|
||||
var _ http.Handler = &mockAIUpstreamServer{}
|
||||
|
||||
type mockAIUpstreamServer struct {
|
||||
+2
-2
@@ -16,6 +16,8 @@ import (
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
"github.com/coder/coder/v2/coderd/aiseats"
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -25,8 +27,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
codermcp "github.com/coder/coder/v2/coderd/mcp"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
+3
-3
@@ -24,6 +24,9 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogjson"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
||||
"github.com/coder/coder/v2/coderd/aibridgedserver"
|
||||
agplaiseats "github.com/coder/coder/v2/coderd/aiseats"
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -36,9 +39,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged/proto"
|
||||
"github.com/coder/coder/v2/enterprise/aibridgedserver"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
@@ -2198,6 +2198,10 @@ type API struct {
|
||||
// UsageInserter is a pointer to an atomic pointer because it is passed to
|
||||
// multiple components.
|
||||
UsageInserter *atomic.Pointer[usage.Inserter]
|
||||
// aibridgedHandler is the in-memory aibridge HTTP handler. Set by
|
||||
// RegisterInMemoryAIBridgedHTTPHandler; read by the enterprise
|
||||
// /api/v2/aibridge route (license-gated).
|
||||
aibridgedHandler http.Handler
|
||||
|
||||
UpdatesProvider tailnet.WorkspaceUpdatesProvider
|
||||
|
||||
|
||||
+7
-7
@@ -1,4 +1,4 @@
|
||||
package aibridged_test
|
||||
package enterprise_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
aibtracing "github.com/coder/coder/v2/aibridge/tracing"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
"github.com/coder/coder/v2/coderd/aibridgedserver"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
@@ -29,13 +31,11 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/aibridged"
|
||||
"github.com/coder/coder/v2/enterprise/aibridgedserver"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
var testTracer = otel.Tracer("aibridged_test")
|
||||
var testTracer = otel.Tracer("aibridged_inttest")
|
||||
|
||||
// TestIntegration is not an exhaustive test against the upstream AI providers' SDKs (see coder/aibridge for those).
|
||||
// This test validates that:
|
||||
@@ -179,7 +179,7 @@ func TestIntegration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create aibridge server & client.
|
||||
aiBridgeClient, err := api.CreateInMemoryAIBridgeServer(ctx)
|
||||
aiBridgeClient, err := api.AGPL.CreateInMemoryAIBridgeServer(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := testutil.Logger(t)
|
||||
@@ -379,7 +379,7 @@ func TestIntegrationWithMetrics(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create aibridge client.
|
||||
aiBridgeClient, err := api.CreateInMemoryAIBridgeServer(ctx)
|
||||
aiBridgeClient, err := api.AGPL.CreateInMemoryAIBridgeServer(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := testutil.Logger(t)
|
||||
@@ -476,7 +476,7 @@ func TestIntegrationCircuitBreaker(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create aibridge client.
|
||||
aiBridgeClient, err := api.CreateInMemoryAIBridgeServer(ctx)
|
||||
aiBridgeClient, err := api.AGPL.CreateInMemoryAIBridgeServer(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := testutil.Logger(t)
|
||||
@@ -0,0 +1,84 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
agplcli "github.com/coder/coder/v2/cli"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestDomainsFromProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ExtractsHostnames", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providers, err := agplcli.BuildProviders(codersdk.AIBridgeConfig{
|
||||
Providers: []codersdk.AIProviderConfig{
|
||||
{Type: aibridge.ProviderOpenAI, Name: "openai", Keys: []string{"k"}},
|
||||
{Type: aibridge.ProviderAnthropic, Name: "anthropic", Keys: []string{"k"}},
|
||||
{Type: aibridge.ProviderOpenAI, Name: "custom", Keys: []string{"k"}, BaseURL: "https://custom-llm.example.com:8443/api"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
domains, mapping := domainsFromProviders(providers)
|
||||
|
||||
assert.Contains(t, domains, "api.openai.com")
|
||||
assert.Contains(t, domains, "api.anthropic.com")
|
||||
assert.Contains(t, domains, "custom-llm.example.com")
|
||||
|
||||
assert.Equal(t, "openai", mapping("api.openai.com"))
|
||||
assert.Equal(t, "anthropic", mapping("api.anthropic.com"))
|
||||
assert.Equal(t, "custom", mapping("custom-llm.example.com"))
|
||||
assert.Empty(t, mapping("unknown.com"))
|
||||
})
|
||||
|
||||
t.Run("DeduplicatesSameHost", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providers, err := agplcli.BuildProviders(codersdk.AIBridgeConfig{
|
||||
Providers: []codersdk.AIProviderConfig{
|
||||
{Type: aibridge.ProviderOpenAI, Name: "first", Keys: []string{"k"}, BaseURL: "https://api.example.com/v1"},
|
||||
{Type: aibridge.ProviderOpenAI, Name: "second", Keys: []string{"k"}, BaseURL: "https://api.example.com/v2"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
domains, mapping := domainsFromProviders(providers)
|
||||
|
||||
// Count occurrences of api.example.com.
|
||||
count := 0
|
||||
for _, d := range domains {
|
||||
if d == "api.example.com" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
// First provider wins.
|
||||
assert.Equal(t, "first", mapping("api.example.com"))
|
||||
})
|
||||
|
||||
t.Run("CaseInsensitive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providers, err := agplcli.BuildProviders(codersdk.AIBridgeConfig{
|
||||
Providers: []codersdk.AIProviderConfig{
|
||||
{Type: aibridge.ProviderOpenAI, Name: "provider", Keys: []string{"k"}, BaseURL: "https://API.Example.COM/v1"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
domains, mapping := domainsFromProviders(providers)
|
||||
|
||||
assert.Contains(t, domains, "api.example.com")
|
||||
assert.Equal(t, "provider", mapping("API.Example.COM"))
|
||||
assert.Equal(t, "provider", mapping("api.example.com"))
|
||||
})
|
||||
}
|
||||
+14
-40
@@ -15,6 +15,7 @@ import (
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
agplcli "github.com/coder/coder/v2/cli"
|
||||
agplcoderd "github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
@@ -161,51 +162,24 @@ func (r *RootCmd) Server(_ func()) *serpent.Command {
|
||||
usageCron.Start(ctx)
|
||||
closers.Add(usageCron)
|
||||
|
||||
// Build the provider list and start AI Bridge daemons only when
|
||||
// at least one of the bridge or proxy features is enabled. The
|
||||
// ai_providers env-seed runs unconditionally in the AGPL
|
||||
// codepath, regardless of whether bridge or proxy are enabled.
|
||||
bridgeEnabled := options.DeploymentValues.AI.BridgeConfig.Enabled.Value()
|
||||
proxyEnabled := options.DeploymentValues.AI.BridgeProxyConfig.Enabled.Value()
|
||||
if bridgeEnabled || proxyEnabled {
|
||||
providers, err := buildProviders(options.DeploymentValues.AI.BridgeConfig)
|
||||
// In-memory AI Bridge Proxy daemon. The bridge daemon itself is
|
||||
// started unconditionally by AGPL cli/server.go (chatd uses its
|
||||
// in-memory roundtripper regardless of license); only the proxy
|
||||
// daemon remains enterprise-gated by config.
|
||||
if options.DeploymentValues.AI.BridgeProxyConfig.Enabled.Value() {
|
||||
providers, err := agplcli.BuildProviders(options.DeploymentValues.AI.BridgeConfig)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("build AI providers: %w", err)
|
||||
}
|
||||
|
||||
// In-memory aibridge daemon.
|
||||
// TODO(@deansheather): the lifecycle of the aibridged server is
|
||||
// probably better managed by the enterprise API type itself. Managing
|
||||
// it in the API type means we can avoid starting it up when the license
|
||||
// is not entitled to the feature.
|
||||
if bridgeEnabled {
|
||||
aibridgeDaemon, err := newAIBridgeDaemon(api, providers)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("create aibridged: %w", err)
|
||||
}
|
||||
|
||||
api.RegisterInMemoryAIBridgedHTTPHandler(aibridgeDaemon)
|
||||
|
||||
// When running as an in-memory daemon, the HTTP handler is
|
||||
// wired into the coderd API and therefore is subject to its
|
||||
// context. Calling Close() on aibridged will NOT affect
|
||||
// in-flight requests but those will be closed once the API
|
||||
// server is itself shutdown.
|
||||
closers.Add(aibridgeDaemon)
|
||||
aiBridgeProxyServer, err := newAIBridgeProxyDaemon(api, providers)
|
||||
if err != nil {
|
||||
_ = closers.Close()
|
||||
return nil, nil, xerrors.Errorf("create aibridgeproxyd: %w", err)
|
||||
}
|
||||
closers.Add(aiBridgeProxyServer)
|
||||
|
||||
// In-memory AI Bridge Proxy daemon.
|
||||
if proxyEnabled {
|
||||
aiBridgeProxyServer, err := newAIBridgeProxyDaemon(api, providers)
|
||||
if err != nil {
|
||||
_ = closers.Close()
|
||||
return nil, nil, xerrors.Errorf("create aibridgeproxyd: %w", err)
|
||||
}
|
||||
closers.Add(aiBridgeProxyServer)
|
||||
|
||||
// Register the handler so coderd can serve the proxy endpoints.
|
||||
api.RegisterInMemoryAIBridgeProxydHTTPHandler(aiBridgeProxyServer.Handler())
|
||||
}
|
||||
// Register the handler so coderd can serve the proxy endpoints.
|
||||
api.RegisterInMemoryAIBridgeProxydHTTPHandler(aiBridgeProxyServer.Handler())
|
||||
}
|
||||
|
||||
return api.AGPL, closers, nil
|
||||
|
||||
@@ -69,7 +69,7 @@ func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) f
|
||||
// This is a bit funky but since aibridge only exposes a HTTP
|
||||
// handler, this is how it has to be.
|
||||
r.HandleFunc("/*", func(rw http.ResponseWriter, r *http.Request) {
|
||||
if api.aibridgedHandler == nil {
|
||||
if api.AGPL.GetAIBridgedHandler() == nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "aibridged handler not mounted",
|
||||
})
|
||||
@@ -86,7 +86,7 @@ func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) f
|
||||
return
|
||||
}
|
||||
|
||||
http.StripPrefix("/api/v2/aibridge", api.aibridgedHandler).ServeHTTP(rw, r)
|
||||
http.StripPrefix("/api/v2/aibridge", api.AGPL.GetAIBridgedHandler()).ServeHTTP(rw, r)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1844,7 +1844,7 @@ func TestAIBridgeRouting(t *testing.T) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write([]byte(r.URL.Path))
|
||||
})
|
||||
api.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
|
||||
api.AGPL.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
@@ -1907,7 +1907,7 @@ func TestAIBridgeRateLimiting(t *testing.T) {
|
||||
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
api.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
|
||||
api.AGPL.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
httpClient := &http.Client{}
|
||||
@@ -1967,7 +1967,7 @@ func TestAIBridgeConcurrencyLimiting(t *testing.T) {
|
||||
<-unblock
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
api.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
|
||||
api.AGPL.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
httpClient := &http.Client{}
|
||||
@@ -2583,7 +2583,7 @@ func TestAIBridgeAllowBYOK(t *testing.T) {
|
||||
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
api.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
|
||||
api.AGPL.RegisterInMemoryAIBridgedHTTPHandler(testHandler)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
reqURL := client.URL.String() + "/api/v2/aibridge/test"
|
||||
|
||||
@@ -797,7 +797,6 @@ type API struct {
|
||||
licenseMetricsCollector *license.MetricsCollector
|
||||
tailnetService *tailnet.ClientService
|
||||
|
||||
aibridgedHandler http.Handler
|
||||
aibridgeproxydHandler http.Handler
|
||||
aiSeatTracker *aiseats.SeatTracker
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user