fix(externalauth): prevent race condition in token refresh with optimistic locking (#22904) (backport to 2.29) (#24901)

Backport of https://github.com/coder/coder/pull/22904 to `release/2.29`.

Adds an optimistic lock to `UpdateExternalAuthLinkRefreshToken` so that
a concurrent caller that lost a token-refresh race cannot overwrite a
valid token stored by the winner. The SQL `WHERE` clause now includes
`AND oauth_refresh_token = @old_oauth_refresh_token`.

Original PR: #22904
Merge commit: 53e52aef78

Cherry-pick applied cleanly with no conflicts.

> Generated by Coder Agents

Co-authored-by: Kyle Carberry <kyle@coder.com>
This commit is contained in:
Garrett Delfosse
2026-05-01 14:37:11 -04:00
committed by GitHub
parent 308b3a0845
commit 63a9280a6f
8 changed files with 59 additions and 5 deletions
+1 -1
View File
@@ -1538,7 +1538,7 @@ func (s *MethodTestSuite) TestUser() {
}))
s.Run("UpdateExternalAuthLinkRefreshToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
link := testutil.Fake(s.T(), faker, database.ExternalAuthLink{})
arg := database.UpdateExternalAuthLinkRefreshTokenParams{OAuthRefreshToken: "", OAuthRefreshTokenKeyID: "", ProviderID: link.ProviderID, UserID: link.UserID, UpdatedAt: link.UpdatedAt}
arg := database.UpdateExternalAuthLinkRefreshTokenParams{OAuthRefreshToken: "", OAuthRefreshTokenKeyID: "", ProviderID: link.ProviderID, UserID: link.UserID, UpdatedAt: link.UpdatedAt, OldOauthRefreshToken: link.OAuthRefreshToken}
dbm.EXPECT().GetExternalAuthLink(gomock.Any(), database.GetExternalAuthLinkParams{ProviderID: link.ProviderID, UserID: link.UserID}).Return(link, nil).AnyTimes()
dbm.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(link, policy.ActionUpdatePersonal)
+4
View File
@@ -659,6 +659,10 @@ type sqlcQuerier interface {
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
// Optimistic lock: only update the row if the refresh token in the database
// still matches the one we read before attempting the refresh. This prevents
// a concurrent caller that lost a token-refresh race from overwriting a valid
// token stored by the winner.
UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error
UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) (GitSSHKey, error)
UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error)
+9 -1
View File
@@ -2979,9 +2979,11 @@ WHERE
provider_id = $4
AND
user_id = $5
AND
oauth_refresh_token = $6
AND
-- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id
$6 :: text = $6 :: text
$7 :: text = $7 :: text
`
type UpdateExternalAuthLinkRefreshTokenParams struct {
@@ -2990,9 +2992,14 @@ type UpdateExternalAuthLinkRefreshTokenParams struct {
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
OldOauthRefreshToken string `db:"old_oauth_refresh_token" json:"old_oauth_refresh_token"`
OAuthRefreshTokenKeyID string `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
}
// Optimistic lock: only update the row if the refresh token in the database
// still matches the one we read before attempting the refresh. This prevents
// a concurrent caller that lost a token-refresh race from overwriting a valid
// token stored by the winner.
func (q *sqlQuerier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error {
_, err := q.db.ExecContext(ctx, updateExternalAuthLinkRefreshToken,
arg.OauthRefreshFailureReason,
@@ -3000,6 +3007,7 @@ func (q *sqlQuerier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg
arg.UpdatedAt,
arg.ProviderID,
arg.UserID,
arg.OldOauthRefreshToken,
arg.OAuthRefreshTokenKeyID,
)
return err
+6
View File
@@ -48,6 +48,10 @@ UPDATE external_auth_links SET
WHERE provider_id = $1 AND user_id = $2 RETURNING *;
-- name: UpdateExternalAuthLinkRefreshToken :exec
-- Optimistic lock: only update the row if the refresh token in the database
-- still matches the one we read before attempting the refresh. This prevents
-- a concurrent caller that lost a token-refresh race from overwriting a valid
-- token stored by the winner.
UPDATE
external_auth_links
SET
@@ -60,6 +64,8 @@ WHERE
provider_id = @provider_id
AND
user_id = @user_id
AND
oauth_refresh_token = @old_oauth_refresh_token
AND
-- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id
@oauth_refresh_token_key_id :: text = @oauth_refresh_token_key_id :: text;
+3 -2
View File
@@ -138,8 +138,6 @@ func IsInvalidTokenError(err error) bool {
}
// RefreshToken automatically refreshes the token if expired and permitted.
// If an error is returned, the token is either invalid, or an error occurred.
// Use 'IsInvalidTokenError(err)' to determine the difference.
func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAuthLink database.ExternalAuthLink) (database.ExternalAuthLink, error) {
// If the token is expired and refresh is disabled, we prompt
// the user to authenticate again.
@@ -195,6 +193,9 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
UpdatedAt: dbtime.Now(),
ProviderID: externalAuthLink.ProviderID,
UserID: externalAuthLink.UserID,
// Optimistic lock: only clear the token if it hasn't been
// updated by a concurrent caller that won the refresh race.
OldOauthRefreshToken: externalAuthLink.OAuthRefreshToken,
})
if dbExecErr != nil {
// This error should be rare.
+2 -1
View File
@@ -93,6 +93,7 @@ func TestRefreshToken(t *testing.T) {
// Zero time used
link.OAuthExpiry = time.Time{}
_, err := config.RefreshToken(ctx, nil, link)
require.NoError(t, err)
require.True(t, validated, "token should have been validated")
@@ -107,6 +108,7 @@ func TestRefreshToken(t *testing.T) {
},
},
}
_, err := config.RefreshToken(context.Background(), nil, database.ExternalAuthLink{
OAuthExpiry: expired,
})
@@ -344,7 +346,6 @@ func TestRefreshToken(t *testing.T) {
require.NoError(t, err)
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB")
})
t.Run("WithExtra", func(t *testing.T) {
t.Parallel()
+33
View File
@@ -262,6 +262,39 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U
}
func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error {
// The SQL query uses an optimistic lock:
// WHERE oauth_refresh_token = @old_oauth_refresh_token
// The caller supplies the plaintext old token (since dbcrypt
// decrypts on read), but the DB stores the encrypted value.
// Because AES-GCM is non-deterministic, we cannot simply
// re-encrypt the old token — the ciphertext would differ.
// Instead, read the current row from the inner (raw) store
// and use the actual encrypted value for the WHERE clause.
if params.OldOauthRefreshToken != "" && db.ciphers != nil && db.primaryCipherDigest != "" {
raw, err := db.Store.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
ProviderID: params.ProviderID,
UserID: params.UserID,
})
if err != nil {
return err
}
// Decrypt the stored token so we can compare with the
// caller-supplied plaintext.
decrypted := raw.OAuthRefreshToken
if err := db.decryptField(&decrypted, raw.OAuthRefreshTokenKeyID); err != nil {
return err
}
if decrypted != params.OldOauthRefreshToken {
// The token has changed since the caller read it;
// the optimistic lock should fail (no rows updated).
// Return nil to match the :exec semantics of the SQL
// query, which silently updates zero rows.
return nil
}
// Use the raw encrypted value so the WHERE clause matches.
params.OldOauthRefreshToken = raw.OAuthRefreshToken
}
// We would normally use a sql.NullString here, but sqlc does not want to make
// a params struct with a nullable string.
var digest sql.NullString
@@ -108,6 +108,7 @@ func TestUserLinks(t *testing.T) {
err := crypt.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
OAuthRefreshToken: "",
OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID.String,
OldOauthRefreshToken: link.OAuthRefreshToken,
UpdatedAt: dbtime.Now(),
ProviderID: link.ProviderID,
UserID: link.UserID,