diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 67923d18c2..3f018a15a0 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -239,6 +239,37 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu return externalAuthLink, xerrors.Errorf("generate token extra: %w", err) } + // Persist the refreshed token to the DB before validation. GitHub + // rotates refresh tokens on every use, so the old refresh token is + // already invalid on the IDP side. If we validated first and the + // validation endpoint was unavailable (e.g. rate-limited 403), the + // new token would be silently lost and the user would be forced to + // re-authenticate manually. + // Use a detached context for the DB write only. The IDP already + // consumed the old refresh token, so if the caller's request + // context is canceled mid-save, the new token would be lost. + persistCtx, persistCancel := context.WithTimeout(context.WithoutCancel(ctx), 10*time.Second) + defer persistCancel() + + originalAccessToken := externalAuthLink.OAuthAccessToken + if token.AccessToken != originalAccessToken { + updatedAuthLink, err := db.UpdateExternalAuthLink(persistCtx, database.UpdateExternalAuthLinkParams{ + ProviderID: c.ID, + UserID: externalAuthLink.UserID, + UpdatedAt: dbtime.Now(), + OAuthAccessToken: token.AccessToken, + OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required + OAuthRefreshToken: token.RefreshToken, + OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required + OAuthExpiry: token.Expiry, + OAuthExtra: extra, + }) + if err != nil { + return updatedAuthLink, xerrors.Errorf("persist refreshed token: %w", err) + } + externalAuthLink = updatedAuthLink + } + r := retry.New(50*time.Millisecond, 200*time.Millisecond) // See the comment below why the retry and cancel is required. retryCtx, retryCtxCancel := context.WithTimeout(ctx, time.Second) @@ -263,35 +294,18 @@ validate: return externalAuthLink, InvalidTokenError("token failed to validate") } - if token.AccessToken != externalAuthLink.OAuthAccessToken { - updatedAuthLink, err := db.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{ - ProviderID: c.ID, - UserID: externalAuthLink.UserID, - UpdatedAt: dbtime.Now(), - OAuthAccessToken: token.AccessToken, - OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required - OAuthRefreshToken: token.RefreshToken, - OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required - OAuthExpiry: token.Expiry, - OAuthExtra: extra, + // Update the associated user's github.com user ID if the token + // is for github.com and validation returned user info. + if token.AccessToken != originalAccessToken && IsGithubDotComURL(c.AuthCodeURL("")) && user != nil { + err = db.UpdateUserGithubComUserID(ctx, database.UpdateUserGithubComUserIDParams{ + ID: externalAuthLink.UserID, + GithubComUserID: sql.NullInt64{ + Int64: user.ID, + Valid: true, + }, }) if err != nil { - return updatedAuthLink, xerrors.Errorf("update external auth link: %w", err) - } - externalAuthLink = updatedAuthLink - - // Update the associated users github.com username if the token is for github.com. - if IsGithubDotComURL(c.AuthCodeURL("")) && user != nil { - err = db.UpdateUserGithubComUserID(ctx, database.UpdateUserGithubComUserIDParams{ - ID: externalAuthLink.UserID, - GithubComUserID: sql.NullInt64{ - Int64: user.ID, - Valid: true, - }, - }) - if err != nil { - return externalAuthLink, xerrors.Errorf("update user github com user id: %w", err) - } + return externalAuthLink, xerrors.Errorf("update user github com user id: %w", err) } } diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 827fafd966..50b68aef88 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "net/url" "strings" + "sync/atomic" "testing" "time" @@ -27,6 +28,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/promoauth" "github.com/coder/coder/v2/codersdk" @@ -120,6 +122,11 @@ func TestRefreshToken(t *testing.T) { t.Run("ValidateServerError", func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + mDB.EXPECT().UpdateExternalAuthLink(gomock.Any(), gomock.Any()). + Return(database.ExternalAuthLink{}, nil).AnyTimes() + const staticError = "static error" validated := false fake, config, link := setupOauth2Test(t, testConfig{ @@ -136,7 +143,7 @@ func TestRefreshToken(t *testing.T) { ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) link.OAuthExpiry = expired - _, err := config.RefreshToken(ctx, nil, link) + _, err := config.RefreshToken(ctx, mDB, link) require.ErrorContains(t, err, staticError) // Unsure if this should be the correct behavior. It's an invalid token because // 'ValidateToken()' failed with a runtime error. This was the previous behavior, @@ -223,6 +230,11 @@ func TestRefreshToken(t *testing.T) { t.Run("ValidateFailure", func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + mDB.EXPECT().UpdateExternalAuthLink(gomock.Any(), gomock.Any()). + Return(database.ExternalAuthLink{}, nil).AnyTimes() + const staticError = "static error" validated := false fake, config, link := setupOauth2Test(t, testConfig{ @@ -239,7 +251,7 @@ func TestRefreshToken(t *testing.T) { ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) link.OAuthExpiry = expired - _, err := config.RefreshToken(ctx, nil, link) + _, err := config.RefreshToken(ctx, mDB, link) require.ErrorContains(t, err, "token failed to validate") require.True(t, externalauth.IsInvalidTokenError(err)) require.True(t, validated, "token should have been attempted to be validated") @@ -380,6 +392,235 @@ func TestRefreshToken(t *testing.T) { require.True(t, ok) require.Equal(t, updated.OAuthAccessToken, mapping["access_token"]) }) + + // SaveBeforeValidate tests that a successfully refreshed token is + // persisted to the DB even when post-refresh validation fails. This + // prevents the data-loss scenario where GitHub rotates the refresh + // token on use but the new token is silently discarded because a + // rate-limited validation endpoint returns 403. + t.Run("SaveBeforeValidate", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + + // simulateRateLimit controls whether the validate endpoint + // returns 403 (true) or 200 (false). + var simulateRateLimit atomic.Bool + simulateRateLimit.Store(true) + + var refreshCalls atomic.Int64 + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + refreshCalls.Add(1) + return nil + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + if simulateRateLimit.Load() { + return jwt.MapClaims{}, oidctest.StatusError(http.StatusForbidden, xerrors.New("rate limit exceeded")) + } + return jwt.MapClaims{}, nil + }), + }, + ExternalAuthOpt: func(cfg *externalauth.Config) { + cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String() + }, + DB: db, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + + oldAccessToken := link.OAuthAccessToken + oldRefreshToken := link.OAuthRefreshToken + + // Expire the token to force a refresh. + link.OAuthExpiry = expired + + // First call: refresh succeeds, validation fails (403). + _, err := config.RefreshToken(ctx, db, link) + require.Error(t, err, "expected error because validation returned 403") + require.True(t, externalauth.IsInvalidTokenError(err)) + require.Equal(t, int64(1), refreshCalls.Load(), "IDP refresh should have been called exactly once") + + // Critical assertion: the DB must contain the NEW tokens from the + // successful refresh, not the old (now-stale) ones. + dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + require.NotEqual(t, oldAccessToken, dbLink.OAuthAccessToken, + "DB should have the new access token from the successful refresh") + require.NotEqual(t, oldRefreshToken, dbLink.OAuthRefreshToken, + "DB should have the new refresh token (old one was rotated by the IDP)") + + // Second call: uses the saved token from DB, no re-refresh. + // The saved token has a future expiry, so TokenSource should return + // it without contacting the IDP. Validation should succeed now. + simulateRateLimit.Store(false) + updated, err := config.RefreshToken(ctx, db, dbLink) + require.NoError(t, err, "second call should succeed because rate limit lifted") + require.Equal(t, int64(1), refreshCalls.Load(), + "IDP refresh should NOT have been called again; the saved token is not expired") + require.Equal(t, dbLink.OAuthAccessToken, updated.OAuthAccessToken, + "returned token should match what was saved in the DB") + }) + + // SaveBeforeValidate_ContextCanceled verifies the early DB save + // uses a detached context. The parent context is canceled inside + // the refresh hook (after TokenSource.Token() but before the DB + // write), and the test asserts the new token is still persisted. + t.Run("SaveBeforeValidate_ContextCanceled", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + + var refreshCalls atomic.Int64 + cancelOnRefresh, cancel := context.WithCancel(context.Background()) + defer cancel() + + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + refreshCalls.Add(1) + // Cancel the parent context after refresh succeeds + // but before the DB save and validation. + cancel() + return nil + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + return jwt.MapClaims{}, nil + }), + }, + ExternalAuthOpt: func(cfg *externalauth.Config) { + cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String() + }, + DB: db, + }) + + ctx := oidc.ClientContext(cancelOnRefresh, fake.HTTPClient(nil)) + + oldAccessToken := link.OAuthAccessToken + oldRefreshToken := link.OAuthRefreshToken + link.OAuthExpiry = expired + + _, err := config.RefreshToken(ctx, db, link) + require.NoError(t, err) + require.Equal(t, int64(1), refreshCalls.Load()) + + dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + require.NotEqual(t, oldAccessToken, dbLink.OAuthAccessToken, + "DB should have the new access token despite context cancellation") + require.NotEqual(t, oldRefreshToken, dbLink.OAuthRefreshToken, + "DB should have the new refresh token despite context cancellation") + }) + + // SaveBeforeValidate_DBError tests that when the early DB save + // fails after a successful IDP refresh, the error is surfaced + // as a non-InvalidTokenError. This is a degraded state (token + // issued by IDP but not persisted), and callers should see a + // real error, not a "please re-authenticate" prompt. + t.Run("SaveBeforeValidate_DBError", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + return nil + }), + }, + ExternalAuthOpt: func(cfg *externalauth.Config) { + cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String() + }, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + link.OAuthExpiry = expired + + mDB.EXPECT(). + UpdateExternalAuthLink(gomock.Any(), gomock.Any()). + Return(database.ExternalAuthLink{}, xerrors.New("db connection lost")) + + _, err := config.RefreshToken(ctx, mDB, link) + require.Error(t, err) + require.Contains(t, err.Error(), "persist refreshed token") + require.False(t, externalauth.IsInvalidTokenError(err), + "DB errors should not be treated as invalid token") + }) + + // OptimisticLockPreventsStaleOverwrite verifies that the + // UpdateExternalAuthLinkRefreshToken WHERE clause prevents a + // stale caller from overwriting a valid refresh token saved + // by a concurrent winner. + t.Run("OptimisticLockPreventsStaleOverwrite", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + + fake, config, link := setupOauth2Test(t, testConfig{ + FakeIDPOpts: []oidctest.FakeIDPOpt{ + oidctest.WithRefresh(func(_ string) error { + return nil + }), + oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) { + return jwt.MapClaims{}, nil + }), + }, + ExternalAuthOpt: func(cfg *externalauth.Config) { + cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String() + }, + DB: db, + }) + + ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil)) + + // Snapshot the original tokens before any refresh. + oldRefreshToken := link.OAuthRefreshToken + + // Expire the token to force a refresh. + link.OAuthExpiry = expired + + // Caller A: refresh and save successfully. + updated, err := config.RefreshToken(ctx, db, link) + require.NoError(t, err) + require.NotEqual(t, oldRefreshToken, updated.OAuthRefreshToken, + "caller A should have a new refresh token") + + // Caller B had a stale read of the original link. It tries to + // destroy the refresh token using the OLD refresh token in the + // optimistic lock. Because caller A already wrote a different + // refresh token, this WHERE clause matches nothing. + err = db.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{ + OauthRefreshFailureReason: "simulated failure from stale caller B", + OAuthRefreshToken: "", + OAuthRefreshTokenKeyID: "", + UpdatedAt: dbtime.Now(), + ProviderID: link.ProviderID, + UserID: link.UserID, + OldOauthRefreshToken: oldRefreshToken, + }) + require.NoError(t, err, "optimistic lock write should not error, it is a no-op") + + // Verify DB still has caller A's valid token. + dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, + "caller A's access token should still be in DB") + require.Equal(t, updated.OAuthRefreshToken, dbLink.OAuthRefreshToken, + "caller A's refresh token should still be in DB") + require.Empty(t, dbLink.OauthRefreshFailureReason, + "caller B's failure reason should not have been written") + }) } func TestRevokeToken(t *testing.T) {