diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index c0ccee4145..16109fab9a 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -564,7 +564,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo // The check `s.OIDCConfig != nil` is not as strict, since it can be an interface // pointing to a typed nil. if !reflect.ValueOf(s.OIDCConfig).IsNil() { - workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID) + workspaceOwnerOIDCAccessToken, err = ObtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID) if err != nil { return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err)) } @@ -3075,15 +3075,15 @@ func deleteSessionTokenForUserAndWorkspace(ctx context.Context, db database.Stor return nil } -func shouldRefreshOIDCToken(link database.UserLink) bool { +func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) { if link.OAuthRefreshToken == "" { // We cannot refresh even if we wanted to - return false + return false, link.OAuthExpiry } if link.OAuthExpiry.IsZero() { // 0 expire means the token never expires, so we shouldn't refresh - return false + return false, link.OAuthExpiry } // This handles an edge case where the token is about to expire. A workspace @@ -3094,17 +3094,18 @@ func shouldRefreshOIDCToken(link database.UserLink) bool { // If an OIDC provider issues short-lived tokens less than our defined period, // the token will always be refreshed on every workspace build. // - // By shifting the time forward, we are asking - // "Will this token be valid in 10 minutes" - expiryCheckTime := dbtime.Now().Add(time.Minute * 10) + // By setting the expiration backwards, we are effectively shortening the + // time a token can be alive for by 10 minutes. + // Note: This is how it is done in the oauth2 package's own token refreshing logic. + expiresAt := link.OAuthExpiry.Add(-time.Minute * 10) // Return if the token is assumed to be expired. - return link.OAuthExpiry.Before(expiryCheckTime) + return expiresAt.Before(dbtime.Now()), expiresAt } -// obtainOIDCAccessToken returns a valid OpenID Connect access token +// ObtainOIDCAccessToken returns a valid OpenID Connect access token // for the user if it's able to obtain one, otherwise it returns an empty string. -func obtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) { +func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) { link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ UserID: userID, LoginType: database.LoginTypeOIDC, @@ -3116,11 +3117,13 @@ func obtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database. return "", xerrors.Errorf("get owner oidc link: %w", err) } - if shouldRefreshOIDCToken(link) { + if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh { token, err := oidcConfig.TokenSource(ctx, &oauth2.Token{ AccessToken: link.OAuthAccessToken, RefreshToken: link.OAuthRefreshToken, - Expiry: link.OAuthExpiry, + // Use the expiresAt returned by shouldRefreshOIDCToken. + // It will force a refresh with an expired time. + Expiry: expiresAt, }).Token() if err != nil { // If OIDC fails to refresh, we return an empty string and don't fail. diff --git a/coderd/provisionerdserver/provisionerdserver_internal_test.go b/coderd/provisionerdserver/provisionerdserver_internal_test.go index cf18d502aa..7e6aa80f9b 100644 --- a/coderd/provisionerdserver/provisionerdserver_internal_test.go +++ b/coderd/provisionerdserver/provisionerdserver_internal_test.go @@ -106,7 +106,8 @@ func TestShouldRefreshOIDCToken(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - require.Equal(t, tc.want, shouldRefreshOIDCToken(tc.link)) + shouldRefresh, _ := shouldRefreshOIDCToken(tc.link) + require.Equal(t, tc.want, shouldRefresh) }) } } @@ -117,7 +118,7 @@ func TestObtainOIDCAccessToken(t *testing.T) { t.Run("NoToken", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) - _, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, nil, uuid.Nil) + _, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, nil, uuid.Nil) require.NoError(t, err) }) t.Run("InvalidConfig", func(t *testing.T) { @@ -130,7 +131,7 @@ func TestObtainOIDCAccessToken(t *testing.T) { LoginType: database.LoginTypeOIDC, OAuthExpiry: dbtime.Now().Add(-time.Hour), }) - _, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID) + _, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID) require.NoError(t, err) }) t.Run("MissingLink", func(t *testing.T) { @@ -139,7 +140,7 @@ func TestObtainOIDCAccessToken(t *testing.T) { user := dbgen.User(t, db, database.User{ LoginType: database.LoginTypeOIDC, }) - tok, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID) + tok, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID) require.Empty(t, tok) require.NoError(t, err) }) @@ -152,7 +153,7 @@ func TestObtainOIDCAccessToken(t *testing.T) { LoginType: database.LoginTypeOIDC, OAuthExpiry: dbtime.Now().Add(-time.Hour), }) - _, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, &testutil.OAuth2Config{ + _, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &testutil.OAuth2Config{ Token: &oauth2.Token{ AccessToken: "token", }, diff --git a/coderd/provisionerdserver/provisionerdserver_test.go b/coderd/provisionerdserver/provisionerdserver_test.go index dbbfa94505..336d91903a 100644 --- a/coderd/provisionerdserver/provisionerdserver_test.go +++ b/coderd/provisionerdserver/provisionerdserver_test.go @@ -15,6 +15,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" @@ -30,6 +31,7 @@ import ( "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" @@ -58,6 +60,175 @@ import ( "github.com/coder/serpent" ) +// TestTokenIsRefreshedEarly creates a fake OIDC IDP that sets expiration times +// of the token to values that are "near expiration". Expiration being 10minutes +// earlier than it needs to be. The `ObtainOIDCAccessToken` should refresh these +// tokens early. +func TestTokenIsRefreshedEarly(t *testing.T) { + t.Parallel() + + t.Run("WithCoderd", func(t *testing.T) { + t.Parallel() + tokenRefreshCount := 0 + fake := oidctest.NewFakeIDP(t, + oidctest.WithServing(), + oidctest.WithDefaultExpire(time.Minute*8), + oidctest.WithRefresh(func(email string) error { + tokenRefreshCount++ + return nil + }), + ) + cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + }) + db, ps := dbtestutil.NewDB(t) + owner := coderdtest.New(t, &coderdtest.Options{ + OIDCConfig: cfg, + IncludeProvisionerDaemon: true, + Database: db, + Pubsub: ps, + }) + first := coderdtest.CreateFirstUser(t, owner) + version := coderdtest.CreateTemplateVersion(t, owner, first.OrganizationID, nil) + coderdtest.AwaitTemplateVersionJobCompleted(t, owner, version.ID) + template := coderdtest.CreateTemplate(t, owner, first.OrganizationID, version.ID) + + // Setup an OIDC user. + client, _ := fake.Login(t, owner, jwt.MapClaims{ + "email": "user@unauthorized.com", + "email_verified": true, + "sub": uuid.NewString(), + }) + + // Creating a workspace should refresh the oidc early. + tokenRefreshCount = 0 + wrk := coderdtest.CreateWorkspace(t, client, template.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID) + require.Equal(t, 1, tokenRefreshCount) + }) +} + +//nolint:tparallel,paralleltest // Sub tests need to run sequentially. +func TestTokenIsRefreshedEarlyWithoutCoderd(t *testing.T) { + t.Parallel() + tokenRefreshCount := 0 + fake := oidctest.NewFakeIDP(t, + oidctest.WithServing(), + oidctest.WithDefaultExpire(time.Minute*8), + oidctest.WithRefresh(func(email string) error { + tokenRefreshCount++ + return nil + }), + ) + cfg := fake.OIDCConfig(t, nil) + + // Fetch a valid token from the fake OIDC provider + token, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{ + "email": "user@unauthorized.com", + "email_verified": true, + "sub": uuid.NewString(), + }) + require.NoError(t, err) + + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + dbgen.UserLink(t, db, database.UserLink{ + UserID: user.ID, + LoginType: database.LoginTypeOIDC, + LinkedID: "foo", + OAuthAccessToken: token.AccessToken, + OAuthRefreshToken: token.RefreshToken, + // The oauth expiry does not really matter, since each test will manually control + // this value. + OAuthExpiry: dbtime.Now().Add(time.Hour), + }) + + setLinkExpiration := func(t *testing.T, exp time.Time) database.UserLink { + ctx := testutil.Context(t, testutil.WaitShort) + links, err := db.GetUserLinksByUserID(ctx, user.ID) + require.NoError(t, err) + require.Len(t, links, 1) + link := links[0] + + newLink, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{ + OAuthAccessToken: link.OAuthAccessToken, + OAuthAccessTokenKeyID: link.OAuthAccessTokenKeyID, + OAuthRefreshToken: link.OAuthRefreshToken, + OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID, + OAuthExpiry: exp, + Claims: link.Claims, + UserID: link.UserID, + LoginType: link.LoginType, + }) + require.NoError(t, err) + return newLink + } + + for _, c := range []struct { + name string + // expires is a function to return a more up to date "now". + // Because the oauth library is calling `time.Now()`, we cannot use + // mocked clocks. + expires func() time.Time + refreshExpected bool + }{ + { + name: "ZeroExpiry", + expires: func() time.Time { return time.Time{} }, + refreshExpected: false, + }, + { + name: "LongExpired", + expires: func() time.Time { return dbtime.Now().Add(-time.Hour) }, + refreshExpected: true, + }, + { + name: "EdgeExpired", + expires: func() time.Time { return dbtime.Now().Add(-time.Minute * 10) }, + refreshExpected: true, + }, + { + name: "RecentExpired", + expires: func() time.Time { return dbtime.Now().Add(-time.Second * -1) }, + refreshExpected: true, + }, + + { + name: "Future", + expires: func() time.Time { return dbtime.Now().Add(time.Hour) }, + refreshExpected: false, + }, + { + name: "FutureWithinRefreshWindow", + expires: func() time.Time { return dbtime.Now().Add(time.Minute * 8) }, + refreshExpected: true, + }, + } { + t.Run(c.name, func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitShort) + oldLink := setLinkExpiration(t, c.expires()) + tokenRefreshCount = 0 + _, err := provisionerdserver.ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, cfg, user.ID) + require.NoError(t, err) + links, err := db.GetUserLinksByUserID(ctx, user.ID) + require.NoError(t, err) + require.Len(t, links, 1) + newLink := links[0] + + if c.refreshExpected { + require.Equal(t, 1, tokenRefreshCount) + + require.NotEqual(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken) + require.NotEqual(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken) + } else { + require.Equal(t, 0, tokenRefreshCount) + require.Equal(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken) + require.Equal(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken) + } + }) + } +} + func testTemplateScheduleStore() *atomic.Pointer[schedule.TemplateScheduleStore] { poitr := &atomic.Pointer[schedule.TemplateScheduleStore]{} store := schedule.NewAGPLTemplateScheduleStore()