mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(externalauth): prevent race condition in token refresh with optimistic locking (#22904) (backport to 2.29) (#24901)
Backport of https://github.com/coder/coder/pull/22904 to `release/2.29`.
Adds an optimistic lock to `UpdateExternalAuthLinkRefreshToken` so that
a concurrent caller that lost a token-refresh race cannot overwrite a
valid token stored by the winner. The SQL `WHERE` clause now includes
`AND oauth_refresh_token = @old_oauth_refresh_token`.
Original PR: #22904
Merge commit: 53e52aef78
Cherry-pick applied cleanly with no conflicts.
> Generated by Coder Agents
Co-authored-by: Kyle Carberry <kyle@coder.com>
This commit is contained in:
@@ -1538,7 +1538,7 @@ func (s *MethodTestSuite) TestUser() {
|
||||
}))
|
||||
s.Run("UpdateExternalAuthLinkRefreshToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
link := testutil.Fake(s.T(), faker, database.ExternalAuthLink{})
|
||||
arg := database.UpdateExternalAuthLinkRefreshTokenParams{OAuthRefreshToken: "", OAuthRefreshTokenKeyID: "", ProviderID: link.ProviderID, UserID: link.UserID, UpdatedAt: link.UpdatedAt}
|
||||
arg := database.UpdateExternalAuthLinkRefreshTokenParams{OAuthRefreshToken: "", OAuthRefreshTokenKeyID: "", ProviderID: link.ProviderID, UserID: link.UserID, UpdatedAt: link.UpdatedAt, OldOauthRefreshToken: link.OAuthRefreshToken}
|
||||
dbm.EXPECT().GetExternalAuthLink(gomock.Any(), database.GetExternalAuthLinkParams{ProviderID: link.ProviderID, UserID: link.UserID}).Return(link, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(link, policy.ActionUpdatePersonal)
|
||||
|
||||
@@ -659,6 +659,10 @@ type sqlcQuerier interface {
|
||||
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
|
||||
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
|
||||
UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error)
|
||||
// Optimistic lock: only update the row if the refresh token in the database
|
||||
// still matches the one we read before attempting the refresh. This prevents
|
||||
// a concurrent caller that lost a token-refresh race from overwriting a valid
|
||||
// token stored by the winner.
|
||||
UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error
|
||||
UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) (GitSSHKey, error)
|
||||
UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error)
|
||||
|
||||
@@ -2979,9 +2979,11 @@ WHERE
|
||||
provider_id = $4
|
||||
AND
|
||||
user_id = $5
|
||||
AND
|
||||
oauth_refresh_token = $6
|
||||
AND
|
||||
-- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id
|
||||
$6 :: text = $6 :: text
|
||||
$7 :: text = $7 :: text
|
||||
`
|
||||
|
||||
type UpdateExternalAuthLinkRefreshTokenParams struct {
|
||||
@@ -2990,9 +2992,14 @@ type UpdateExternalAuthLinkRefreshTokenParams struct {
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ProviderID string `db:"provider_id" json:"provider_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
OldOauthRefreshToken string `db:"old_oauth_refresh_token" json:"old_oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID string `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
}
|
||||
|
||||
// Optimistic lock: only update the row if the refresh token in the database
|
||||
// still matches the one we read before attempting the refresh. This prevents
|
||||
// a concurrent caller that lost a token-refresh race from overwriting a valid
|
||||
// token stored by the winner.
|
||||
func (q *sqlQuerier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error {
|
||||
_, err := q.db.ExecContext(ctx, updateExternalAuthLinkRefreshToken,
|
||||
arg.OauthRefreshFailureReason,
|
||||
@@ -3000,6 +3007,7 @@ func (q *sqlQuerier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg
|
||||
arg.UpdatedAt,
|
||||
arg.ProviderID,
|
||||
arg.UserID,
|
||||
arg.OldOauthRefreshToken,
|
||||
arg.OAuthRefreshTokenKeyID,
|
||||
)
|
||||
return err
|
||||
|
||||
@@ -48,6 +48,10 @@ UPDATE external_auth_links SET
|
||||
WHERE provider_id = $1 AND user_id = $2 RETURNING *;
|
||||
|
||||
-- name: UpdateExternalAuthLinkRefreshToken :exec
|
||||
-- Optimistic lock: only update the row if the refresh token in the database
|
||||
-- still matches the one we read before attempting the refresh. This prevents
|
||||
-- a concurrent caller that lost a token-refresh race from overwriting a valid
|
||||
-- token stored by the winner.
|
||||
UPDATE
|
||||
external_auth_links
|
||||
SET
|
||||
@@ -60,6 +64,8 @@ WHERE
|
||||
provider_id = @provider_id
|
||||
AND
|
||||
user_id = @user_id
|
||||
AND
|
||||
oauth_refresh_token = @old_oauth_refresh_token
|
||||
AND
|
||||
-- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id
|
||||
@oauth_refresh_token_key_id :: text = @oauth_refresh_token_key_id :: text;
|
||||
|
||||
@@ -138,8 +138,6 @@ func IsInvalidTokenError(err error) bool {
|
||||
}
|
||||
|
||||
// RefreshToken automatically refreshes the token if expired and permitted.
|
||||
// If an error is returned, the token is either invalid, or an error occurred.
|
||||
// Use 'IsInvalidTokenError(err)' to determine the difference.
|
||||
func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAuthLink database.ExternalAuthLink) (database.ExternalAuthLink, error) {
|
||||
// If the token is expired and refresh is disabled, we prompt
|
||||
// the user to authenticate again.
|
||||
@@ -195,6 +193,9 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
|
||||
UpdatedAt: dbtime.Now(),
|
||||
ProviderID: externalAuthLink.ProviderID,
|
||||
UserID: externalAuthLink.UserID,
|
||||
// Optimistic lock: only clear the token if it hasn't been
|
||||
// updated by a concurrent caller that won the refresh race.
|
||||
OldOauthRefreshToken: externalAuthLink.OAuthRefreshToken,
|
||||
})
|
||||
if dbExecErr != nil {
|
||||
// This error should be rare.
|
||||
|
||||
@@ -93,6 +93,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
|
||||
// Zero time used
|
||||
link.OAuthExpiry = time.Time{}
|
||||
|
||||
_, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, validated, "token should have been validated")
|
||||
@@ -107,6 +108,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := config.RefreshToken(context.Background(), nil, database.ExternalAuthLink{
|
||||
OAuthExpiry: expired,
|
||||
})
|
||||
@@ -344,7 +346,6 @@ func TestRefreshToken(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB")
|
||||
})
|
||||
|
||||
t.Run("WithExtra", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -262,6 +262,39 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error {
|
||||
// The SQL query uses an optimistic lock:
|
||||
// WHERE oauth_refresh_token = @old_oauth_refresh_token
|
||||
// The caller supplies the plaintext old token (since dbcrypt
|
||||
// decrypts on read), but the DB stores the encrypted value.
|
||||
// Because AES-GCM is non-deterministic, we cannot simply
|
||||
// re-encrypt the old token — the ciphertext would differ.
|
||||
// Instead, read the current row from the inner (raw) store
|
||||
// and use the actual encrypted value for the WHERE clause.
|
||||
if params.OldOauthRefreshToken != "" && db.ciphers != nil && db.primaryCipherDigest != "" {
|
||||
raw, err := db.Store.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
|
||||
ProviderID: params.ProviderID,
|
||||
UserID: params.UserID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Decrypt the stored token so we can compare with the
|
||||
// caller-supplied plaintext.
|
||||
decrypted := raw.OAuthRefreshToken
|
||||
if err := db.decryptField(&decrypted, raw.OAuthRefreshTokenKeyID); err != nil {
|
||||
return err
|
||||
}
|
||||
if decrypted != params.OldOauthRefreshToken {
|
||||
// The token has changed since the caller read it;
|
||||
// the optimistic lock should fail (no rows updated).
|
||||
// Return nil to match the :exec semantics of the SQL
|
||||
// query, which silently updates zero rows.
|
||||
return nil
|
||||
}
|
||||
// Use the raw encrypted value so the WHERE clause matches.
|
||||
params.OldOauthRefreshToken = raw.OAuthRefreshToken
|
||||
}
|
||||
|
||||
// We would normally use a sql.NullString here, but sqlc does not want to make
|
||||
// a params struct with a nullable string.
|
||||
var digest sql.NullString
|
||||
|
||||
@@ -108,6 +108,7 @@ func TestUserLinks(t *testing.T) {
|
||||
err := crypt.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
|
||||
OAuthRefreshToken: "",
|
||||
OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID.String,
|
||||
OldOauthRefreshToken: link.OAuthRefreshToken,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
|
||||
Reference in New Issue
Block a user