mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
55e525fc28
Replace the old `InTx` ruleguard rule in `scripts/rules.go` with a custom in-tree `go/analysis` analyzer under `scripts/intxcheck/`. The new analyzer catches the same direct and pass-through misuse classes as before, plus two new classes the pattern-matcher couldn't reach:

- **Indirect same-package helper misuse** — flags `p.someHelper(ctx)` inside `InTx` when the helper body uses the outer store (the PR #24369 bug class).
- **Nested dangerous closures** — descends into `go func() { ... }()`, `defer func() { ... }()`, and immediately-invoked function literals.

The analyzer uses semantic `types.Object` identity instead of raw expression string comparison, which avoids false positives from closure-local shadowing and catches simple aliases like `outer := s.db` and `alias := s`.

This PR also fixes three real outer-store-inside-transaction bugs the new analyzer surfaced:

- `coderd/wsbuilder/wsbuilder.go`: `FindMatchingPresetID` and `getWorkspaceTask` now use the inner transaction store instead of `b.store`.
- `enterprise/dbcrypt/dbcrypt.go`: `ensureEncrypted` now calls `s.InsertDBCryptKey` (the tx-wrapped store) instead of `db.InsertDBCryptKey`. The `dbCrypt.InTx` method wraps the raw tx in a new `*dbCrypt`, so `s.InsertDBCryptKey` still dispatches through the encryption layer.

Two call sites need `// intxcheck:ignore` suppressions. Both are one-off patterns that only look like misuse because the analyzer doesn't track assignments — proving them safe would require full dataflow analysis, which is well beyond what a targeted lint like this should attempt:

- `coderd/database/dbfake/dbfake.go` — `b.db` is reassigned to `tx` on the preceding line, so `b.doInTX()` actually uses the transaction. The analyzer sees the original `b.db` identity and flags it.
- `coderd/database/db_test.go` — test intentionally passes the outer store to `require.Equal` to assert that nested `InTx` returns the same handle.
Suppressions use `// intxcheck:ignore` instead of `//nolint:intxcheck` because `intxcheck` runs as a standalone `go/analysis` tool outside golangci-lint. golangci-lint's `nolintlint` checker flags `//nolint` directives for linters it doesn't control, so we use a custom comment prefix to avoid that conflict.
885 lines
29 KiB
Go
885 lines
29 KiB
Go
package dbcrypt
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
)
|
|
|
|
// testValue is the plaintext placed in the Test field of a new
// dbcrypt_keys row. It is used to determine whether a key is valid.
const (
	testValue = "coder"
)
|
|
|
|
// b64encode converts raw bytes to standard (padded) base64 text.
var b64encode = base64.StdEncoding.EncodeToString

