mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
071383bbe8
# Enhanced OAuth2 and MCP Compliance for API Authentication This PR improves OAuth2 and MCP (Model Context Protocol) compliance by: 1. Adding RFC 9728 compliant `WWW-Authenticate` headers with resource metadata URLs 2. Passing the configured `AccessURL` to API key middleware for proper audience validation 3. Creating specialized CORS handling for OAuth2 and MCP endpoints with appropriate headers 4. Making the `state` parameter optional in OAuth2 authorization requests These changes ensure proper OAuth2 token audience validation against the configured access URL and improve interoperability with OAuth2 clients by providing better error responses and metadata discovery. Signed-off-by: Thomas Kosiewski <tk@coder.com>
266 lines
5.6 KiB
Go
266 lines
5.6 KiB
Go
package httpmw
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/v2/codersdk"
|
|
)
|
|
|
|
// Fixtures shared by the ParseUUIDParam tests.
const (
	// testParam is the chi URL parameter name the tests populate and parse.
	testParam = "workspaceagent"

	// testWorkspaceAgentID is a well-formed UUID string used as the
	// "valid" parameter value.
	testWorkspaceAgentID = "8a70c576-12dc-42bc-b791-112a32b5bd43"
)
|
|
|
|
func TestParseUUID_Valid(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
rw := httptest.NewRecorder()
|
|
r := httptest.NewRequest("GET", "/{workspaceagent}", nil)
|
|
|
|
ctx := chi.NewRouteContext()
|
|
ctx.URLParams.Add(testParam, testWorkspaceAgentID)
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
|
|
|
|
parsed, ok := ParseUUIDParam(rw, r, "workspaceagent")
|
|
assert.True(t, ok, "UUID should be parsed")
|
|
assert.Equal(t, testWorkspaceAgentID, parsed.String())
|
|
}
|
|
|
|
func TestParseUUID_Invalid(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
rw := httptest.NewRecorder()
|
|
r := httptest.NewRequest("GET", "/{workspaceagent}", nil)
|
|
|
|
ctx := chi.NewRouteContext()
|
|
ctx.URLParams.Add(testParam, "wrong-id")
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
|
|
|
|
_, ok := ParseUUIDParam(rw, r, "workspaceagent")
|
|
assert.False(t, ok, "UUID should not be parsed")
|
|
assert.Equal(t, http.StatusBadRequest, rw.Code)
|
|
|
|
var response codersdk.Response
|
|
err := json.Unmarshal(rw.Body.Bytes(), &response)
|
|
require.NoError(t, err)
|
|
assert.Contains(t, response.Message, `Invalid UUID "wrong-id"`)
|
|
}
|
|
|
|
// TestNormalizeAudienceURI tests URI normalization for OAuth2 audience validation
|
|
func TestNormalizeAudienceURI(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
input string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "EmptyString",
|
|
input: "",
|
|
expected: "",
|
|
},
|
|
{
|
|
name: "SimpleHTTPWithoutTrailingSlash",
|
|
input: "http://example.com",
|
|
expected: "http://example.com/",
|
|
},
|
|
{
|
|
name: "SimpleHTTPWithTrailingSlash",
|
|
input: "http://example.com/",
|
|
expected: "http://example.com/",
|
|
},
|
|
{
|
|
name: "HTTPSWithPath",
|
|
input: "https://api.example.com/v1/",
|
|
expected: "https://api.example.com/v1",
|
|
},
|
|
{
|
|
name: "CaseNormalization",
|
|
input: "HTTPS://API.EXAMPLE.COM/V1/",
|
|
expected: "https://api.example.com/V1",
|
|
},
|
|
{
|
|
name: "DefaultHTTPPort",
|
|
input: "http://example.com:80/api/",
|
|
expected: "http://example.com/api",
|
|
},
|
|
{
|
|
name: "DefaultHTTPSPort",
|
|
input: "https://example.com:443/api/",
|
|
expected: "https://example.com/api",
|
|
},
|
|
{
|
|
name: "NonDefaultPort",
|
|
input: "http://example.com:8080/api/",
|
|
expected: "http://example.com:8080/api",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
result := normalizeAudienceURI(tc.input)
|
|
assert.Equal(t, tc.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestNormalizeHost tests host normalization including IDN support
|
|
func TestNormalizeHost(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
input string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "EmptyString",
|
|
input: "",
|
|
expected: "",
|
|
},
|
|
{
|
|
name: "SimpleHost",
|
|
input: "example.com",
|
|
expected: "example.com",
|
|
},
|
|
{
|
|
name: "HostWithPort",
|
|
input: "example.com:8080",
|
|
expected: "example.com:8080",
|
|
},
|
|
{
|
|
name: "CaseNormalization",
|
|
input: "EXAMPLE.COM",
|
|
expected: "example.com",
|
|
},
|
|
{
|
|
name: "IPv4Address",
|
|
input: "192.168.1.1",
|
|
expected: "192.168.1.1",
|
|
},
|
|
{
|
|
name: "IPv6Address",
|
|
input: "[::1]:8080",
|
|
expected: "[::1]:8080",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
result := normalizeHost(tc.input)
|
|
assert.Equal(t, tc.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestNormalizePathSegments tests path normalization including dot-segment removal
|
|
func TestNormalizePathSegments(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
input string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "EmptyString",
|
|
input: "",
|
|
expected: "/",
|
|
},
|
|
{
|
|
name: "SimplePath",
|
|
input: "/api/v1",
|
|
expected: "/api/v1",
|
|
},
|
|
{
|
|
name: "PathWithDotSegments",
|
|
input: "/api/../v1/./test",
|
|
expected: "/v1/test",
|
|
},
|
|
{
|
|
name: "TrailingSlash",
|
|
input: "/api/v1/",
|
|
expected: "/api/v1",
|
|
},
|
|
{
|
|
name: "MultipleSlashes",
|
|
input: "/api//v1///test",
|
|
expected: "/api//v1///test",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
result := normalizePathSegments(tc.input)
|
|
assert.Equal(t, tc.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestExtractExpectedAudience tests audience extraction from HTTP requests
|
|
func TestExtractExpectedAudience(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
scheme string
|
|
host string
|
|
path string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "SimpleHTTP",
|
|
scheme: "http",
|
|
host: "example.com",
|
|
path: "/api/test",
|
|
expected: "http://example.com/",
|
|
},
|
|
{
|
|
name: "HTTPS",
|
|
scheme: "https",
|
|
host: "api.example.com",
|
|
path: "/v1/users",
|
|
expected: "https://api.example.com/",
|
|
},
|
|
{
|
|
name: "WithPort",
|
|
scheme: "http",
|
|
host: "localhost:8080",
|
|
path: "/api",
|
|
expected: "http://localhost:8080/",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
var req *http.Request
|
|
if tc.scheme == "https" {
|
|
req = httptest.NewRequest("GET", "https://"+tc.host+tc.path, nil)
|
|
} else {
|
|
req = httptest.NewRequest("GET", "http://"+tc.host+tc.path, nil)
|
|
}
|
|
req.Host = tc.host
|
|
|
|
result := extractExpectedAudience(nil, req)
|
|
assert.Equal(t, tc.expected, result)
|
|
})
|
|
}
|
|
}
|