mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
eddd4a8c2f
Allows an `api_key_id` to be passed from a trusted in-memory transport (currently: `chatd`) to `aibridged` for use in authenticating LLM requests. This value can _only_ be passed via context, and all users of the in-memory transport _must_ provide it. It can be used in conjunction with BYOK headers. --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
206 lines
6.9 KiB
Go
206 lines
6.9 KiB
Go
package aibridged
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"regexp"
|
|
"time"
|
|
|
|
"go.opentelemetry.io/otel/trace"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/aibridge/mcp"
|
|
"github.com/coder/coder/v2/coderd/aibridged/proto"
|
|
)
|
|
|
|
var (
	// ErrEmptyConfig is reported by newStreamableHTTPServerProxy when it is
	// given a nil MCP server config.
	ErrEmptyConfig = xerrors.New("empty config given")
	// ErrCompileRegex is reported by newStreamableHTTPServerProxy when a
	// configured tool allow or deny regex does not compile.
	ErrCompileRegex = xerrors.New("compile tool regex")
)
|
|
|
|
// InternalMCPServerID is the key under which the Coder MCP server proxy is
// registered in the proxier map assembled by retrieveMCPServerConfigs.
const InternalMCPServerID = "coder"
|
|
|
|
// MCPProxyBuilder builds the [mcp.ServerProxier] used to reach MCP servers on
// behalf of a request initiator.
//
// Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.
type MCPProxyBuilder interface {
	// Build creates a [mcp.ServerProxier] for the given request initiator.
	// The SessionKey from [Request] is used to authenticate against the Coder MCP server;
	// [MCPProxyFactory] leaves that server out when the SessionKey is empty.
	//
	// NOTE: the [mcp.ServerProxier] instance may be proxying one or more MCP servers.
	Build(ctx context.Context, req Request, tracer trace.Tracer) (mcp.ServerProxier, error)
}
|
|
|
|
// Compile-time check that *MCPProxyFactory satisfies MCPProxyBuilder.
var _ MCPProxyBuilder = (*MCPProxyFactory)(nil)
|
|
|
|
// MCPProxyFactory is an [MCPProxyBuilder] that looks up a user's MCP server
// configs through a client obtained from clientFn and creates a streamable
// HTTP proxy for each usable one.
//
// Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.
type MCPProxyFactory struct {
	// logger receives diagnostics; each created proxy gets a named child of it.
	logger slog.Logger
	// tracer is handed to every streamable HTTP proxy this factory creates.
	tracer trace.Tracer
	// clientFn is called on every config lookup to acquire the client.
	clientFn ClientFunc
}
|
|
|
|
func NewMCPProxyFactory(logger slog.Logger, tracer trace.Tracer, clientFn ClientFunc) *MCPProxyFactory {
|
|
return &MCPProxyFactory{
|
|
logger: logger,
|
|
tracer: tracer,
|
|
clientFn: clientFn,
|
|
}
|
|
}
|
|
|
|
func (m *MCPProxyFactory) Build(ctx context.Context, req Request, tracer trace.Tracer) (mcp.ServerProxier, error) {
|
|
proxiers, err := m.retrieveMCPServerConfigs(ctx, req)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("resolve configs: %w", err)
|
|
}
|
|
|
|
return mcp.NewServerProxyManager(proxiers, tracer), nil
|
|
}
|
|
|
|
func (m *MCPProxyFactory) retrieveMCPServerConfigs(ctx context.Context, req Request) (map[string]mcp.ServerProxier, error) {
|
|
client, err := m.clientFn()
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("acquire client: %w", err)
|
|
}
|
|
|
|
srvCfgCtx, srvCfgCancel := context.WithTimeout(ctx, time.Second*10)
|
|
defer srvCfgCancel()
|
|
|
|
// Fetch MCP server configs.
|
|
mcpSrvCfgs, err := client.GetMCPServerConfigs(srvCfgCtx, &proto.GetMCPServerConfigsRequest{
|
|
UserId: req.InitiatorID.String(),
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("get MCP server configs: %w", err)
|
|
}
|
|
|
|
proxiers := make(map[string]mcp.ServerProxier, len(mcpSrvCfgs.GetExternalAuthMcpConfigs())+1) // Extra one for Coder MCP server.
|
|
|
|
if mcpSrvCfgs.GetCoderMcpConfig() != nil {
|
|
// Delegated callers (e.g., chatd) do not hold the user's API key
|
|
// secret and so cannot authenticate against the Coder MCP server.
|
|
// Skip the proxy in that case rather than attempting a connection
|
|
// with an empty bearer token, which will fail upstream.
|
|
if req.SessionKey == "" {
|
|
m.logger.Debug(ctx, "skipping Coder MCP server proxy: no session key (delegated request)", slog.F("mcp_server_id", mcpSrvCfgs.GetCoderMcpConfig().GetId()))
|
|
} else {
|
|
// Setup the Coder MCP server proxy.
|
|
coderMCPProxy, err := m.newStreamableHTTPServerProxy(mcpSrvCfgs.GetCoderMcpConfig(), req.SessionKey) // The session key is used to auth against our internal MCP server.
|
|
if err != nil {
|
|
m.logger.Warn(ctx, "failed to create MCP server proxy", slog.F("mcp_server_id", mcpSrvCfgs.GetCoderMcpConfig().GetId()), slog.Error(err))
|
|
} else {
|
|
proxiers[InternalMCPServerID] = coderMCPProxy
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(mcpSrvCfgs.GetExternalAuthMcpConfigs()) == 0 {
|
|
return proxiers, nil
|
|
}
|
|
|
|
serverIDs := make([]string, 0, len(mcpSrvCfgs.GetExternalAuthMcpConfigs()))
|
|
for _, cfg := range mcpSrvCfgs.GetExternalAuthMcpConfigs() {
|
|
serverIDs = append(serverIDs, cfg.GetId())
|
|
}
|
|
|
|
accTokCtx, accTokCancel := context.WithTimeout(ctx, time.Second*10)
|
|
defer accTokCancel()
|
|
|
|
// Request a batch of access tokens, one per given server ID.
|
|
resp, err := client.GetMCPServerAccessTokensBatch(accTokCtx, &proto.GetMCPServerAccessTokensBatchRequest{
|
|
UserId: req.InitiatorID.String(),
|
|
McpServerConfigIds: serverIDs,
|
|
})
|
|
if err != nil {
|
|
m.logger.Warn(ctx, "failed to retrieve access token(s)", slog.F("server_ids", serverIDs), slog.Error(err))
|
|
}
|
|
|
|
if resp == nil {
|
|
m.logger.Warn(ctx, "nil response given to mcp access tokens call")
|
|
return proxiers, nil
|
|
}
|
|
tokens := resp.GetAccessTokens()
|
|
if len(tokens) == 0 {
|
|
return proxiers, nil
|
|
}
|
|
|
|
// Iterate over all External Auth configurations which are configured for MCP and attempt to setup
|
|
// a [mcp.ServerProxier] for it using the access token retrieved above.
|
|
for _, cfg := range mcpSrvCfgs.GetExternalAuthMcpConfigs() {
|
|
if err, ok := resp.GetErrors()[cfg.GetId()]; ok {
|
|
m.logger.Debug(ctx, "failed to get access token", slog.F("mcp_server_id", cfg.GetId()), slog.F("error", err))
|
|
continue
|
|
}
|
|
|
|
token, ok := tokens[cfg.GetId()]
|
|
if !ok {
|
|
m.logger.Warn(ctx, "no access token found", slog.F("mcp_server_id", cfg.GetId()))
|
|
continue
|
|
}
|
|
|
|
proxy, err := m.newStreamableHTTPServerProxy(cfg, token)
|
|
if err != nil {
|
|
m.logger.Warn(ctx, "failed to create MCP server proxy", slog.F("mcp_server_id", cfg.GetId()), slog.Error(err))
|
|
continue
|
|
}
|
|
|
|
proxiers[cfg.Id] = proxy
|
|
}
|
|
return proxiers, nil
|
|
}
|
|
|
|
// newStreamableHTTPServerProxy creates an MCP server capable of proxying requests using the Streamable HTTP transport.
|
|
//
|
|
// TODO: support SSE transport.
|
|
func (m *MCPProxyFactory) newStreamableHTTPServerProxy(cfg *proto.MCPServerConfig, accessToken string) (mcp.ServerProxier, error) {
|
|
if cfg == nil {
|
|
return nil, ErrEmptyConfig
|
|
}
|
|
|
|
var (
|
|
allowlist, denylist *regexp.Regexp
|
|
err error
|
|
)
|
|
if cfg.GetToolAllowRegex() != "" {
|
|
allowlist, err = regexp.Compile(cfg.GetToolAllowRegex())
|
|
if err != nil {
|
|
return nil, ErrCompileRegex
|
|
}
|
|
}
|
|
if cfg.GetToolDenyRegex() != "" {
|
|
denylist, err = regexp.Compile(cfg.GetToolDenyRegex())
|
|
if err != nil {
|
|
return nil, ErrCompileRegex
|
|
}
|
|
}
|
|
|
|
// TODO: future improvement:
|
|
//
|
|
// The access token provided here may expire at any time, or the connection to the MCP server could be severed.
|
|
// Instead of passing through an access token directly, rather provide an interface through which to retrieve
|
|
// an access token imperatively. In the event of a tool call failing, we could Ping() the MCP server to establish
|
|
// whether the connection is still active. If not, this indicates that the access token is probably expired/revoked.
|
|
// (It could also mean the server has a problem, which we should account for.)
|
|
// The proxy could then use its interface to retrieve a new access token and re-establish a connection.
|
|
// For now though, the short TTL of this cache should mostly mask this problem.
|
|
srv, err := mcp.NewStreamableHTTPServerProxy(
|
|
cfg.GetId(),
|
|
cfg.GetUrl(),
|
|
// See https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#token-requirements.
|
|
map[string]string{
|
|
"Authorization": fmt.Sprintf("Bearer %s", accessToken),
|
|
},
|
|
allowlist,
|
|
denylist,
|
|
m.logger.Named(fmt.Sprintf("mcp-server-proxy-%s", cfg.GetId())),
|
|
m.tracer,
|
|
)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("create streamable HTTP MCP server proxy: %w", err)
|
|
}
|
|
|
|
return srv, nil
|
|
}
|