chore(coderd/coderdtest/oidctest): protect mutable fields with rwmutex (#17151)

Protects mutable fields of `FakeIDP` to avoid data races.
This commit is contained in:
Cian Johnston
2025-04-02 13:36:26 +01:00
committed by GitHub
parent d6c034d2a3
commit 8cecc4f12d
+157 -61
View File
@@ -20,6 +20,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"testing"
"time"
@@ -58,15 +59,107 @@ type deviceFlow struct {
granted bool
}
// fakeIDPLocked is a set of fields of FakeIDP that are protected
// behind a mutex.
type fakeIDPLocked struct {
mu sync.RWMutex
issuer string
issuerURL *url.URL
key *rsa.PrivateKey
provider ProviderJSON
handler http.Handler
cfg *oauth2.Config
fakeCoderd func(req *http.Request) (*http.Response, error)
}
func (f *fakeIDPLocked) Issuer() string {
f.mu.RLock()
defer f.mu.RUnlock()
return f.issuer
}
func (f *fakeIDPLocked) IssuerURL() *url.URL {
f.mu.RLock()
defer f.mu.RUnlock()
return f.issuerURL
}
func (f *fakeIDPLocked) PrivateKey() *rsa.PrivateKey {
f.mu.RLock()
defer f.mu.RUnlock()
return f.key
}
func (f *fakeIDPLocked) Provider() ProviderJSON {
f.mu.RLock()
defer f.mu.RUnlock()
return f.provider
}
func (f *fakeIDPLocked) Config() *oauth2.Config {
f.mu.RLock()
defer f.mu.RUnlock()
return f.cfg
}
func (f *fakeIDPLocked) Handler() http.Handler {
f.mu.RLock()
defer f.mu.RUnlock()
return f.handler
}
func (f *fakeIDPLocked) SetIssuer(issuer string) {
f.mu.Lock()
defer f.mu.Unlock()
f.issuer = issuer
}
func (f *fakeIDPLocked) SetIssuerURL(issuerURL *url.URL) {
f.mu.Lock()
defer f.mu.Unlock()
f.issuerURL = issuerURL
}
func (f *fakeIDPLocked) SetProvider(provider ProviderJSON) {
f.mu.Lock()
defer f.mu.Unlock()
f.provider = provider
}
// MutateConfig is a helper function to mutate the oauth2.Config.
// Beware of re-entrant locks!
func (f *fakeIDPLocked) MutateConfig(fn func(cfg *oauth2.Config)) {
f.mu.Lock()
if f.cfg == nil {
f.cfg = &oauth2.Config{}
}
fn(f.cfg)
f.mu.Unlock()
}
func (f *fakeIDPLocked) SetHandler(handler http.Handler) {
f.mu.Lock()
defer f.mu.Unlock()
f.handler = handler
}
func (f *fakeIDPLocked) SetFakeCoderd(fakeCoderd func(req *http.Request) (*http.Response, error)) {
f.mu.Lock()
defer f.mu.Unlock()
f.fakeCoderd = fakeCoderd
}
func (f *fakeIDPLocked) FakeCoderd() func(req *http.Request) (*http.Response, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.fakeCoderd
}
// FakeIDP is a functional OIDC provider.
// It only supports 1 OIDC client.
type FakeIDP struct {
issuer string
issuerURL *url.URL
key *rsa.PrivateKey
provider ProviderJSON
handler http.Handler
cfg *oauth2.Config
locked fakeIDPLocked
// callbackPath allows changing where the callback path to coderd is expected.
// This only affects using the Login helper functions.
@@ -110,7 +203,6 @@ type FakeIDP struct {
// some claims.
defaultIDClaims jwt.MapClaims
hookMutateToken func(token map[string]interface{})
fakeCoderd func(req *http.Request) (*http.Response, error)
hookOnRefresh func(email string) error
// Custom authentication for the client. This is useful if you want
// to test something like PKI auth vs a client_secret.
@@ -256,7 +348,7 @@ func WithServing() func(*FakeIDP) {
func WithIssuer(issuer string) func(*FakeIDP) {
return func(f *FakeIDP) {
f.issuer = issuer
f.locked.SetIssuer(issuer)
}
}
@@ -327,7 +419,9 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
require.NoError(t, err)
idp := &FakeIDP{
key: pkey,
locked: fakeIDPLocked{
key: pkey,
},
clientID: uuid.NewString(),
clientSecret: uuid.NewString(),
logger: slog.Make(),
@@ -348,12 +442,12 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
opt(idp)
}
if idp.issuer == "" {
idp.issuer = "https://coder.com"
if idp.locked.Issuer() == "" {
idp.locked.SetIssuer("https://coder.com")
}
idp.handler = idp.httpHandler(t)
idp.updateIssuerURL(t, idp.issuer)
idp.locked.SetHandler(idp.httpHandler(t))
idp.updateIssuerURL(t, idp.locked.Issuer())
if idp.serve {
idp.realServer(t)
}
@@ -369,11 +463,11 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
}
func (f *FakeIDP) WellknownConfig() ProviderJSON {
return f.provider
return f.locked.Provider()
}
func (f *FakeIDP) IssuerURL() *url.URL {
return f.issuerURL
return f.locked.IssuerURL()
}
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
@@ -382,11 +476,11 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
u, err := url.Parse(issuer)
require.NoError(t, err, "invalid issuer URL")
f.issuer = issuer
f.issuerURL = u
f.locked.SetIssuer(issuer)
f.locked.SetIssuerURL(u)
// ProviderJSON is the JSON representation of the OpenID Connect provider
// These are all the urls that the IDP will respond to.
f.provider = ProviderJSON{
f.locked.SetProvider(ProviderJSON{
Issuer: issuer,
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
@@ -397,7 +491,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
"RS256",
},
ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(),
}
})
}
// realServer turns the FakeIDP into a real http server.
@@ -405,7 +499,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
t.Helper()
srvURL := "localhost:0"
issURL, err := url.Parse(f.issuer)
issURL, err := url.Parse(f.locked.Issuer())
if err == nil {
if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" {
srvURL = issURL.Host
@@ -418,7 +512,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
ctx, cancel := context.WithCancel(context.Background())
srv := &httptest.Server{
Listener: l,
Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5},
Config: &http.Server{Handler: f.locked.Handler(), ReadHeaderTimeout: time.Second * 5},
}
srv.Config.BaseContext = func(_ net.Listener) context.Context {
@@ -439,7 +533,7 @@ func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Toke
state := uuid.NewString()
f.stateToIDTokenClaims.Store(state, claims)
code := f.newCode(state)
return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
return f.locked.Config().Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
}
// Login does the full OIDC flow starting at the "LoginButton".
@@ -620,9 +714,9 @@ func (f *FakeIDP) CreateAuthCode(t testing.TB, state string) string {
// it expects some claims to be present.
f.stateToIDTokenClaims.Store(state, jwt.MapClaims{})
code, err := OAuth2GetCode(f.cfg.AuthCodeURL(state), func(req *http.Request) (*http.Response, error) {
code, err := OAuth2GetCode(f.locked.Config().AuthCodeURL(state), func(req *http.Request) (*http.Response, error) {
rw := httptest.NewRecorder()
f.handler.ServeHTTP(rw, req)
f.locked.Handler().ServeHTTP(rw, req)
resp := rw.Result()
return resp, nil
})
@@ -644,7 +738,7 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
f.stateToIDTokenClaims.Store(state, idTokenClaims)
cli := f.HTTPClient(nil)
u := f.cfg.AuthCodeURL(state)
u := f.locked.Config().AuthCodeURL(state)
req, err := http.NewRequest("GET", u, nil)
require.NoError(t, err)
@@ -762,10 +856,10 @@ func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string {
}
if _, ok := claims["iss"]; !ok {
claims["iss"] = f.issuer
claims["iss"] = f.locked.Issuer()
}
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key)
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.locked.PrivateKey())
require.NoError(t, err)
return signed
@@ -782,7 +876,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
f.logger.Info(r.Context(), "http OIDC config", slogRequestFields(r)...)
cpy := f.provider
cpy := f.locked.Provider()
if f.hookWellKnown != nil {
err := f.hookWellKnown(r, &cpy)
if err != nil {
@@ -1082,7 +1176,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
set := jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: f.key.Public(),
Key: f.locked.PrivateKey().Public(),
KeyID: "test-key",
Algorithm: "RSA",
},
@@ -1181,7 +1275,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
exp: time.Now().Add(lifetime),
})
verifyURL := f.issuerURL.ResolveReference(&url.URL{
verifyURL := f.locked.IssuerURL().ResolveReference(&url.URL{
Path: deviceVerify,
RawQuery: url.Values{
"device_code": {deviceCode},
@@ -1240,10 +1334,10 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
Jar: jar,
Transport: fakeRoundTripper{
roundTrip: func(req *http.Request) (*http.Response, error) {
u, _ := url.Parse(f.issuer)
u, _ := url.Parse(f.locked.Issuer())
if req.URL.Host != u.Host {
if f.fakeCoderd != nil {
return f.fakeCoderd(req)
if fakeCoderd := f.locked.FakeCoderd(); fakeCoderd != nil {
return fakeCoderd(req)
}
if rest == nil || rest.Transport == nil {
return nil, xerrors.Errorf("unexpected network request to %q", req.URL.Host)
@@ -1251,7 +1345,7 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
return rest.Transport.RoundTrip(req)
}
resp := httptest.NewRecorder()
f.handler.ServeHTTP(resp, req)
f.locked.Handler().ServeHTTP(resp, req)
return resp.Result(), nil
},
},
@@ -1269,6 +1363,7 @@ func (f *FakeIDP) RefreshUsed(refreshToken string) bool {
// for a given refresh token. By default, all refreshes use the same claims as
// the original IDToken issuance.
func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) {
// no mutex because it's a sync.Map
f.refreshIDTokenClaims.Store(refreshToken, claims)
}
@@ -1276,8 +1371,9 @@ func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims)
// Coderd.
func (f *FakeIDP) SetRedirect(t testing.TB, u string) {
t.Helper()
f.cfg.RedirectURL = u
f.locked.MutateConfig(func(cfg *oauth2.Config) {
cfg.RedirectURL = u
})
}
// SetCoderdCallback is optional and only works if not using the IsServing.
@@ -1287,7 +1383,7 @@ func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Resp
if f.serve {
panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'")
}
f.fakeCoderd = callback
f.locked.SetFakeCoderd(callback)
}
func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) {
@@ -1384,13 +1480,13 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
DisplayIcon: f.WellknownConfig().UserInfoURL,
// Omit the /user for the validate so we can easily append to it when modifying
// the cfg for advanced tests.
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
ValidateURL: f.locked.IssuerURL().ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
DeviceAuth: &externalauth.DeviceAuth{
Config: oauthCfg,
ClientID: f.clientID,
TokenURL: f.provider.TokenURL,
TokenURL: f.locked.Provider().TokenURL,
Scopes: []string{},
CodeURL: f.provider.DeviceCodeURL,
CodeURL: f.locked.Provider().DeviceCodeURL,
},
}
@@ -1401,7 +1497,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
for _, opt := range opts {
opt(cfg)
}
f.updateIssuerURL(t, f.issuer)
f.updateIssuerURL(t, f.locked.Issuer())
return cfg
}
@@ -1410,35 +1506,35 @@ func (f *FakeIDP) AppCredentials() (clientID string, clientSecret string) {
}
func (f *FakeIDP) PublicKey() crypto.PublicKey {
return f.key.Public()
return f.locked.PrivateKey().Public()
}
func (f *FakeIDP) OauthConfig(t testing.TB, scopes []string) *oauth2.Config {
t.Helper()
if len(scopes) == 0 {
scopes = []string{"openid", "email", "profile"}
}
oauthCfg := &oauth2.Config{
ClientID: f.clientID,
ClientSecret: f.clientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: f.provider.AuthURL,
TokenURL: f.provider.TokenURL,
provider := f.locked.Provider()
f.locked.MutateConfig(func(cfg *oauth2.Config) {
if len(scopes) == 0 {
scopes = []string{"openid", "email", "profile"}
}
cfg.ClientID = f.clientID
cfg.ClientSecret = f.clientSecret
cfg.Endpoint = oauth2.Endpoint{
AuthURL: provider.AuthURL,
TokenURL: provider.TokenURL,
AuthStyle: oauth2.AuthStyleInParams,
},
}
// If the user is using a real network request, they will need to do
// 'fake.SetRedirect()'
RedirectURL: "https://redirect.com",
Scopes: scopes,
}
f.cfg = oauthCfg
cfg.RedirectURL = "https://redirect.com"
cfg.Scopes = scopes
})
return oauthCfg
return f.locked.Config()
}
func (f *FakeIDP) OIDCConfigSkipIssuerChecks(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
ctx := oidc.InsecureIssuerURLContext(context.Background(), f.issuer)
ctx := oidc.InsecureIssuerURLContext(context.Background(), f.locked.Issuer())
return f.internalOIDCConfig(ctx, t, scopes, func(config *oidc.Config) {
config.SkipIssuerCheck = true
@@ -1456,7 +1552,7 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [
oauthCfg := f.OauthConfig(t, scopes)
ctx = oidc.ClientContext(ctx, f.HTTPClient(nil))
p, err := oidc.NewProvider(ctx, f.provider.Issuer)
p, err := oidc.NewProvider(ctx, f.locked.Issuer())
require.NoError(t, err, "failed to create OIDC provider")
verifierConfig := &oidc.Config{
@@ -1473,8 +1569,8 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [
cfg := &coderd.OIDCConfig{
OAuth2Config: oauthCfg,
Provider: p,
Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{f.key.Public()},
Verifier: oidc.NewVerifier(f.locked.Issuer(), &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{f.locked.PrivateKey().Public()},
}, verifierConfig),
UsernameField: "preferred_username",
EmailField: "email",