diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 5215955901..2707ca3d84 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1564,7 +1564,7 @@ func (s *MethodTestSuite) TestUser() { })) s.Run("UpdateExternalAuthLinkRefreshToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { link := testutil.Fake(s.T(), faker, database.ExternalAuthLink{}) - arg := database.UpdateExternalAuthLinkRefreshTokenParams{OAuthRefreshToken: "", OAuthRefreshTokenKeyID: "", ProviderID: link.ProviderID, UserID: link.UserID, UpdatedAt: link.UpdatedAt} + arg := database.UpdateExternalAuthLinkRefreshTokenParams{OAuthRefreshToken: "", OAuthRefreshTokenKeyID: "", ProviderID: link.ProviderID, UserID: link.UserID, UpdatedAt: link.UpdatedAt, OldOauthRefreshToken: link.OAuthRefreshToken} dbm.EXPECT().GetExternalAuthLink(gomock.Any(), database.GetExternalAuthLinkParams{ProviderID: link.ProviderID, UserID: link.UserID}).Return(link, nil).AnyTimes() dbm.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), arg).Return(nil).AnyTimes() check.Args(arg).Asserts(link, policy.ActionUpdatePersonal) diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3b6c07dd05..477b759c91 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -695,6 +695,10 @@ type sqlcQuerier interface { UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) + // Optimistic lock: only update the row if the refresh token in the database + // still matches the one we read before attempting the refresh. This prevents + // a concurrent caller that lost a token-refresh race from overwriting a valid + // token stored by the winner. UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) (GitSSHKey, error) UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 20efe1c988..f6dc9ce957 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3160,9 +3160,11 @@ WHERE provider_id = $4 AND user_id = $5 +AND + oauth_refresh_token = $6 AND -- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id - $6 :: text = $6 :: text + $7 :: text = $7 :: text ` type UpdateExternalAuthLinkRefreshTokenParams struct { @@ -3171,9 +3173,14 @@ type UpdateExternalAuthLinkRefreshTokenParams struct { UpdatedAt time.Time `db:"updated_at" json:"updated_at"` ProviderID string `db:"provider_id" json:"provider_id"` UserID uuid.UUID `db:"user_id" json:"user_id"` + OldOauthRefreshToken string `db:"old_oauth_refresh_token" json:"old_oauth_refresh_token"` OAuthRefreshTokenKeyID string `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` } +// Optimistic lock: only update the row if the refresh token in the database +// still matches the one we read before attempting the refresh. This prevents +// a concurrent caller that lost a token-refresh race from overwriting a valid +// token stored by the winner. func (q *sqlQuerier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error { _, err := q.db.ExecContext(ctx, updateExternalAuthLinkRefreshToken, arg.OauthRefreshFailureReason, @@ -3181,6 +3188,7 @@ func (q *sqlQuerier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg arg.UpdatedAt, arg.ProviderID, arg.UserID, + arg.OldOauthRefreshToken, arg.OAuthRefreshTokenKeyID, ) return err diff --git a/coderd/database/queries/externalauth.sql b/coderd/database/queries/externalauth.sql index 9ca5cf6f87..e5d0ec548b 100644 --- a/coderd/database/queries/externalauth.sql +++ b/coderd/database/queries/externalauth.sql @@ -48,6 +48,10 @@ UPDATE external_auth_links SET WHERE provider_id = $1 AND user_id = $2 RETURNING *; -- name: UpdateExternalAuthLinkRefreshToken :exec +-- Optimistic lock: only update the row if the refresh token in the database +-- still matches the one we read before attempting the refresh. This prevents +-- a concurrent caller that lost a token-refresh race from overwriting a valid +-- token stored by the winner. UPDATE external_auth_links SET @@ -60,6 +64,8 @@ WHERE provider_id = @provider_id AND user_id = @user_id +AND + oauth_refresh_token = @old_oauth_refresh_token AND -- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id @oauth_refresh_token_key_id :: text = @oauth_refresh_token_key_id :: text; diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index c4e2aa8241..67923d18c2 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -139,8 +139,6 @@ func IsInvalidTokenError(err error) bool { } // RefreshToken automatically refreshes the token if expired and permitted. -// If an error is returned, the token is either invalid, or an error occurred. -// Use 'IsInvalidTokenError(err)' to determine the difference. func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAuthLink database.ExternalAuthLink) (database.ExternalAuthLink, error) { // If the token is expired and refresh is disabled, we prompt // the user to authenticate again. @@ -196,6 +194,9 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu UpdatedAt: dbtime.Now(), ProviderID: externalAuthLink.ProviderID, UserID: externalAuthLink.UserID, + // Optimistic lock: only clear the token if it hasn't been + // updated by a concurrent caller that won the refresh race. + OldOauthRefreshToken: externalAuthLink.OAuthRefreshToken, }) if dbExecErr != nil { // This error should be rare. diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 61fdbb2de5..827fafd966 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -93,6 +93,7 @@ func TestRefreshToken(t *testing.T) { // Zero time used link.OAuthExpiry = time.Time{} + _, err := config.RefreshToken(ctx, nil, link) require.NoError(t, err) require.True(t, validated, "token should have been validated") @@ -107,6 +108,7 @@ func TestRefreshToken(t *testing.T) { }, }, } + _, err := config.RefreshToken(context.Background(), nil, database.ExternalAuthLink{ OAuthExpiry: expired, }) @@ -344,7 +346,6 @@ func TestRefreshToken(t *testing.T) { require.NoError(t, err) require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB") }) - t.Run("WithExtra", func(t *testing.T) { t.Parallel() diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index 08136122ad..4d7223f75d 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -262,6 +262,39 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U } func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error { + // The SQL query uses an optimistic lock: + // WHERE oauth_refresh_token = @old_oauth_refresh_token + // The caller supplies the plaintext old token (since dbcrypt + // decrypts on read), but the DB stores the encrypted value. + // Because AES-GCM is non-deterministic, we cannot simply + // re-encrypt the old token — the ciphertext would differ. + // Instead, read the current row from the inner (raw) store + // and use the actual encrypted value for the WHERE clause. + if params.OldOauthRefreshToken != "" && db.ciphers != nil && db.primaryCipherDigest != "" { + raw, err := db.Store.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{ + ProviderID: params.ProviderID, + UserID: params.UserID, + }) + if err != nil { + return err + } + // Decrypt the stored token so we can compare with the + // caller-supplied plaintext. + decrypted := raw.OAuthRefreshToken + if err := db.decryptField(&decrypted, raw.OAuthRefreshTokenKeyID); err != nil { + return err + } + if decrypted != params.OldOauthRefreshToken { + // The token has changed since the caller read it; + // the optimistic lock should fail (no rows updated). + // Return nil to match the :exec semantics of the SQL + // query, which silently updates zero rows. + return nil + } + // Use the raw encrypted value so the WHERE clause matches. + params.OldOauthRefreshToken = raw.OAuthRefreshToken + } + // We would normally use a sql.NullString here, but sqlc does not want to make // a params struct with a nullable string. var digest sql.NullString diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index e73c3eee85..fcf9eae2de 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -108,6 +108,7 @@ func TestUserLinks(t *testing.T) { err := crypt.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{ OAuthRefreshToken: "", OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID.String, + OldOauthRefreshToken: link.OAuthRefreshToken, UpdatedAt: dbtime.Now(), ProviderID: link.ProviderID, UserID: link.UserID,