feat: add RFC 9728 OAuth2 resource metadata support (#18920)

# Enhanced OAuth2 and MCP Compliance for API Authentication

This PR improves OAuth2 and MCP (Microsoft Cloud for Sovereignty)
compliance by:

1. Adding RFC 9728 compliant `WWW-Authenticate` headers with resource
metadata URLs
2. Passing the configured `AccessURL` to API key middleware for proper
audience validation
3. Creating specialized CORS handling for OAuth2 and MCP endpoints with
appropriate headers
4. Making the `state` parameter optional in OAuth2 authorization
requests

These changes ensure proper OAuth2 token audience validation against the
configured access URL and improve interoperability with OAuth2 clients
by providing better error responses and metadata discovery.

Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
Thomas Kosiewski
2025-07-19 22:05:15 +02:00
committed by GitHub
parent f47efc62ee
commit 071383bbe8
6 changed files with 119 additions and 42 deletions
+3
View File
@@ -790,6 +790,7 @@ func New(options *Options) *API {
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
Logger: options.Logger,
AccessURL: options.AccessURL,
})
// Same as above but it redirects to the login page.
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
@@ -801,6 +802,7 @@ func New(options *Options) *API {
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
Logger: options.Logger,
AccessURL: options.AccessURL,
})
// Same as the first but it's optional.
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
@@ -812,6 +814,7 @@ func New(options *Options) *API {
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
Logger: options.Logger,
AccessURL: options.AccessURL,
})
workspaceAgentInfo := httpmw.ExtractWorkspaceAgentAndLatestBuild(httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{
+61 -36
View File
@@ -113,6 +113,10 @@ type ExtractAPIKeyConfig struct {
// a user is authenticated to prevent additional CLI invocations.
PostAuthAdditionalHeadersFunc func(a rbac.Subject, header http.Header)
// AccessURL is the configured access URL for this Coder deployment.
// Used for generating OAuth2 resource metadata URLs in WWW-Authenticate headers.
AccessURL *url.URL
// Logger is used for logging middleware operations.
Logger slog.Logger
}
@@ -214,29 +218,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return nil, nil, false
}
// Add WWW-Authenticate header for 401/403 responses (RFC 6750)
// Add WWW-Authenticate header for 401/403 responses (RFC 6750 + RFC 9728)
if code == http.StatusUnauthorized || code == http.StatusForbidden {
var wwwAuth string
switch code {
case http.StatusUnauthorized:
// Map 401 to invalid_token with specific error descriptions
switch {
case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"):
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token has expired"`
case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"):
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource"`
default:
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token is invalid"`
}
case http.StatusForbidden:
// Map 403 to insufficient_scope per RFC 6750
wwwAuth = `Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token"`
default:
wwwAuth = `Bearer realm="coder"`
}
rw.Header().Set("WWW-Authenticate", wwwAuth)
rw.Header().Set("WWW-Authenticate", buildWWWAuthenticateHeader(cfg.AccessURL, r, code, response))
}
httpapi.Write(ctx, rw, code, response)
@@ -272,7 +256,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, r); err != nil {
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil {
// Log the detailed error for debugging but don't expose it to the client
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
return optionalWrite(http.StatusForbidden, codersdk.Response{
@@ -489,7 +473,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
// validateOAuth2ProviderAppTokenAudience validates that an OAuth2 provider app token
// is being used with the correct audience/resource server (RFC 8707).
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, r *http.Request) error {
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, accessURL *url.URL, r *http.Request) error {
// Get the OAuth2 provider app token to check its audience
//nolint:gocritic // System needs to access token for audience validation
token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemRestricted(ctx), key.ID)
@@ -502,8 +486,8 @@ func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Sto
return nil
}
// Extract the expected audience from the request
expectedAudience := extractExpectedAudience(r)
// Extract the expected audience from the access URL
expectedAudience := extractExpectedAudience(accessURL, r)
// Normalize both audience values for RFC 3986 compliant comparison
normalizedTokenAudience := normalizeAudienceURI(token.Audience.String)
@@ -624,18 +608,59 @@ func normalizePathSegments(path string) string {
// Test export functions for testing package access
// extractExpectedAudience determines the expected audience for the current request.
// This should match the resource parameter used during authorization.
func extractExpectedAudience(r *http.Request) string {
// For MCP compliance, the audience should be the canonical URI of the resource server
// This typically matches the access URL of the Coder deployment
scheme := "https"
if r.TLS == nil {
scheme = "http"
// buildWWWAuthenticateHeader constructs RFC 6750 + RFC 9728 compliant WWW-Authenticate header
func buildWWWAuthenticateHeader(accessURL *url.URL, r *http.Request, code int, response codersdk.Response) string {
// Use the configured access URL for resource metadata
if accessURL == nil {
scheme := "https"
if r.TLS == nil {
scheme = "http"
}
// Use the Host header to construct the canonical audience URI
accessURL = &url.URL{
Scheme: scheme,
Host: r.Host,
}
}
// Use the Host header to construct the canonical audience URI
audience := fmt.Sprintf("%s://%s", scheme, r.Host)
resourceMetadata := accessURL.JoinPath("/.well-known/oauth-protected-resource").String()
switch code {
case http.StatusUnauthorized:
switch {
case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"):
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token has expired", resource_metadata=%q`, resourceMetadata)
case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"):
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource", resource_metadata=%q`, resourceMetadata)
default:
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token is invalid", resource_metadata=%q`, resourceMetadata)
}
case http.StatusForbidden:
return fmt.Sprintf(`Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token", resource_metadata=%q`, resourceMetadata)
default:
return fmt.Sprintf(`Bearer realm="coder", resource_metadata=%q`, resourceMetadata)
}
}
// extractExpectedAudience determines the expected audience for the current request.
// This should match the resource parameter used during authorization.
func extractExpectedAudience(accessURL *url.URL, r *http.Request) string {
// For MCP compliance, the audience should be the canonical URI of the resource server
// This typically matches the access URL of the Coder deployment
var audience string
if accessURL != nil {
audience = accessURL.String()
} else {
scheme := "https"
if r.TLS == nil {
scheme = "http"
}
// Use the Host header to construct the canonical audience URI
audience = fmt.Sprintf("%s://%s", scheme, r.Host)
}
// Normalize the URI according to RFC 3986 for consistent comparison
return normalizeAudienceURI(audience)
+49 -2
View File
@@ -4,6 +4,7 @@ import (
"net/http"
"net/url"
"regexp"
"strings"
"github.com/go-chi/cors"
@@ -28,13 +29,15 @@ const (
func Cors(allowAll bool, origins ...string) func(next http.Handler) http.Handler {
if len(origins) == 0 {
// The default behavior is '*', so putting the empty string defaults to
// the secure behavior of blocking CORs requests.
// the secure behavior of blocking CORS requests.
origins = []string{""}
}
if allowAll {
origins = []string{"*"}
}
return cors.Handler(cors.Options{
// Standard CORS for most endpoints
standardCors := cors.Handler(cors.Options{
AllowedOrigins: origins,
// We only need GET for latency requests
AllowedMethods: []string{http.MethodOptions, http.MethodGet},
@@ -42,6 +45,50 @@ func Cors(allowAll bool, origins ...string) func(next http.Handler) http.Handler
// Do not send any cookies
AllowCredentials: false,
})
// Permissive CORS for OAuth2 and MCP endpoints
permissiveCors := cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{
http.MethodGet,
http.MethodPost,
http.MethodDelete,
http.MethodOptions,
},
AllowedHeaders: []string{
"Content-Type",
"Accept",
"Authorization",
"x-api-key",
"Mcp-Session-Id",
"MCP-Protocol-Version",
"Last-Event-ID",
},
ExposedHeaders: []string{
"Content-Type",
"Authorization",
"x-api-key",
"Mcp-Session-Id",
"MCP-Protocol-Version",
},
MaxAge: 86400, // 24 hours in seconds
AllowCredentials: false,
})
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Use permissive CORS for OAuth2, MCP, and well-known endpoints
if strings.HasPrefix(r.URL.Path, "/oauth2/") ||
strings.HasPrefix(r.URL.Path, "/api/experimental/mcp/") ||
strings.HasPrefix(r.URL.Path, "/.well-known/oauth-") {
permissiveCors(next).ServeHTTP(w, r)
return
}
// Use standard CORS for all other endpoints
standardCors(next).ServeHTTP(w, r)
})
}
}
func WorkspaceAppCors(regex *regexp.Regexp, app appurl.ApplicationURL) func(next http.Handler) http.Handler {
+1 -1
View File
@@ -34,7 +34,7 @@ func TestCSP(t *testing.T) {
expected := []string{
"frame-src 'self' *.test.com *.coder.com *.coder2.com",
"media-src 'self' media.com media2.com",
"media-src 'self' " + strings.Join(expectedMedia, " "),
strings.Join([]string{
"connect-src", "'self'",
// Added from host header.
+1 -1
View File
@@ -258,7 +258,7 @@ func TestExtractExpectedAudience(t *testing.T) {
}
req.Host = tc.host
result := extractExpectedAudience(req)
result := extractExpectedAudience(nil, req)
assert.Equal(t, tc.expected, result)
})
}
+4 -2
View File
@@ -33,7 +33,7 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
p := httpapi.NewQueryParamParser()
vals := r.URL.Query()
p.RequiredNotEmpty("state", "response_type", "client_id")
p.RequiredNotEmpty("response_type", "client_id")
params := authorizeParams{
clientID: p.String(vals, "", "client_id"),
@@ -154,7 +154,9 @@ func ProcessAuthorize(db database.Store, accessURL *url.URL) http.HandlerFunc {
newQuery := params.redirectURL.Query()
newQuery.Add("code", code.Formatted)
newQuery.Add("state", params.state)
if params.state != "" {
newQuery.Add("state", params.state)
}
params.redirectURL.RawQuery = newQuery.Encode()
http.Redirect(rw, r, params.redirectURL.String(), http.StatusTemporaryRedirect)