mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd): harden Azure identity certificate fetch (cherry-pick v2.31) (#25278)
Cherry-pick of https://github.com/coder/coder/commit/57b11d405f17492aa789d4b9ff33366f961a37f8 to `release/2.31`. Backport of #25274. > [!NOTE] > This PR was created by Coder Agents on behalf of a human.
This commit is contained in:
@@ -8,7 +8,9 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -25,6 +27,158 @@ var allowedSigners = regexp.MustCompile(`^(.*\.)?metadata\.(azure\.(com|us|cn)|m
|
||||
// each time a parse occurs.
|
||||
var pkcs7Mutex sync.Mutex
|
||||
|
||||
// allowedCertHosts contains the hosts Azure intermediate
|
||||
// certificates are served from. Only these hosts are permitted
|
||||
// when fetching issuing certificates referenced in the signer
|
||||
// certificate. This prevents SSRF via crafted
|
||||
// IssuingCertificateURL values.
|
||||
//
|
||||
// Source: https://learn.microsoft.com/en-us/azure/security/fundamentals/azure-ca-details
|
||||
var allowedCertHosts = map[string]bool{
|
||||
"www.microsoft.com": true,
|
||||
"cacerts.digicert.com": true,
|
||||
}
|
||||
|
||||
// maxCertResponseBytes is the maximum size of a certificate
|
||||
// response body we will read. Azure intermediate certificates
|
||||
// are typically under 4 KiB; 1 MiB is a generous upper bound
|
||||
// that prevents memory exhaustion from malicious responses.
|
||||
const maxCertResponseBytes = 1 << 20 // 1 MiB
|
||||
|
||||
// extraBlockedNetworks lists special-use CIDR ranges that the
|
||||
// stdlib classification methods (IsLoopback, IsPrivate, etc.) do
|
||||
// not cover. Blocking these prevents SSRF against carrier-grade
|
||||
// NAT, network-benchmarking, documentation, discard-only, and
|
||||
// the all-zeros "this network" range.
|
||||
//
|
||||
// IPv6 ranges already handled by stdlib:
|
||||
// - ::1/128 (IsLoopback)
|
||||
// - fc00::/7 (IsPrivate, ULA)
|
||||
// - fe80::/10 (IsLinkLocalUnicast)
|
||||
// - ff00::/8 (IsMulticast)
|
||||
// - ::/128 (IsUnspecified)
|
||||
var extraBlockedNetworks []*net.IPNet
|
||||
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
// IPv4 special-use ranges.
|
||||
"0.0.0.0/8", // RFC 1122 "this network".
|
||||
"100.64.0.0/10", // RFC 6598 carrier-grade NAT.
|
||||
"198.18.0.0/15", // RFC 2544 benchmarking.
|
||||
|
||||
// IPv6 special-use ranges not covered by stdlib.
|
||||
"64:ff9b:1::/48", // RFC 8215 IPv4/IPv6 translation.
|
||||
"100::/64", // RFC 6666 discard-only.
|
||||
"2001:2::/48", // RFC 5180 benchmarking.
|
||||
"2001:db8::/32", // RFC 3849 documentation.
|
||||
} {
|
||||
_, network, _ := net.ParseCIDR(cidr)
|
||||
extraBlockedNetworks = append(extraBlockedNetworks, network)
|
||||
}
|
||||
}
|
||||
|
||||
// isPrivateIP reports whether the IP is on a network that must
|
||||
// not be reachable when fetching certificates. IPv4-mapped IPv6
|
||||
// addresses are canonicalized to IPv4 first so a literal like
|
||||
// ::ffff:169.254.169.254 cannot bypass the IPv4 ranges.
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
ip = v4
|
||||
}
|
||||
if ip.IsLoopback() ||
|
||||
ip.IsPrivate() ||
|
||||
ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() ||
|
||||
ip.IsMulticast() ||
|
||||
ip.IsUnspecified() ||
|
||||
ip.IsInterfaceLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
for _, network := range extraBlockedNetworks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// certFetchClient is an HTTP client that refuses to connect
|
||||
// to private or link-local IP addresses. This provides
|
||||
// defense-in-depth against SSRF even if the host allowlist is
|
||||
// somehow bypassed (e.g. via DNS rebinding).
|
||||
var certFetchClient = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("split host/port: %w", err)
|
||||
}
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("resolve host: %w", err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil, xerrors.Errorf("no addresses for %q", host)
|
||||
}
|
||||
// Reject up front so a single tainted answer
|
||||
// short-circuits the dial rather than racing it.
|
||||
for _, ip := range ips {
|
||||
if isPrivateIP(ip.IP) {
|
||||
return nil, xerrors.Errorf(
|
||||
"certificate fetch blocked: %q resolved to private IP %s",
|
||||
host, ip.IP,
|
||||
)
|
||||
}
|
||||
}
|
||||
// Dial the validated IP directly. If we dialed by
|
||||
// hostname here, Go's stdlib would re-resolve and a
|
||||
// hostile resolver could swap in a private IP after
|
||||
// validation (DNS rebinding). TLS verification still
|
||||
// uses the URL host via the Transport's TLS config.
|
||||
var d net.Dialer
|
||||
var firstErr error
|
||||
for _, ip := range ips {
|
||||
conn, derr := d.DialContext(ctx, network, net.JoinHostPort(ip.IP.String(), port))
|
||||
if derr == nil {
|
||||
return conn, nil
|
||||
}
|
||||
if firstErr == nil {
|
||||
firstErr = derr
|
||||
}
|
||||
}
|
||||
return nil, firstErr
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// IsAllowedCertificateURL reports whether rawURL points to a
|
||||
// host on the allowlist, uses http or https, and targets a
|
||||
// standard PKI distribution port. Microsoft and DigiCert serve
|
||||
// these artifacts on 80/443 only; any other port is rejected to
|
||||
// keep the SSRF surface as narrow as the hostname itself.
|
||||
func IsAllowedCertificateURL(rawURL string) bool {
|
||||
if rawURL == "" {
|
||||
return false
|
||||
}
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return false
|
||||
}
|
||||
if !allowedCertHosts[u.Hostname()] {
|
||||
return false
|
||||
}
|
||||
switch u.Port() {
|
||||
case "", "80", "443":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type metadata struct {
|
||||
VMID string `json:"vmId"`
|
||||
}
|
||||
@@ -81,29 +235,42 @@ func Validate(ctx context.Context, signature string, options Options) (string, e
|
||||
ctx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancelFunc()
|
||||
for _, certURL := range signer.IssuingCertificateURL {
|
||||
if !IsAllowedCertificateURL(certURL) {
|
||||
return "", xerrors.New("issuing certificate URL not on allowlist")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", certURL, nil)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("new request %q: %w", certURL, err)
|
||||
return "", xerrors.New("construct certificate request")
|
||||
}
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
res, err := certFetchClient.Do(req)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("no cached certificate for %q found. error fetching: %w", certURL, err)
|
||||
}
|
||||
data, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
_ = res.Body.Close()
|
||||
return "", xerrors.Errorf("read body %q: %w", certURL, err)
|
||||
return "", xerrors.New("certificate fetch unsuccessful")
|
||||
}
|
||||
limited := io.LimitReader(res.Body, maxCertResponseBytes+1)
|
||||
data, err := io.ReadAll(limited)
|
||||
_ = res.Body.Close()
|
||||
if err != nil {
|
||||
return "", xerrors.New("read certificate response body")
|
||||
}
|
||||
if int64(len(data)) > maxCertResponseBytes {
|
||||
return "", xerrors.New(
|
||||
"certificate response exceeds maximum size",
|
||||
)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(data)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("parse certificate %q: %w", certURL, err)
|
||||
// Do not wrap the parse error; it may contain
|
||||
// fragments of the HTTP response body, which
|
||||
// could leak internal data to the caller.
|
||||
return "", xerrors.New(
|
||||
"fetched data is not a valid certificate",
|
||||
)
|
||||
}
|
||||
options.Intermediates.AddCert(cert)
|
||||
}
|
||||
_, err = signer.Verify(options.VerifyOptions)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", xerrors.New("signature verification failed after fetching issuing certificates")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package azureidentity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
ip string
|
||||
blocked bool
|
||||
}{
|
||||
{"loopback v4", "127.0.0.1", true},
|
||||
{"loopback v6", "::1", true},
|
||||
{"link local v4 (azure metadata)", "169.254.169.254", true},
|
||||
{"link local v6", "fe80::1", true},
|
||||
{"rfc1918 10/8", "10.0.0.1", true},
|
||||
{"rfc1918 172.16/12", "172.16.0.1", true},
|
||||
{"rfc1918 192.168/16", "192.168.0.1", true},
|
||||
{"ipv6 ula", "fc00::1", true},
|
||||
{"unspecified v4", "0.0.0.0", true},
|
||||
{"unspecified v6", "::", true},
|
||||
{"this-network 0.0.0.0/8", "0.1.2.3", true},
|
||||
{"cgnat 100.64/10", "100.64.0.1", true},
|
||||
{"benchmarking 198.18/15", "198.18.0.1", true},
|
||||
{"multicast v4", "224.0.0.1", true},
|
||||
{"ipv6 nat64 well-known", "64:ff9b:1::1", true},
|
||||
{"ipv6 discard-only", "100::1", true},
|
||||
{"ipv6 benchmarking", "2001:2::1", true},
|
||||
{"ipv6 documentation", "2001:db8::1", true},
|
||||
// IPv4-mapped IPv6: must canonicalize to v4 before
|
||||
// classification, otherwise an attacker could bypass
|
||||
// the metadata block via ::ffff:169.254.169.254.
|
||||
{"ipv4-mapped metadata", "::ffff:169.254.169.254", true},
|
||||
{"ipv4-mapped rfc1918", "::ffff:10.0.0.1", true},
|
||||
|
||||
{"public v4", "8.8.8.8", false},
|
||||
{"public v6", "2606:4700:4700::1111", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ip := net.ParseIP(tc.ip)
|
||||
require.NotNil(t, ip, "parse %q", tc.ip)
|
||||
require.Equal(t, tc.blocked, isPrivateIP(ip))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCertFetchClientRejectsLoopback proves the dialer refuses
|
||||
// to connect even when the URL itself would have passed an
|
||||
// allowlist (httptest.Server always binds to 127.0.0.1, so a
|
||||
// successful fetch here would mean the SSRF guard had failed).
|
||||
func TestCertFetchClientRejectsLoopback(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("should never be reached"))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := certFetchClient.Do(req)
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "private IP")
|
||||
}
|
||||
@@ -87,3 +87,37 @@ func TestExpiresSoon(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAllowedCertificateURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
allowed bool
|
||||
}{
|
||||
{"microsoft http", "http://www.microsoft.com/pki/mscorp/cert.crt", true},
|
||||
{"microsoft https", "https://www.microsoft.com/pkiops/certs/cert.crt", true},
|
||||
{"digicert http", "http://cacerts.digicert.com/DigiCertGlobalRootG2.crt", true},
|
||||
{"digicert https", "https://cacerts.digicert.com/DigiCertGlobalRootG3.crt", true},
|
||||
{"evil domain", "http://evil.example.com/cert.crt", false},
|
||||
{"metadata endpoint", "http://169.254.169.254/latest/meta-data/", false},
|
||||
{"localhost", "http://localhost/secret", false},
|
||||
{"subdomain trick", "http://www.microsoft.com.evil.com/cert.crt", false},
|
||||
{"empty string", "", false},
|
||||
{"ftp scheme", "ftp://www.microsoft.com/cert.crt", false},
|
||||
{"no scheme", "www.microsoft.com/cert.crt", false},
|
||||
{"javascript scheme", "javascript:alert(1)", false},
|
||||
{"microsoft with path", "http://www.microsoft.com/pkiops/certs/cert.crt", true},
|
||||
{"microsoft explicit port 80", "http://www.microsoft.com:80/cert.crt", true},
|
||||
{"microsoft explicit port 443", "https://www.microsoft.com:443/cert.crt", true},
|
||||
{"microsoft non-standard port", "http://www.microsoft.com:8080/cert.crt", false},
|
||||
{"microsoft port 22", "http://www.microsoft.com:22/cert.crt", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := azureidentity.IsAllowedCertificateURL(tc.url)
|
||||
require.Equal(t, tc.allowed, result, "URL: %s", tc.url)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/awsidentity"
|
||||
"github.com/coder/coder/v2/coderd/azureidentity"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -39,9 +40,17 @@ func (api *API) postWorkspaceAuthAzureInstanceIdentity(rw http.ResponseWriter, r
|
||||
VerifyOptions: api.AzureCertificates,
|
||||
})
|
||||
if err != nil {
|
||||
// Log the full error for operators but return only a
|
||||
// generic message to the caller. Errors from the
|
||||
// certificate fetch path may contain fragments of
|
||||
// internal HTTP responses, so exposing them would be
|
||||
// an information disclosure risk.
|
||||
api.Logger.Warn(ctx, "azure identity validation failed",
|
||||
slog.Error(err),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: "Invalid Azure identity.",
|
||||
Detail: err.Error(),
|
||||
Detail: "Signature verification failed.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user