// b64decode converts standard (padded) base64 text back to raw bytes.
var b64decode = base64.StdEncoding.DecodeString
|
|
|
|
// DecryptFailedError is returned when decryption fails.
|
|
type DecryptFailedError struct {
|
|
Inner error
|
|
}
|
|
|
|
func (e *DecryptFailedError) Error() string {
|
|
return xerrors.Errorf("decrypt failed: %w", e.Inner).Error()
|
|
}
|
|
|
|
// New creates a database.Store wrapper that encrypts/decrypts values
|
|
// stored at rest in the database.
|
|
func New(ctx context.Context, db database.Store, ciphers ...Cipher) (database.Store, error) {
|
|
cm := make(map[string]Cipher)
|
|
for _, c := range ciphers {
|
|
cm[c.HexDigest()] = c
|
|
}
|
|
dbc := &dbCrypt{
|
|
ciphers: cm,
|
|
Store: db,
|
|
}
|
|
if len(ciphers) > 0 {
|
|
dbc.primaryCipherDigest = ciphers[0].HexDigest()
|
|
}
|
|
// nolint: gocritic // This is allowed.
|
|
authCtx := dbauthz.AsSystemRestricted(ctx)
|
|
if err := dbc.ensureEncryptedWithRetry(authCtx); err != nil {
|
|
return nil, xerrors.Errorf("ensure encrypted database fields: %w", err)
|
|
}
|
|
return dbc, nil
|
|
}
|
|
|
|
type dbCrypt struct {
|
|
// primaryCipherDigest is the digest of the primary cipher used for encrypting data.
|
|
primaryCipherDigest string
|
|
// ciphers is a map of cipher digests to ciphers.
|
|
ciphers map[string]Cipher
|
|
database.Store
|
|
}
|
|
|
|
func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *database.TxOptions) error {
|
|
return db.Store.InTx(func(s database.Store) error {
|
|
return function(&dbCrypt{
|
|
primaryCipherDigest: db.primaryCipherDigest,
|
|
ciphers: db.ciphers,
|
|
Store: s,
|
|
})
|
|
}, txOpts)
|
|
}
|
|
|
|
func (db *dbCrypt) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
|
|
ks, err := db.Store.GetDBCryptKeys(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Decrypt the test field to ensure that the key is valid.
|
|
for i := range ks {
|
|
if !ks[i].ActiveKeyDigest.Valid {
|
|
// Key has been revoked. We can't decrypt the test field, but
|
|
// we need to return it so that the caller knows that the key
|
|
// has been revoked.
|
|
continue
|
|
}
|
|
if err := db.decryptField(&ks[i].Test, ks[i].ActiveKeyDigest); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return ks, nil
|
|
}
|
|
|
|
// This does not need any special handling as it does not touch any encrypted fields.
|
|
// Explicitly defining this here to avoid confusion.
|
|
func (db *dbCrypt) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
|
return db.Store.RevokeDBCryptKey(ctx, activeKeyDigest)
|
|
}
|
|
|
|
func (db *dbCrypt) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
|
|
// It's nicer to be able to pass a *sql.NullString to encryptField, but we need to pass a string here.
|
|
var digest sql.NullString
|
|
err := db.encryptField(&arg.Test, &digest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
arg.ActiveKeyDigest = digest.String
|
|
return db.Store.InsertDBCryptKey(ctx, arg)
|
|
}
|
|
|
|
func (db *dbCrypt) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) {
|
|
link, err := db.Store.GetUserLinkByLinkedID(ctx, linkedID)
|
|
if err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
|
|
links, err := db.Store.GetUserLinksByUserID(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for idx := range links {
|
|
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return links, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetUserLinkByUserIDLoginType(ctx context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
|
|
link, err := db.Store.GetUserLinkByUserIDLoginType(ctx, params)
|
|
if err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertUserLink(ctx context.Context, params database.InsertUserLinkParams) (database.UserLink, error) {
|
|
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
link, err := db.Store.InsertUserLink(ctx, params)
|
|
if err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateUserLink(ctx context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
|
|
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
link, err := db.Store.UpdateUserLink(ctx, params)
|
|
if err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.UserLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertExternalAuthLink(ctx context.Context, params database.InsertExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
|
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
link, err := db.Store.InsertExternalAuthLink(ctx, params)
|
|
if err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetExternalAuthLink(ctx context.Context, params database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
|
link, err := db.Store.GetExternalAuthLink(ctx, params)
|
|
if err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.ExternalAuthLink, error) {
|
|
links, err := db.Store.GetExternalAuthLinksByUserID(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for idx := range links {
|
|
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return links, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.UpdateExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
|
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
link, err := db.Store.UpdateExternalAuthLink(ctx, params)
|
|
if err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
|
return database.ExternalAuthLink{}, err
|
|
}
|
|
return link, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error {
|
|
// The SQL query uses an optimistic lock:
|
|
// WHERE oauth_refresh_token = @old_oauth_refresh_token
|
|
// The caller supplies the plaintext old token (since dbcrypt
|
|
// decrypts on read), but the DB stores the encrypted value.
|
|
// Because AES-GCM is non-deterministic, we cannot simply
|
|
// re-encrypt the old token — the ciphertext would differ.
|
|
// Instead, read the current row from the inner (raw) store
|
|
// and use the actual encrypted value for the WHERE clause.
|
|
if params.OldOauthRefreshToken != "" && db.ciphers != nil && db.primaryCipherDigest != "" {
|
|
raw, err := db.Store.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
|
|
ProviderID: params.ProviderID,
|
|
UserID: params.UserID,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Decrypt the stored token so we can compare with the
|
|
// caller-supplied plaintext.
|
|
decrypted := raw.OAuthRefreshToken
|
|
if err := db.decryptField(&decrypted, raw.OAuthRefreshTokenKeyID); err != nil {
|
|
return err
|
|
}
|
|
if decrypted != params.OldOauthRefreshToken {
|
|
// The token has changed since the caller read it;
|
|
// the optimistic lock should fail (no rows updated).
|
|
// Return nil to match the :exec semantics of the SQL
|
|
// query, which silently updates zero rows.
|
|
return nil
|
|
}
|
|
// Use the raw encrypted value so the WHERE clause matches.
|
|
params.OldOauthRefreshToken = raw.OAuthRefreshToken
|
|
}
|
|
|
|
// We would normally use a sql.NullString here, but sqlc does not want to make
|
|
// a params struct with a nullable string.
|
|
var digest sql.NullString
|
|
if params.OAuthRefreshTokenKeyID != "" {
|
|
digest.String = params.OAuthRefreshTokenKeyID
|
|
digest.Valid = true
|
|
}
|
|
if err := db.encryptField(¶ms.OAuthRefreshToken, &digest); err != nil {
|
|
return err
|
|
}
|
|
|
|
return db.Store.UpdateExternalAuthLinkRefreshToken(ctx, params)
|
|
}
|
|
|
|
func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) {
|
|
keys, err := db.Store.GetCryptoKeys(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range keys {
|
|
if err := db.decryptField(&keys[i].Secret.String, keys[i].SecretKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetLatestCryptoKeyByFeature(ctx context.Context, feature database.CryptoKeyFeature) (database.CryptoKey, error) {
|
|
key, err := db.Store.GetLatestCryptoKeyByFeature(ctx, feature)
|
|
if err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetCryptoKeyByFeatureAndSequence(ctx context.Context, params database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) {
|
|
key, err := db.Store.GetCryptoKeyByFeatureAndSequence(ctx, params)
|
|
if err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertCryptoKey(ctx context.Context, params database.InsertCryptoKeyParams) (database.CryptoKey, error) {
|
|
if err := db.encryptField(¶ms.Secret.String, ¶ms.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
key, err := db.Store.InsertCryptoKey(ctx, params)
|
|
if err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {
|
|
key, err := db.Store.UpdateCryptoKeyDeletesAt(ctx, arg)
|
|
if err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
if err := db.decryptField(&key.Secret.String, key.SecretKeyID); err != nil {
|
|
return database.CryptoKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetCryptoKeysByFeature(ctx context.Context, feature database.CryptoKeyFeature) ([]database.CryptoKey, error) {
|
|
keys, err := db.Store.GetCryptoKeysByFeature(ctx, feature)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i := range keys {
|
|
if err := db.decryptField(&keys[i].Secret.String, keys[i].SecretKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return keys, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
|
provider, err := db.Store.GetChatProviderByID(ctx, id)
|
|
if err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) {
|
|
provider, err := db.Store.GetChatProviderByProvider(ctx, providerName)
|
|
if err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
|
providers, err := db.Store.GetChatProviders(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i := range providers {
|
|
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return providers, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
|
|
providers, err := db.Store.GetEnabledChatProviders(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i := range providers {
|
|
if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return providers, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) {
|
|
if strings.TrimSpace(params.APIKey) == "" {
|
|
params.ApiKeyKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
|
|
provider, err := db.Store.InsertChatProvider(ctx, params)
|
|
if err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) {
|
|
if strings.TrimSpace(params.APIKey) == "" {
|
|
params.ApiKeyKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
|
|
provider, err := db.Store.UpdateChatProvider(ctx, params)
|
|
if err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil {
|
|
return database.ChatProvider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error {
|
|
return db.decryptField(&key.APIKey, key.ApiKeyKeyID)
|
|
}
|
|
|
|
func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
|
keys, err := db.Store.GetUserChatProviderKeys(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range keys {
|
|
if err := db.decryptUserChatProviderKey(&keys[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
|
if strings.TrimSpace(params.APIKey) == "" {
|
|
params.ApiKeyKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
|
return database.UserChatProviderKey{}, err
|
|
}
|
|
|
|
key, err := db.Store.UpsertUserChatProviderKey(ctx, params)
|
|
if err != nil {
|
|
return database.UserChatProviderKey{}, err
|
|
}
|
|
if err := db.decryptUserChatProviderKey(&key); err != nil {
|
|
return database.UserChatProviderKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
|
if strings.TrimSpace(params.APIKey) == "" {
|
|
params.ApiKeyKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil {
|
|
return database.UserChatProviderKey{}, err
|
|
}
|
|
|
|
key, err := db.Store.UpdateUserChatProviderKey(ctx, params)
|
|
if err != nil {
|
|
return database.UserChatProviderKey{}, err
|
|
}
|
|
if err := db.decryptUserChatProviderKey(&key); err != nil {
|
|
return database.UserChatProviderKey{}, err
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
// decryptMCPServerConfig decrypts all encrypted fields on a
|
|
// single MCPServerConfig in place.
|
|
func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error {
|
|
if err := db.decryptField(&cfg.OAuth2ClientSecret, cfg.OAuth2ClientSecretKeyID); err != nil {
|
|
return err
|
|
}
|
|
if err := db.decryptField(&cfg.APIKeyValue, cfg.APIKeyValueKeyID); err != nil {
|
|
return err
|
|
}
|
|
return db.decryptField(&cfg.CustomHeaders, cfg.CustomHeadersKeyID)
|
|
}
|
|
|
|
// decryptMCPServerUserToken decrypts all encrypted fields on a
|
|
// single MCPServerUserToken in place.
|
|
func (db *dbCrypt) decryptMCPServerUserToken(tok *database.MCPServerUserToken) error {
|
|
if err := db.decryptField(&tok.AccessToken, tok.AccessTokenKeyID); err != nil {
|
|
return err
|
|
}
|
|
return db.decryptField(&tok.RefreshToken, tok.RefreshTokenKeyID)
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
|
|
cfg, err := db.Store.GetMCPServerConfigByID(ctx, id)
|
|
if err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
|
|
cfg, err := db.Store.GetMCPServerConfigBySlug(ctx, slug)
|
|
if err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
|
cfgs, err := db.Store.GetMCPServerConfigs(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range cfgs {
|
|
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return cfgs, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
|
|
cfgs, err := db.Store.GetMCPServerConfigsByIDs(ctx, ids)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range cfgs {
|
|
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return cfgs, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
|
cfgs, err := db.Store.GetEnabledMCPServerConfigs(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range cfgs {
|
|
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return cfgs, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
|
cfgs, err := db.Store.GetForcedMCPServerConfigs(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range cfgs {
|
|
if err := db.decryptMCPServerConfig(&cfgs[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return cfgs, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
|
tok, err := db.Store.GetMCPServerUserToken(ctx, arg)
|
|
if err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
if err := db.decryptMCPServerUserToken(&tok); err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
return tok, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
|
|
toks, err := db.Store.GetMCPServerUserTokensByUserID(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range toks {
|
|
if err := db.decryptMCPServerUserToken(&toks[i]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return toks, nil
|
|
}
|
|
|
|
func (db *dbCrypt) InsertMCPServerConfig(ctx context.Context, params database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
|
|
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
|
|
params.OAuth2ClientSecretKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.OAuth2ClientSecret, ¶ms.OAuth2ClientSecretKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if strings.TrimSpace(params.APIKeyValue) == "" {
|
|
params.APIKeyValueKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKeyValue, ¶ms.APIKeyValueKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if strings.TrimSpace(params.CustomHeaders) == "" {
|
|
params.CustomHeadersKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.CustomHeaders, ¶ms.CustomHeadersKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
|
|
cfg, err := db.Store.InsertMCPServerConfig(ctx, params)
|
|
if err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateMCPServerConfig(ctx context.Context, params database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
|
if strings.TrimSpace(params.OAuth2ClientSecret) == "" {
|
|
params.OAuth2ClientSecretKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.OAuth2ClientSecret, ¶ms.OAuth2ClientSecretKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if strings.TrimSpace(params.APIKeyValue) == "" {
|
|
params.APIKeyValueKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.APIKeyValue, ¶ms.APIKeyValueKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if strings.TrimSpace(params.CustomHeaders) == "" {
|
|
params.CustomHeadersKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.CustomHeaders, ¶ms.CustomHeadersKeyID); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
|
|
cfg, err := db.Store.UpdateMCPServerConfig(ctx, params)
|
|
if err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
if err := db.decryptMCPServerConfig(&cfg); err != nil {
|
|
return database.MCPServerConfig{}, err
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpsertMCPServerUserToken(ctx context.Context, params database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
|
if strings.TrimSpace(params.AccessToken) == "" {
|
|
params.AccessTokenKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.AccessToken, ¶ms.AccessTokenKeyID); err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
if strings.TrimSpace(params.RefreshToken) == "" {
|
|
params.RefreshTokenKeyID = sql.NullString{}
|
|
} else if err := db.encryptField(¶ms.RefreshToken, ¶ms.RefreshTokenKeyID); err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
|
|
tok, err := db.Store.UpsertMCPServerUserToken(ctx, params)
|
|
if err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
if err := db.decryptMCPServerUserToken(&tok); err != nil {
|
|
return database.MCPServerUserToken{}, err
|
|
}
|
|
return tok, nil
|
|
}
|
|
|
|
func (db *dbCrypt) CreateUserSecret(ctx context.Context, params database.CreateUserSecretParams) (database.UserSecret, error) {
|
|
if err := db.encryptField(¶ms.Value, ¶ms.ValueKeyID); err != nil {
|
|
return database.UserSecret{}, err
|
|
}
|
|
secret, err := db.Store.CreateUserSecret(ctx, params)
|
|
if err != nil {
|
|
return database.UserSecret{}, err
|
|
}
|
|
if err := db.decryptField(&secret.Value, secret.ValueKeyID); err != nil {
|
|
return database.UserSecret{}, err
|
|
}
|
|
return secret, nil
|
|
}
|
|
|
|
func (db *dbCrypt) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
|
secret, err := db.Store.GetUserSecretByUserIDAndName(ctx, arg)
|
|
if err != nil {
|
|
return database.UserSecret{}, err
|
|
}
|
|
if err := db.decryptField(&secret.Value, secret.ValueKeyID); err != nil {
|
|
return database.UserSecret{}, err
|
|
}
|
|
return secret, nil
|
|
}
|
|
|
|
func (db *dbCrypt) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
|
secrets, err := db.Store.ListUserSecretsWithValues(ctx, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range secrets {
|
|
if err := db.decryptField(&secrets[i].Value, secrets[i].ValueKeyID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return secrets, nil
|
|
}
|
|
|
|
func (db *dbCrypt) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
|
if arg.UpdateValue {
|
|
if err := db.encryptField(&arg.Value, &arg.ValueKeyID); err != nil {
|
|
return database.UserSecret{}, err
|
|
}
|
|
}
|
|
secret, err := db.Store.UpdateUserSecretByUserIDAndName(ctx, arg)
|
|
if err != nil {
|
|
return database.UserSecret{}, err
|
|
}
|
|
if err := db.decryptField(&secret.Value, secret.ValueKeyID); err != nil {
|
|
return database.UserSecret{}, err
|
|
}
|
|
return secret, nil
|
|
}
|
|
|
|
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
|
|
// If no cipher is loaded, then we can't encrypt anything!
|
|
if db.ciphers == nil || db.primaryCipherDigest == "" {
|
|
return nil
|
|
}
|
|
|
|
if field == nil {
|
|
return xerrors.Errorf("developer error: encryptField called with nil field")
|
|
}
|
|
if digest == nil {
|
|
return xerrors.Errorf("developer error: encryptField called with nil digest")
|
|
}
|
|
|
|
encrypted, err := db.ciphers[db.primaryCipherDigest].Encrypt([]byte(*field))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Base64 is used to support UTF-8 encoding in PostgreSQL.
|
|
*field = b64encode(encrypted)
|
|
*digest = sql.NullString{String: db.primaryCipherDigest, Valid: true}
|
|
return nil
|
|
}
|
|
|
|
// decryptFields decrypts the given field using the key with the given digest.
|
|
// If the value fails to decrypt, sql.ErrNoRows will be returned.
|
|
func (db *dbCrypt) decryptField(field *string, digest sql.NullString) error {
|
|
if field == nil {
|
|
return xerrors.Errorf("developer error: decryptField called with nil field")
|
|
}
|
|
|
|
if !digest.Valid || digest.String == "" {
|
|
// This field is not encrypted.
|
|
return nil
|
|
}
|
|
|
|
key, ok := db.ciphers[digest.String]
|
|
if !ok {
|
|
return &DecryptFailedError{
|
|
Inner: xerrors.Errorf("no cipher with digest %q", digest.String),
|
|
}
|
|
}
|
|
|
|
data, err := b64decode(*field)
|
|
if err != nil {
|
|
// If it's not valid base64, we should complain loudly.
|
|
return &DecryptFailedError{
|
|
Inner: xerrors.Errorf("malformed encrypted field %q: %w", *field, err),
|
|
}
|
|
}
|
|
decrypted, err := key.Decrypt(data)
|
|
if err != nil {
|
|
return &DecryptFailedError{Inner: err}
|
|
}
|
|
*field = string(decrypted)
|
|
return nil
|
|
}
|
|
|
|
func (db *dbCrypt) ensureEncryptedWithRetry(ctx context.Context) error {
|
|
var err error
|
|
for i := 0; i < 3; i++ {
|
|
err = db.ensureEncrypted(ctx)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
// If we get a serialization error, then we need to retry.
|
|
if !database.IsSerializedError(err) {
|
|
return err
|
|
}
|
|
// otherwise, retry
|
|
}
|
|
// If we get here, then we ran out of retries
|
|
return err
|
|
}
|
|
|
|
func (db *dbCrypt) ensureEncrypted(ctx context.Context) error {
|
|
return db.InTx(func(s database.Store) error {
|
|
// Attempt to read the encrypted test fields of the currently active keys.
|
|
ks, err := s.GetDBCryptKeys(ctx)
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
return err
|
|
}
|
|
|
|
var highestNumber int32
|
|
var activeCipherFound bool
|
|
for _, k := range ks {
|
|
// If our primary key has been revoked, then we can't do anything.
|
|
if k.RevokedKeyDigest.Valid && k.RevokedKeyDigest.String == db.primaryCipherDigest {
|
|
return xerrors.Errorf("primary encryption key %q has been revoked", db.primaryCipherDigest)
|
|
}
|
|
|
|
if k.ActiveKeyDigest.Valid && k.ActiveKeyDigest.String == db.primaryCipherDigest {
|
|
activeCipherFound = true
|
|
}
|
|
|
|
if k.Number > highestNumber {
|
|
highestNumber = k.Number
|
|
}
|
|
}
|
|
|
|
if activeCipherFound || len(db.ciphers) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// If we get here, then we have a new key that we need to insert.
|
|
return s.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
|
|
Number: highestNumber + 1,
|
|
ActiveKeyDigest: db.primaryCipherDigest,
|
|
Test: testValue,
|
|
})
|
|
}, &database.TxOptions{Isolation: sql.LevelRepeatableRead})
|
|
}
|