Files
coder/aibridge/mcp/proxy_streamable_http.go
T
Paweł Banaszewski e00e85765b chore: move aibridge library code into coder repo (#24190)
This PR merges code from `coder/aibridge` repository into `coder/coder`.
It was split into 4 PRs for easier review, but the stacked PRs need to
be merged into this PR so that all checks pass.

* https://github.com/coder/coder/pull/24190 -> raw code copy (this PR;
before the stacked PRs were merged on top of it, it was a single commit:
https://github.com/coder/coder/commit/70d33f33200c7e77df910957595715f81f9bec24)
* https://github.com/coder/coder/pull/24570 -> update imports in
`coder/coder` to use copied code
* https://github.com/coder/coder/pull/24586 -> linter fixes and CI
integration (also added README.md)
* https://github.com/coder/coder/pull/24571 -> added exclude to
scripts/check_emdash.sh check

Original PR message (before PR squash):
Moves coder/aibridge code into coder/coder repository.

Omitted files:

- `go.mod`, `go.sum`, `.gitignore`, `.github/workflows/ci.yml`,
`Makefile`, `LICENSE`, `README.md` (a modified README.md is added later)
- `.github`, `example`, `buildinfo`, `scripts` directories

Simple verification script (lists the omitted files):

```
tmp=$(mktemp -d)
echo "$tmp"
git clone --depth=1 https://github.com/coder/aibridge "$tmp/aibridge"
git clone --depth=1 --branch pb/aibridge-code-move https://github.com/coder/coder "$tmp/coder"
diff -rq --exclude=.git "$tmp/aibridge" "$tmp/coder/aibridge"
# rm -rf "$tmp"
```
2026-04-22 17:01:01 +02:00

181 lines
5.2 KiB
Go

package mcp
import (
"context"
"regexp"
"slices"
"strings"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/exp/maps"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/tracing"
)
// Compile-time check that *StreamableHTTPServerProxy implements ServerProxier.
var _ ServerProxier = (*StreamableHTTPServerProxy)(nil)

// StreamableHTTPServerProxy proxies a single remote MCP server that is reached
// through mcp-go's streamable HTTP client. It exposes the server's tools,
// filtered by the configured allow/deny patterns, and forwards tool calls to
// the server.
type StreamableHTTPServerProxy struct {
	client *client.Client
	logger slog.Logger
	tracer trace.Tracer

	// allowlistPattern and denylistPattern are handed to FilterAllowedTools
	// during Init to decide which of the server's tools are exposed.
	allowlistPattern *regexp.Regexp
	denylistPattern  *regexp.Regexp

	serverName string
	serverURL  string

	// tools is the filtered tool set. It is assigned only at the end of a
	// successful Init and is nil before then.
	tools map[string]*Tool
}
// NewStreamableHTTPServerProxy creates a proxy for the MCP server called
// serverName at serverURL, using mcp-go's streamable HTTP client.
//
// When headers is non-nil it is passed to the transport through
// transport.WithHTTPHeaders, appended after any caller-supplied opts.
// allowlist and denylist are applied to the server's tools during Init.
// The underlying client is only constructed here; Init starts it.
//
// The caller's opts slice is never modified.
func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[string]string, allowlist, denylist *regexp.Regexp, logger slog.Logger, tracer trace.Tracer, opts ...transport.StreamableHTTPCOption) (*StreamableHTTPServerProxy, error) {
	// nit: headers should be passed in as an option instead of a separate parameter. Not changed as this would be a breaking change.
	if headers != nil {
		// Clip before appending. When the caller passes `someSlice...`, opts
		// shares someSlice's backing array. If that array has spare capacity,
		// an in-place append would write into memory the caller owns. It would
		// also race with concurrent constructions that share the same base
		// options, letting one proxy's headers leak into another's.
		opts = append(slices.Clip(opts), transport.WithHTTPHeaders(headers))
	}

	mcpClient, err := client.NewStreamableHttpClient(serverURL, opts...)
	if err != nil {
		return nil, xerrors.Errorf("create streamable http client: %w", err)
	}

	return &StreamableHTTPServerProxy{
		serverName:       serverName,
		serverURL:        serverURL,
		client:           mcpClient,
		logger:           logger,
		tracer:           tracer,
		allowlistPattern: allowlist,
		denylistPattern:  denylist,
	}, nil
}
// Name reports the server name the proxy was constructed with. The same
// value identifies the proxy itself in trace attributes.
func (p *StreamableHTTPServerProxy) Name() string {
	name := p.serverName
	return name
}
// Init starts the underlying MCP client, performs the MCP initialize handshake
// and loads the server's tools. Only the tools accepted by the configured
// allow/deny patterns are kept.
//
// Init fails if the server answers with a protocol version that is not listed
// in mcp.ValidProtocolVersions. Once the client has been started, every
// failure path closes it again, so a failed Init releases the client instead
// of leaking it. Errors from that close are only logged at debug level,
// because the error being returned is the one that matters to the caller.
// The whole operation is recorded in a "StreamableHTTPServerProxy.Init" span.
func (p *StreamableHTTPServerProxy) Init(ctx context.Context) (outErr error) {
	ctx, span := p.tracer.Start(ctx, "StreamableHTTPServerProxy.Init", trace.WithAttributes(p.traceAttributes()...))
	defer tracing.EndSpanErr(span, &outErr)

	if err := p.client.Start(ctx); err != nil {
		return xerrors.Errorf("start client: %w", err)
	}

	// closeClient releases the started client when Init bails out early.
	// msg is logged only if closing the client itself fails.
	closeClient := func(msg string) {
		if err := p.client.Close(); err != nil {
			p.logger.Debug(ctx, msg, slog.Error(err))
		}
	}

	version := mcp.LATEST_PROTOCOL_VERSION
	initReq := mcp.InitializeRequest{
		Params: mcp.InitializeParams{
			ProtocolVersion: version,
			ClientInfo:      GetClientInfo(),
		},
	}
	result, err := p.client.Initialize(ctx, initReq)
	if err != nil {
		closeClient("failed to close MCP client on unsuccessful initialization")
		return xerrors.Errorf("init MCP client: %w", err)
	}

	// Reject servers that negotiated a protocol version this library does not accept.
	if !slices.Contains(mcp.ValidProtocolVersions, result.ProtocolVersion) {
		closeClient("failed to close MCP client on unsuccessful version negotiation")
		return xerrors.Errorf("MCP version negotiation failed; requested %q, accepts %q, received %q", version, strings.Join(mcp.ValidProtocolVersions, ","), result.ProtocolVersion)
	}

	p.logger.Debug(ctx, "mcp client initialized", slog.F("name", result.ServerInfo.Name), slog.F("server_version", result.ServerInfo.Version))

	tools, err := p.fetchTools(ctx)
	if err != nil {
		// The handshake succeeded, so the client is live and must be released.
		closeClient("failed to close MCP client after unsuccessful tool fetch")
		return xerrors.Errorf("fetch tools: %w", err)
	}

	// Only include allowed tools.
	p.tools = FilterAllowedTools(p.logger.Named("tool-filterer"), tools, p.allowlistPattern, p.denylistPattern)
	return nil
}
// ListTools returns the filtered tools ordered by ID. Sorting gives callers a
// deterministic order even though the tools are stored in a map.
func (p *StreamableHTTPServerProxy) ListTools() []*Tool {
	byID := func(x, y *Tool) int {
		return strings.Compare(x.ID, y.ID)
	}

	sorted := maps.Values(p.tools)
	slices.SortStableFunc(sorted, byID)
	return sorted
}
// GetTool returns the tool stored under name in the filtered tool set. It
// returns nil when there is no such entry, including before Init has
// populated the set.
func (p *StreamableHTTPServerProxy) GetTool(name string) *Tool {
	// Indexing a nil map is safe in Go. A missing key yields the zero value
	// (nil), so no separate nil-map or presence check is needed.
	return p.tools[name]
}
// CallTool invokes the tool registered under name on the remote server,
// passing input as the call's arguments. Names the proxy does not know are
// rejected locally, without contacting the server.
func (p *StreamableHTTPServerProxy) CallTool(ctx context.Context, name string, input any) (*mcp.CallToolResult, error) {
	t := p.GetTool(name)
	if t == nil {
		return nil, xerrors.Errorf("%q tool not known", name)
	}

	// The request carries the tool's Name field, which is the name the server
	// reported, not the key it was looked up by.
	var req mcp.CallToolRequest
	req.Params.Name = t.Name
	req.Params.Arguments = input
	return p.client.CallTool(ctx, req)
}
// fetchTools lists the server's tools and wraps each one in a *Tool, keyed by
// EncodeToolID(serverName, toolName). The tool count is recorded on the
// "StreamableHTTPServerProxy.Init.fetchTools" span.
func (p *StreamableHTTPServerProxy) fetchTools(ctx context.Context) (_ map[string]*Tool, outErr error) {
	ctx, span := p.tracer.Start(ctx, "StreamableHTTPServerProxy.Init.fetchTools", trace.WithAttributes(p.traceAttributes()...))
	defer tracing.EndSpanErr(span, &outErr)

	listed, err := p.client.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		return nil, xerrors.Errorf("list MCP tools: %w", err)
	}

	byID := make(map[string]*Tool, len(listed.Tools))
	for i := range listed.Tools {
		// Take a pointer to avoid copying each (large) mcp.Tool value.
		remote := &listed.Tools[i]
		id := EncodeToolID(p.serverName, remote.Name)
		byID[id] = &Tool{
			Client:      p.client,
			ID:          id,
			Name:        remote.Name,
			ServerName:  p.serverName,
			ServerURL:   p.serverURL,
			Description: remote.Description,
			Params:      remote.InputSchema.Properties,
			Required:    remote.InputSchema.Required,
			Logger:      p.logger,
		}
	}

	attrs := append(p.traceAttributes(), attribute.Int(tracing.MCPToolCount, len(byID)))
	span.SetAttributes(attrs...)
	return byID, nil
}
// Shutdown closes the underlying MCP client and returns any error from Close.
// It does nothing when the proxy has no client. The context argument is
// ignored (see the note below).
func (p *StreamableHTTPServerProxy) Shutdown(_ context.Context) error {
	if c := p.client; c != nil {
		// NOTE: as of v0.38.0 the lib doesn't allow an outside context to be passed in;
		// it has an internal timeout of 5s, though.
		return c.Close()
	}
	return nil
}
// traceAttributes returns the span attributes that identify this proxy and
// its upstream server. Each call builds a new slice, so callers may append
// to the result.
func (p *StreamableHTTPServerProxy) traceAttributes() []attribute.KeyValue {
	// Name() currently returns serverName as well. Both keys are still emitted
	// so the proxy and server identities remain separate attributes.
	proxyName := p.Name()
	return []attribute.KeyValue{
		attribute.String(tracing.MCPProxyName, proxyName),
		attribute.String(tracing.MCPServerName, p.serverName),
		attribute.String(tracing.MCPServerURL, p.serverURL),
	}
}