fix: early oidc refresh with fake idp tests (#22712)

Wrote unit tests that implement a fake idp to verify the oauth package
actually refreshes the token
This commit is contained in:
Steven Masley
2026-03-06 10:51:27 -06:00
committed by GitHub
parent ec48636ba8
commit 537260aa22
3 changed files with 192 additions and 17 deletions
+15 -12
View File
@@ -564,7 +564,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
// The check `s.OIDCConfig != nil` is not as strict, since it can be an interface
// pointing to a typed nil.
if !reflect.ValueOf(s.OIDCConfig).IsNil() {
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
workspaceOwnerOIDCAccessToken, err = ObtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
if err != nil {
return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err))
}
@@ -3075,15 +3075,15 @@ func deleteSessionTokenForUserAndWorkspace(ctx context.Context, db database.Stor
return nil
}
func shouldRefreshOIDCToken(link database.UserLink) bool {
func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) {
if link.OAuthRefreshToken == "" {
// We cannot refresh even if we wanted to
return false
return false, link.OAuthExpiry
}
if link.OAuthExpiry.IsZero() {
// 0 expire means the token never expires, so we shouldn't refresh
return false
return false, link.OAuthExpiry
}
// This handles an edge case where the token is about to expire. A workspace
@@ -3094,17 +3094,18 @@ func shouldRefreshOIDCToken(link database.UserLink) bool {
// If an OIDC provider issues short-lived tokens less than our defined period,
// the token will always be refreshed on every workspace build.
//
// By shifting the time forward, we are asking
// "Will this token be valid in 10 minutes"
expiryCheckTime := dbtime.Now().Add(time.Minute * 10)
// By setting the expiration backwards, we are effectively shortening the
// time a token can be alive for by 10 minutes.
// Note: This is how it is done in the oauth2 package's own token refreshing logic.
expiresAt := link.OAuthExpiry.Add(-time.Minute * 10)
// Return if the token is assumed to be expired.
return link.OAuthExpiry.Before(expiryCheckTime)
return expiresAt.Before(dbtime.Now()), expiresAt
}
// obtainOIDCAccessToken returns a valid OpenID Connect access token
// ObtainOIDCAccessToken returns a valid OpenID Connect access token
// for the user if it's able to obtain one, otherwise it returns an empty string.
func obtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
UserID: userID,
LoginType: database.LoginTypeOIDC,
@@ -3116,11 +3117,13 @@ func obtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.
return "", xerrors.Errorf("get owner oidc link: %w", err)
}
if shouldRefreshOIDCToken(link) {
if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh {
token, err := oidcConfig.TokenSource(ctx, &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
Expiry: link.OAuthExpiry,
// Use the expiresAt returned by shouldRefreshOIDCToken.
// It will force a refresh with an expired time.
Expiry: expiresAt,
}).Token()
if err != nil {
// If OIDC fails to refresh, we return an empty string and don't fail.
@@ -106,7 +106,8 @@ func TestShouldRefreshOIDCToken(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.want, shouldRefreshOIDCToken(tc.link))
shouldRefresh, _ := shouldRefreshOIDCToken(tc.link)
require.Equal(t, tc.want, shouldRefresh)
})
}
}
@@ -117,7 +118,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
t.Run("NoToken", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
_, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, nil, uuid.Nil)
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, nil, uuid.Nil)
require.NoError(t, err)
})
t.Run("InvalidConfig", func(t *testing.T) {
@@ -130,7 +131,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
LoginType: database.LoginTypeOIDC,
OAuthExpiry: dbtime.Now().Add(-time.Hour),
})
_, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
require.NoError(t, err)
})
t.Run("MissingLink", func(t *testing.T) {
@@ -139,7 +140,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
user := dbgen.User(t, db, database.User{
LoginType: database.LoginTypeOIDC,
})
tok, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
tok, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
require.Empty(t, tok)
require.NoError(t, err)
})
@@ -152,7 +153,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
LoginType: database.LoginTypeOIDC,
OAuthExpiry: dbtime.Now().Add(-time.Hour),
})
_, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, &testutil.OAuth2Config{
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &testutil.OAuth2Config{
Token: &oauth2.Token{
AccessToken: "token",
},
@@ -15,6 +15,7 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
@@ -30,6 +31,7 @@ import (
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
@@ -58,6 +60,175 @@ import (
"github.com/coder/serpent"
)
// TestTokenIsRefreshedEarly creates a fake OIDC IDP that sets expiration times
// of the token to values that are "near expiration". Expiration being 10minutes
// earlier than it needs to be. The `ObtainOIDCAccessToken` should refresh these
// tokens early.
func TestTokenIsRefreshedEarly(t *testing.T) {
t.Parallel()
t.Run("WithCoderd", func(t *testing.T) {
t.Parallel()
tokenRefreshCount := 0
fake := oidctest.NewFakeIDP(t,
oidctest.WithServing(),
oidctest.WithDefaultExpire(time.Minute*8),
oidctest.WithRefresh(func(email string) error {
tokenRefreshCount++
return nil
}),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
db, ps := dbtestutil.NewDB(t)
owner := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: cfg,
IncludeProvisionerDaemon: true,
Database: db,
Pubsub: ps,
})
first := coderdtest.CreateFirstUser(t, owner)
version := coderdtest.CreateTemplateVersion(t, owner, first.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, owner, version.ID)
template := coderdtest.CreateTemplate(t, owner, first.OrganizationID, version.ID)
// Setup an OIDC user.
client, _ := fake.Login(t, owner, jwt.MapClaims{
"email": "user@unauthorized.com",
"email_verified": true,
"sub": uuid.NewString(),
})
// Creating a workspace should refresh the oidc early.
tokenRefreshCount = 0
wrk := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID)
require.Equal(t, 1, tokenRefreshCount)
})
}
//nolint:tparallel,paralleltest // Sub tests need to run sequentially.
func TestTokenIsRefreshedEarlyWithoutCoderd(t *testing.T) {
t.Parallel()
tokenRefreshCount := 0
fake := oidctest.NewFakeIDP(t,
oidctest.WithServing(),
oidctest.WithDefaultExpire(time.Minute*8),
oidctest.WithRefresh(func(email string) error {
tokenRefreshCount++
return nil
}),
)
cfg := fake.OIDCConfig(t, nil)
// Fetch a valid token from the fake OIDC provider
token, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{
"email": "user@unauthorized.com",
"email_verified": true,
"sub": uuid.NewString(),
})
require.NoError(t, err)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
LinkedID: "foo",
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
// The oauth expiry does not really matter, since each test will manually control
// this value.
OAuthExpiry: dbtime.Now().Add(time.Hour),
})
setLinkExpiration := func(t *testing.T, exp time.Time) database.UserLink {
ctx := testutil.Context(t, testutil.WaitShort)
links, err := db.GetUserLinksByUserID(ctx, user.ID)
require.NoError(t, err)
require.Len(t, links, 1)
link := links[0]
newLink, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: link.OAuthAccessTokenKeyID,
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID,
OAuthExpiry: exp,
Claims: link.Claims,
UserID: link.UserID,
LoginType: link.LoginType,
})
require.NoError(t, err)
return newLink
}
for _, c := range []struct {
name string
// expires is a function to return a more up to date "now".
// Because the oauth library is calling `time.Now()`, we cannot use
// mocked clocks.
expires func() time.Time
refreshExpected bool
}{
{
name: "ZeroExpiry",
expires: func() time.Time { return time.Time{} },
refreshExpected: false,
},
{
name: "LongExpired",
expires: func() time.Time { return dbtime.Now().Add(-time.Hour) },
refreshExpected: true,
},
{
name: "EdgeExpired",
expires: func() time.Time { return dbtime.Now().Add(-time.Minute * 10) },
refreshExpected: true,
},
{
name: "RecentExpired",
expires: func() time.Time { return dbtime.Now().Add(-time.Second * -1) },
refreshExpected: true,
},
{
name: "Future",
expires: func() time.Time { return dbtime.Now().Add(time.Hour) },
refreshExpected: false,
},
{
name: "FutureWithinRefreshWindow",
expires: func() time.Time { return dbtime.Now().Add(time.Minute * 8) },
refreshExpected: true,
},
} {
t.Run(c.name, func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
oldLink := setLinkExpiration(t, c.expires())
tokenRefreshCount = 0
_, err := provisionerdserver.ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, cfg, user.ID)
require.NoError(t, err)
links, err := db.GetUserLinksByUserID(ctx, user.ID)
require.NoError(t, err)
require.Len(t, links, 1)
newLink := links[0]
if c.refreshExpected {
require.Equal(t, 1, tokenRefreshCount)
require.NotEqual(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken)
require.NotEqual(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken)
} else {
require.Equal(t, 0, tokenRefreshCount)
require.Equal(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken)
require.Equal(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken)
}
})
}
}
func testTemplateScheduleStore() *atomic.Pointer[schedule.TemplateScheduleStore] {
poitr := &atomic.Pointer[schedule.TemplateScheduleStore]{}
store := schedule.NewAGPLTemplateScheduleStore()