fix(coderd/externalauth): save refreshed token before validation (#24332) (backport to 2.31) (#24899)

Backport of https://github.com/coder/coder/pull/24332 to `release/2.31`.

Moves the `UpdateExternalAuthLink` call to immediately after
`TokenSource.Token()` succeeds (before validation). GitHub rotates
refresh tokens on use, so if post-refresh validation fails (e.g.
rate-limited 403), the new token was previously silently discarded,
forcing manual re-authentication.

Original PR: #24332
Merge commit: 2a1984f0e8

**Note:** This branch includes the cherry-pick of #22904 (optimistic
locking) as a prerequisite since #24332's tests depend on it. The #22904
backport PR is #24902. Once that merges, the overlapping commit in this
PR will be a no-op.

Cherry-picks applied cleanly with no conflicts.

> Generated by Coder Agents

---------

Co-authored-by: Kyle Carberry <kyle@coder.com>
Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
This commit is contained in:
Garrett Delfosse
2026-05-01 14:51:22 -04:00
committed by GitHub
parent bd06fc5d84
commit 1a078790b1
2 changed files with 284 additions and 29 deletions
+41 -27
View File
@@ -239,6 +239,37 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
return externalAuthLink, xerrors.Errorf("generate token extra: %w", err)
}
// Persist the refreshed token to the DB before validation. GitHub
// rotates refresh tokens on every use, so the old refresh token is
// already invalid on the IDP side. If we validated first and the
// validation endpoint was unavailable (e.g. rate-limited 403), the
// new token would be silently lost and the user would be forced to
// re-authenticate manually.
// Use a detached context for the DB write only. The IDP already
// consumed the old refresh token, so if the caller's request
// context is canceled mid-save, the new token would be lost.
persistCtx, persistCancel := context.WithTimeout(context.WithoutCancel(ctx), 10*time.Second)
defer persistCancel()
originalAccessToken := externalAuthLink.OAuthAccessToken
if token.AccessToken != originalAccessToken {
updatedAuthLink, err := db.UpdateExternalAuthLink(persistCtx, database.UpdateExternalAuthLinkParams{
ProviderID: c.ID,
UserID: externalAuthLink.UserID,
UpdatedAt: dbtime.Now(),
OAuthAccessToken: token.AccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: token.Expiry,
OAuthExtra: extra,
})
if err != nil {
return updatedAuthLink, xerrors.Errorf("persist refreshed token: %w", err)
}
externalAuthLink = updatedAuthLink
}
r := retry.New(50*time.Millisecond, 200*time.Millisecond)
// See the comment below why the retry and cancel is required.
retryCtx, retryCtxCancel := context.WithTimeout(ctx, time.Second)
@@ -263,35 +294,18 @@ validate:
return externalAuthLink, InvalidTokenError("token failed to validate")
}
if token.AccessToken != externalAuthLink.OAuthAccessToken {
updatedAuthLink, err := db.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
ProviderID: c.ID,
UserID: externalAuthLink.UserID,
UpdatedAt: dbtime.Now(),
OAuthAccessToken: token.AccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: token.Expiry,
OAuthExtra: extra,
// Update the associated user's github.com user ID if the token
// is for github.com and validation returned user info.
if token.AccessToken != originalAccessToken && IsGithubDotComURL(c.AuthCodeURL("")) && user != nil {
err = db.UpdateUserGithubComUserID(ctx, database.UpdateUserGithubComUserIDParams{
ID: externalAuthLink.UserID,
GithubComUserID: sql.NullInt64{
Int64: user.ID,
Valid: true,
},
})
if err != nil {
return updatedAuthLink, xerrors.Errorf("update external auth link: %w", err)
}
externalAuthLink = updatedAuthLink
// Update the associated users github.com username if the token is for github.com.
if IsGithubDotComURL(c.AuthCodeURL("")) && user != nil {
err = db.UpdateUserGithubComUserID(ctx, database.UpdateUserGithubComUserIDParams{
ID: externalAuthLink.UserID,
GithubComUserID: sql.NullInt64{
Int64: user.ID,
Valid: true,
},
})
if err != nil {
return externalAuthLink, xerrors.Errorf("update user github com user id: %w", err)
}
return externalAuthLink, xerrors.Errorf("update user github com user id: %w", err)
}
}
+243 -2
View File
@@ -8,6 +8,7 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"
@@ -27,6 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/codersdk"
@@ -120,6 +122,11 @@ func TestRefreshToken(t *testing.T) {
t.Run("ValidateServerError", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
mDB.EXPECT().UpdateExternalAuthLink(gomock.Any(), gomock.Any()).
Return(database.ExternalAuthLink{}, nil).AnyTimes()
const staticError = "static error"
validated := false
fake, config, link := setupOauth2Test(t, testConfig{
@@ -136,7 +143,7 @@ func TestRefreshToken(t *testing.T) {
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
link.OAuthExpiry = expired
_, err := config.RefreshToken(ctx, nil, link)
_, err := config.RefreshToken(ctx, mDB, link)
require.ErrorContains(t, err, staticError)
// Unsure if this should be the correct behavior. It's an invalid token because
// 'ValidateToken()' failed with a runtime error. This was the previous behavior,
@@ -223,6 +230,11 @@ func TestRefreshToken(t *testing.T) {
t.Run("ValidateFailure", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
mDB.EXPECT().UpdateExternalAuthLink(gomock.Any(), gomock.Any()).
Return(database.ExternalAuthLink{}, nil).AnyTimes()
const staticError = "static error"
validated := false
fake, config, link := setupOauth2Test(t, testConfig{
@@ -239,7 +251,7 @@ func TestRefreshToken(t *testing.T) {
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
link.OAuthExpiry = expired
_, err := config.RefreshToken(ctx, nil, link)
_, err := config.RefreshToken(ctx, mDB, link)
require.ErrorContains(t, err, "token failed to validate")
require.True(t, externalauth.IsInvalidTokenError(err))
require.True(t, validated, "token should have been attempted to be validated")
@@ -380,6 +392,235 @@ func TestRefreshToken(t *testing.T) {
require.True(t, ok)
require.Equal(t, updated.OAuthAccessToken, mapping["access_token"])
})
// SaveBeforeValidate tests that a successfully refreshed token is
// persisted to the DB even when post-refresh validation fails. This
// prevents the data-loss scenario where GitHub rotates the refresh
// token on use but the new token is silently discarded because a
// rate-limited validation endpoint returns 403.
t.Run("SaveBeforeValidate", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
// simulateRateLimit controls whether the validate endpoint
// returns 403 (true) or 200 (false).
var simulateRateLimit atomic.Bool
simulateRateLimit.Store(true)
var refreshCalls atomic.Int64
fake, config, link := setupOauth2Test(t, testConfig{
FakeIDPOpts: []oidctest.FakeIDPOpt{
oidctest.WithRefresh(func(_ string) error {
refreshCalls.Add(1)
return nil
}),
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
if simulateRateLimit.Load() {
return jwt.MapClaims{}, oidctest.StatusError(http.StatusForbidden, xerrors.New("rate limit exceeded"))
}
return jwt.MapClaims{}, nil
}),
},
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
DB: db,
})
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
oldAccessToken := link.OAuthAccessToken
oldRefreshToken := link.OAuthRefreshToken
// Expire the token to force a refresh.
link.OAuthExpiry = expired
// First call: refresh succeeds, validation fails (403).
_, err := config.RefreshToken(ctx, db, link)
require.Error(t, err, "expected error because validation returned 403")
require.True(t, externalauth.IsInvalidTokenError(err))
require.Equal(t, int64(1), refreshCalls.Load(), "IDP refresh should have been called exactly once")
// Critical assertion: the DB must contain the NEW tokens from the
// successful refresh, not the old (now-stale) ones.
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
ProviderID: link.ProviderID,
UserID: link.UserID,
})
require.NoError(t, err)
require.NotEqual(t, oldAccessToken, dbLink.OAuthAccessToken,
"DB should have the new access token from the successful refresh")
require.NotEqual(t, oldRefreshToken, dbLink.OAuthRefreshToken,
"DB should have the new refresh token (old one was rotated by the IDP)")
// Second call: uses the saved token from DB, no re-refresh.
// The saved token has a future expiry, so TokenSource should return
// it without contacting the IDP. Validation should succeed now.
simulateRateLimit.Store(false)
updated, err := config.RefreshToken(ctx, db, dbLink)
require.NoError(t, err, "second call should succeed because rate limit lifted")
require.Equal(t, int64(1), refreshCalls.Load(),
"IDP refresh should NOT have been called again; the saved token is not expired")
require.Equal(t, dbLink.OAuthAccessToken, updated.OAuthAccessToken,
"returned token should match what was saved in the DB")
})
// SaveBeforeValidate_ContextCanceled verifies the early DB save
// uses a detached context. The parent context is canceled inside
// the refresh hook (after TokenSource.Token() but before the DB
// write), and the test asserts the new token is still persisted.
t.Run("SaveBeforeValidate_ContextCanceled", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
var refreshCalls atomic.Int64
cancelOnRefresh, cancel := context.WithCancel(context.Background())
defer cancel()
fake, config, link := setupOauth2Test(t, testConfig{
FakeIDPOpts: []oidctest.FakeIDPOpt{
oidctest.WithRefresh(func(_ string) error {
refreshCalls.Add(1)
// Cancel the parent context after refresh succeeds
// but before the DB save and validation.
cancel()
return nil
}),
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
return jwt.MapClaims{}, nil
}),
},
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
DB: db,
})
ctx := oidc.ClientContext(cancelOnRefresh, fake.HTTPClient(nil))
oldAccessToken := link.OAuthAccessToken
oldRefreshToken := link.OAuthRefreshToken
link.OAuthExpiry = expired
_, err := config.RefreshToken(ctx, db, link)
require.NoError(t, err)
require.Equal(t, int64(1), refreshCalls.Load())
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
ProviderID: link.ProviderID,
UserID: link.UserID,
})
require.NoError(t, err)
require.NotEqual(t, oldAccessToken, dbLink.OAuthAccessToken,
"DB should have the new access token despite context cancellation")
require.NotEqual(t, oldRefreshToken, dbLink.OAuthRefreshToken,
"DB should have the new refresh token despite context cancellation")
})
// SaveBeforeValidate_DBError tests that when the early DB save
// fails after a successful IDP refresh, the error is surfaced
// as a non-InvalidTokenError. This is a degraded state (token
// issued by IDP but not persisted), and callers should see a
// real error, not a "please re-authenticate" prompt.
t.Run("SaveBeforeValidate_DBError", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
fake, config, link := setupOauth2Test(t, testConfig{
FakeIDPOpts: []oidctest.FakeIDPOpt{
oidctest.WithRefresh(func(_ string) error {
return nil
}),
},
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
})
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
link.OAuthExpiry = expired
mDB.EXPECT().
UpdateExternalAuthLink(gomock.Any(), gomock.Any()).
Return(database.ExternalAuthLink{}, xerrors.New("db connection lost"))
_, err := config.RefreshToken(ctx, mDB, link)
require.Error(t, err)
require.Contains(t, err.Error(), "persist refreshed token")
require.False(t, externalauth.IsInvalidTokenError(err),
"DB errors should not be treated as invalid token")
})
// OptimisticLockPreventsStaleOverwrite verifies that the
// UpdateExternalAuthLinkRefreshToken WHERE clause prevents a
// stale caller from overwriting a valid refresh token saved
// by a concurrent winner.
t.Run("OptimisticLockPreventsStaleOverwrite", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
fake, config, link := setupOauth2Test(t, testConfig{
FakeIDPOpts: []oidctest.FakeIDPOpt{
oidctest.WithRefresh(func(_ string) error {
return nil
}),
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
return jwt.MapClaims{}, nil
}),
},
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
DB: db,
})
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
// Snapshot the original tokens before any refresh.
oldRefreshToken := link.OAuthRefreshToken
// Expire the token to force a refresh.
link.OAuthExpiry = expired
// Caller A: refresh and save successfully.
updated, err := config.RefreshToken(ctx, db, link)
require.NoError(t, err)
require.NotEqual(t, oldRefreshToken, updated.OAuthRefreshToken,
"caller A should have a new refresh token")
// Caller B had a stale read of the original link. It tries to
// destroy the refresh token using the OLD refresh token in the
// optimistic lock. Because caller A already wrote a different
// refresh token, this WHERE clause matches nothing.
err = db.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
OauthRefreshFailureReason: "simulated failure from stale caller B",
OAuthRefreshToken: "",
OAuthRefreshTokenKeyID: "",
UpdatedAt: dbtime.Now(),
ProviderID: link.ProviderID,
UserID: link.UserID,
OldOauthRefreshToken: oldRefreshToken,
})
require.NoError(t, err, "optimistic lock write should not error, it is a no-op")
// Verify DB still has caller A's valid token.
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
ProviderID: link.ProviderID,
UserID: link.UserID,
})
require.NoError(t, err)
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken,
"caller A's access token should still be in DB")
require.Equal(t, updated.OAuthRefreshToken, dbLink.OAuthRefreshToken,
"caller A's refresh token should still be in DB")
require.Empty(t, dbLink.OauthRefreshFailureReason,
"caller B's failure reason should not have been written")
})
}
func TestRevokeToken(t *testing.T) {