Files
coder/enterprise/dbcrypt/dbcrypt.go
T
Kyle Carberry 53e52aef78 fix(externalauth): prevent race condition in token refresh with optimistic locking (#22904)
## Problem

When multiple concurrent callers (e.g., parallel workspace builds) read
the same single-use OAuth2 refresh token from the database and race to
exchange it with the provider, the first caller succeeds but subsequent
callers get `bad_refresh_token`. The losing caller then **clears the
valid new token** from the database, permanently breaking the auth link
until the user manually re-authenticates.

This is reliably reproducible when launching multiple workspaces
simultaneously with GitHub App external auth and user-to-server token
expiration enabled.

## Solution

Two layers of protection:

### 1. Singleflight deduplication (`Config.RefreshToken` +
`ObtainOIDCAccessToken`)

Concurrent callers for the same user/provider share a single refresh
call via `golang.org/x/sync/singleflight`, keyed by `userID`. The
singleflight callback re-reads the link from the database to pick up any
token already refreshed by a prior in-flight call, avoiding redundant
IDP round-trips entirely.

### 2. Optimistic locking on `UpdateExternalAuthLinkRefreshToken`

The SQL `WHERE` clause now includes `AND oauth_refresh_token =
@old_oauth_refresh_token`, so if two replicas (HA) race past
singleflight, the loser's destructive UPDATE is a harmless no-op rather
than overwriting the winner's valid token.

## Changes

| File | Change |
|------|--------|
| `coderd/externalauth/externalauth.go` | Added `singleflight.Group` to
`Config`; split `RefreshToken` into public wrapper +
`refreshTokenInner`; pass `OldOauthRefreshToken` to DB update |
| `coderd/provisionerdserver/provisionerdserver.go` | Wrapped OIDC
refresh in `ObtainOIDCAccessToken` with package-level singleflight |
| `coderd/database/queries/externalauth.sql` | Added optimistic lock
(`WHERE ... AND oauth_refresh_token = @old_oauth_refresh_token`) |
| `coderd/database/queries.sql.go` | Regenerated |
| `coderd/database/querier.go` | Regenerated |
| `coderd/database/dbauthz/dbauthz_test.go` | Updated test params for
new field |
| `coderd/externalauth/externalauth_test.go` | Added
`ConcurrentRefreshDedup` test; updated existing tests for singleflight
DB re-read |

## Testing

- **New test `ConcurrentRefreshDedup`**: 5 goroutines call
`RefreshToken` concurrently, asserts IDP refresh called exactly once,
all callers get same token.
- All existing `TestRefreshToken/*` subtests updated and passing.
- `TestObtainOIDCAccessToken` passing.
- `dbauthz` tests passing.
2026-03-10 13:52:55 -04:00

585 lines
19 KiB
Go

