fix(coderd): harden Azure identity certificate fetch (cherry-pick v2.31) (#25278)

Cherry-pick of
https://github.com/coder/coder/commit/57b11d405f17492aa789d4b9ff33366f961a37f8
to `release/2.31`.

Backport of #25274.

> [!NOTE]
> This PR was created by Coder Agents on behalf of a human.
This commit is contained in:
Jakub Domeracki
2026-05-13 17:35:00 +02:00
committed by GitHub
parent bddd73d5d2
commit eb461163c7
4 changed files with 297 additions and 11 deletions
+177 -10
View File
@@ -8,7 +8,9 @@ import (
"encoding/pem"
"errors"
"io"
"net"
"net/http"
"net/url"
"regexp"
"sync"
"time"
@@ -25,6 +27,158 @@ var allowedSigners = regexp.MustCompile(`^(.*\.)?metadata\.(azure\.(com|us|cn)|m
// each time a parse occurs.
var pkcs7Mutex sync.Mutex
// allowedCertHosts is the set of hostnames that issuing
// (intermediate) certificates may be downloaded from. A signer
// certificate carries its own IssuingCertificateURL values, so
// without this allowlist a crafted certificate could point the
// fetch at an arbitrary host (SSRF).
//
// Source: https://learn.microsoft.com/en-us/azure/security/fundamentals/azure-ca-details
var allowedCertHosts = map[string]bool{
	"cacerts.digicert.com": true,
	"www.microsoft.com":    true,
}
// maxCertResponseBytes caps how much of a certificate response
// body is read. Azure intermediate certificates are typically
// under 4 KiB, so 1 MiB leaves ample headroom while keeping a
// malicious or runaway response from exhausting memory.
const maxCertResponseBytes = 1024 * 1024 // 1 MiB
// extraBlockedNetworks lists special-use CIDR ranges that the
// stdlib classification methods (IsLoopback, IsPrivate, etc.) do
// not cover: the IPv4 "this network", carrier-grade NAT and
// benchmarking ranges, plus the IPv6 local-use translation,
// discard-only, benchmarking and documentation prefixes. It is
// populated once by init and only read afterwards.
//
// IPv6 ranges already handled by stdlib:
//   - ::1/128 (IsLoopback)
//   - fc00::/7 (IsPrivate, ULA)
//   - fe80::/10 (IsLinkLocalUnicast)
//   - ff00::/8 (IsMulticast)
//   - ::/128 (IsUnspecified)
var extraBlockedNetworks []*net.IPNet

// init parses the static CIDR list into extraBlockedNetworks.
// The entries are compile-time literals, so a parse failure is a
// programmer error: panic at startup instead of appending a nil
// network that would crash later, mid-request, in Contains.
func init() {
	for _, cidr := range []string{
		// IPv4 special-use ranges.
		"0.0.0.0/8",     // RFC 1122 "this network".
		"100.64.0.0/10", // RFC 6598 carrier-grade NAT.
		"198.18.0.0/15", // RFC 2544 benchmarking.
		// IPv6 special-use ranges not covered by stdlib.
		"64:ff9b:1::/48", // RFC 8215 local-use IPv4/IPv6 translation.
		"100::/64",       // RFC 6666 discard-only.
		"2001:2::/48",    // RFC 5180 benchmarking.
		"2001:db8::/32",  // RFC 3849 documentation.
	} {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			panic("azureidentity: invalid blocked CIDR " + cidr + ": " + err.Error())
		}
		extraBlockedNetworks = append(extraBlockedNetworks, network)
	}
}
// isPrivateIP reports whether ip is on a network that must not be
// reachable when fetching certificates: loopback, private,
// link-local, multicast, unspecified, or any range listed in
// extraBlockedNetworks.
//
// IPv4-mapped IPv6 addresses are canonicalized to IPv4 first so a
// literal like ::ffff:169.254.169.254 cannot bypass the IPv4
// ranges. A nil or malformed address (neither 4 nor 16 bytes) is
// reported as blocked: every stdlib classifier returns false for
// such a value, which would otherwise let it pass as "public".
func isPrivateIP(ip net.IP) bool {
	if v4 := ip.To4(); v4 != nil {
		ip = v4
	} else if len(ip) != net.IPv6len {
		// Not a usable address at all; fail closed.
		return true
	}

	// IsMulticast already covers the link-local and
	// interface-local multicast scopes, so they need no separate
	// checks.
	if ip.IsLoopback() ||
		ip.IsPrivate() ||
		ip.IsLinkLocalUnicast() ||
		ip.IsMulticast() ||
		ip.IsUnspecified() {
		return true
	}

	for _, network := range extraBlockedNetworks {
		if network.Contains(ip) {
			return true
		}
	}
	return false
}
// certFetchClient is the HTTP client used to download issuing
// certificates. It adds two SSRF defenses on top of the allowlist
// check callers run against the initial URL:
//
//   - CheckRedirect re-validates every redirect target with
//     IsAllowedCertificateURL. Without it the client would follow
//     redirects under the default policy and an allowlisted host
//     could bounce the request to an arbitrary destination.
//   - DialContext resolves the host itself, refuses the connection
//     if any resolved address is private or special-use, and then
//     dials the validated IP directly (defense-in-depth against
//     DNS rebinding).
//
// Proxy is left unset on the Transport, so no proxy is used and
// every connection goes through the guarded dialer.
var certFetchClient = &http.Client{
	Timeout: 5 * time.Second,
	CheckRedirect: func(req *http.Request, via []*http.Request) error {
		// Setting CheckRedirect replaces the default policy, so
		// the usual 10-hop cap has to be enforced here as well.
		if len(via) >= 10 {
			return xerrors.New("certificate fetch blocked: stopped after 10 redirects")
		}
		if !IsAllowedCertificateURL(req.URL.String()) {
			return xerrors.New("certificate fetch blocked: redirect target not on allowlist")
		}
		return nil
	},
	Transport: &http.Transport{
		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			host, port, err := net.SplitHostPort(addr)
			if err != nil {
				return nil, xerrors.Errorf("split host/port: %w", err)
			}
			ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
			if err != nil {
				return nil, xerrors.Errorf("resolve host: %w", err)
			}
			if len(ips) == 0 {
				return nil, xerrors.Errorf("no addresses for %q", host)
			}
			// Reject up front so a single tainted answer
			// short-circuits the dial rather than racing it.
			for _, ip := range ips {
				if isPrivateIP(ip.IP) {
					return nil, xerrors.Errorf(
						"certificate fetch blocked: %q resolved to private IP %s",
						host, ip.IP,
					)
				}
			}
			// Dial the validated IP directly. If we dialed by
			// hostname here, Go's stdlib would re-resolve and a
			// hostile resolver could swap in a private IP after
			// validation (DNS rebinding). TLS verification still
			// uses the URL host via the Transport's TLS config.
			var d net.Dialer
			var firstErr error
			for _, ip := range ips {
				conn, derr := d.DialContext(ctx, network, net.JoinHostPort(ip.IP.String(), port))
				if derr == nil {
					return conn, nil
				}
				// Keep the first failure; later addresses are
				// only fallbacks.
				if firstErr == nil {
					firstErr = derr
				}
			}
			return nil, firstErr
		},
	},
}
// IsAllowedCertificateURL reports whether rawURL is acceptable as
// a source for issuing certificates. To pass, the URL must parse,
// use the http or https scheme, name a host present in
// allowedCertHosts, and either omit the port or use 80 or 443.
// Microsoft and DigiCert serve these artifacts on 80/443 only, so
// rejecting every other port keeps the SSRF surface as narrow as
// the hostname itself.
func IsAllowedCertificateURL(rawURL string) bool {
	if rawURL == "" {
		return false
	}

	parsed, err := url.Parse(rawURL)
	if err != nil {
		return false
	}

	switch parsed.Scheme {
	case "http", "https":
		// Supported scheme; keep checking.
	default:
		return false
	}

	if !allowedCertHosts[parsed.Hostname()] {
		return false
	}

	port := parsed.Port()
	return port == "" || port == "80" || port == "443"
}
// metadata holds the fields this package reads from a JSON
// document; only the "vmId" key is mapped.
// NOTE(review): presumably the Azure instance metadata payload
// carried in the signed content — the decoding site is outside
// this view, confirm there.
type metadata struct {
VMID string `json:"vmId"`
}
@@ -81,29 +235,42 @@ func Validate(ctx context.Context, signature string, options Options) (string, e
ctx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
defer cancelFunc()
for _, certURL := range signer.IssuingCertificateURL {
if !IsAllowedCertificateURL(certURL) {
return "", xerrors.New("issuing certificate URL not on allowlist")
}
req, err := http.NewRequestWithContext(ctx, "GET", certURL, nil)
if err != nil {
return "", xerrors.Errorf("new request %q: %w", certURL, err)
return "", xerrors.New("construct certificate request")
}
res, err := http.DefaultClient.Do(req)
res, err := certFetchClient.Do(req)
if err != nil {
return "", xerrors.Errorf("no cached certificate for %q found. error fetching: %w", certURL, err)
}
data, err := io.ReadAll(res.Body)
if err != nil {
_ = res.Body.Close()
return "", xerrors.Errorf("read body %q: %w", certURL, err)
return "", xerrors.New("certificate fetch unsuccessful")
}
limited := io.LimitReader(res.Body, maxCertResponseBytes+1)
data, err := io.ReadAll(limited)
_ = res.Body.Close()
if err != nil {
return "", xerrors.New("read certificate response body")
}
if int64(len(data)) > maxCertResponseBytes {
return "", xerrors.New(
"certificate response exceeds maximum size",
)
}
cert, err := x509.ParseCertificate(data)
if err != nil {
return "", xerrors.Errorf("parse certificate %q: %w", certURL, err)
// Do not wrap the parse error; it may contain
// fragments of the HTTP response body, which
// could leak internal data to the caller.
return "", xerrors.New(
"fetched data is not a valid certificate",
)
}
options.Intermediates.AddCert(cert)
}
_, err = signer.Verify(options.VerifyOptions)
if err != nil {
return "", err
return "", xerrors.New("signature verification failed after fetching issuing certificates")
}
}
@@ -0,0 +1,76 @@
package azureidentity
import (
"context"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
// TestIsPrivateIP checks isPrivateIP against one representative
// address per blocked range, the edges of the CIDR-based ranges,
// IPv4-mapped IPv6 literals, and public addresses that must stay
// reachable.
func TestIsPrivateIP(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name    string
		ip      string
		blocked bool
	}{
		{"loopback v4", "127.0.0.1", true},
		{"loopback v6", "::1", true},
		{"link local v4 (azure metadata)", "169.254.169.254", true},
		{"link local v6", "fe80::1", true},
		{"rfc1918 10/8", "10.0.0.1", true},
		{"rfc1918 172.16/12", "172.16.0.1", true},
		{"rfc1918 192.168/16", "192.168.0.1", true},
		{"ipv6 ula", "fc00::1", true},
		{"unspecified v4", "0.0.0.0", true},
		{"unspecified v6", "::", true},
		{"this-network 0.0.0.0/8", "0.1.2.3", true},
		{"cgnat 100.64/10", "100.64.0.1", true},
		{"benchmarking 198.18/15", "198.18.0.1", true},
		{"multicast v4", "224.0.0.1", true},
		// 64:ff9b:1::/48 is the RFC 8215 local-use translation
		// prefix, not the RFC 6052 well-known 64:ff9b::/96.
		{"ipv6 nat64 local-use", "64:ff9b:1::1", true},
		{"ipv6 discard-only", "100::1", true},
		{"ipv6 benchmarking", "2001:2::1", true},
		{"ipv6 documentation", "2001:db8::1", true},

		// Range edges: the last blocked address and its public
		// neighbours, to catch an off-by-one prefix length.
		{"just below cgnat", "100.63.255.255", false},
		{"cgnat upper edge", "100.127.255.255", true},
		{"just above cgnat", "100.128.0.0", false},
		{"benchmarking upper edge", "198.19.255.255", true},
		{"just above benchmarking", "198.20.0.1", false},

		// IPv4-mapped IPv6: must canonicalize to v4 before
		// classification, otherwise an attacker could bypass
		// the metadata block via ::ffff:169.254.169.254.
		{"ipv4-mapped metadata", "::ffff:169.254.169.254", true},
		{"ipv4-mapped rfc1918", "::ffff:10.0.0.1", true},

		{"public v4", "8.8.8.8", false},
		{"public v6", "2606:4700:4700::1111", false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			ip := net.ParseIP(tc.ip)
			require.NotNil(t, ip, "parse %q", tc.ip)
			require.Equal(t, tc.blocked, isPrivateIP(ip))
		})
	}
}
// TestCertFetchClientRejectsLoopback verifies that the guarded
// dialer blocks a request on its own, independently of any URL
// allowlist. The httptest server listens on a loopback address,
// so reaching its handler would mean the SSRF guard had failed.
func TestCertFetchClientRejectsLoopback(t *testing.T) {
	t.Parallel()

	handler := func(rw http.ResponseWriter, _ *http.Request) {
		_, _ = rw.Write([]byte("should never be reached"))
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	t.Cleanup(server.Close)

	request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
	require.NoError(t, err)

	response, doErr := certFetchClient.Do(request)
	if response != nil {
		// Only reachable if the guard failed; still release the body.
		defer response.Body.Close()
	}
	require.Error(t, doErr)
	require.Contains(t, doErr.Error(), "private IP")
}
@@ -87,3 +87,37 @@ func TestExpiresSoon(t *testing.T) {
}
}
}
// TestIsAllowedCertificateURL is a table-driven check of the URL
// allowlist: accepted Microsoft/DigiCert URLs over http and https,
// and rejections for foreign hosts, look-alike hostnames,
// unsupported or missing schemes, and non-standard ports.
func TestIsAllowedCertificateURL(t *testing.T) {
	t.Parallel()

	type urlCase struct {
		name    string
		url     string
		allowed bool
	}
	cases := []urlCase{
		// Allowlisted hosts on both supported schemes.
		{"microsoft http", "http://www.microsoft.com/pki/mscorp/cert.crt", true},
		{"microsoft https", "https://www.microsoft.com/pkiops/certs/cert.crt", true},
		{"digicert http", "http://cacerts.digicert.com/DigiCertGlobalRootG2.crt", true},
		{"digicert https", "https://cacerts.digicert.com/DigiCertGlobalRootG3.crt", true},

		// Hosts that are not on the allowlist.
		{"evil domain", "http://evil.example.com/cert.crt", false},
		{"metadata endpoint", "http://169.254.169.254/latest/meta-data/", false},
		{"localhost", "http://localhost/secret", false},
		{"subdomain trick", "http://www.microsoft.com.evil.com/cert.crt", false},

		// Missing or unsupported schemes.
		{"empty string", "", false},
		{"ftp scheme", "ftp://www.microsoft.com/cert.crt", false},
		{"no scheme", "www.microsoft.com/cert.crt", false},
		{"javascript scheme", "javascript:alert(1)", false},

		// Path and port handling on an allowlisted host.
		{"microsoft with path", "http://www.microsoft.com/pkiops/certs/cert.crt", true},
		{"microsoft explicit port 80", "http://www.microsoft.com:80/cert.crt", true},
		{"microsoft explicit port 443", "https://www.microsoft.com:443/cert.crt", true},
		{"microsoft non-standard port", "http://www.microsoft.com:8080/cert.crt", false},
		{"microsoft port 22", "http://www.microsoft.com:22/cert.crt", false},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()
			got := azureidentity.IsAllowedCertificateURL(c.url)
			require.Equal(t, c.allowed, got, "URL: %s", c.url)
		})
	}
}
+10 -1
View File
@@ -7,6 +7,7 @@ import (
"github.com/mitchellh/mapstructure"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/awsidentity"
"github.com/coder/coder/v2/coderd/azureidentity"
"github.com/coder/coder/v2/coderd/database"
@@ -39,9 +40,17 @@ func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r
VerifyOptions: api.AzureCertificates,
})
if err != nil {
// Log the full error for operators but return only a
// generic message to the caller. Errors from the
// certificate fetch path may contain fragments of
// internal HTTP responses, so exposing them would be
// an information disclosure risk.
api.Logger.Warn(ctx, "azure identity validation failed",
slog.Error(err),
)
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
Message: "Invalid Azure identity.",
Detail: err.Error(),
Detail: "Signature verification failed.",
})
return
}