mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd/externalauth): save refreshed token before validation (#24332) (backport to 2.31) (#24899)
Backport of https://github.com/coder/coder/pull/24332 to `release/2.31`.
Moves the `UpdateExternalAuthLink` call to immediately after
`TokenSource.Token()` succeeds (before validation). GitHub rotates
refresh tokens on use, so if post-refresh validation fails (e.g.
rate-limited 403), the new token was previously silently discarded,
forcing manual re-authentication.
Original PR: #24332
Merge commit: 2a1984f0e8
**Note:** This branch includes the cherry-pick of #22904 (optimistic
locking) as a prerequisite since #24332's tests depend on it. The #22904
backport PR is #24902. Once that merges, the overlapping commit in this
PR will be a no-op.
Cherry-picks applied cleanly with no conflicts.
> Generated by Coder Agents
---------
Co-authored-by: Kyle Carberry <kyle@coder.com>
Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
This commit is contained in:
@@ -239,6 +239,37 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
|
||||
return externalAuthLink, xerrors.Errorf("generate token extra: %w", err)
|
||||
}
|
||||
|
||||
// Persist the refreshed token to the DB before validation. GitHub
|
||||
// rotates refresh tokens on every use, so the old refresh token is
|
||||
// already invalid on the IDP side. If we validated first and the
|
||||
// validation endpoint was unavailable (e.g. rate-limited 403), the
|
||||
// new token would be silently lost and the user would be forced to
|
||||
// re-authenticate manually.
|
||||
// Use a detached context for the DB write only. The IDP already
|
||||
// consumed the old refresh token, so if the caller's request
|
||||
// context is canceled mid-save, the new token would be lost.
|
||||
persistCtx, persistCancel := context.WithTimeout(context.WithoutCancel(ctx), 10*time.Second)
|
||||
defer persistCancel()
|
||||
|
||||
originalAccessToken := externalAuthLink.OAuthAccessToken
|
||||
if token.AccessToken != originalAccessToken {
|
||||
updatedAuthLink, err := db.UpdateExternalAuthLink(persistCtx, database.UpdateExternalAuthLinkParams{
|
||||
ProviderID: c.ID,
|
||||
UserID: externalAuthLink.UserID,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthExtra: extra,
|
||||
})
|
||||
if err != nil {
|
||||
return updatedAuthLink, xerrors.Errorf("persist refreshed token: %w", err)
|
||||
}
|
||||
externalAuthLink = updatedAuthLink
|
||||
}
|
||||
|
||||
r := retry.New(50*time.Millisecond, 200*time.Millisecond)
|
||||
// See the comment below why the retry and cancel is required.
|
||||
retryCtx, retryCtxCancel := context.WithTimeout(ctx, time.Second)
|
||||
@@ -263,35 +294,18 @@ validate:
|
||||
return externalAuthLink, InvalidTokenError("token failed to validate")
|
||||
}
|
||||
|
||||
if token.AccessToken != externalAuthLink.OAuthAccessToken {
|
||||
updatedAuthLink, err := db.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
|
||||
ProviderID: c.ID,
|
||||
UserID: externalAuthLink.UserID,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthExtra: extra,
|
||||
// Update the associated user's github.com user ID if the token
|
||||
// is for github.com and validation returned user info.
|
||||
if token.AccessToken != originalAccessToken && IsGithubDotComURL(c.AuthCodeURL("")) && user != nil {
|
||||
err = db.UpdateUserGithubComUserID(ctx, database.UpdateUserGithubComUserIDParams{
|
||||
ID: externalAuthLink.UserID,
|
||||
GithubComUserID: sql.NullInt64{
|
||||
Int64: user.ID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return updatedAuthLink, xerrors.Errorf("update external auth link: %w", err)
|
||||
}
|
||||
externalAuthLink = updatedAuthLink
|
||||
|
||||
// Update the associated users github.com username if the token is for github.com.
|
||||
if IsGithubDotComURL(c.AuthCodeURL("")) && user != nil {
|
||||
err = db.UpdateUserGithubComUserID(ctx, database.UpdateUserGithubComUserIDParams{
|
||||
ID: externalAuthLink.UserID,
|
||||
GithubComUserID: sql.NullInt64{
|
||||
Int64: user.ID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return externalAuthLink, xerrors.Errorf("update user github com user id: %w", err)
|
||||
}
|
||||
return externalAuthLink, xerrors.Errorf("update user github com user id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -27,6 +28,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/promoauth"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -120,6 +122,11 @@ func TestRefreshToken(t *testing.T) {
|
||||
t.Run("ValidateServerError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
mDB.EXPECT().UpdateExternalAuthLink(gomock.Any(), gomock.Any()).
|
||||
Return(database.ExternalAuthLink{}, nil).AnyTimes()
|
||||
|
||||
const staticError = "static error"
|
||||
validated := false
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
@@ -136,7 +143,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
_, err := config.RefreshToken(ctx, nil, link)
|
||||
_, err := config.RefreshToken(ctx, mDB, link)
|
||||
require.ErrorContains(t, err, staticError)
|
||||
// Unsure if this should be the correct behavior. It's an invalid token because
|
||||
// 'ValidateToken()' failed with a runtime error. This was the previous behavior,
|
||||
@@ -223,6 +230,11 @@ func TestRefreshToken(t *testing.T) {
|
||||
t.Run("ValidateFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
mDB.EXPECT().UpdateExternalAuthLink(gomock.Any(), gomock.Any()).
|
||||
Return(database.ExternalAuthLink{}, nil).AnyTimes()
|
||||
|
||||
const staticError = "static error"
|
||||
validated := false
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
@@ -239,7 +251,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
_, err := config.RefreshToken(ctx, nil, link)
|
||||
_, err := config.RefreshToken(ctx, mDB, link)
|
||||
require.ErrorContains(t, err, "token failed to validate")
|
||||
require.True(t, externalauth.IsInvalidTokenError(err))
|
||||
require.True(t, validated, "token should have been attempted to be validated")
|
||||
@@ -380,6 +392,235 @@ func TestRefreshToken(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
require.Equal(t, updated.OAuthAccessToken, mapping["access_token"])
|
||||
})
|
||||
|
||||
// SaveBeforeValidate tests that a successfully refreshed token is
|
||||
// persisted to the DB even when post-refresh validation fails. This
|
||||
// prevents the data-loss scenario where GitHub rotates the refresh
|
||||
// token on use but the new token is silently discarded because a
|
||||
// rate-limited validation endpoint returns 403.
|
||||
t.Run("SaveBeforeValidate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
// simulateRateLimit controls whether the validate endpoint
|
||||
// returns 403 (true) or 200 (false).
|
||||
var simulateRateLimit atomic.Bool
|
||||
simulateRateLimit.Store(true)
|
||||
|
||||
var refreshCalls atomic.Int64
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
refreshCalls.Add(1)
|
||||
return nil
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
if simulateRateLimit.Load() {
|
||||
return jwt.MapClaims{}, oidctest.StatusError(http.StatusForbidden, xerrors.New("rate limit exceeded"))
|
||||
}
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
},
|
||||
DB: db,
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
|
||||
oldAccessToken := link.OAuthAccessToken
|
||||
oldRefreshToken := link.OAuthRefreshToken
|
||||
|
||||
// Expire the token to force a refresh.
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
// First call: refresh succeeds, validation fails (403).
|
||||
_, err := config.RefreshToken(ctx, db, link)
|
||||
require.Error(t, err, "expected error because validation returned 403")
|
||||
require.True(t, externalauth.IsInvalidTokenError(err))
|
||||
require.Equal(t, int64(1), refreshCalls.Load(), "IDP refresh should have been called exactly once")
|
||||
|
||||
// Critical assertion: the DB must contain the NEW tokens from the
|
||||
// successful refresh, not the old (now-stale) ones.
|
||||
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, oldAccessToken, dbLink.OAuthAccessToken,
|
||||
"DB should have the new access token from the successful refresh")
|
||||
require.NotEqual(t, oldRefreshToken, dbLink.OAuthRefreshToken,
|
||||
"DB should have the new refresh token (old one was rotated by the IDP)")
|
||||
|
||||
// Second call: uses the saved token from DB, no re-refresh.
|
||||
// The saved token has a future expiry, so TokenSource should return
|
||||
// it without contacting the IDP. Validation should succeed now.
|
||||
simulateRateLimit.Store(false)
|
||||
updated, err := config.RefreshToken(ctx, db, dbLink)
|
||||
require.NoError(t, err, "second call should succeed because rate limit lifted")
|
||||
require.Equal(t, int64(1), refreshCalls.Load(),
|
||||
"IDP refresh should NOT have been called again; the saved token is not expired")
|
||||
require.Equal(t, dbLink.OAuthAccessToken, updated.OAuthAccessToken,
|
||||
"returned token should match what was saved in the DB")
|
||||
})
|
||||
|
||||
// SaveBeforeValidate_ContextCanceled verifies the early DB save
|
||||
// uses a detached context. The parent context is canceled inside
|
||||
// the refresh hook (after TokenSource.Token() but before the DB
|
||||
// write), and the test asserts the new token is still persisted.
|
||||
t.Run("SaveBeforeValidate_ContextCanceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
var refreshCalls atomic.Int64
|
||||
cancelOnRefresh, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
refreshCalls.Add(1)
|
||||
// Cancel the parent context after refresh succeeds
|
||||
// but before the DB save and validation.
|
||||
cancel()
|
||||
return nil
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
},
|
||||
DB: db,
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(cancelOnRefresh, fake.HTTPClient(nil))
|
||||
|
||||
oldAccessToken := link.OAuthAccessToken
|
||||
oldRefreshToken := link.OAuthRefreshToken
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
_, err := config.RefreshToken(ctx, db, link)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), refreshCalls.Load())
|
||||
|
||||
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, oldAccessToken, dbLink.OAuthAccessToken,
|
||||
"DB should have the new access token despite context cancellation")
|
||||
require.NotEqual(t, oldRefreshToken, dbLink.OAuthRefreshToken,
|
||||
"DB should have the new refresh token despite context cancellation")
|
||||
})
|
||||
|
||||
// SaveBeforeValidate_DBError tests that when the early DB save
|
||||
// fails after a successful IDP refresh, the error is surfaced
|
||||
// as a non-InvalidTokenError. This is a degraded state (token
|
||||
// issued by IDP but not persisted), and callers should see a
|
||||
// real error, not a "please re-authenticate" prompt.
|
||||
t.Run("SaveBeforeValidate_DBError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
},
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
mDB.EXPECT().
|
||||
UpdateExternalAuthLink(gomock.Any(), gomock.Any()).
|
||||
Return(database.ExternalAuthLink{}, xerrors.New("db connection lost"))
|
||||
|
||||
_, err := config.RefreshToken(ctx, mDB, link)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "persist refreshed token")
|
||||
require.False(t, externalauth.IsInvalidTokenError(err),
|
||||
"DB errors should not be treated as invalid token")
|
||||
})
|
||||
|
||||
// OptimisticLockPreventsStaleOverwrite verifies that the
|
||||
// UpdateExternalAuthLinkRefreshToken WHERE clause prevents a
|
||||
// stale caller from overwriting a valid refresh token saved
|
||||
// by a concurrent winner.
|
||||
t.Run("OptimisticLockPreventsStaleOverwrite", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
return nil
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
},
|
||||
DB: db,
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
|
||||
// Snapshot the original tokens before any refresh.
|
||||
oldRefreshToken := link.OAuthRefreshToken
|
||||
|
||||
// Expire the token to force a refresh.
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
// Caller A: refresh and save successfully.
|
||||
updated, err := config.RefreshToken(ctx, db, link)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, oldRefreshToken, updated.OAuthRefreshToken,
|
||||
"caller A should have a new refresh token")
|
||||
|
||||
// Caller B had a stale read of the original link. It tries to
|
||||
// destroy the refresh token using the OLD refresh token in the
|
||||
// optimistic lock. Because caller A already wrote a different
|
||||
// refresh token, this WHERE clause matches nothing.
|
||||
err = db.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
|
||||
OauthRefreshFailureReason: "simulated failure from stale caller B",
|
||||
OAuthRefreshToken: "",
|
||||
OAuthRefreshTokenKeyID: "",
|
||||
UpdatedAt: dbtime.Now(),
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
OldOauthRefreshToken: oldRefreshToken,
|
||||
})
|
||||
require.NoError(t, err, "optimistic lock write should not error, it is a no-op")
|
||||
|
||||
// Verify DB still has caller A's valid token.
|
||||
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken,
|
||||
"caller A's access token should still be in DB")
|
||||
require.Equal(t, updated.OAuthRefreshToken, dbLink.OAuthRefreshToken,
|
||||
"caller A's refresh token should still be in DB")
|
||||
require.Empty(t, dbLink.OauthRefreshFailureReason,
|
||||
"caller B's failure reason should not have been written")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRevokeToken(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user