package dbcrypt
import (
"context"
"database/sql"
"encoding/base64"
"strings"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
)
// testValue is the value that is stored in dbcrypt_keys.test.
// This is used to determine if the key is valid.
const testValue = "coder"
var (
b64encode = base64.StdEncoding.EncodeToString
b64decode = base64.StdEncoding.DecodeString
)
// DecryptFailedError is returned when decryption fails.
type DecryptFailedError struct {
Inner error
}
func (e *DecryptFailedError) Error() string {
return xerrors.Errorf("decrypt failed: %w", e.Inner).Error()
}
// New creates a database.Store wrapper that encrypts/decrypts values
// stored at rest in the database.
func New(ctx context.Context, db database.Store, ciphers ...Cipher) (database.Store, error) {
cm := make(map[string]Cipher)
for _, c := range ciphers {
cm[c.HexDigest()] = c
}
dbc := &dbCrypt{
ciphers: cm,
Store: db,
}
if len(ciphers) > 0 {
dbc.primaryCipherDigest = ciphers[0].HexDigest()
}
// nolint: gocritic // This is allowed.
authCtx := dbauthz.AsSystemRestricted(ctx)
if err := dbc.ensureEncryptedWithRetry(authCtx); err != nil {
return nil, xerrors.Errorf("ensure encrypted database fields: %w", err)
}
return dbc, nil
}
type dbCrypt struct {
// primaryCipherDigest is the digest of the primary cipher used for encrypting data.
primaryCipherDigest string
// ciphers is a map of cipher digests to ciphers.
ciphers map[string]Cipher
database.Store
}
func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *database.TxOptions) error {
return db.Store.InTx(func(s database.Store) error {
return function(&dbCrypt{
primaryCipherDigest: db.primaryCipherDigest,
ciphers: db.ciphers,
Store: s,
})
}, txOpts)
}
func (db *dbCrypt) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
ks, err := db.Store.GetDBCryptKeys(ctx)
if err != nil {
return nil, err
}
// Decrypt the test field to ensure that the key is valid.
for i := range ks {
if !ks[i].ActiveKeyDigest.Valid {
// Key has been revoked. We can't decrypt the test field, but
// we need to return it so that the caller knows that the key
// has been revoked.
continue
}
if err := db.decryptField(&ks[i].Test, ks[i].ActiveKeyDigest); err != nil {
return nil, err
}
}
return ks, nil
}
// This does not need any special handling as it does not touch any encrypted fields.
// Explicitly defining this here to avoid confusion.
func (db *dbCrypt) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
return db.Store.RevokeDBCryptKey(ctx, activeKeyDigest)
}
func (db *dbCrypt) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
// It's nicer to be able to pass a *sql.NullString to encryptField, but we need to pass a string here.
var digest sql.NullString
err := db.encryptField(&arg.Test, &digest)
if err != nil {
return err
}
arg.ActiveKeyDigest = digest.String
return db.Store.InsertDBCryptKey(ctx, arg)
}
func (db *dbCrypt) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) {
link, err := db.Store.GetUserLinkByLinkedID(ctx, linkedID)
if err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
return link, nil
}
func (db *dbCrypt) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
links, err := db.Store.GetUserLinksByUserID(ctx, userID)
if err != nil {
return nil, err
}
for idx := range links {
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
return nil, err
}
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
return nil, err
}
}
return links, nil
}
func (db *dbCrypt) GetUserLinkByUserIDLoginType(ctx context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
link, err := db.Store.GetUserLinkByUserIDLoginType(ctx, params)
if err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
return link, nil
}
func (db *dbCrypt) InsertUserLink(ctx context.Context, params database.InsertUserLinkParams) (database.UserLink, error) {
if err := db.encryptField(&params.OAuthAccessToken, &params.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.encryptField(&params.OAuthRefreshToken, &params.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
link, err := db.Store.InsertUserLink(ctx, params)
if err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
return link, nil
}
func (db *dbCrypt) UpdateUserLink(ctx context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
if err := db.encryptField(&params.OAuthAccessToken, &params.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.encryptField(&params.OAuthRefreshToken, &params.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
link, err := db.Store.UpdateUserLink(ctx, params)
if err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
return link, nil
}
func (db *dbCrypt) InsertExternalAuthLink(ctx context.Context, params database.InsertExternalAuthLinkParams) (database.ExternalAuthLink, error) {
if err := db.encryptField(&params.OAuthAccessToken, &params.OAuthAccessTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
if err := db.encryptField(&params.OAuthRefreshToken, &params.OAuthRefreshTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
link, err := db.Store.InsertExternalAuthLink(ctx, params)
if err != nil {
return database.ExternalAuthLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
return link, nil
}
func (db *dbCrypt) GetExternalAuthLink(ctx context.Context, params database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
link, err := db.Store.GetExternalAuthLink(ctx, params)
if err != nil {
return database.ExternalAuthLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
return link, nil
}
func (db *dbCrypt) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.ExternalAuthLink, error) {
links, err := db.Store.GetExternalAuthLinksByUserID(ctx, userID)
if err != nil {
return nil, err
}
for idx := range links {
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
return nil, err
}
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
return nil, err
}
}
return links, nil
}
func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.UpdateExternalAuthLinkParams) (database.ExternalAuthLink, error) {
if err := db.encryptField(&params.OAuthAccessToken, &params.OAuthAccessTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
if err := db.encryptField(&params.OAuthRefreshToken, &params.OAuthRefreshTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
link, err := db.Store.UpdateExternalAuthLink(ctx, params)
if err != nil {
return database.ExternalAuthLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.ExternalAuthLink{}, err
}
return link, nil
}
func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error {
// The SQL query uses an optimistic lock:
// WHERE oauth_refresh_token = @old_oauth_refresh_token
// The caller supplies the plaintext old token (since dbcrypt
// decrypts on read), but the DB stores the encrypted value.
// Because AES-GCM is non-deterministic, we cannot simply
// re-encrypt the old token — the ciphertext would differ.
// Instead, read the current row from the inner (raw) store
// and use the actual encrypted value for the WHERE clause.
if params.OldOauthRefreshToken != "" && db.ciphers != nil && db.primaryCipherDigest != "" {
raw, err := db.Store.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
ProviderID: params.ProviderID,
UserID: params.UserID,
})
if err != nil {
return err
}
// Decrypt the stored token so we can compare with the
// caller-supplied plaintext.
decrypted := raw.OAuthRefreshToken
if err := db.decryptField(&decrypted, raw.OAuthRefreshTokenKeyID); err != nil {
return err
}
if decrypted != params.OldOauthRefreshToken {
// The token has changed since the caller read it;
// the optimistic lock should fail (no rows updated).
// Return nil to match the :exec semantics of the SQL
// query, which silently updates zero rows.
return nil
}
// Use the raw encrypted value so the WHERE clause matches.
params.OldOauthRefreshToken = raw.OAuthRefreshToken
}
// We would normally use a sql.NullString here, but sqlc does not want to make
// a params struct with a nullable string.
var digest sql.NullString
if params.OAuthRefreshTokenKeyID != "" {
digest.String = params.OAuthRefreshTokenKeyID
digest.Valid = true
}
if err := db.encryptField(&params.OAuthRefreshToken, &digest); err != nil {
return err
}
return db.Store.UpdateExternalAuthLinkRefreshToken(ctx, params)
}
func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) {
keys, err := db.Store.GetCryptoKeys(ctx)
if err != nil {
return nil, err
}
for i := range keys {
if err := db.decryptField(&keys[i].Secret.String, keys[i].SecretKeyID); err != nil {
return nil, err
}
}
return keys, nil
}
func (db *dbCrypt) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) {
key, err := db.Store.GetLatestCryptoKeyByFeature(ctx, feature)
if err != nil {
return database.CryptoKey{}, err
}
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
return database.CryptoKey{}, err
}
return key, nil
}
func (db *dbCrypt) GetCryptoKeyByFeatureAndSequence(ctx context.Context, params database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) {
key, err := db.Store.GetCryptoKeyByFeatureAndSequence(ctx, params)
if err != nil {
return database.CryptoKey{}, err
}
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
return database.CryptoKey{}, err
}
return key, nil
}
func (db *dbCrypt) InsertCryptoKey(ctx context.Context, params database.InsertCryptoKeyParams) (database.CryptoKey, error) {
if err := db.encryptField(&params.Secret.String, &params.SecretKeyID); err != nil {
return database.CryptoKey{}, err
}
key, err := db.Store.InsertCryptoKey(ctx, params)
if err != nil {
return database.CryptoKey{}, err
}
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
return database.CryptoKey{}, err
}
return key, nil
}
func (db *dbCrypt) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
key, err := db.Store.UpdateCryptoKeyDeletesAt(ctx, arg)
if err != nil {
return database.CryptoKey{}, err
}
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
return database.CryptoKey{}, err
}
return key, nil
}
func (db *dbCrypt) GetCryptoKeysByFeature(ctx context.Context, feature database.CryptoKeyFeature) ([]database.CryptoKey, error) {
keys, err := db.Store.GetCryptoKeysByFeature(ctx, feature)
if err != nil {
return nil, err
}
for i := range keys {
if err := db.decryptField(&keys[i].Secret.String, keys[i].SecretKeyID); err != nil {
return nil, err
}
}
return keys, nil
}
func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
provider, err := db.Store.GetChatProviderByID(ctx, id)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) {
provider, err := db.Store.GetChatProviderByProvider(ctx, providerName)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
providers, err := db.Store.GetChatProviders(ctx)
if err != nil {
return nil, err
}
for i := range providers {
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
return nil, err
}
}
return providers, nil
}
func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
providers, err := db.Store.GetEnabledChatProviders(ctx)
if err != nil {
return nil, err
}
for i := range providers {
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
return nil, err
}
}
return providers, nil
}
func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
provider, err := db.Store.InsertChatProvider(ctx, params)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) {
if strings.TrimSpace(params.APIKey) == "" {
params.ApiKeyKeyID = sql.NullString{}
} else if err := db.encryptField(&params.APIKey, &params.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
provider, err := db.Store.UpdateChatProvider(ctx, params)
if err != nil {
return database.ChatProvider{}, err
}
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
return database.ChatProvider{}, err
}
return provider, nil
}
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
// If no cipher is loaded, then we can't encrypt anything!
if db.ciphers == nil || db.primaryCipherDigest == "" {
return nil
}
if field == nil {
return xerrors.Errorf("developer error: encryptField called with nil field")
}
if digest == nil {
return xerrors.Errorf("developer error: encryptField called with nil digest")
}
encrypted, err := db.ciphers[db.primaryCipherDigest].Encrypt([]byte(*field))
if err != nil {
return err
}
// Base64 is used to support UTF-8 encoding in PostgreSQL.
*field = b64encode(encrypted)
*digest = sql.NullString{String: db.primaryCipherDigest, Valid: true}
return nil
}
// decryptFields decrypts the given field using the key with the given digest.
// If the value fails to decrypt, sql.ErrNoRows will be returned.
func (db *dbCrypt) decryptField(field *string, digest sql.NullString) error {
if field == nil {
return xerrors.Errorf("developer error: decryptField called with nil field")
}
if !digest.Valid || digest.String == "" {
// This field is not encrypted.
return nil
}
key, ok := db.ciphers[digest.String]
if !ok {
return &DecryptFailedError{
Inner: xerrors.Errorf("no cipher with digest %q", digest.String),
}
}
data, err := b64decode(*field)
if err != nil {
// If it's not valid base64, we should complain loudly.
return &DecryptFailedError{
Inner: xerrors.Errorf("malformed encrypted field %q: %w", *field, err),
}
}
decrypted, err := key.Decrypt(data)
if err != nil {
return &DecryptFailedError{Inner: err}
}
*field = string(decrypted)
return nil
}
func (db *dbCrypt) ensureEncryptedWithRetry(ctx context.Context) error {
var err error
for i := 0; i < 3; i++ {
err = db.ensureEncrypted(ctx)
if err == nil {
return nil
}
// If we get a serialization error, then we need to retry.
if !database.IsSerializedError(err) {
return err
}
// otherwise, retry
}
// If we get here, then we ran out of retries
return err
}
func (db *dbCrypt) ensureEncrypted(ctx context.Context) error {
return db.InTx(func(s database.Store) error {
// Attempt to read the encrypted test fields of the currently active keys.
ks, err := s.GetDBCryptKeys(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return err
}
var highestNumber int32
var activeCipherFound bool
for _, k := range ks {
// If our primary key has been revoked, then we can't do anything.
if k.RevokedKeyDigest.Valid && k.RevokedKeyDigest.String == db.primaryCipherDigest {
return xerrors.Errorf("primary encryption key %q has been revoked", db.primaryCipherDigest)
}
if k.ActiveKeyDigest.Valid && k.ActiveKeyDigest.String == db.primaryCipherDigest {
activeCipherFound = true
}
if k.Number > highestNumber {
highestNumber = k.Number
}
}
if activeCipherFound || len(db.ciphers) == 0 {
return nil
}
// If we get here, then we have a new key that we need to insert.
return db.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
Number: highestNumber + 1,
ActiveKeyDigest: db.primaryCipherDigest,
Test: testValue,
})
}, &database.TxOptions{Isolation: sql.LevelRepeatableRead})
}