mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: isolate MCP HTTP transports from DefaultTransport in tests (#25821)
Use testing.Testing() inside createTransport to automatically clone http.DefaultTransport when running in tests. In production, DefaultTransport is used as-is (efficient connection pooling). This fixes the CloseIdleConnections flake class: httptest.Server.Close() calls http.DefaultTransport.CloseIdleConnections(), which disrupts any MCP client sharing that transport. The testing.Testing() check means every MCP transport created during tests gets isolation automatically, with no caller changes needed. Closes coder/internal#1016 Closes PLAT-291
This commit is contained in:
committed by
GitHub
parent
c8555e2163
commit
82752844bc
@@ -975,15 +975,19 @@ func (m *Manager) createTransport(ctx context.Context, cfg ServerConfig) (transp
|
|||||||
}),
|
}),
|
||||||
), nil
|
), nil
|
||||||
case "http", "":
|
case "http", "":
|
||||||
return transport.NewStreamableHTTP(
|
var opts []transport.StreamableHTTPCOption
|
||||||
cfg.URL,
|
opts = append(opts, transport.WithHTTPHeaders(cfg.Headers))
|
||||||
transport.WithHTTPHeaders(cfg.Headers),
|
if c := mcpHTTPClient(); c != nil {
|
||||||
)
|
opts = append(opts, transport.WithHTTPBasicClient(c))
|
||||||
|
}
|
||||||
|
return transport.NewStreamableHTTP(cfg.URL, opts...)
|
||||||
case "sse":
|
case "sse":
|
||||||
return transport.NewSSE(
|
var sseOpts []transport.ClientOption
|
||||||
cfg.URL,
|
sseOpts = append(sseOpts, transport.WithHeaders(cfg.Headers))
|
||||||
transport.WithHeaders(cfg.Headers),
|
if c := mcpHTTPClient(); c != nil {
|
||||||
)
|
sseOpts = append(sseOpts, transport.WithHTTPClient(c))
|
||||||
|
}
|
||||||
|
return transport.NewSSE(cfg.URL, sseOpts...)
|
||||||
default:
|
default:
|
||||||
return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport)
|
return nil, xerrors.Errorf("unsupported transport %q", cfg.Transport)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package agentmcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mcpHTTPClient returns an isolated *http.Client when running
|
||||||
|
// inside tests, or nil for production. During tests,
|
||||||
|
// httptest.Server.Close() calls
|
||||||
|
// http.DefaultTransport.CloseIdleConnections(), which disrupts
|
||||||
|
// any MCP client sharing that transport. When DefaultTransport
|
||||||
|
// is a *http.Transport it is cloned; otherwise a minimal
|
||||||
|
// transport with ProxyFromEnvironment is created as a fallback.
|
||||||
|
func mcpHTTPClient() *http.Client {
|
||||||
|
if !testing.Testing() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||||
|
return &http.Client{Transport: dt.Clone()}
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
}}
|
||||||
|
}
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mcpHTTPClient returns an isolated *http.Client when running
|
||||||
|
// inside tests, or nil for production. During tests,
|
||||||
|
// httptest.Server.Close() calls
|
||||||
|
// http.DefaultTransport.CloseIdleConnections(), which disrupts
|
||||||
|
// any MCP client sharing that transport. When DefaultTransport
|
||||||
|
// is a *http.Transport it is cloned; otherwise a minimal
|
||||||
|
// transport with ProxyFromEnvironment is created as a fallback.
|
||||||
|
func mcpHTTPClient() *http.Client {
|
||||||
|
if !testing.Testing() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||||
|
return &http.Client{Transport: dt.Clone()}
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
}}
|
||||||
|
}
|
||||||
@@ -39,6 +39,17 @@ func NewStreamableHTTPServerProxy(serverName, serverURL string, headers map[stri
|
|||||||
opts = append(opts, transport.WithHTTPHeaders(headers))
|
opts = append(opts, transport.WithHTTPHeaders(headers))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prepend an isolated HTTP client when running in tests so
|
||||||
|
// httptest.Server.Close() does not disrupt this proxy's
|
||||||
|
// connections via http.DefaultTransport.CloseIdleConnections().
|
||||||
|
// Caller-provided WithHTTPBasicClient in opts overrides this
|
||||||
|
// (last-wins).
|
||||||
|
if c := mcpHTTPClient(); c != nil {
|
||||||
|
opts = append([]transport.StreamableHTTPCOption{
|
||||||
|
transport.WithHTTPBasicClient(c),
|
||||||
|
}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
mcpClient, err := client.NewStreamableHttpClient(serverURL, opts...)
|
mcpClient, err := client.NewStreamableHttpClient(serverURL, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("create streamable http client: %w", err)
|
return nil, xerrors.Errorf("create streamable http client: %w", err)
|
||||||
|
|||||||
+107
-30
@@ -12,6 +12,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -57,11 +58,10 @@ func TestMCPHTTP_E2E_ClientIntegration(t *testing.T) {
|
|||||||
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
||||||
|
|
||||||
// Configure client with authentication headers using RFC 6750 Bearer token
|
// Configure client with authentication headers using RFC 6750 Bearer token
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
mcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + coderClient.SessionToken(),
|
"Authorization": "Bearer " + coderClient.SessionToken(),
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := mcpClient.Close(); closeErr != nil {
|
if closeErr := mcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close MCP client: %v", closeErr)
|
t.Logf("Failed to close MCP client: %v", closeErr)
|
||||||
@@ -72,7 +72,7 @@ func TestMCPHTTP_E2E_ClientIntegration(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start client
|
// Start client
|
||||||
err = mcpClient.Start(ctx)
|
err := mcpClient.Start(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Initialize connection
|
// Initialize connection
|
||||||
@@ -190,8 +190,7 @@ func TestMCPHTTP_E2E_UnauthenticatedAccess(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "Should get HTTP 401 for unauthenticated access")
|
require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "Should get HTTP 401 for unauthenticated access")
|
||||||
|
|
||||||
// Also test with MCP client to ensure it handles the error gracefully
|
// Also test with MCP client to ensure it handles the error gracefully
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL)
|
mcpClient := newIsolatedMCPClient(t, mcpURL)
|
||||||
require.NoError(t, err, "Should be able to create MCP client without authentication")
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := mcpClient.Close(); closeErr != nil {
|
if closeErr := mcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close MCP client: %v", closeErr)
|
t.Logf("Failed to close MCP client: %v", closeErr)
|
||||||
@@ -245,11 +244,10 @@ func TestMCPHTTP_E2E_ToolWithWorkspace(t *testing.T) {
|
|||||||
coderdtest.NewWorkspaceAgentWaiter(t, coderClient, r.Workspace.ID).Wait()
|
coderdtest.NewWorkspaceAgentWaiter(t, coderClient, r.Workspace.ID).Wait()
|
||||||
|
|
||||||
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
mcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + coderClient.SessionToken(),
|
"Authorization": "Bearer " + coderClient.SessionToken(),
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := mcpClient.Close(); closeErr != nil {
|
if closeErr := mcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close MCP client: %v", closeErr)
|
t.Logf("Failed to close MCP client: %v", closeErr)
|
||||||
@@ -260,7 +258,7 @@ func TestMCPHTTP_E2E_ToolWithWorkspace(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
require.NoError(t, mcpClient.Start(ctx))
|
require.NoError(t, mcpClient.Start(ctx))
|
||||||
_, err = mcpClient.Initialize(ctx, mcp.InitializeRequest{
|
_, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{
|
||||||
Params: mcp.InitializeParams{
|
Params: mcp.InitializeParams{
|
||||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||||
ClientInfo: mcp.Implementation{
|
ClientInfo: mcp.Implementation{
|
||||||
@@ -307,11 +305,10 @@ func TestMCPHTTP_E2E_ErrorHandling(t *testing.T) {
|
|||||||
|
|
||||||
// Create MCP client
|
// Create MCP client
|
||||||
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
mcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + coderClient.SessionToken(),
|
"Authorization": "Bearer " + coderClient.SessionToken(),
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := mcpClient.Close(); closeErr != nil {
|
if closeErr := mcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close MCP client: %v", closeErr)
|
t.Logf("Failed to close MCP client: %v", closeErr)
|
||||||
@@ -322,7 +319,7 @@ func TestMCPHTTP_E2E_ErrorHandling(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start and initialize client
|
// Start and initialize client
|
||||||
err = mcpClient.Start(ctx)
|
err := mcpClient.Start(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
initReq := mcp.InitializeRequest{
|
initReq := mcp.InitializeRequest{
|
||||||
@@ -366,11 +363,10 @@ func TestMCPHTTP_E2E_ConcurrentRequests(t *testing.T) {
|
|||||||
|
|
||||||
// Create MCP client
|
// Create MCP client
|
||||||
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
mcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + coderClient.SessionToken(),
|
"Authorization": "Bearer " + coderClient.SessionToken(),
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := mcpClient.Close(); closeErr != nil {
|
if closeErr := mcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close MCP client: %v", closeErr)
|
t.Logf("Failed to close MCP client: %v", closeErr)
|
||||||
@@ -381,7 +377,7 @@ func TestMCPHTTP_E2E_ConcurrentRequests(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start and initialize client
|
// Start and initialize client
|
||||||
err = mcpClient.Start(ctx)
|
err := mcpClient.Start(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
initReq := mcp.InitializeRequest{
|
initReq := mcp.InitializeRequest{
|
||||||
@@ -520,11 +516,10 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
|||||||
sessionToken := coderClient.SessionToken()
|
sessionToken := coderClient.SessionToken()
|
||||||
|
|
||||||
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
mcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + sessionToken,
|
"Authorization": "Bearer " + sessionToken,
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := mcpClient.Close(); closeErr != nil {
|
if closeErr := mcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close MCP client: %v", closeErr)
|
t.Logf("Failed to close MCP client: %v", closeErr)
|
||||||
@@ -669,11 +664,10 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
|||||||
|
|
||||||
// Step 3: Use access token to authenticate with MCP endpoint
|
// Step 3: Use access token to authenticate with MCP endpoint
|
||||||
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
mcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + accessToken,
|
"Authorization": "Bearer " + accessToken,
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := mcpClient.Close(); closeErr != nil {
|
if closeErr := mcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close MCP client: %v", closeErr)
|
t.Logf("Failed to close MCP client: %v", closeErr)
|
||||||
@@ -762,11 +756,10 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
|||||||
t.Logf("Successfully refreshed token: %s...", newAccessToken[:10])
|
t.Logf("Successfully refreshed token: %s...", newAccessToken[:10])
|
||||||
|
|
||||||
// Step 5: Use new access token to create another MCP connection
|
// Step 5: Use new access token to create another MCP connection
|
||||||
newMcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
newMcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + newAccessToken,
|
"Authorization": "Bearer " + newAccessToken,
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := newMcpClient.Close(); closeErr != nil {
|
if closeErr := newMcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close new MCP client: %v", closeErr)
|
t.Logf("Failed to close new MCP client: %v", closeErr)
|
||||||
@@ -990,11 +983,10 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
|||||||
t.Logf("Successfully obtained access token: %s...", accessToken[:10])
|
t.Logf("Successfully obtained access token: %s...", accessToken[:10])
|
||||||
|
|
||||||
// Step 5: Use access token to get user information via MCP
|
// Step 5: Use access token to get user information via MCP
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
mcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + accessToken,
|
"Authorization": "Bearer " + accessToken,
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := mcpClient.Close(); closeErr != nil {
|
if closeErr := mcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close MCP client: %v", closeErr)
|
t.Logf("Failed to close MCP client: %v", closeErr)
|
||||||
@@ -1088,11 +1080,10 @@ func TestMCPHTTP_E2E_OAuth2_EndToEnd(t *testing.T) {
|
|||||||
t.Logf("Successfully refreshed token: %s...", newAccessToken[:10])
|
t.Logf("Successfully refreshed token: %s...", newAccessToken[:10])
|
||||||
|
|
||||||
// Step 7: Use refreshed token to get user information again via MCP
|
// Step 7: Use refreshed token to get user information again via MCP
|
||||||
newMcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
newMcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + newAccessToken,
|
"Authorization": "Bearer " + newAccessToken,
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := newMcpClient.Close(); closeErr != nil {
|
if closeErr := newMcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close new MCP client: %v", closeErr)
|
t.Logf("Failed to close new MCP client: %v", closeErr)
|
||||||
@@ -1268,11 +1259,10 @@ func TestMCPHTTP_E2E_ChatGPTEndpoint(t *testing.T) {
|
|||||||
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint + "?toolset=chatgpt"
|
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint + "?toolset=chatgpt"
|
||||||
|
|
||||||
// Configure client with authentication headers using RFC 6750 Bearer token
|
// Configure client with authentication headers using RFC 6750 Bearer token
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
mcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + coderClient.SessionToken(),
|
"Authorization": "Bearer " + coderClient.SessionToken(),
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
if closeErr := mcpClient.Close(); closeErr != nil {
|
if closeErr := mcpClient.Close(); closeErr != nil {
|
||||||
t.Logf("Failed to close MCP client: %v", closeErr)
|
t.Logf("Failed to close MCP client: %v", closeErr)
|
||||||
@@ -1283,7 +1273,7 @@ func TestMCPHTTP_E2E_ChatGPTEndpoint(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Start client
|
// Start client
|
||||||
err = mcpClient.Start(ctx)
|
err := mcpClient.Start(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Initialize connection
|
// Initialize connection
|
||||||
@@ -1433,11 +1423,10 @@ func TestMCPHTTP_E2E_WorkspaceSSHAuthz(t *testing.T) {
|
|||||||
|
|
||||||
// Connect with the template-admin user.
|
// Connect with the template-admin user.
|
||||||
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
||||||
mcpClient, err := mcpclient.NewStreamableHttpClient(mcpURL,
|
mcpClient := newIsolatedMCPClient(t, mcpURL,
|
||||||
transport.WithHTTPHeaders(map[string]string{
|
transport.WithHTTPHeaders(map[string]string{
|
||||||
"Authorization": "Bearer " + tmplAdminClient.SessionToken(),
|
"Authorization": "Bearer " + tmplAdminClient.SessionToken(),
|
||||||
}))
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = mcpClient.Close()
|
_ = mcpClient.Close()
|
||||||
}()
|
}()
|
||||||
@@ -1446,7 +1435,7 @@ func TestMCPHTTP_E2E_WorkspaceSSHAuthz(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
require.NoError(t, mcpClient.Start(ctx))
|
require.NoError(t, mcpClient.Start(ctx))
|
||||||
_, err = mcpClient.Initialize(ctx, mcp.InitializeRequest{
|
_, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{
|
||||||
Params: mcp.InitializeParams{
|
Params: mcp.InitializeParams{
|
||||||
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||||
ClientInfo: mcp.Implementation{
|
ClientInfo: mcp.Implementation{
|
||||||
@@ -1489,3 +1478,91 @@ func mustParseURL(t *testing.T, rawURL string) *url.URL {
|
|||||||
require.NoError(t, err, "Failed to parse URL %q", rawURL)
|
require.NoError(t, err, "Failed to parse URL %q", rawURL)
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newIsolatedMCPClient creates a streamable HTTP MCP client that uses
|
||||||
|
// an isolated http.Transport cloned from http.DefaultTransport.
|
||||||
|
// This prevents httptest.Server.Close() (which calls
|
||||||
|
// http.DefaultTransport.CloseIdleConnections()) from disrupting the
|
||||||
|
// client's connections during parallel tests.
|
||||||
|
func newIsolatedMCPClient(t *testing.T, mcpURL string, opts ...transport.StreamableHTTPCOption) *mcpclient.Client {
|
||||||
|
t.Helper()
|
||||||
|
isolated := coderdtest.NewIsolatedHTTPClient(nil)
|
||||||
|
opts = append([]transport.StreamableHTTPCOption{transport.WithHTTPBasicClient(isolated)}, opts...)
|
||||||
|
client, err := mcpclient.NewStreamableHttpClient(mcpURL, opts...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
// sentinelTransport wraps an http.RoundTripper and counts how many
|
||||||
|
// requests flow through it. Used as a test sentinel to verify
|
||||||
|
// whether a client is (or is not) using http.DefaultTransport.
|
||||||
|
type sentinelTransport struct {
|
||||||
|
inner http.RoundTripper
|
||||||
|
hits atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sentinelTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
s.hits.Add(1)
|
||||||
|
return s.inner.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMCPHTTP_E2E_TransportIsolation verifies that the
|
||||||
|
// newIsolatedMCPClient helper creates clients that do NOT route
|
||||||
|
// requests through http.DefaultTransport, while raw
|
||||||
|
// mcpclient.NewStreamableHttpClient (without explicit
|
||||||
|
// WithHTTPBasicClient) does use it.
|
||||||
|
//
|
||||||
|
//nolint:paralleltest // Mutates http.DefaultTransport.
|
||||||
|
func TestMCPHTTP_E2E_TransportIsolation(t *testing.T) {
|
||||||
|
// Replace DefaultTransport with a counting sentinel.
|
||||||
|
original := http.DefaultTransport
|
||||||
|
sentinel := &sentinelTransport{inner: original}
|
||||||
|
http.DefaultTransport = sentinel
|
||||||
|
t.Cleanup(func() { http.DefaultTransport = original })
|
||||||
|
|
||||||
|
coderClient, closer, api := coderdtest.NewWithAPI(t, nil)
|
||||||
|
t.Cleanup(func() { closer.Close() })
|
||||||
|
_ = coderdtest.CreateFirstUser(t, coderClient)
|
||||||
|
|
||||||
|
mcpURL := api.AccessURL.String() + mcpserver.MCPEndpoint
|
||||||
|
authOpt := transport.WithHTTPHeaders(map[string]string{
|
||||||
|
"Authorization": "Bearer " + coderClient.SessionToken(),
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
initReq := mcp.InitializeRequest{
|
||||||
|
Params: mcp.InitializeParams{
|
||||||
|
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
|
||||||
|
ClientInfo: mcp.Implementation{Name: "sentinel-test", Version: "1.0.0"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("RawClientUsesDefaultTransport", func(t *testing.T) {
|
||||||
|
sentinel.hits.Store(0)
|
||||||
|
rawClient, err := mcpclient.NewStreamableHttpClient(mcpURL, authOpt)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = rawClient.Close() }()
|
||||||
|
|
||||||
|
require.NoError(t, rawClient.Start(ctx))
|
||||||
|
_, err = rawClient.Initialize(ctx, initReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Greater(t, sentinel.hits.Load(), int64(0),
|
||||||
|
"raw client should route requests through http.DefaultTransport")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IsolatedClientBypassesDefaultTransport", func(t *testing.T) {
|
||||||
|
sentinel.hits.Store(0)
|
||||||
|
isoClient := newIsolatedMCPClient(t, mcpURL, authOpt)
|
||||||
|
defer func() { _ = isoClient.Close() }()
|
||||||
|
|
||||||
|
require.NoError(t, isoClient.Start(ctx))
|
||||||
|
_, err := isoClient.Initialize(ctx, initReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, int64(0), sentinel.hits.Load(),
|
||||||
|
"isolated client must NOT route requests through http.DefaultTransport")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -285,31 +285,24 @@ func createTransport(
|
|||||||
cfg database.MCPServerConfig,
|
cfg database.MCPServerConfig,
|
||||||
headers map[string]string,
|
headers map[string]string,
|
||||||
) (transport.Interface, error) {
|
) (transport.Interface, error) {
|
||||||
// Each connection gets its own HTTP client with a dedicated
|
httpClient := mcpHTTPClient()
|
||||||
// transport so that httptest.Server.Close() (which calls
|
|
||||||
// CloseIdleConnections on http.DefaultTransport) does not
|
|
||||||
// disrupt unrelated connections during parallel tests.
|
|
||||||
var httpClient *http.Client
|
|
||||||
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
|
|
||||||
httpClient = &http.Client{Transport: dt.Clone()}
|
|
||||||
} else {
|
|
||||||
httpClient = &http.Client{}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch cfg.Transport {
|
switch cfg.Transport {
|
||||||
case "sse":
|
case "sse":
|
||||||
return transport.NewSSE(
|
var opts []transport.ClientOption
|
||||||
cfg.Url,
|
opts = append(opts, transport.WithHeaders(headers))
|
||||||
transport.WithHeaders(headers),
|
if httpClient != nil {
|
||||||
transport.WithHTTPClient(httpClient),
|
opts = append(opts, transport.WithHTTPClient(httpClient))
|
||||||
)
|
}
|
||||||
|
return transport.NewSSE(cfg.Url, opts...)
|
||||||
case "", "streamable_http":
|
case "", "streamable_http":
|
||||||
// Default to streamable HTTP, the newer transport.
|
// Default to streamable HTTP, the newer transport.
|
||||||
return transport.NewStreamableHTTP(
|
var opts []transport.StreamableHTTPCOption
|
||||||
cfg.Url,
|
opts = append(opts, transport.WithHTTPHeaders(headers))
|
||||||
transport.WithHTTPHeaders(headers),
|
if httpClient != nil {
|
||||||
transport.WithHTTPBasicClient(httpClient),
|
opts = append(opts, transport.WithHTTPBasicClient(httpClient))
|
||||||
)
|
}
|
||||||
|
return transport.NewStreamableHTTP(cfg.Url, opts...)
|
||||||
default:
|
default:
|
||||||
return nil, xerrors.Errorf(
|
return nil, xerrors.Errorf(
|
||||||
"unsupported transport %q", cfg.Transport,
|
"unsupported transport %q", cfg.Transport,
|
||||||
|
|||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package mcpclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mcpHTTPClient returns an isolated *http.Client when running
|
||||||
|
// inside tests, or nil for production. During tests,
|
||||||
|
// httptest.Server.Close() calls
|
||||||
|
// http.DefaultTransport.CloseIdleConnections(), which disrupts
|
||||||
|
// any MCP client sharing that transport. When DefaultTransport
|
||||||
|
// is a *http.Transport it is cloned; otherwise a minimal
|
||||||
|
// transport with ProxyFromEnvironment is created as a fallback.
|
||||||
|
func mcpHTTPClient() *http.Client {
|
||||||
|
if !testing.Testing() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||||
|
return &http.Client{Transport: dt.Clone()}
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
}}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user