mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
43db8282b0
ConnectAll grows a fourth indexed input — []database.McpServerUserHeaderValue — which buildAuthHeaders consults inside the custom_headers branch to overlay per-user values on top of admin static headers, scoped to cfg.CustomHeadersUserKeys. chatd loads the user's stored header values via GetMCPServerUserHeaderValuesByUserID alongside the existing GetMCPServerUserTokensByUserID call and threads them into ConnectAll. A missing row is non-fatal: admin headers still ship, user-keyed headers are simply absent and a warning is logged.
331 lines
10 KiB
Go
331 lines
10 KiB
Go
package mcpclient_test
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
mcpserver "github.com/mark3labs/mcp-go/server"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
|
|
)
|
|
|
|
// newHeaderRecordingServer creates a streamable HTTP MCP server with a
|
|
// single "ping" tool. Every request's headers are appended to the
|
|
// returned slice so tests can assert which headers were forwarded.
|
|
func newHeaderRecordingServer(t *testing.T) (*httptest.Server, *sync.Mutex, *[]http.Header) {
|
|
t.Helper()
|
|
var (
|
|
mu sync.Mutex
|
|
headers []http.Header
|
|
)
|
|
srv := mcpserver.NewMCPServer("hdr-server", "1.0.0")
|
|
srv.AddTools(mcpserver.ServerTool{
|
|
Tool: mcp.NewTool("ping", mcp.WithDescription("records the request headers")),
|
|
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
mu.Lock()
|
|
headers = append(headers, req.Header.Clone())
|
|
mu.Unlock()
|
|
return mcp.NewToolResultText("ok"), nil
|
|
},
|
|
})
|
|
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
|
|
ts := httptest.NewServer(httpSrv)
|
|
t.Cleanup(ts.Close)
|
|
return ts, &mu, &headers
|
|
}
|
|
|
|
// TestConnectAll_ForwardCoderHeaders_DefaultOff is a regression guard
|
|
// that the Coder identity headers are NOT sent when the option is
|
|
// left at its default (false).
|
|
func TestConnectAll_ForwardCoderHeaders_DefaultOff(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := t.Context()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
ts, mu, recorded := newHeaderRecordingServer(t)
|
|
|
|
cfg := makeConfig("no-hdr", ts.URL)
|
|
assert.False(t, cfg.ForwardCoderHeaders, "default must be false")
|
|
|
|
coderHeaders := map[string]string{
|
|
chatprovider.HeaderCoderOwnerID: uuid.NewString(),
|
|
chatprovider.HeaderCoderChatID: uuid.NewString(),
|
|
chatprovider.HeaderCoderWorkspaceID: uuid.NewString(),
|
|
}
|
|
|
|
tools, cleanup := mcpclient.ConnectAll(
|
|
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
|
coderHeaders,
|
|
)
|
|
t.Cleanup(cleanup)
|
|
require.Len(t, tools, 1)
|
|
|
|
_, err := tools[0].Run(ctx, fantasy.ToolCall{
|
|
ID: "call-1", Name: "no-hdr__ping", Input: "{}",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
require.NotEmpty(t, *recorded)
|
|
for _, h := range *recorded {
|
|
assert.Empty(t, h.Get(chatprovider.HeaderCoderOwnerID))
|
|
assert.Empty(t, h.Get(chatprovider.HeaderCoderChatID))
|
|
assert.Empty(t, h.Get(chatprovider.HeaderCoderSubchatID))
|
|
assert.Empty(t, h.Get(chatprovider.HeaderCoderWorkspaceID))
|
|
}
|
|
}
|
|
|
|
// TestConnectAll_ForwardCoderHeaders_Enabled verifies that when the
|
|
// option is enabled, the Coder identity headers are forwarded on every
|
|
// outgoing MCP request, including the subchat and workspace headers.
|
|
func TestConnectAll_ForwardCoderHeaders_Enabled(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := t.Context()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
ts, mu, recorded := newHeaderRecordingServer(t)
|
|
|
|
ownerID := uuid.New()
|
|
chatID := uuid.New()
|
|
workspaceID := uuid.New()
|
|
subchatID := uuid.New()
|
|
|
|
cfg := makeConfig("hdr", ts.URL)
|
|
cfg.ForwardCoderHeaders = true
|
|
|
|
// Subchat headers: parent's chat ID lives in X-Coder-Chat-Id, the
|
|
// subchat's own ID lives in X-Coder-Subchat-Id.
|
|
coderHeaders := chatprovider.CoderHeaders(database.Chat{
|
|
ID: subchatID,
|
|
OwnerID: ownerID,
|
|
ParentChatID: uuid.NullUUID{UUID: chatID, Valid: true},
|
|
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
|
})
|
|
|
|
tools, cleanup := mcpclient.ConnectAll(
|
|
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
|
coderHeaders,
|
|
)
|
|
t.Cleanup(cleanup)
|
|
require.Len(t, tools, 1)
|
|
|
|
_, err := tools[0].Run(ctx, fantasy.ToolCall{
|
|
ID: "call-1", Name: "hdr__ping", Input: "{}",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
require.NotEmpty(t, *recorded)
|
|
last := (*recorded)[len(*recorded)-1]
|
|
assert.Equal(t, ownerID.String(), last.Get(chatprovider.HeaderCoderOwnerID))
|
|
assert.Equal(t, chatID.String(), last.Get(chatprovider.HeaderCoderChatID))
|
|
assert.Equal(t, subchatID.String(), last.Get(chatprovider.HeaderCoderSubchatID))
|
|
assert.Equal(t, workspaceID.String(), last.Get(chatprovider.HeaderCoderWorkspaceID))
|
|
}
|
|
|
|
// TestConnectAll_ForwardCoderHeaders_RootChat verifies that for a root
|
|
// chat (no parent), the chat's own ID is forwarded as
|
|
// X-Coder-Chat-Id and the X-Coder-Subchat-Id header is absent.
|
|
func TestConnectAll_ForwardCoderHeaders_RootChat(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := t.Context()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
ts, mu, recorded := newHeaderRecordingServer(t)
|
|
|
|
ownerID := uuid.New()
|
|
chatID := uuid.New()
|
|
|
|
cfg := makeConfig("hdr-root", ts.URL)
|
|
cfg.ForwardCoderHeaders = true
|
|
|
|
coderHeaders := chatprovider.CoderHeaders(database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
})
|
|
|
|
tools, cleanup := mcpclient.ConnectAll(
|
|
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
|
coderHeaders,
|
|
)
|
|
t.Cleanup(cleanup)
|
|
require.Len(t, tools, 1)
|
|
|
|
_, err := tools[0].Run(ctx, fantasy.ToolCall{
|
|
ID: "call-1", Name: "hdr-root__ping", Input: "{}",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
require.NotEmpty(t, *recorded)
|
|
last := (*recorded)[len(*recorded)-1]
|
|
assert.Equal(t, ownerID.String(), last.Get(chatprovider.HeaderCoderOwnerID))
|
|
assert.Equal(t, chatID.String(), last.Get(chatprovider.HeaderCoderChatID))
|
|
assert.Empty(t, last.Get(chatprovider.HeaderCoderSubchatID))
|
|
assert.Empty(t, last.Get(chatprovider.HeaderCoderWorkspaceID))
|
|
}
|
|
|
|
// TestConnectAll_ForwardCoderHeaders_WithAPIKeyAuth verifies that the
|
|
// api_key auth header is preserved when Coder identity headers are
|
|
// forwarded alongside.
|
|
func TestConnectAll_ForwardCoderHeaders_WithAPIKeyAuth(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := t.Context()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
ts, mu, recorded := newHeaderRecordingServer(t)
|
|
|
|
ownerID := uuid.New()
|
|
chatID := uuid.New()
|
|
|
|
cfg := makeConfig("hdr-apikey", ts.URL)
|
|
cfg.AuthType = "api_key"
|
|
cfg.APIKeyHeader = "X-Api-Key"
|
|
cfg.APIKeyValue = "sekret"
|
|
cfg.ForwardCoderHeaders = true
|
|
|
|
coderHeaders := chatprovider.CoderHeaders(database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
})
|
|
|
|
tools, cleanup := mcpclient.ConnectAll(
|
|
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
|
coderHeaders,
|
|
)
|
|
t.Cleanup(cleanup)
|
|
require.Len(t, tools, 1)
|
|
|
|
_, err := tools[0].Run(ctx, fantasy.ToolCall{
|
|
ID: "call-1", Name: "hdr-apikey__ping", Input: "{}",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
require.NotEmpty(t, *recorded)
|
|
last := (*recorded)[len(*recorded)-1]
|
|
assert.Equal(t, "sekret", last.Get("X-Api-Key"))
|
|
assert.Equal(t, ownerID.String(), last.Get(chatprovider.HeaderCoderOwnerID))
|
|
assert.Equal(t, chatID.String(), last.Get(chatprovider.HeaderCoderChatID))
|
|
}
|
|
|
|
// TestConnectAll_ForwardCoderHeaders_WithOAuth2 verifies that the
|
|
// oauth2 Authorization header is preserved when Coder identity
|
|
// headers are forwarded alongside, and that auth wins on a conflict.
|
|
func TestConnectAll_ForwardCoderHeaders_WithOAuth2(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := t.Context()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
ts, mu, recorded := newHeaderRecordingServer(t)
|
|
|
|
cfgID := uuid.New()
|
|
cfg := makeConfig("hdr-oauth", ts.URL)
|
|
cfg.ID = cfgID
|
|
cfg.AuthType = "oauth2"
|
|
cfg.ForwardCoderHeaders = true
|
|
token := database.MCPServerUserToken{
|
|
MCPServerConfigID: cfgID,
|
|
AccessToken: "oauth-token-xyz",
|
|
TokenType: "Bearer",
|
|
}
|
|
|
|
// Intentionally include an Authorization key to verify the auth
|
|
// header wins on conflict.
|
|
ownerID := uuid.NewString()
|
|
coderHeaders := map[string]string{
|
|
"Authorization": "Bearer should-be-overridden",
|
|
chatprovider.HeaderCoderOwnerID: ownerID,
|
|
}
|
|
|
|
tools, cleanup := mcpclient.ConnectAll(
|
|
ctx, logger,
|
|
[]database.MCPServerConfig{cfg},
|
|
[]database.MCPServerUserToken{token},
|
|
nil,
|
|
uuid.Nil, nil,
|
|
coderHeaders,
|
|
)
|
|
t.Cleanup(cleanup)
|
|
require.Len(t, tools, 1)
|
|
|
|
_, err := tools[0].Run(ctx, fantasy.ToolCall{
|
|
ID: "call-1", Name: "hdr-oauth__ping", Input: "{}",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
require.NotEmpty(t, *recorded)
|
|
last := (*recorded)[len(*recorded)-1]
|
|
assert.Equal(t, "Bearer oauth-token-xyz", last.Get("Authorization"))
|
|
assert.Equal(t, ownerID, last.Get(chatprovider.HeaderCoderOwnerID))
|
|
}
|
|
|
|
// TestConnectAll_ForwardCoderHeaders_WithCustomHeaders verifies that
|
|
// custom_headers admin-configured values are preserved when Coder
|
|
// identity headers are forwarded alongside, including the case where
|
|
// the admin configures a custom header whose name only differs from a
|
|
// Coder identity header by case. Conflict detection is case-
|
|
// insensitive because http.Header.Set canonicalizes header names.
|
|
func TestConnectAll_ForwardCoderHeaders_WithCustomHeaders(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := t.Context()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
ts, mu, recorded := newHeaderRecordingServer(t)
|
|
|
|
ownerID := uuid.New()
|
|
chatID := uuid.New()
|
|
|
|
cfg := makeConfig("hdr-custom", ts.URL)
|
|
cfg.AuthType = "custom_headers"
|
|
// Include both an unrelated custom header AND a case-variant of
|
|
// X-Coder-Owner-Id to exercise the case-insensitive conflict
|
|
// check. The admin-configured value MUST win.
|
|
cfg.CustomHeaders = `{"X-Tenant":"acme","x-coder-owner-id":"admin-controlled"}`
|
|
cfg.ForwardCoderHeaders = true
|
|
|
|
coderHeaders := chatprovider.CoderHeaders(database.Chat{
|
|
ID: chatID,
|
|
OwnerID: ownerID,
|
|
})
|
|
|
|
tools, cleanup := mcpclient.ConnectAll(
|
|
ctx, logger, []database.MCPServerConfig{cfg}, nil, nil, uuid.Nil, nil,
|
|
coderHeaders,
|
|
)
|
|
t.Cleanup(cleanup)
|
|
require.Len(t, tools, 1)
|
|
|
|
_, err := tools[0].Run(ctx, fantasy.ToolCall{
|
|
ID: "call-1", Name: "hdr-custom__ping", Input: "{}",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
require.NotEmpty(t, *recorded)
|
|
last := (*recorded)[len(*recorded)-1]
|
|
assert.Equal(t, "acme", last.Get("X-Tenant"))
|
|
// The admin's case-variant header must win, because HTTP header
|
|
// names are case-insensitive at the transport level.
|
|
assert.Equal(t, "admin-controlled", last.Get(chatprovider.HeaderCoderOwnerID))
|
|
assert.Equal(t, chatID.String(), last.Get(chatprovider.HeaderCoderChatID))
|
|
}
|