mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: block CONNECT tunnels to private/reserved IP ranges (#23109)
## Description Blocks `CONNECT` tunnels to private and reserved IP ranges in aibridgeproxyd, preventing the proxy from being used to reach internal networks. The Coder access URL is always exempt (hostname+port match) so the proxy can reach its own deployment. It is possible to exempt additional ranges via `CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS`. DNS rebinding is handled differently per path: * Direct (no upstream proxy): validate the resolved IP right before the TCP dial, no window between check and connect. * Upstream proxy: Resolves and checks before forwarding to the upstream dialer. A small rebinding window exists since the upstream proxy re-resolves independently. ## Changes * Add blocked IP denylist covering private, reserved, and special-purpose ranges * Add `AllowedPrivateCIDRs` option with CLI flag and env var * Wire IP checks into `proxy.ConnectDial` for both upstream and direct paths * Add tests for blocked/allowed cases across direct dial, upstream proxy, CIDR exemptions, and CoderAccessURL exemption Notes: documentation will be handled in a follow-up PR. Closes: https://github.com/coder/security/issues/124
This commit is contained in:
+6
@@ -170,6 +170,12 @@ AI BRIDGE OPTIONS:
|
|||||||
exporting these records to external SIEM or observability systems.
|
exporting these records to external SIEM or observability systems.
|
||||||
|
|
||||||
AI BRIDGE PROXY OPTIONS:
|
AI BRIDGE PROXY OPTIONS:
|
||||||
|
--aibridge-proxy-allowed-private-cidrs string-array, $CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS
|
||||||
|
Comma-separated list of CIDR ranges that are permitted even though
|
||||||
|
they fall within blocked private/reserved IP ranges. By default all
|
||||||
|
private ranges are blocked to prevent SSRF attacks. Use this to allow
|
||||||
|
access to specific internal networks.
|
||||||
|
|
||||||
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
|
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
|
||||||
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
|
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
|
||||||
provider requests.
|
provider requests.
|
||||||
|
|||||||
+6
@@ -873,6 +873,12 @@ aibridgeproxy:
|
|||||||
# by the system. If not provided, the system certificate pool is used.
|
# by the system. If not provided, the system certificate pool is used.
|
||||||
# (default: <unset>, type: string)
|
# (default: <unset>, type: string)
|
||||||
upstream_proxy_ca: ""
|
upstream_proxy_ca: ""
|
||||||
|
# Comma-separated list of CIDR ranges that are permitted even though they fall
|
||||||
|
# within blocked private/reserved IP ranges. By default all private ranges are
|
||||||
|
# blocked to prevent SSRF attacks. Use this to allow access to specific internal
|
||||||
|
# networks.
|
||||||
|
# (default: <unset>, type: string-array)
|
||||||
|
allowed_private_cidrs: []
|
||||||
# Configure data retention policies for various database tables. Retention
|
# Configure data retention policies for various database tables. Retention
|
||||||
# policies automatically purge old data to reduce database size and improve
|
# policies automatically purge old data to reduce database size and improve
|
||||||
# performance. Setting a retention duration to 0 disables automatic purging for
|
# performance. Setting a retention duration to 0 disables automatic purging for
|
||||||
|
|||||||
Generated
+6
@@ -12670,6 +12670,12 @@ const docTemplate = `{
|
|||||||
"codersdk.AIBridgeProxyConfig": {
|
"codersdk.AIBridgeProxyConfig": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
"allowed_private_cidrs": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
"cert_file": {
|
"cert_file": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
|
|||||||
Generated
+6
@@ -11266,6 +11266,12 @@
|
|||||||
"codersdk.AIBridgeProxyConfig": {
|
"codersdk.AIBridgeProxyConfig": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
"allowed_private_cidrs": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
"cert_file": {
|
"cert_file": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
|
|||||||
+20
-9
@@ -3953,6 +3953,16 @@ Write out the current server config as YAML to stdout.`,
|
|||||||
Group: &deploymentGroupAIBridgeProxy,
|
Group: &deploymentGroupAIBridgeProxy,
|
||||||
YAML: "upstream_proxy_ca",
|
YAML: "upstream_proxy_ca",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "AI Bridge Proxy Allowed Private CIDRs",
|
||||||
|
Description: "Comma-separated list of CIDR ranges that are permitted even though they fall within blocked private/reserved IP ranges. By default all private ranges are blocked to prevent SSRF attacks. Use this to allow access to specific internal networks.",
|
||||||
|
Flag: "aibridge-proxy-allowed-private-cidrs",
|
||||||
|
Env: "CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS",
|
||||||
|
Value: &c.AI.BridgeProxyConfig.AllowedPrivateCIDRs,
|
||||||
|
Default: "",
|
||||||
|
Group: &deploymentGroupAIBridgeProxy,
|
||||||
|
YAML: "allowed_private_cidrs",
|
||||||
|
},
|
||||||
|
|
||||||
// Retention settings
|
// Retention settings
|
||||||
{
|
{
|
||||||
@@ -4058,15 +4068,16 @@ type AIBridgeBedrockConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AIBridgeProxyConfig struct {
|
type AIBridgeProxyConfig struct {
|
||||||
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
|
Enabled serpent.Bool `json:"enabled" typescript:",notnull"`
|
||||||
ListenAddr serpent.String `json:"listen_addr" typescript:",notnull"`
|
ListenAddr serpent.String `json:"listen_addr" typescript:",notnull"`
|
||||||
TLSCertFile serpent.String `json:"tls_cert_file" typescript:",notnull"`
|
TLSCertFile serpent.String `json:"tls_cert_file" typescript:",notnull"`
|
||||||
TLSKeyFile serpent.String `json:"tls_key_file" typescript:",notnull"`
|
TLSKeyFile serpent.String `json:"tls_key_file" typescript:",notnull"`
|
||||||
MITMCertFile serpent.String `json:"cert_file" typescript:",notnull"`
|
MITMCertFile serpent.String `json:"cert_file" typescript:",notnull"`
|
||||||
MITMKeyFile serpent.String `json:"key_file" typescript:",notnull"`
|
MITMKeyFile serpent.String `json:"key_file" typescript:",notnull"`
|
||||||
DomainAllowlist serpent.StringArray `json:"domain_allowlist" typescript:",notnull"`
|
DomainAllowlist serpent.StringArray `json:"domain_allowlist" typescript:",notnull"`
|
||||||
UpstreamProxy serpent.String `json:"upstream_proxy" typescript:",notnull"`
|
UpstreamProxy serpent.String `json:"upstream_proxy" typescript:",notnull"`
|
||||||
UpstreamProxyCA serpent.String `json:"upstream_proxy_ca" typescript:",notnull"`
|
UpstreamProxyCA serpent.String `json:"upstream_proxy_ca" typescript:",notnull"`
|
||||||
|
AllowedPrivateCIDRs serpent.StringArray `json:"allowed_private_cidrs" typescript:",notnull"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatConfig struct {
|
type ChatConfig struct {
|
||||||
|
|||||||
Generated
+3
@@ -163,6 +163,9 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
|
|||||||
"agent_stat_refresh_interval": 0,
|
"agent_stat_refresh_interval": 0,
|
||||||
"ai": {
|
"ai": {
|
||||||
"aibridge_proxy": {
|
"aibridge_proxy": {
|
||||||
|
"allowed_private_cidrs": [
|
||||||
|
"string"
|
||||||
|
],
|
||||||
"cert_file": "string",
|
"cert_file": "string",
|
||||||
"domain_allowlist": [
|
"domain_allowlist": [
|
||||||
"string"
|
"string"
|
||||||
|
|||||||
Generated
+24
-11
@@ -618,6 +618,9 @@
|
|||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
|
"allowed_private_cidrs": [
|
||||||
|
"string"
|
||||||
|
],
|
||||||
"cert_file": "string",
|
"cert_file": "string",
|
||||||
"domain_allowlist": [
|
"domain_allowlist": [
|
||||||
"string"
|
"string"
|
||||||
@@ -634,17 +637,18 @@
|
|||||||
|
|
||||||
### Properties
|
### Properties
|
||||||
|
|
||||||
| Name | Type | Required | Restrictions | Description |
|
| Name | Type | Required | Restrictions | Description |
|
||||||
|---------------------|-----------------|----------|--------------|-------------|
|
|-------------------------|-----------------|----------|--------------|-------------|
|
||||||
| `cert_file` | string | false | | |
|
| `allowed_private_cidrs` | array of string | false | | |
|
||||||
| `domain_allowlist` | array of string | false | | |
|
| `cert_file` | string | false | | |
|
||||||
| `enabled` | boolean | false | | |
|
| `domain_allowlist` | array of string | false | | |
|
||||||
| `key_file` | string | false | | |
|
| `enabled` | boolean | false | | |
|
||||||
| `listen_addr` | string | false | | |
|
| `key_file` | string | false | | |
|
||||||
| `tls_cert_file` | string | false | | |
|
| `listen_addr` | string | false | | |
|
||||||
| `tls_key_file` | string | false | | |
|
| `tls_cert_file` | string | false | | |
|
||||||
| `upstream_proxy` | string | false | | |
|
| `tls_key_file` | string | false | | |
|
||||||
| `upstream_proxy_ca` | string | false | | |
|
| `upstream_proxy` | string | false | | |
|
||||||
|
| `upstream_proxy_ca` | string | false | | |
|
||||||
|
|
||||||
## codersdk.AIBridgeTokenUsage
|
## codersdk.AIBridgeTokenUsage
|
||||||
|
|
||||||
@@ -745,6 +749,9 @@
|
|||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"aibridge_proxy": {
|
"aibridge_proxy": {
|
||||||
|
"allowed_private_cidrs": [
|
||||||
|
"string"
|
||||||
|
],
|
||||||
"cert_file": "string",
|
"cert_file": "string",
|
||||||
"domain_allowlist": [
|
"domain_allowlist": [
|
||||||
"string"
|
"string"
|
||||||
@@ -2697,6 +2704,9 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
|||||||
"agent_stat_refresh_interval": 0,
|
"agent_stat_refresh_interval": 0,
|
||||||
"ai": {
|
"ai": {
|
||||||
"aibridge_proxy": {
|
"aibridge_proxy": {
|
||||||
|
"allowed_private_cidrs": [
|
||||||
|
"string"
|
||||||
|
],
|
||||||
"cert_file": "string",
|
"cert_file": "string",
|
||||||
"domain_allowlist": [
|
"domain_allowlist": [
|
||||||
"string"
|
"string"
|
||||||
@@ -3272,6 +3282,9 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
|||||||
"agent_stat_refresh_interval": 0,
|
"agent_stat_refresh_interval": 0,
|
||||||
"ai": {
|
"ai": {
|
||||||
"aibridge_proxy": {
|
"aibridge_proxy": {
|
||||||
|
"allowed_private_cidrs": [
|
||||||
|
"string"
|
||||||
|
],
|
||||||
"cert_file": "string",
|
"cert_file": "string",
|
||||||
"domain_allowlist": [
|
"domain_allowlist": [
|
||||||
"string"
|
"string"
|
||||||
|
|||||||
Generated
+10
@@ -1961,6 +1961,16 @@ URL of an upstream HTTP proxy to chain tunneled (non-allowlisted) requests throu
|
|||||||
|
|
||||||
Path to a PEM-encoded CA certificate to trust for the upstream proxy's TLS connection. Only needed for HTTPS upstream proxies with certificates not trusted by the system. If not provided, the system certificate pool is used.
|
Path to a PEM-encoded CA certificate to trust for the upstream proxy's TLS connection. Only needed for HTTPS upstream proxies with certificates not trusted by the system. If not provided, the system certificate pool is used.
|
||||||
|
|
||||||
|
### --aibridge-proxy-allowed-private-cidrs
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|-------------|----------------------------------------------------------|
|
||||||
|
| Type | <code>string-array</code> |
|
||||||
|
| Environment | <code>$CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS</code> |
|
||||||
|
| YAML | <code>aibridgeproxy.allowed_private_cidrs</code> |
|
||||||
|
|
||||||
|
Comma-separated list of CIDR ranges that are permitted even though they fall within blocked private/reserved IP ranges. By default all private ranges are blocked to prevent SSRF attacks. Use this to allow access to specific internal networks.
|
||||||
|
|
||||||
### --audit-logs-retention
|
### --audit-logs-retention
|
||||||
|
|
||||||
| | |
|
| | |
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -17,6 +18,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/elazarl/goproxy"
|
"github.com/elazarl/goproxy"
|
||||||
@@ -55,6 +57,45 @@ var proxyAuthRequiredMsg = []byte(http.StatusText(http.StatusProxyAuthRequired))
|
|||||||
// to GoproxyCa. In production, only one server runs, so this has no impact.
|
// to GoproxyCa. In production, only one server runs, so this has no impact.
|
||||||
var loadMITMOnce sync.Once
|
var loadMITMOnce sync.Once
|
||||||
|
|
||||||
|
// blockedIPRanges defines private, reserved, and special-purpose IP ranges
|
||||||
|
// that are blocked by default to prevent connections to internal networks.
|
||||||
|
// Operators can selectively allow specific ranges via AllowedPrivateCIDRs.
|
||||||
|
var blockedIPRanges = func() []net.IPNet {
|
||||||
|
cidrs := []string{
|
||||||
|
"0.0.0.0/8", // RFC 1122: "This" network
|
||||||
|
"10.0.0.0/8", // RFC 1918: Private-Use
|
||||||
|
"100.64.0.0/10", // RFC 6598: Shared Address Space (CGNAT / Tailscale)
|
||||||
|
"127.0.0.0/8", // RFC 1122: Loopback
|
||||||
|
"169.254.0.0/16", // RFC 3927: Link-Local (cloud IMDS: AWS, GCP, Azure)
|
||||||
|
"172.16.0.0/12", // RFC 1918: Private-Use
|
||||||
|
"192.0.0.0/24", // RFC 6890: IETF Protocol Assignments
|
||||||
|
"192.168.0.0/16", // RFC 1918: Private-Use
|
||||||
|
"198.18.0.0/15", // RFC 2544: Benchmarking
|
||||||
|
"240.0.0.0/4", // RFC 1112: Reserved for Future Use
|
||||||
|
"::1/128", // RFC 4291: Loopback
|
||||||
|
"64:ff9b::/96", // RFC 6052: NAT64 well-known prefix
|
||||||
|
"64:ff9b:1::/48", // RFC 8215: NAT64 local-use prefix
|
||||||
|
"2002::/16", // RFC 3056: 6to4
|
||||||
|
"fc00::/7", // RFC 4193: Unique-Local
|
||||||
|
"fe80::/10", // RFC 4291: Link-Local Unicast
|
||||||
|
|
||||||
|
// Note: intentionally excluded because Go's net.IPNet.Contains matches
|
||||||
|
// all IPv4 addresses against this range due to internal IPv4-to-IPv6 mapping.
|
||||||
|
// See https://github.com/golang/go/issues/51906
|
||||||
|
// "::ffff:0:0/96", // RFC 4291: IPv4-mapped IPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
ranges := make([]net.IPNet, 0, len(cidrs))
|
||||||
|
for _, cidr := range cidrs {
|
||||||
|
_, ipNet, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid blocked CIDR %q: %v", cidr, err))
|
||||||
|
}
|
||||||
|
ranges = append(ranges, *ipNet)
|
||||||
|
}
|
||||||
|
return ranges
|
||||||
|
}()
|
||||||
|
|
||||||
// Server is the AI MITM (Man-in-the-Middle) proxy server.
|
// Server is the AI MITM (Man-in-the-Middle) proxy server.
|
||||||
// It is responsible for:
|
// It is responsible for:
|
||||||
// - intercepting HTTPS requests to AI providers
|
// - intercepting HTTPS requests to AI providers
|
||||||
@@ -72,6 +113,8 @@ type Server struct {
|
|||||||
// caCert is the PEM-encoded MITM CA certificate loaded during initialization.
|
// caCert is the PEM-encoded MITM CA certificate loaded during initialization.
|
||||||
// This is served to clients who need to trust the proxy's generated certificates.
|
// This is served to clients who need to trust the proxy's generated certificates.
|
||||||
caCert []byte
|
caCert []byte
|
||||||
|
// allowedPrivateRanges are CIDR ranges exempt from the blocked IP denylist.
|
||||||
|
allowedPrivateRanges []net.IPNet
|
||||||
// Metrics is the Prometheus metrics for the proxy. If nil, metrics are disabled.
|
// Metrics is the Prometheus metrics for the proxy. If nil, metrics are disabled.
|
||||||
metrics *Metrics
|
metrics *Metrics
|
||||||
}
|
}
|
||||||
@@ -134,6 +177,11 @@ type Options struct {
|
|||||||
// proxies with certificates not trusted by the system. If empty, the system
|
// proxies with certificates not trusted by the system. If empty, the system
|
||||||
// certificate pool is used.
|
// certificate pool is used.
|
||||||
UpstreamProxyCA string
|
UpstreamProxyCA string
|
||||||
|
// AllowedPrivateCIDRs is a list of CIDR ranges that are permitted even
|
||||||
|
// though they fall within blocked private/reserved IP ranges. This allows
|
||||||
|
// access to specific internal networks while keeping all other private
|
||||||
|
// ranges blocked. If empty, all private ranges are blocked.
|
||||||
|
AllowedPrivateCIDRs []string
|
||||||
// Metrics is the prometheus metrics instance for recording proxy metrics.
|
// Metrics is the prometheus metrics instance for recording proxy metrics.
|
||||||
// If nil, metrics will not be recorded.
|
// If nil, metrics will not be recorded.
|
||||||
Metrics *Metrics
|
Metrics *Metrics
|
||||||
@@ -159,6 +207,17 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("invalid coder access URL %q: %w", opts.CoderAccessURL, err)
|
return nil, xerrors.Errorf("invalid coder access URL %q: %w", opts.CoderAccessURL, err)
|
||||||
}
|
}
|
||||||
|
// Resolve the default port when not explicitly specified in the URL.
|
||||||
|
coderAccessPort := coderAccessURL.Port()
|
||||||
|
if coderAccessPort == "" {
|
||||||
|
switch coderAccessURL.Scheme {
|
||||||
|
case "https":
|
||||||
|
coderAccessPort = "443"
|
||||||
|
default:
|
||||||
|
coderAccessPort = "80"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
coderAccessURL.Host = net.JoinHostPort(coderAccessURL.Hostname(), coderAccessPort)
|
||||||
|
|
||||||
// MITM cert and key are required to intercept and decrypt HTTPS traffic.
|
// MITM cert and key are required to intercept and decrypt HTTPS traffic.
|
||||||
if opts.MITMCertFile == "" || opts.MITMKeyFile == "" {
|
if opts.MITMCertFile == "" || opts.MITMKeyFile == "" {
|
||||||
@@ -194,6 +253,16 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse configured exceptions to the blocked IP ranges.
|
||||||
|
allowedPrivateRanges := make([]net.IPNet, 0, len(opts.AllowedPrivateCIDRs))
|
||||||
|
for _, cidr := range opts.AllowedPrivateCIDRs {
|
||||||
|
_, ipNet, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("invalid allowed private CIDR %q: %w", cidr, err)
|
||||||
|
}
|
||||||
|
allowedPrivateRanges = append(allowedPrivateRanges, *ipNet)
|
||||||
|
}
|
||||||
|
|
||||||
// Load the CA certificate for MITM.
|
// Load the CA certificate for MITM.
|
||||||
certPEM, err := loadMITMCertificate(opts.MITMCertFile, opts.MITMKeyFile)
|
certPEM, err := loadMITMCertificate(opts.MITMCertFile, opts.MITMKeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -219,12 +288,21 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
|
|||||||
return nil, xerrors.Errorf("failed to load system certificate pool: %w", err)
|
return nil, xerrors.Errorf("failed to load system certificate pool: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure upstream proxy for tunneled (non-allowlisted) requests.
|
srv := &Server{
|
||||||
// This only affects CONNECT requests to domains not in the allowlist.
|
ctx: ctx,
|
||||||
// MITM'd requests (allowlisted domains) are handled by aiproxy and forwarded
|
logger: logger,
|
||||||
// to aibridge directly, not through the upstream proxy. AI Bridge respects
|
proxy: proxy,
|
||||||
// proxy environment variables if set, so the upstream proxy is used at that
|
tlsEnabled: opts.TLSCertFile != "",
|
||||||
// layer instead.
|
coderAccessURL: coderAccessURL,
|
||||||
|
aibridgeProviderFromHost: aibridgeProviderFromHost,
|
||||||
|
caCert: certPEM,
|
||||||
|
allowedPrivateRanges: allowedPrivateRanges,
|
||||||
|
metrics: opts.Metrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure upstream proxy for tunneled (non-allowlisted) CONNECT requests.
|
||||||
|
// Allowlisted domains are MITM'd and forwarded to aibridge directly,
|
||||||
|
// bypassing the upstream proxy.
|
||||||
if opts.UpstreamProxy != "" {
|
if opts.UpstreamProxy != "" {
|
||||||
upstreamURL, err := url.Parse(opts.UpstreamProxy)
|
upstreamURL, err := url.Parse(opts.UpstreamProxy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -273,21 +351,29 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure tunneled CONNECT requests to go through upstream proxy.
|
connectDialer := proxy.NewConnectDialToProxyWithHandler(opts.UpstreamProxy, connectReqHandler)
|
||||||
// This only affects non-allowlisted domains; allowlisted domains are
|
proxy.ConnectDial = func(network, addr string) (net.Conn, error) {
|
||||||
// MITM'd and forwarded to aibridge.
|
// Block CONNECT tunnels to private/reserved IP ranges.
|
||||||
proxy.ConnectDial = proxy.NewConnectDialToProxyWithHandler(opts.UpstreamProxy, connectReqHandler)
|
// addr is the CONNECT target, not the upstream proxy address.
|
||||||
|
if err := srv.checkBlockedIP(ctx, addr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return connectDialer(network, addr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := &Server{
|
// No upstream proxy configured: check private/reserved IPs and dial to the destination.
|
||||||
ctx: ctx,
|
if proxy.ConnectDial == nil {
|
||||||
logger: logger,
|
proxy.ConnectDial = func(network, addr string) (net.Conn, error) {
|
||||||
proxy: proxy,
|
return srv.checkBlockedIPAndDial(srv.ctx, network, addr)
|
||||||
tlsEnabled: opts.TLSCertFile != "",
|
}
|
||||||
coderAccessURL: coderAccessURL,
|
}
|
||||||
aibridgeProviderFromHost: aibridgeProviderFromHost,
|
|
||||||
caCert: certPEM,
|
// Override goproxy's default CONNECT error handler to avoid leaking
|
||||||
metrics: opts.Metrics,
|
// internal error details to clients. Errors are still logged by the caller.
|
||||||
|
proxy.ConnectionErrHandler = func(w io.Writer, _ *goproxy.ProxyCtx, _ error) {
|
||||||
|
msg := "Bad Gateway"
|
||||||
|
_, _ = fmt.Fprintf(w, "HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", len(msg), msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject CONNECT requests to non-standard ports.
|
// Reject CONNECT requests to non-standard ports.
|
||||||
@@ -348,6 +434,7 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
|
|||||||
slog.F("coder_access_url", coderAccessURL.String()),
|
slog.F("coder_access_url", coderAccessURL.String()),
|
||||||
slog.F("domain_allowlist", mitmHosts),
|
slog.F("domain_allowlist", mitmHosts),
|
||||||
slog.F("upstream_proxy", opts.UpstreamProxy),
|
slog.F("upstream_proxy", opts.UpstreamProxy),
|
||||||
|
slog.F("allowed_private_cidrs", opts.AllowedPrivateCIDRs),
|
||||||
)
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -374,6 +461,11 @@ func (s *Server) IsTLSListener() bool {
|
|||||||
return s.tlsEnabled
|
return s.tlsEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CoderAccessURL returns the parsed Coder access URL with a normalized port.
|
||||||
|
func (s *Server) CoderAccessURL() *url.URL {
|
||||||
|
return s.coderAccessURL
|
||||||
|
}
|
||||||
|
|
||||||
// Close gracefully shuts down the proxy server.
|
// Close gracefully shuts down the proxy server.
|
||||||
func (s *Server) Close() error {
|
func (s *Server) Close() error {
|
||||||
if s.httpServer == nil {
|
if s.httpServer == nil {
|
||||||
@@ -674,6 +766,105 @@ func (s *Server) tunneledMiddleware(host string, _ *goproxy.ProxyCtx) (*goproxy.
|
|||||||
return goproxy.OkConnect, host
|
return goproxy.OkConnect, host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isBlockedIP reports whether the given IP is in a blocked private/reserved range
|
||||||
|
// and not exempted by AllowedPrivateCIDRs or the Coder access URL hostname.
|
||||||
|
func (s *Server) isBlockedIP(ip net.IP, hostname string, port string) bool {
|
||||||
|
// Always allow the Coder access URL hostname+port so the proxy doesn't
|
||||||
|
// block connections to its own deployment. Hostname-based (not IP-based)
|
||||||
|
// to handle dynamic IPs (DNS changes, load balancers, k8s rescheduling).
|
||||||
|
// The port is normalized at startup to handle URLs without explicit ports.
|
||||||
|
if strings.EqualFold(hostname, s.coderAccessURL.Hostname()) && port == s.coderAccessURL.Port() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, blocked := range blockedIPRanges {
|
||||||
|
if blocked.Contains(ip) {
|
||||||
|
for _, allowed := range s.allowedPrivateRanges {
|
||||||
|
if allowed.Contains(ip) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkBlockedIP resolves the destination address and returns an error if any
|
||||||
|
// resolved IP falls within a blocked range. Used in the upstream proxy path,
|
||||||
|
// where the actual dial is delegated to the upstream proxy dialer.
|
||||||
|
//
|
||||||
|
// Note: this only prevents DNS rebinding on aibridgeproxyd, not on upstream proxies.
|
||||||
|
// The upstream proxy performs its own DNS resolution when dialing, so there is
|
||||||
|
// a window between this check and the actual connection.
|
||||||
|
func (s *Server) checkBlockedIP(ctx context.Context, addr string) error {
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("invalid address %q: %w", addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNS resolution relies on the OS resolver. We avoid application-level
|
||||||
|
// caching to keep the implementation simple. DNS caching behavior depends
|
||||||
|
// on the OS resolver.
|
||||||
|
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("failed to resolve %q: %w", host, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
if s.isBlockedIP(ip.IP, host, port) {
|
||||||
|
s.logger.Warn(ctx, "blocking connection to private/reserved IP",
|
||||||
|
slog.F("hostname", host),
|
||||||
|
slog.F("port", port),
|
||||||
|
slog.F("resolved_ip", ip.IP.String()),
|
||||||
|
)
|
||||||
|
return xerrors.Errorf("connection to %s (%s) blocked: destination is in a private/reserved IP range", host, ip.IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkBlockedIPAndDial dials the destination address, blocking connections to
|
||||||
|
// private/reserved IPs. Used for tunneled CONNECT requests when no upstream
|
||||||
|
// proxy is configured.
|
||||||
|
func (s *Server) checkBlockedIPAndDial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("invalid address %q: %w", addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNS resolution is handled by Go's DialContext using the OS resolver.
|
||||||
|
// We avoid application-level DNS caching to keep the implementation
|
||||||
|
// simple. DNS caching behavior depends on the OS resolver.
|
||||||
|
dialer := net.Dialer{
|
||||||
|
// ControlContext fires after DNS resolution and before each TCP dial,
|
||||||
|
// receiving the resolved IP:port. The resolved address is always an IP,
|
||||||
|
// so there is no risk of DNS rebinding between validation and the dial.
|
||||||
|
ControlContext: func(ctx context.Context, _, address string, _ syscall.RawConn) error {
|
||||||
|
resolvedIP, _, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("invalid resolved address %q: %w", address, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := net.ParseIP(resolvedIP)
|
||||||
|
if ip == nil {
|
||||||
|
return xerrors.Errorf("invalid resolved IP %q", resolvedIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.isBlockedIP(ip, host, port) {
|
||||||
|
s.logger.Warn(ctx, "blocking connection to private/reserved IP",
|
||||||
|
slog.F("hostname", host),
|
||||||
|
slog.F("port", port),
|
||||||
|
slog.F("resolved_ip", ip.String()),
|
||||||
|
)
|
||||||
|
return xerrors.Errorf("CONNECT to private/reserved IP %s (%s) is blocked", ip, host)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
// handleRequest intercepts HTTP requests after MITM decryption.
|
// handleRequest intercepts HTTP requests after MITM decryption.
|
||||||
// - Requests to known AI providers are rewritten to aibridged, with the Coder token
|
// - Requests to known AI providers are rewritten to aibridged, with the Coder token
|
||||||
// (from ctx.UserData, set during CONNECT) set in the X-Coder-Token header.
|
// (from ctx.UserData, set during CONNECT) set in the X-Coder-Token header.
|
||||||
|
|||||||
@@ -153,6 +153,7 @@ type testProxyConfig struct {
|
|||||||
aibridgeProviderFromHost func(string) string
|
aibridgeProviderFromHost func(string) string
|
||||||
upstreamProxy string
|
upstreamProxy string
|
||||||
upstreamProxyCA string
|
upstreamProxyCA string
|
||||||
|
allowedPrivateCIDRs []string
|
||||||
metrics *aibridgeproxyd.Metrics
|
metrics *aibridgeproxyd.Metrics
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -200,6 +201,12 @@ func withUpstreamProxyCA(upstreamProxyCA string) testProxyOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func withAllowedPrivateCIDRs(cidrs ...string) testProxyOption {
|
||||||
|
return func(cfg *testProxyConfig) {
|
||||||
|
cfg.allowedPrivateCIDRs = cidrs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func withMetrics(metrics *aibridgeproxyd.Metrics) testProxyOption {
|
func withMetrics(metrics *aibridgeproxyd.Metrics) testProxyOption {
|
||||||
return func(cfg *testProxyConfig) {
|
return func(cfg *testProxyConfig) {
|
||||||
cfg.metrics = metrics
|
cfg.metrics = metrics
|
||||||
@@ -223,6 +230,9 @@ func newTestProxy(t *testing.T, opts ...testProxyOption) *aibridgeproxyd.Server
|
|||||||
listenAddr: "127.0.0.1:0",
|
listenAddr: "127.0.0.1:0",
|
||||||
coderAccessURL: "http://localhost:3000",
|
coderAccessURL: "http://localhost:3000",
|
||||||
domainAllowlist: []string{"127.0.0.1", "localhost"},
|
domainAllowlist: []string{"127.0.0.1", "localhost"},
|
||||||
|
// Allow 127.0.0.1 by default so test servers, which always listen on
|
||||||
|
// loopback, are reachable. Tests that verify IP blocking override this.
|
||||||
|
allowedPrivateCIDRs: []string{"127.0.0.1/32"},
|
||||||
aibridgeProviderFromHost: func(host string) string {
|
aibridgeProviderFromHost: func(host string) string {
|
||||||
return "test-provider"
|
return "test-provider"
|
||||||
},
|
},
|
||||||
@@ -246,6 +256,7 @@ func newTestProxy(t *testing.T, opts ...testProxyOption) *aibridgeproxyd.Server
|
|||||||
AIBridgeProviderFromHost: cfg.aibridgeProviderFromHost,
|
AIBridgeProviderFromHost: cfg.aibridgeProviderFromHost,
|
||||||
UpstreamProxy: cfg.upstreamProxy,
|
UpstreamProxy: cfg.upstreamProxy,
|
||||||
UpstreamProxyCA: cfg.upstreamProxyCA,
|
UpstreamProxyCA: cfg.upstreamProxyCA,
|
||||||
|
AllowedPrivateCIDRs: cfg.allowedPrivateCIDRs,
|
||||||
Metrics: cfg.metrics,
|
Metrics: cfg.metrics,
|
||||||
}
|
}
|
||||||
if cfg.certStore != nil {
|
if cfg.certStore != nil {
|
||||||
@@ -291,11 +302,12 @@ func getProxyCertPool(t *testing.T) *x509.CertPool {
|
|||||||
|
|
||||||
// newProxyClient creates an HTTP(S) client configured to use the proxy.
|
// newProxyClient creates an HTTP(S) client configured to use the proxy.
|
||||||
// It adds a Proxy-Authorization header with the provided token for authentication.
|
// It adds a Proxy-Authorization header with the provided token for authentication.
|
||||||
// The certPool parameter specifies which certificates the client should trust:
|
// The certPool and insecureSkipVerify parameters control TLS verification:
|
||||||
// - If the proxy listener is TLS, include the listener certificate.
|
// - If the proxy listener is TLS, include the listener certificate.
|
||||||
// - For MITM'd requests, include the proxy's MITM certificate.
|
// - For MITM'd requests, include the proxy's MITM certificate.
|
||||||
// - For tunneled requests, include the target server's certificate.
|
// - For tunneled requests, include the target server's certificate.
|
||||||
func newProxyClient(t *testing.T, srv *aibridgeproxyd.Server, proxyAuth string, certPool *x509.CertPool) *http.Client {
|
// - Set insecureSkipVerify when the target cert SANs do not match the hostname.
|
||||||
|
func newProxyClient(t *testing.T, srv *aibridgeproxyd.Server, proxyAuth string, certPool *x509.CertPool, insecureSkipVerify bool) *http.Client {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Create an HTTP(S) client configured to use the proxy.
|
// Create an HTTP(S) client configured to use the proxy.
|
||||||
@@ -309,8 +321,9 @@ func newProxyClient(t *testing.T, srv *aibridgeproxyd.Server, proxyAuth string,
|
|||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyURL(proxyURL),
|
Proxy: http.ProxyURL(proxyURL),
|
||||||
TLSClientConfig: &tls.Config{
|
TLSClientConfig: &tls.Config{
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
RootCAs: certPool,
|
RootCAs: certPool,
|
||||||
|
InsecureSkipVerify: insecureSkipVerify, //nolint:gosec
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -523,6 +536,60 @@ func TestNew(t *testing.T) {
|
|||||||
require.Contains(t, err.Error(), "invalid coder access URL")
|
require.Contains(t, err.Error(), "invalid coder access URL")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("CoderAccessURLDefaultHTTPPort", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
|
||||||
|
logger := slogtest.Make(t, nil)
|
||||||
|
|
||||||
|
srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
|
||||||
|
ListenAddr: "127.0.0.1:0",
|
||||||
|
CoderAccessURL: "http://localhost",
|
||||||
|
MITMCertFile: mitmCertFile,
|
||||||
|
MITMKeyFile: mitmKeyFile,
|
||||||
|
DomainAllowlist: []string{aibridgeproxyd.HostAnthropic},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "localhost", srv.CoderAccessURL().Hostname())
|
||||||
|
require.Equal(t, "80", srv.CoderAccessURL().Port())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CoderAccessURLDefaultHTTPSPort", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
|
||||||
|
logger := slogtest.Make(t, nil)
|
||||||
|
|
||||||
|
srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
|
||||||
|
ListenAddr: "127.0.0.1:0",
|
||||||
|
CoderAccessURL: "https://localhost",
|
||||||
|
MITMCertFile: mitmCertFile,
|
||||||
|
MITMKeyFile: mitmKeyFile,
|
||||||
|
DomainAllowlist: []string{aibridgeproxyd.HostAnthropic},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "localhost", srv.CoderAccessURL().Hostname())
|
||||||
|
require.Equal(t, "443", srv.CoderAccessURL().Port())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CoderAccessURLExplicitPort", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
|
||||||
|
logger := slogtest.Make(t, nil)
|
||||||
|
|
||||||
|
srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
|
||||||
|
ListenAddr: "127.0.0.1:0",
|
||||||
|
CoderAccessURL: "http://localhost:3000",
|
||||||
|
MITMCertFile: mitmCertFile,
|
||||||
|
MITMKeyFile: mitmKeyFile,
|
||||||
|
DomainAllowlist: []string{aibridgeproxyd.HostAnthropic},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "localhost", srv.CoderAccessURL().Hostname())
|
||||||
|
require.Equal(t, "3000", srv.CoderAccessURL().Port())
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("MissingCertFile", func(t *testing.T) {
|
t.Run("MissingCertFile", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -708,6 +775,24 @@ func TestNew(t *testing.T) {
|
|||||||
require.Contains(t, err.Error(), "invalid credentials: both username and password are empty")
|
require.Contains(t, err.Error(), "invalid credentials: both username and password are empty")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("InvalidAllowedPrivateCIDR", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
|
||||||
|
logger := slogtest.Make(t, nil)
|
||||||
|
|
||||||
|
_, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
|
||||||
|
ListenAddr: "127.0.0.1:0",
|
||||||
|
CoderAccessURL: "http://localhost:3000",
|
||||||
|
MITMCertFile: mitmCertFile,
|
||||||
|
MITMKeyFile: mitmKeyFile,
|
||||||
|
DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI},
|
||||||
|
AllowedPrivateCIDRs: []string{"not-a-cidr"},
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid allowed private CIDR")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("Success", func(t *testing.T) {
|
t.Run("Success", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -877,6 +962,24 @@ func TestNew(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, srv)
|
require.NotNil(t, srv)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("SuccessWithAllowedPrivateCIDRs", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t)
|
||||||
|
logger := slogtest.Make(t, nil)
|
||||||
|
|
||||||
|
srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{
|
||||||
|
ListenAddr: "127.0.0.1:0",
|
||||||
|
CoderAccessURL: "http://localhost:3000",
|
||||||
|
MITMCertFile: mitmCertFile,
|
||||||
|
MITMKeyFile: mitmKeyFile,
|
||||||
|
DomainAllowlist: []string{aibridgeproxyd.HostAnthropic, aibridgeproxyd.HostOpenAI},
|
||||||
|
AllowedPrivateCIDRs: []string{"127.0.0.1/32"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, srv)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClose(t *testing.T) {
|
func TestClose(t *testing.T) {
|
||||||
@@ -1006,7 +1109,7 @@ func TestProxy_CertCaching(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Make a request through the proxy to the target server.
|
// Make a request through the proxy to the target server.
|
||||||
client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), certPool)
|
client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), certPool, false)
|
||||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
@@ -1082,7 +1185,7 @@ func TestProxy_PortValidation(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Make a request through the proxy to the target server.
|
// Make a request through the proxy to the target server.
|
||||||
client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), getProxyCertPool(t))
|
client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), getProxyCertPool(t), false)
|
||||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -1159,7 +1262,7 @@ func TestProxy_Authentication(t *testing.T) {
|
|||||||
|
|
||||||
if tt.expectSuccess {
|
if tt.expectSuccess {
|
||||||
// Use the standard HTTP client for successful requests.
|
// Use the standard HTTP client for successful requests.
|
||||||
client := newProxyClient(t, srv, tt.proxyAuth, getProxyCertPool(t))
|
client := newProxyClient(t, srv, tt.proxyAuth, getProxyCertPool(t), false)
|
||||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
@@ -1325,7 +1428,7 @@ func TestProxy_MITM(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Make a request through the proxy to the target URL.
|
// Make a request through the proxy to the target URL.
|
||||||
client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), certPool)
|
client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), certPool, false)
|
||||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, targetURL, strings.NewReader(`{}`))
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, targetURL, strings.NewReader(`{}`))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -1448,7 +1551,7 @@ func TestListenerTLS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
certPool.AppendCertsFromPEM(listenerCertPEM)
|
certPool.AppendCertsFromPEM(listenerCertPEM)
|
||||||
|
|
||||||
client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), certPool)
|
client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), certPool, false)
|
||||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, targetURL.String(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
@@ -1835,7 +1938,7 @@ func TestUpstreamProxy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create HTTP client configured to use aiproxy.
|
// Create HTTP client configured to use aiproxy.
|
||||||
client := newProxyClient(t, srv, makeProxyAuthHeader("test-coder-token"), certPool)
|
client := newProxyClient(t, srv, makeProxyAuthHeader("test-coder-token"), certPool, false)
|
||||||
|
|
||||||
// Make request through aiproxy.
|
// Make request through aiproxy.
|
||||||
requestBody := `{"test": "data", "foo": "bar"}`
|
requestBody := `{"test": "data", "foo": "bar"}`
|
||||||
@@ -1889,3 +1992,144 @@ func TestUpstreamProxy(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxy_PrivateIPBlocking(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
targetHostname string
|
||||||
|
useUpstreamProxy bool
|
||||||
|
allowedCIDRs []string
|
||||||
|
coderAccessURLFn func(targetHostname, port string) string
|
||||||
|
expectBlocked bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// Direct IP: by default, all private/reserved IPs are blocked.
|
||||||
|
name: "BlockedDirectDial",
|
||||||
|
targetHostname: "127.0.0.1",
|
||||||
|
expectBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Hostname: DNS resolves to 127.0.0.1, which is then blocked.
|
||||||
|
name: "BlockedDirectDialByHostname",
|
||||||
|
targetHostname: "localhost",
|
||||||
|
expectBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Direct IP: block applies even with an upstream proxy configured.
|
||||||
|
name: "BlockedViaUpstreamProxy",
|
||||||
|
targetHostname: "127.0.0.1",
|
||||||
|
useUpstreamProxy: true,
|
||||||
|
expectBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Hostname: DNS resolves to 127.0.0.1, which is then blocked.
|
||||||
|
name: "BlockedViaUpstreamProxyByHostname",
|
||||||
|
targetHostname: "localhost",
|
||||||
|
useUpstreamProxy: true,
|
||||||
|
expectBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Direct IP: a configured CIDR exception allows the range.
|
||||||
|
name: "AllowedByPrivateCIDR",
|
||||||
|
targetHostname: "127.0.0.1",
|
||||||
|
allowedCIDRs: []string{"127.0.0.1/32"},
|
||||||
|
expectBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Hostname: DNS resolves to 127.0.0.1, which is allowed by the CIDR exception.
|
||||||
|
name: "AllowedByPrivateCIDRByHostname",
|
||||||
|
targetHostname: "localhost",
|
||||||
|
allowedCIDRs: []string{"127.0.0.1/32"},
|
||||||
|
expectBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Direct IP: the Coder access URL host:port is always exempt.
|
||||||
|
name: "AllowedByCoderAccessURL",
|
||||||
|
targetHostname: "127.0.0.1",
|
||||||
|
coderAccessURLFn: func(targetHostname, port string) string {
|
||||||
|
return fmt.Sprintf("http://%s:%s", targetHostname, port)
|
||||||
|
},
|
||||||
|
expectBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Hostname: DNS resolves to 127.0.0.1, which is exempt as the Coder access URL.
|
||||||
|
name: "AllowedByCoderAccessURLByHostname",
|
||||||
|
targetHostname: "localhost",
|
||||||
|
coderAccessURLFn: func(targetHostname, port string) string {
|
||||||
|
return fmt.Sprintf("http://%s:%s", targetHostname, port)
|
||||||
|
},
|
||||||
|
expectBlocked: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// The target server always listens on 127.0.0.1. When targetHostname is
|
||||||
|
// "localhost", the proxy resolves it to 127.0.0.1 via DNS, exercising
|
||||||
|
// the hostname resolution path of the IP check.
|
||||||
|
targetServer, targetURL := newTargetServer(t, func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("hello from target"))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Build the CONNECT target using the configured hostname.
|
||||||
|
connectTarget := fmt.Sprintf("%s:%s", tt.targetHostname, targetURL.Port())
|
||||||
|
|
||||||
|
// Use a domain allowlist that excludes the target so CONNECT requests
|
||||||
|
// go through the tunnel path rather than being MITM'd.
|
||||||
|
opts := []testProxyOption{
|
||||||
|
withDomainAllowlist(aibridgeproxyd.HostAnthropic),
|
||||||
|
withAllowedPorts(targetURL.Port()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.useUpstreamProxy {
|
||||||
|
// A minimal upstream proxy server is sufficient here: the IP check
|
||||||
|
// fires inside ConnectDial before any connection reaches it.
|
||||||
|
upstreamProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {}))
|
||||||
|
t.Cleanup(upstreamProxy.Close)
|
||||||
|
opts = append(opts, withUpstreamProxy(upstreamProxy.URL))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always override the default allowedPrivateCIDRs so blocked cases
|
||||||
|
// are not accidentally exempted by the loopback default.
|
||||||
|
opts = append(opts, withAllowedPrivateCIDRs(tt.allowedCIDRs...))
|
||||||
|
if tt.coderAccessURLFn != nil {
|
||||||
|
opts = append(opts, withCoderAccessURL(tt.coderAccessURLFn(tt.targetHostname, targetURL.Port())))
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := newTestProxy(t, opts...)
|
||||||
|
|
||||||
|
if tt.expectBlocked {
|
||||||
|
// Use a raw CONNECT to observe the 502 returned when ConnectDial fails.
|
||||||
|
// Go's HTTP client does not expose the response for non-2xx CONNECT results.
|
||||||
|
resp := sendConnect(t, srv.Addr(), connectTarget, makeProxyAuthHeader("test-token"))
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
||||||
|
require.Equal(t, "Bad Gateway", string(body), "error details should not be leaked to the client")
|
||||||
|
} else {
|
||||||
|
certPool := x509.NewCertPool()
|
||||||
|
certPool.AddCert(targetServer.Certificate())
|
||||||
|
// InsecureSkipVerify is needed for "localhost": by default the cert SAN is 127.0.0.1.
|
||||||
|
client := newProxyClient(t, srv, makeProxyAuthHeader("test-token"), certPool, tt.targetHostname != "127.0.0.1")
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("https://%s/", connectTarget)
|
||||||
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, reqURL, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
require.Equal(t, "hello from target", string(body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,16 +22,17 @@ func newAIBridgeProxyDaemon(coderAPI *coderd.API) (*aibridgeproxyd.Server, error
|
|||||||
metrics := aibridgeproxyd.NewMetrics(reg)
|
metrics := aibridgeproxyd.NewMetrics(reg)
|
||||||
|
|
||||||
srv, err := aibridgeproxyd.New(ctx, logger, aibridgeproxyd.Options{
|
srv, err := aibridgeproxyd.New(ctx, logger, aibridgeproxyd.Options{
|
||||||
ListenAddr: coderAPI.DeploymentValues.AI.BridgeProxyConfig.ListenAddr.String(),
|
ListenAddr: coderAPI.DeploymentValues.AI.BridgeProxyConfig.ListenAddr.String(),
|
||||||
TLSCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSCertFile.String(),
|
TLSCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSCertFile.String(),
|
||||||
TLSKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSKeyFile.String(),
|
TLSKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSKeyFile.String(),
|
||||||
CoderAccessURL: coderAPI.AccessURL.String(),
|
CoderAccessURL: coderAPI.AccessURL.String(),
|
||||||
MITMCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMCertFile.String(),
|
MITMCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMCertFile.String(),
|
||||||
MITMKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMKeyFile.String(),
|
MITMKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMKeyFile.String(),
|
||||||
DomainAllowlist: coderAPI.DeploymentValues.AI.BridgeProxyConfig.DomainAllowlist.Value(),
|
DomainAllowlist: coderAPI.DeploymentValues.AI.BridgeProxyConfig.DomainAllowlist.Value(),
|
||||||
UpstreamProxy: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxy.String(),
|
UpstreamProxy: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxy.String(),
|
||||||
UpstreamProxyCA: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxyCA.String(),
|
UpstreamProxyCA: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxyCA.String(),
|
||||||
Metrics: metrics,
|
AllowedPrivateCIDRs: coderAPI.DeploymentValues.AI.BridgeProxyConfig.AllowedPrivateCIDRs.Value(),
|
||||||
|
Metrics: metrics,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("failed to start in-memory aibridgeproxy daemon: %w", err)
|
return nil, xerrors.Errorf("failed to start in-memory aibridgeproxy daemon: %w", err)
|
||||||
|
|||||||
@@ -171,6 +171,12 @@ AI BRIDGE OPTIONS:
|
|||||||
exporting these records to external SIEM or observability systems.
|
exporting these records to external SIEM or observability systems.
|
||||||
|
|
||||||
AI BRIDGE PROXY OPTIONS:
|
AI BRIDGE PROXY OPTIONS:
|
||||||
|
--aibridge-proxy-allowed-private-cidrs string-array, $CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS
|
||||||
|
Comma-separated list of CIDR ranges that are permitted even though
|
||||||
|
they fall within blocked private/reserved IP ranges. By default all
|
||||||
|
private ranges are blocked to prevent SSRF attacks. Use this to allow
|
||||||
|
access to specific internal networks.
|
||||||
|
|
||||||
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
|
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
|
||||||
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
|
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
|
||||||
provider requests.
|
provider requests.
|
||||||
|
|||||||
Generated
+1
@@ -92,6 +92,7 @@ export interface AIBridgeProxyConfig {
|
|||||||
readonly domain_allowlist: string;
|
readonly domain_allowlist: string;
|
||||||
readonly upstream_proxy: string;
|
readonly upstream_proxy: string;
|
||||||
readonly upstream_proxy_ca: string;
|
readonly upstream_proxy_ca: string;
|
||||||
|
readonly allowed_private_cidrs: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
// From codersdk/aibridge.go
|
// From codersdk/aibridge.go
|
||||||
|
|||||||
Reference in New Issue
Block a user