mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
chore(coderd/coderdtest/oidctest): protect mutable fields with rwmutex (#17151)
Protects mutable fields of `FakeIDP` to avoid data races.
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -58,15 +59,107 @@ type deviceFlow struct {
|
||||
granted bool
|
||||
}
|
||||
|
||||
// fakeIDPLocked is a set of fields of FakeIDP that are protected
|
||||
// behind a mutex.
|
||||
type fakeIDPLocked struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
issuer string
|
||||
issuerURL *url.URL
|
||||
key *rsa.PrivateKey
|
||||
provider ProviderJSON
|
||||
handler http.Handler
|
||||
cfg *oauth2.Config
|
||||
fakeCoderd func(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) Issuer() string {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
return f.issuer
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) IssuerURL() *url.URL {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
return f.issuerURL
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) PrivateKey() *rsa.PrivateKey {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
return f.key
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) Provider() ProviderJSON {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
return f.provider
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) Config() *oauth2.Config {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
return f.cfg
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) Handler() http.Handler {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
return f.handler
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) SetIssuer(issuer string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.issuer = issuer
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) SetIssuerURL(issuerURL *url.URL) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.issuerURL = issuerURL
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) SetProvider(provider ProviderJSON) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.provider = provider
|
||||
}
|
||||
|
||||
// MutateConfig is a helper function to mutate the oauth2.Config.
|
||||
// Beware of re-entrant locks!
|
||||
func (f *fakeIDPLocked) MutateConfig(fn func(cfg *oauth2.Config)) {
|
||||
f.mu.Lock()
|
||||
if f.cfg == nil {
|
||||
f.cfg = &oauth2.Config{}
|
||||
}
|
||||
fn(f.cfg)
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) SetHandler(handler http.Handler) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.handler = handler
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) SetFakeCoderd(fakeCoderd func(req *http.Request) (*http.Response, error)) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.fakeCoderd = fakeCoderd
|
||||
}
|
||||
|
||||
func (f *fakeIDPLocked) FakeCoderd() func(req *http.Request) (*http.Response, error) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
return f.fakeCoderd
|
||||
}
|
||||
|
||||
// FakeIDP is a functional OIDC provider.
|
||||
// It only supports 1 OIDC client.
|
||||
type FakeIDP struct {
|
||||
issuer string
|
||||
issuerURL *url.URL
|
||||
key *rsa.PrivateKey
|
||||
provider ProviderJSON
|
||||
handler http.Handler
|
||||
cfg *oauth2.Config
|
||||
locked fakeIDPLocked
|
||||
|
||||
// callbackPath allows changing where the callback path to coderd is expected.
|
||||
// This only affects using the Login helper functions.
|
||||
@@ -110,7 +203,6 @@ type FakeIDP struct {
|
||||
// some claims.
|
||||
defaultIDClaims jwt.MapClaims
|
||||
hookMutateToken func(token map[string]interface{})
|
||||
fakeCoderd func(req *http.Request) (*http.Response, error)
|
||||
hookOnRefresh func(email string) error
|
||||
// Custom authentication for the client. This is useful if you want
|
||||
// to test something like PKI auth vs a client_secret.
|
||||
@@ -256,7 +348,7 @@ func WithServing() func(*FakeIDP) {
|
||||
|
||||
func WithIssuer(issuer string) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.issuer = issuer
|
||||
f.locked.SetIssuer(issuer)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,7 +419,9 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||
require.NoError(t, err)
|
||||
|
||||
idp := &FakeIDP{
|
||||
key: pkey,
|
||||
locked: fakeIDPLocked{
|
||||
key: pkey,
|
||||
},
|
||||
clientID: uuid.NewString(),
|
||||
clientSecret: uuid.NewString(),
|
||||
logger: slog.Make(),
|
||||
@@ -348,12 +442,12 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||
opt(idp)
|
||||
}
|
||||
|
||||
if idp.issuer == "" {
|
||||
idp.issuer = "https://coder.com"
|
||||
if idp.locked.Issuer() == "" {
|
||||
idp.locked.SetIssuer("https://coder.com")
|
||||
}
|
||||
|
||||
idp.handler = idp.httpHandler(t)
|
||||
idp.updateIssuerURL(t, idp.issuer)
|
||||
idp.locked.SetHandler(idp.httpHandler(t))
|
||||
idp.updateIssuerURL(t, idp.locked.Issuer())
|
||||
if idp.serve {
|
||||
idp.realServer(t)
|
||||
}
|
||||
@@ -369,11 +463,11 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||
}
|
||||
|
||||
func (f *FakeIDP) WellknownConfig() ProviderJSON {
|
||||
return f.provider
|
||||
return f.locked.Provider()
|
||||
}
|
||||
|
||||
func (f *FakeIDP) IssuerURL() *url.URL {
|
||||
return f.issuerURL
|
||||
return f.locked.IssuerURL()
|
||||
}
|
||||
|
||||
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||
@@ -382,11 +476,11 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||
u, err := url.Parse(issuer)
|
||||
require.NoError(t, err, "invalid issuer URL")
|
||||
|
||||
f.issuer = issuer
|
||||
f.issuerURL = u
|
||||
f.locked.SetIssuer(issuer)
|
||||
f.locked.SetIssuerURL(u)
|
||||
// ProviderJSON is the JSON representation of the OpenID Connect provider
|
||||
// These are all the urls that the IDP will respond to.
|
||||
f.provider = ProviderJSON{
|
||||
f.locked.SetProvider(ProviderJSON{
|
||||
Issuer: issuer,
|
||||
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
||||
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
||||
@@ -397,7 +491,7 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||
"RS256",
|
||||
},
|
||||
ExternalAuthURL: u.ResolveReference(&url.URL{Path: "/external-auth-validate/user"}).String(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// realServer turns the FakeIDP into a real http server.
|
||||
@@ -405,7 +499,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
srvURL := "localhost:0"
|
||||
issURL, err := url.Parse(f.issuer)
|
||||
issURL, err := url.Parse(f.locked.Issuer())
|
||||
if err == nil {
|
||||
if issURL.Hostname() == "localhost" || issURL.Hostname() == "127.0.0.1" {
|
||||
srvURL = issURL.Host
|
||||
@@ -418,7 +512,7 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
srv := &httptest.Server{
|
||||
Listener: l,
|
||||
Config: &http.Server{Handler: f.handler, ReadHeaderTimeout: time.Second * 5},
|
||||
Config: &http.Server{Handler: f.locked.Handler(), ReadHeaderTimeout: time.Second * 5},
|
||||
}
|
||||
|
||||
srv.Config.BaseContext = func(_ net.Listener) context.Context {
|
||||
@@ -439,7 +533,7 @@ func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Toke
|
||||
state := uuid.NewString()
|
||||
f.stateToIDTokenClaims.Store(state, claims)
|
||||
code := f.newCode(state)
|
||||
return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
|
||||
return f.locked.Config().Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
|
||||
}
|
||||
|
||||
// Login does the full OIDC flow starting at the "LoginButton".
|
||||
@@ -620,9 +714,9 @@ func (f *FakeIDP) CreateAuthCode(t testing.TB, state string) string {
|
||||
// it expects some claims to be present.
|
||||
f.stateToIDTokenClaims.Store(state, jwt.MapClaims{})
|
||||
|
||||
code, err := OAuth2GetCode(f.cfg.AuthCodeURL(state), func(req *http.Request) (*http.Response, error) {
|
||||
code, err := OAuth2GetCode(f.locked.Config().AuthCodeURL(state), func(req *http.Request) (*http.Response, error) {
|
||||
rw := httptest.NewRecorder()
|
||||
f.handler.ServeHTTP(rw, req)
|
||||
f.locked.Handler().ServeHTTP(rw, req)
|
||||
resp := rw.Result()
|
||||
return resp, nil
|
||||
})
|
||||
@@ -644,7 +738,7 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
|
||||
f.stateToIDTokenClaims.Store(state, idTokenClaims)
|
||||
|
||||
cli := f.HTTPClient(nil)
|
||||
u := f.cfg.AuthCodeURL(state)
|
||||
u := f.locked.Config().AuthCodeURL(state)
|
||||
req, err := http.NewRequest("GET", u, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -762,10 +856,10 @@ func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string {
|
||||
}
|
||||
|
||||
if _, ok := claims["iss"]; !ok {
|
||||
claims["iss"] = f.issuer
|
||||
claims["iss"] = f.locked.Issuer()
|
||||
}
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key)
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.locked.PrivateKey())
|
||||
require.NoError(t, err)
|
||||
|
||||
return signed
|
||||
@@ -782,7 +876,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Info(r.Context(), "http OIDC config", slogRequestFields(r)...)
|
||||
|
||||
cpy := f.provider
|
||||
cpy := f.locked.Provider()
|
||||
if f.hookWellKnown != nil {
|
||||
err := f.hookWellKnown(r, &cpy)
|
||||
if err != nil {
|
||||
@@ -1082,7 +1176,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
set := jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: f.key.Public(),
|
||||
Key: f.locked.PrivateKey().Public(),
|
||||
KeyID: "test-key",
|
||||
Algorithm: "RSA",
|
||||
},
|
||||
@@ -1181,7 +1275,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
exp: time.Now().Add(lifetime),
|
||||
})
|
||||
|
||||
verifyURL := f.issuerURL.ResolveReference(&url.URL{
|
||||
verifyURL := f.locked.IssuerURL().ResolveReference(&url.URL{
|
||||
Path: deviceVerify,
|
||||
RawQuery: url.Values{
|
||||
"device_code": {deviceCode},
|
||||
@@ -1240,10 +1334,10 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
|
||||
Jar: jar,
|
||||
Transport: fakeRoundTripper{
|
||||
roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
u, _ := url.Parse(f.issuer)
|
||||
u, _ := url.Parse(f.locked.Issuer())
|
||||
if req.URL.Host != u.Host {
|
||||
if f.fakeCoderd != nil {
|
||||
return f.fakeCoderd(req)
|
||||
if fakeCoderd := f.locked.FakeCoderd(); fakeCoderd != nil {
|
||||
return fakeCoderd(req)
|
||||
}
|
||||
if rest == nil || rest.Transport == nil {
|
||||
return nil, xerrors.Errorf("unexpected network request to %q", req.URL.Host)
|
||||
@@ -1251,7 +1345,7 @@ func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
|
||||
return rest.Transport.RoundTrip(req)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
f.handler.ServeHTTP(resp, req)
|
||||
f.locked.Handler().ServeHTTP(resp, req)
|
||||
return resp.Result(), nil
|
||||
},
|
||||
},
|
||||
@@ -1269,6 +1363,7 @@ func (f *FakeIDP) RefreshUsed(refreshToken string) bool {
|
||||
// for a given refresh token. By default, all refreshes use the same claims as
|
||||
// the original IDToken issuance.
|
||||
func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) {
|
||||
// no mutex because it's a sync.Map
|
||||
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||
}
|
||||
|
||||
@@ -1276,8 +1371,9 @@ func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims)
|
||||
// Coderd.
|
||||
func (f *FakeIDP) SetRedirect(t testing.TB, u string) {
|
||||
t.Helper()
|
||||
|
||||
f.cfg.RedirectURL = u
|
||||
f.locked.MutateConfig(func(cfg *oauth2.Config) {
|
||||
cfg.RedirectURL = u
|
||||
})
|
||||
}
|
||||
|
||||
// SetCoderdCallback is optional and only works if not using the IsServing.
|
||||
@@ -1287,7 +1383,7 @@ func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Resp
|
||||
if f.serve {
|
||||
panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'")
|
||||
}
|
||||
f.fakeCoderd = callback
|
||||
f.locked.SetFakeCoderd(callback)
|
||||
}
|
||||
|
||||
func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) {
|
||||
@@ -1384,13 +1480,13 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
||||
DisplayIcon: f.WellknownConfig().UserInfoURL,
|
||||
// Omit the /user for the validate so we can easily append to it when modifying
|
||||
// the cfg for advanced tests.
|
||||
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
|
||||
ValidateURL: f.locked.IssuerURL().ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
|
||||
DeviceAuth: &externalauth.DeviceAuth{
|
||||
Config: oauthCfg,
|
||||
ClientID: f.clientID,
|
||||
TokenURL: f.provider.TokenURL,
|
||||
TokenURL: f.locked.Provider().TokenURL,
|
||||
Scopes: []string{},
|
||||
CodeURL: f.provider.DeviceCodeURL,
|
||||
CodeURL: f.locked.Provider().DeviceCodeURL,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1401,7 +1497,7 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
f.updateIssuerURL(t, f.issuer)
|
||||
f.updateIssuerURL(t, f.locked.Issuer())
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -1410,35 +1506,35 @@ func (f *FakeIDP) AppCredentials() (clientID string, clientSecret string) {
|
||||
}
|
||||
|
||||
func (f *FakeIDP) PublicKey() crypto.PublicKey {
|
||||
return f.key.Public()
|
||||
return f.locked.PrivateKey().Public()
|
||||
}
|
||||
|
||||
func (f *FakeIDP) OauthConfig(t testing.TB, scopes []string) *oauth2.Config {
|
||||
t.Helper()
|
||||
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{"openid", "email", "profile"}
|
||||
}
|
||||
oauthCfg := &oauth2.Config{
|
||||
ClientID: f.clientID,
|
||||
ClientSecret: f.clientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: f.provider.AuthURL,
|
||||
TokenURL: f.provider.TokenURL,
|
||||
provider := f.locked.Provider()
|
||||
f.locked.MutateConfig(func(cfg *oauth2.Config) {
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{"openid", "email", "profile"}
|
||||
}
|
||||
cfg.ClientID = f.clientID
|
||||
cfg.ClientSecret = f.clientSecret
|
||||
cfg.Endpoint = oauth2.Endpoint{
|
||||
AuthURL: provider.AuthURL,
|
||||
TokenURL: provider.TokenURL,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
}
|
||||
// If the user is using a real network request, they will need to do
|
||||
// 'fake.SetRedirect()'
|
||||
RedirectURL: "https://redirect.com",
|
||||
Scopes: scopes,
|
||||
}
|
||||
f.cfg = oauthCfg
|
||||
cfg.RedirectURL = "https://redirect.com"
|
||||
cfg.Scopes = scopes
|
||||
})
|
||||
|
||||
return oauthCfg
|
||||
return f.locked.Config()
|
||||
}
|
||||
|
||||
func (f *FakeIDP) OIDCConfigSkipIssuerChecks(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
ctx := oidc.InsecureIssuerURLContext(context.Background(), f.issuer)
|
||||
ctx := oidc.InsecureIssuerURLContext(context.Background(), f.locked.Issuer())
|
||||
|
||||
return f.internalOIDCConfig(ctx, t, scopes, func(config *oidc.Config) {
|
||||
config.SkipIssuerCheck = true
|
||||
@@ -1456,7 +1552,7 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [
|
||||
oauthCfg := f.OauthConfig(t, scopes)
|
||||
|
||||
ctx = oidc.ClientContext(ctx, f.HTTPClient(nil))
|
||||
p, err := oidc.NewProvider(ctx, f.provider.Issuer)
|
||||
p, err := oidc.NewProvider(ctx, f.locked.Issuer())
|
||||
require.NoError(t, err, "failed to create OIDC provider")
|
||||
|
||||
verifierConfig := &oidc.Config{
|
||||
@@ -1473,8 +1569,8 @@ func (f *FakeIDP) internalOIDCConfig(ctx context.Context, t testing.TB, scopes [
|
||||
cfg := &coderd.OIDCConfig{
|
||||
OAuth2Config: oauthCfg,
|
||||
Provider: p,
|
||||
Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{f.key.Public()},
|
||||
Verifier: oidc.NewVerifier(f.locked.Issuer(), &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{f.locked.PrivateKey().Public()},
|
||||
}, verifierConfig),
|
||||
UsernameField: "preferred_username",
|
||||
EmailField: "email",
|
||||
|
||||
Reference in New Issue
Block a user