Files
coder/coderd/httpmw/httpmw_internal_test.go
T
Thomas Kosiewski 071383bbe8 feat: add RFC 9728 OAuth2 resource metadata support (#18920)
# Enhanced OAuth2 and MCP Compliance for API Authentication

This PR improves OAuth2 and MCP (Model Context Protocol)
compliance by:

1. Adding RFC 9728 compliant `WWW-Authenticate` headers with resource
metadata URLs
2. Passing the configured `AccessURL` to API key middleware for proper
audience validation
3. Creating specialized CORS handling for OAuth2 and MCP endpoints with
appropriate headers
4. Making the `state` parameter optional in OAuth2 authorization
requests

These changes ensure proper OAuth2 token audience validation against the
configured access URL and improve interoperability with OAuth2 clients
by providing better error responses and metadata discovery.

Signed-off-by: Thomas Kosiewski <tk@coder.com>
2025-07-19 22:05:15 +02:00

266 lines
5.6 KiB
Go

package httpmw
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk"
)
// Shared fixtures for the ParseUUIDParam tests in this file.
const (
// testParam is the chi URL parameter name the tests register on the route
// context.
testParam = "workspaceagent"
// testWorkspaceAgentID is a well-formed UUID string used as the parameter
// value in the success case.
testWorkspaceAgentID = "8a70c576-12dc-42bc-b791-112a32b5bd43"
)
// TestParseUUID_Valid checks that ParseUUIDParam reports success and yields
// the expected UUID when the chi route parameter holds a well-formed UUID
// string.
func TestParseUUID_Valid(t *testing.T) {
	t.Parallel()

	// No router runs in this test, so the URL parameter is registered on a
	// chi route context by hand and attached to the request below.
	routeCtx := chi.NewRouteContext()
	routeCtx.URLParams.Add(testParam, testWorkspaceAgentID)

	recorder := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/{workspaceagent}", nil)
	req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, routeCtx))

	id, ok := ParseUUIDParam(recorder, req, testParam)
	assert.True(t, ok, "UUID should be parsed")
	assert.Equal(t, testWorkspaceAgentID, id.String())
}
// TestParseUUID_Invalid checks the failure path of ParseUUIDParam: a route
// parameter that is not a UUID must make the call report failure, write a
// 400 Bad Request, and produce a JSON codersdk.Response whose message names
// the offending value.
func TestParseUUID_Invalid(t *testing.T) {
	t.Parallel()

	// No router runs in this test, so the URL parameter is registered on a
	// chi route context by hand and attached to the request below.
	routeCtx := chi.NewRouteContext()
	routeCtx.URLParams.Add(testParam, "wrong-id")

	recorder := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/{workspaceagent}", nil)
	req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, routeCtx))

	_, ok := ParseUUIDParam(recorder, req, testParam)
	assert.False(t, ok, "UUID should not be parsed")
	assert.Equal(t, http.StatusBadRequest, recorder.Code)

	// The error body must decode cleanly before its message is inspected.
	var body codersdk.Response
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &body))
	assert.Contains(t, body.Message, `Invalid UUID "wrong-id"`)
}
// TestNormalizeAudienceURI tests URI normalization for OAuth2 audience validation
func TestNormalizeAudienceURI(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "EmptyString",
input: "",
expected: "",
},
{
name: "SimpleHTTPWithoutTrailingSlash",
input: "http://example.com",
expected: "http://example.com/",
},
{
name: "SimpleHTTPWithTrailingSlash",
input: "http://example.com/",
expected: "http://example.com/",
},
{
name: "HTTPSWithPath",
input: "https://api.example.com/v1/",
expected: "https://api.example.com/v1",
},
{
name: "CaseNormalization",
input: "HTTPS://API.EXAMPLE.COM/V1/",
expected: "https://api.example.com/V1",
},
{
name: "DefaultHTTPPort",
input: "http://example.com:80/api/",
expected: "http://example.com/api",
},
{
name: "DefaultHTTPSPort",
input: "https://example.com:443/api/",
expected: "https://example.com/api",
},
{
name: "NonDefaultPort",
input: "http://example.com:8080/api/",
expected: "http://example.com:8080/api",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := normalizeAudienceURI(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// TestNormalizeHost tests host normalization including IDN support
func TestNormalizeHost(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "EmptyString",
input: "",
expected: "",
},
{
name: "SimpleHost",
input: "example.com",
expected: "example.com",
},
{
name: "HostWithPort",
input: "example.com:8080",
expected: "example.com:8080",
},
{
name: "CaseNormalization",
input: "EXAMPLE.COM",
expected: "example.com",
},
{
name: "IPv4Address",
input: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "IPv6Address",
input: "[::1]:8080",
expected: "[::1]:8080",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := normalizeHost(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// TestNormalizePathSegments tests path normalization including dot-segment removal
func TestNormalizePathSegments(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "EmptyString",
input: "",
expected: "/",
},
{
name: "SimplePath",
input: "/api/v1",
expected: "/api/v1",
},
{
name: "PathWithDotSegments",
input: "/api/../v1/./test",
expected: "/v1/test",
},
{
name: "TrailingSlash",
input: "/api/v1/",
expected: "/api/v1",
},
{
name: "MultipleSlashes",
input: "/api//v1///test",
expected: "/api//v1///test",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := normalizePathSegments(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// TestExtractExpectedAudience tests audience extraction from HTTP requests
func TestExtractExpectedAudience(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
scheme string
host string
path string
expected string
}{
{
name: "SimpleHTTP",
scheme: "http",
host: "example.com",
path: "/api/test",
expected: "http://example.com/",
},
{
name: "HTTPS",
scheme: "https",
host: "api.example.com",
path: "/v1/users",
expected: "https://api.example.com/",
},
{
name: "WithPort",
scheme: "http",
host: "localhost:8080",
path: "/api",
expected: "http://localhost:8080/",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var req *http.Request
if tc.scheme == "https" {
req = httptest.NewRequest("GET", "https://"+tc.host+tc.path, nil)
} else {
req = httptest.NewRequest("GET", "http://"+tc.host+tc.path, nil)
}
req.Host = tc.host
result := extractExpectedAudience(nil, req)
assert.Equal(t, tc.expected, result)
})
}
}