Files
coder/enterprise/dbcrypt/dbcrypt_internal_test.go
T
Cian Johnston e744cde86f fix(coderd): ensure that clearing invalid oauth refresh tokens works with dbcrypt (#15721)
https://github.com/coder/coder/pull/15608 introduced a buggy behaviour
with dbcrypt enabled.
When clearing an oauth refresh token, we had been setting the value to
the empty string.
The database encryption package considers decrypting an empty string to
be an error, as an empty encrypted string value will still have a nonce
associated with it and thus not actually be empty when stored at rest.

Instead of 'deleting' the refresh token, 'update' it to be the empty
string.
This plays nicely with dbcrypt.

It also adds a 'utility test' in the dbcrypt package to help encrypt a
value. This was useful when manually fixing users affected by this bug
on our dogfood instance.
2024-12-03 13:26:31 -06:00

880 lines
30 KiB
Go

package dbcrypt
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"io"
"testing"
"time"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
)
// TestUserLinks exercises the dbCrypt wrappers around the user-link queries,
// plus the external-auth refresh-token update. Every subtest builds its own
// stores via setup. The "OK"-style subtests write or read through the
// encrypting store and then read the row back through the raw store to check
// what was persisted; the "DecryptErr" subtests do the reverse and plant
// undecryptable data through the raw store.
func TestUserLinks(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	t.Run("InsertUserLink", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		user := dbgen.User(t, crypt, database.User{})
		link := dbgen.UserLink(t, crypt, database.UserLink{
			UserID:            user.ID,
			OAuthAccessToken:  "access",
			OAuthRefreshToken: "refresh",
		})
		// The value handed back by the encrypting store is plaintext and is
		// tagged with the digest of the cipher that was used.
		require.Equal(t, "access", link.OAuthAccessToken)
		require.Equal(t, "refresh", link.OAuthRefreshToken)
		require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
		require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)

		// The raw row must hold ciphertext that decrypts to the same values.
		rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
		require.NoError(t, err)
		requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
		requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
	})

	t.Run("UpdateUserLink", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		user := dbgen.User(t, crypt, database.User{})
		link := dbgen.UserLink(t, crypt, database.UserLink{
			UserID: user.ID,
		})

		expectedClaims := database.UserLinkClaims{
			IDTokenClaims: map[string]interface{}{
				"sub": "123",
				"groups": []interface{}{
					"foo", "bar",
				},
			},
			UserInfoClaims: map[string]interface{}{
				"number": float64(2),
				"struct": map[string]interface{}{
					"number": float64(2),
				},
			},
			MergedClaims: map[string]interface{}{
				"sub": "123",
				"groups": []interface{}{
					"foo", "bar",
				},
				"number": float64(2),
				"struct": map[string]interface{}{
					"number": float64(2),
				},
			},
		}

		updated, err := crypt.UpdateUserLink(ctx, database.UpdateUserLinkParams{
			OAuthAccessToken:  "access",
			OAuthRefreshToken: "refresh",
			UserID:            link.UserID,
			LoginType:         link.LoginType,
			Claims:            expectedClaims,
		})
		require.NoError(t, err)
		require.Equal(t, "access", updated.OAuthAccessToken)
		require.Equal(t, "refresh", updated.OAuthRefreshToken)
		// Check the key IDs on the value returned by the update itself, not
		// on the link that was inserted beforehand.
		require.Equal(t, ciphers[0].HexDigest(), updated.OAuthAccessTokenKeyID.String)
		require.Equal(t, ciphers[0].HexDigest(), updated.OAuthRefreshTokenKeyID.String)

		rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
		require.NoError(t, err)
		requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
		requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
		require.EqualValues(t, expectedClaims, rawLink.Claims)
	})

	t.Run("UpdateExternalAuthLinkRefreshToken", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		user := dbgen.User(t, crypt, database.User{})
		link := dbgen.ExternalAuthLink(t, crypt, database.ExternalAuthLink{
			UserID: user.ID,
		})

		// Update the refresh token to the empty string through the
		// encrypting store.
		err := crypt.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
			OAuthRefreshToken:      "",
			OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID.String,
			UpdatedAt:              dbtime.Now(),
			ProviderID:             link.ProviderID,
			UserID:                 link.UserID,
		})
		require.NoError(t, err)

		// The stored value must still be valid ciphertext that decrypts to "".
		rawLink, err := db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
			ProviderID: link.ProviderID,
			UserID:     link.UserID,
		})
		require.NoError(t, err)
		requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "")
	})

	t.Run("GetUserLinkByLinkedID", func(t *testing.T) {
		t.Parallel()

		t.Run("OK", func(t *testing.T) {
			t.Parallel()
			db, crypt, ciphers := setup(t)
			user := dbgen.User(t, crypt, database.User{})
			link := dbgen.UserLink(t, crypt, database.UserLink{
				UserID:            user.ID,
				OAuthAccessToken:  "access",
				OAuthRefreshToken: "refresh",
			})

			link, err := crypt.GetUserLinkByLinkedID(ctx, link.LinkedID)
			require.NoError(t, err)
			require.Equal(t, "access", link.OAuthAccessToken)
			require.Equal(t, "refresh", link.OAuthRefreshToken)
			require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
			require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)

			rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
			require.NoError(t, err)
			requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
			requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
		})

		t.Run("DecryptErr", func(t *testing.T) {
			t.Parallel()
			db, crypt, ciphers := setup(t)
			// Insert through the raw store: random base64 data labeled with
			// the active cipher's digest.
			user := dbgen.User(t, db, database.User{})
			link := dbgen.UserLink(t, db, database.UserLink{
				UserID:                 user.ID,
				OAuthAccessToken:       fakeBase64RandomData(t, 32),
				OAuthRefreshToken:      fakeBase64RandomData(t, 32),
				OAuthAccessTokenKeyID:  sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
				OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
			})

			_, err := crypt.GetUserLinkByLinkedID(ctx, link.LinkedID)
			require.Error(t, err, "expected an error")
			var derr *DecryptFailedError
			require.ErrorAs(t, err, &derr, "expected a decrypt error")
		})
	})

	t.Run("GetUserLinksByUserID", func(t *testing.T) {
		t.Parallel()

		t.Run("OK", func(t *testing.T) {
			t.Parallel()
			db, crypt, ciphers := setup(t)
			user := dbgen.User(t, crypt, database.User{})
			link := dbgen.UserLink(t, crypt, database.UserLink{
				UserID:            user.ID,
				OAuthAccessToken:  "access",
				OAuthRefreshToken: "refresh",
			})

			links, err := crypt.GetUserLinksByUserID(ctx, link.UserID)
			require.NoError(t, err)
			require.Len(t, links, 1)
			require.Equal(t, "access", links[0].OAuthAccessToken)
			require.Equal(t, "refresh", links[0].OAuthRefreshToken)
			require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthAccessTokenKeyID.String)
			require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthRefreshTokenKeyID.String)

			rawLinks, err := db.GetUserLinksByUserID(ctx, link.UserID)
			require.NoError(t, err)
			require.Len(t, rawLinks, 1)
			requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthAccessToken, "access")
			requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthRefreshToken, "refresh")
		})

		t.Run("Empty", func(t *testing.T) {
			t.Parallel()
			_, crypt, _ := setup(t)
			user := dbgen.User(t, crypt, database.User{})

			// A user without any links yields an empty result and no error.
			links, err := crypt.GetUserLinksByUserID(ctx, user.ID)
			require.NoError(t, err)
			require.Empty(t, links)
		})

		t.Run("DecryptErr", func(t *testing.T) {
			t.Parallel()
			db, crypt, ciphers := setup(t)
			user := dbgen.User(t, db, database.User{})
			_ = dbgen.UserLink(t, db, database.UserLink{
				UserID:                 user.ID,
				OAuthAccessToken:       fakeBase64RandomData(t, 32),
				OAuthRefreshToken:      fakeBase64RandomData(t, 32),
				OAuthAccessTokenKeyID:  sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
				OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
			})

			_, err := crypt.GetUserLinksByUserID(ctx, user.ID)
			require.Error(t, err, "expected an error")
			var derr *DecryptFailedError
			require.ErrorAs(t, err, &derr, "expected a decrypt error")
		})
	})

	t.Run("GetUserLinkByUserIDLoginType", func(t *testing.T) {
		t.Parallel()

		t.Run("OK", func(t *testing.T) {
			t.Parallel()
			db, crypt, ciphers := setup(t)
			user := dbgen.User(t, crypt, database.User{})
			link := dbgen.UserLink(t, crypt, database.UserLink{
				UserID:            user.ID,
				OAuthAccessToken:  "access",
				OAuthRefreshToken: "refresh",
			})

			link, err := crypt.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
				UserID:    link.UserID,
				LoginType: link.LoginType,
			})
			require.NoError(t, err)
			require.Equal(t, "access", link.OAuthAccessToken)
			require.Equal(t, "refresh", link.OAuthRefreshToken)
			require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
			require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)

			rawLink, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
				UserID:    link.UserID,
				LoginType: link.LoginType,
			})
			require.NoError(t, err)
			requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
			requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
		})

		t.Run("DecryptErr", func(t *testing.T) {
			t.Parallel()
			db, crypt, ciphers := setup(t)
			user := dbgen.User(t, db, database.User{})
			link := dbgen.UserLink(t, db, database.UserLink{
				UserID:                 user.ID,
				OAuthAccessToken:       fakeBase64RandomData(t, 32),
				OAuthRefreshToken:      fakeBase64RandomData(t, 32),
				OAuthAccessTokenKeyID:  sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
				OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
			})

			_, err := crypt.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
				UserID:    link.UserID,
				LoginType: link.LoginType,
			})
			require.Error(t, err, "expected an error")
			var derr *DecryptFailedError
			require.ErrorAs(t, err, &derr, "expected a decrypt error")
		})
	})
}
// TestExternalAuthLinks exercises the dbCrypt wrappers around the
// external-auth-link queries. Each subtest builds its own stores via setup.
// Success cases compare the plaintext seen through the encrypting store with
// the ciphertext visible through the raw store; "DecryptErr" cases plant
// random base64 data through the raw store and expect a DecryptFailedError.
func TestExternalAuthLinks(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	t.Run("InsertExternalAuthLink", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		link := dbgen.ExternalAuthLink(t, crypt, database.ExternalAuthLink{
			OAuthAccessToken:  "access",
			OAuthRefreshToken: "refresh",
		})
		require.Equal(t, "access", link.OAuthAccessToken)
		require.Equal(t, "refresh", link.OAuthRefreshToken)

		link, err := db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
			ProviderID: link.ProviderID,
			UserID:     link.UserID,
		})
		require.NoError(t, err)
		requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access")
		requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh")
	})

	t.Run("UpdateExternalAuthLink", func(t *testing.T) {
		t.Parallel()
		db, crypt, ciphers := setup(t)
		link := dbgen.ExternalAuthLink(t, crypt, database.ExternalAuthLink{})

		updated, err := crypt.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
			ProviderID:        link.ProviderID,
			UserID:            link.UserID,
			OAuthAccessToken:  "access",
			OAuthRefreshToken: "refresh",
		})
		require.NoError(t, err)
		require.Equal(t, "access", updated.OAuthAccessToken)
		require.Equal(t, "refresh", updated.OAuthRefreshToken)

		link, err = db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
			ProviderID: link.ProviderID,
			UserID:     link.UserID,
		})
		require.NoError(t, err)
		requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access")
		requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh")
	})

	t.Run("GetExternalAuthLink", func(t *testing.T) {
		// Run in parallel like every other subtest group in this file.
		t.Parallel()

		t.Run("OK", func(t *testing.T) {
			t.Parallel()
			db, crypt, ciphers := setup(t)
			link := dbgen.ExternalAuthLink(t, crypt, database.ExternalAuthLink{
				OAuthAccessToken:  "access",
				OAuthRefreshToken: "refresh",
			})

			// Read through the encrypting store so the success path of the
			// method this subtest is named after is actually exercised.
			got, err := crypt.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
				UserID:     link.UserID,
				ProviderID: link.ProviderID,
			})
			require.NoError(t, err)
			require.Equal(t, "access", got.OAuthAccessToken)
			require.Equal(t, "refresh", got.OAuthRefreshToken)

			// The raw row must still contain ciphertext.
			rawLink, err := db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
				UserID:     link.UserID,
				ProviderID: link.ProviderID,
			})
			require.NoError(t, err)
			requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
			requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
		})

		t.Run("DecryptErr", func(t *testing.T) {
			t.Parallel()
			db, crypt, ciphers := setup(t)
			link := dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
				OAuthAccessToken:       fakeBase64RandomData(t, 32),
				OAuthRefreshToken:      fakeBase64RandomData(t, 32),
				OAuthAccessTokenKeyID:  sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
				OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
			})

			_, err := crypt.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
				UserID:     link.UserID,
				ProviderID: link.ProviderID,
			})
			require.Error(t, err, "expected an error")
			var derr *DecryptFailedError
			require.ErrorAs(t, err, &derr, "expected a decrypt error")
		})
	})

	t.Run("GetExternalAuthLinksByUserID", func(t *testing.T) {
		t.Parallel()

		t.Run("OK", func(t *testing.T) {
			t.Parallel()
			db, crypt, ciphers := setup(t)
			user := dbgen.User(t, crypt, database.User{})
			link := dbgen.ExternalAuthLink(t, crypt, database.ExternalAuthLink{
				UserID:            user.ID,
				OAuthAccessToken:  "access",
				OAuthRefreshToken: "refresh",
			})

			links, err := crypt.GetExternalAuthLinksByUserID(ctx, link.UserID)
			require.NoError(t, err)
			require.Len(t, links, 1)
			require.Equal(t, "access", links[0].OAuthAccessToken)
			require.Equal(t, "refresh", links[0].OAuthRefreshToken)
			require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthAccessTokenKeyID.String)
			require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthRefreshTokenKeyID.String)

			rawLinks, err := db.GetExternalAuthLinksByUserID(ctx, link.UserID)
			require.NoError(t, err)
			require.Len(t, rawLinks, 1)
			requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthAccessToken, "access")
			requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthRefreshToken, "refresh")
		})

		t.Run("DecryptErr", func(t *testing.T) {
			// This subtest builds its own stores via setup, so it can run in
			// parallel like its siblings.
			t.Parallel()
			db, crypt, ciphers := setup(t)
			user := dbgen.User(t, db, database.User{})
			link := dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
				UserID:                 user.ID,
				OAuthAccessToken:       fakeBase64RandomData(t, 32),
				OAuthRefreshToken:      fakeBase64RandomData(t, 32),
				OAuthAccessTokenKeyID:  sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
				OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
			})

			_, err := crypt.GetExternalAuthLinksByUserID(ctx, link.UserID)
			require.Error(t, err, "expected an error")
			var derr *DecryptFailedError
			require.ErrorAs(t, err, &derr, "expected a decrypt error")
		})
	})
}
// TestCryptoKeys checks the crypto-key queries of the encrypting store: the
// secret is plaintext when read through the wrapper, ciphertext when read
// through the raw store, and the key ID equals the active cipher's digest.
func TestCryptoKeys(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	t.Run("InsertCryptoKey", func(t *testing.T) {
		t.Parallel()
		rawStore, encStore, ciphers := setup(t)
		active := ciphers[0]

		inserted := dbgen.CryptoKey(t, encStore, database.CryptoKey{
			Secret: sql.NullString{String: "test", Valid: true},
		})
		require.Equal(t, "test", inserted.Secret.String)

		lookup := database.GetCryptoKeyByFeatureAndSequenceParams{
			Feature:  inserted.Feature,
			Sequence: inserted.Sequence,
		}
		stored, err := rawStore.GetCryptoKeyByFeatureAndSequence(ctx, lookup)
		require.NoError(t, err)
		require.Equal(t, active.HexDigest(), stored.SecretKeyID.String)
		requireEncryptedEquals(t, active, stored.Secret.String, "test")
	})

	t.Run("GetCryptoKeys", func(t *testing.T) {
		t.Parallel()
		rawStore, encStore, ciphers := setup(t)
		active := ciphers[0]

		_ = dbgen.CryptoKey(t, encStore, database.CryptoKey{
			Secret: sql.NullString{String: "test", Valid: true},
		})

		// Through the wrapper: plaintext secret.
		decrypted, err := encStore.GetCryptoKeys(ctx)
		require.NoError(t, err)
		require.Len(t, decrypted, 1)
		require.Equal(t, "test", decrypted[0].Secret.String)
		require.Equal(t, active.HexDigest(), decrypted[0].SecretKeyID.String)

		// Through the raw store: ciphertext secret.
		stored, err := rawStore.GetCryptoKeys(ctx)
		require.NoError(t, err)
		require.Len(t, stored, 1)
		requireEncryptedEquals(t, active, stored[0].Secret.String, "test")
		require.Equal(t, active.HexDigest(), stored[0].SecretKeyID.String)
	})

	t.Run("GetLatestCryptoKeyByFeature", func(t *testing.T) {
		t.Parallel()
		rawStore, encStore, ciphers := setup(t)
		active := ciphers[0]

		_ = dbgen.CryptoKey(t, encStore, database.CryptoKey{
			Secret: sql.NullString{String: "test", Valid: true},
		})

		decrypted, err := encStore.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
		require.NoError(t, err)
		require.Equal(t, "test", decrypted.Secret.String)
		require.Equal(t, active.HexDigest(), decrypted.SecretKeyID.String)

		stored, err := rawStore.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
		require.NoError(t, err)
		requireEncryptedEquals(t, active, stored.Secret.String, "test")
		require.Equal(t, active.HexDigest(), stored.SecretKeyID.String)
	})

	t.Run("GetCryptoKeyByFeatureAndSequence", func(t *testing.T) {
		t.Parallel()
		rawStore, encStore, ciphers := setup(t)
		active := ciphers[0]

		created := dbgen.CryptoKey(t, encStore, database.CryptoKey{
			Secret: sql.NullString{String: "test", Valid: true},
		})

		decrypted, err := encStore.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
			Feature:  database.CryptoKeyFeatureWorkspaceAppsAPIKey,
			Sequence: created.Sequence,
		})
		require.NoError(t, err)
		require.Equal(t, "test", decrypted.Secret.String)
		require.Equal(t, active.HexDigest(), decrypted.SecretKeyID.String)

		// The raw lookup reuses the sequence reported by the wrapper.
		stored, err := rawStore.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
			Feature:  database.CryptoKeyFeatureWorkspaceAppsAPIKey,
			Sequence: decrypted.Sequence,
		})
		require.NoError(t, err)
		requireEncryptedEquals(t, active, stored.Secret.String, "test")
		require.Equal(t, active.HexDigest(), stored.SecretKeyID.String)
	})

	t.Run("UpdateCryptoKeyDeletesAt", func(t *testing.T) {
		t.Parallel()
		_, encStore, ciphers := setup(t)

		created := dbgen.CryptoKey(t, encStore, database.CryptoKey{
			Secret: sql.NullString{String: "test", Valid: true},
		})

		updated, err := encStore.UpdateCryptoKeyDeletesAt(ctx, database.UpdateCryptoKeyDeletesAtParams{
			Feature:  created.Feature,
			Sequence: created.Sequence,
			DeletesAt: sql.NullTime{
				Time:  time.Now().Add(time.Hour),
				Valid: true,
			},
		})
		require.NoError(t, err)
		require.Equal(t, "test", updated.Secret.String)
		require.Equal(t, ciphers[0].HexDigest(), updated.SecretKeyID.String)
	})

	t.Run("GetCryptoKeysByFeature", func(t *testing.T) {
		t.Parallel()
		rawStore, encStore, ciphers := setup(t)
		active := ciphers[0]

		// One key for the feature being queried ...
		want := dbgen.CryptoKey(t, encStore, database.CryptoKey{
			Sequence: 2,
			Feature:  database.CryptoKeyFeatureTailnetResume,
			Secret:   sql.NullString{String: "test", Valid: true},
		})
		// ... and one for a different feature, which must not be returned.
		_ = dbgen.CryptoKey(t, encStore, database.CryptoKey{
			Feature:  database.CryptoKeyFeatureWorkspaceAppsAPIKey,
			Sequence: 43,
		})

		decrypted, err := encStore.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureTailnetResume)
		require.NoError(t, err)
		require.Len(t, decrypted, 1)
		require.Equal(t, "test", decrypted[0].Secret.String)
		require.Equal(t, active.HexDigest(), decrypted[0].SecretKeyID.String)
		require.Equal(t, want.Sequence, decrypted[0].Sequence)
		require.Equal(t, want.Feature, decrypted[0].Feature)

		stored, err := rawStore.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureTailnetResume)
		require.NoError(t, err)
		require.Len(t, stored, 1)
		requireEncryptedEquals(t, active, stored[0].Secret.String, "test")
		require.Equal(t, active.HexDigest(), stored[0].SecretKeyID.String)
		require.Equal(t, want.Sequence, stored[0].Sequence)
		require.Equal(t, want.Feature, stored[0].Feature)
	})

	t.Run("DecryptErr", func(t *testing.T) {
		t.Parallel()
		rawStore, encStore, ciphers := setup(t)

		// Plant random base64 data through the raw store, labeled with the
		// active cipher's digest.
		bogus := dbgen.CryptoKey(t, rawStore, database.CryptoKey{
			Secret: sql.NullString{
				String: fakeBase64RandomData(t, 32),
				Valid:  true,
			},
			SecretKeyID: sql.NullString{
				String: ciphers[0].HexDigest(),
				Valid:  true,
			},
		})

		_, err := encStore.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
			Feature:  bogus.Feature,
			Sequence: bogus.Sequence,
		})
		require.Error(t, err, "expected an error")
		var derr *DecryptFailedError
		require.ErrorAs(t, err, &derr, "expected a decrypt error")
	})
}
// TestNew covers initialization of the encrypting store via New against
// databases in various key states: no keys, an unknown active key, a revoked
// key, a revoked primary cipher, and a serialization failure on first insert.
func TestNew(t *testing.T) {
	t.Parallel()

	t.Run("OK", func(t *testing.T) {
		t.Parallel()
		// Given: a cipher is loaded
		cipher := initCipher(t)
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)
		rawDB, _ := dbtestutil.NewDB(t)

		// Before: no keys should be present
		keys, err := rawDB.GetDBCryptKeys(ctx)
		require.NoError(t, err, "no error should be returned")
		require.Empty(t, keys, "no keys should be present")

		// When: we init the crypt db
		_, err = New(ctx, rawDB, cipher)
		require.NoError(t, err)

		// Then: a new key is inserted
		keys, err = rawDB.GetDBCryptKeys(ctx)
		require.NoError(t, err)
		require.Len(t, keys, 1, "one key should be present")
		require.Equal(t, cipher.HexDigest(), keys[0].ActiveKeyDigest.String, "key digest mismatch")
		require.Empty(t, keys[0].RevokedKeyDigest.String, "key should not be revoked")
		requireEncryptedEquals(t, cipher, keys[0].Test, "coder")
	})

	t.Run("MissingKey", func(t *testing.T) {
		t.Parallel()
		// Given: there exist two valid encryption keys
		cipher1 := initCipher(t)
		cipher2 := initCipher(t)
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)
		rawDB, _ := dbtestutil.NewDB(t)

		// Given: key 1 is already present in the database
		err := rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
			Number:          1,
			ActiveKeyDigest: cipher1.HexDigest(),
			Test:            fakeBase64RandomData(t, 32),
		})
		require.NoError(t, err, "no error should be returned")
		keys, err := rawDB.GetDBCryptKeys(ctx)
		require.NoError(t, err, "no error should be returned")
		require.Len(t, keys, 1, "one key should be present")

		// When: we init the crypt db with no keys
		_, err = New(ctx, rawDB)
		// Then: we error because we don't know how to decrypt the existing key
		require.Error(t, err, "expected an error")
		var derr *DecryptFailedError
		require.ErrorAs(t, err, &derr, "expected a decrypt error")

		// When: we init the crypt db with key 2
		_, err = New(ctx, rawDB, cipher2)
		// Then: we error because the key is not revoked and we don't know how to decrypt it
		require.Error(t, err, "expected an error")
		require.ErrorAs(t, err, &derr, "expected a decrypt error")

		// When: the existing key is marked as having been revoked
		err = rawDB.RevokeDBCryptKey(ctx, cipher1.HexDigest())
		require.NoError(t, err, "no error should be returned")

		// And: we init the crypt db with key 2
		_, err = New(ctx, rawDB, cipher2)
		// Then: we succeed
		require.NoError(t, err)

		// And: key 2 should now be the active key
		keys, err = rawDB.GetDBCryptKeys(ctx)
		require.NoError(t, err)
		require.Len(t, keys, 2, "two keys should be present")
		// testify takes (expected, actual); keeping that order makes a
		// failure report label the two values correctly.
		require.EqualValues(t, 1, keys[0].Number, "key number mismatch")
		require.Empty(t, keys[0].ActiveKeyDigest.String, "key should not be active")
		require.Equal(t, cipher1.HexDigest(), keys[0].RevokedKeyDigest.String, "key should be revoked")
		require.EqualValues(t, 2, keys[1].Number, "key number mismatch")
		require.Equal(t, cipher2.HexDigest(), keys[1].ActiveKeyDigest.String, "key digest mismatch")
		require.Empty(t, keys[1].RevokedKeyDigest.String, "key should not be revoked")
		requireEncryptedEquals(t, cipher2, keys[1].Test, "coder")
	})

	t.Run("NoKeys", func(t *testing.T) {
		t.Parallel()
		// Given: no cipher is loaded
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)
		rawDB, _ := dbtestutil.NewDB(t)
		keys, err := rawDB.GetDBCryptKeys(ctx)
		require.NoError(t, err, "no error should be returned")
		require.Empty(t, keys, "no keys should be present")

		// When: we init the crypt db with no ciphers
		_, err = New(ctx, rawDB)
		// Then: it should succeed.
		require.NoError(t, err, "dbcrypt.New should work with no keys against an unencrypted database")

		// Assert invariant: no keys are inserted
		keys, err = rawDB.GetDBCryptKeys(ctx)
		require.NoError(t, err, "no error should be returned")
		require.Empty(t, keys, "no keys should be present")

		// Insert a key
		require.NoError(t, rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
			Number:          1,
			ActiveKeyDigest: "whatever",
			Test:            fakeBase64RandomData(t, 32),
		}))

		// This should fail as we do not know how to decrypt the key:
		_, err = New(ctx, rawDB)
		require.Error(t, err)

		// Until we revoke the key:
		require.NoError(t, rawDB.RevokeDBCryptKey(ctx, "whatever"))
		_, err = New(ctx, rawDB)
		require.NoError(t, err, "the above should still hold if the key is revoked")
	})

	t.Run("PrimaryRevoked", func(t *testing.T) {
		t.Parallel()
		// Given: a cipher is loaded
		cipher := initCipher(t)
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)
		rawDB, _ := dbtestutil.NewDB(t)

		// And: the cipher is revoked before we init the crypt db
		err := rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
			Number:          1,
			ActiveKeyDigest: cipher.HexDigest(),
			Test:            fakeBase64RandomData(t, 32),
		})
		require.NoError(t, err, "no error should be returned")
		err = rawDB.RevokeDBCryptKey(ctx, cipher.HexDigest())
		require.NoError(t, err, "no error should be returned")

		// Then: when we init the crypt db, we error because the key is revoked
		_, err = New(ctx, rawDB, cipher)
		require.Error(t, err)
		require.ErrorContains(t, err, "has been revoked")
	})

	t.Run("Retry", func(t *testing.T) {
		t.Parallel()
		// Given: a cipher is loaded
		cipher := initCipher(t)
		ctx, cancel := context.WithCancel(context.Background())
		t.Cleanup(cancel)

		// Check the encryption error before the ciphertext is used.
		testVal, err := cipher.Encrypt([]byte("coder"))
		require.NoError(t, err)
		key := database.DBCryptKey{
			Number:          1,
			ActiveKeyDigest: sql.NullString{String: cipher.HexDigest(), Valid: true},
			Test:            b64encode(testVal),
		}

		// And: a database that returns an error once when we try to serialize a key
		ctrl := gomock.NewController(t)
		mockDB := dbmock.NewMockStore(ctrl)

		gomock.InOrder(
			// First try: we get a serialization error.
			expectInTx(mockDB),
			mockDB.EXPECT().GetDBCryptKeys(gomock.Any()).Times(1).Return([]database.DBCryptKey{}, nil),
			mockDB.EXPECT().InsertDBCryptKey(gomock.Any(), gomock.Any()).Times(1).Return(&pq.Error{Code: "40001"}),
			// Second try: we get the key we wanted to insert initially.
			expectInTx(mockDB),
			mockDB.EXPECT().GetDBCryptKeys(gomock.Any()).Times(1).Return([]database.DBCryptKey{key}, nil),
		)

		_, err = New(ctx, mockDB, cipher)
		require.NoError(t, err)
	})
}
// TestEncryptDecryptField exercises dbCrypt.encryptField and
// dbCrypt.decryptField directly, covering the round trip, the no-cipher
// case, an unknown key digest, nil arguments, empty values, and malformed
// base64 input.
func TestEncryptDecryptField(t *testing.T) {
	t.Parallel()

	t.Run("OK", func(t *testing.T) {
		t.Parallel()
		_, store, ciphers := setup(t)
		active := ciphers[0]

		value := "coder"
		var keyID sql.NullString

		// Encrypting rewrites value in place and records the cipher digest.
		require.NoError(t, store.encryptField(&value, &keyID))
		require.Equal(t, active.HexDigest(), keyID.String)
		requireEncryptedEquals(t, active, value, "coder")

		// Decrypting restores the plaintext.
		require.NoError(t, store.decryptField(&value, keyID))
		require.Equal(t, "coder", value)
	})

	t.Run("NoKeys", func(t *testing.T) {
		t.Parallel()
		// With no keys, encryption and decryption are both no-ops.
		_, store := setupNoCiphers(t)

		value := "coder"
		var keyID sql.NullString

		require.NoError(t, store.encryptField(&value, &keyID))
		require.Empty(t, keyID.String)
		require.Equal(t, "coder", value)

		require.NoError(t, store.decryptField(&value, keyID))
		require.Equal(t, "coder", value)
	})

	t.Run("MissingKey", func(t *testing.T) {
		t.Parallel()
		_, store, ciphers := setup(t)
		active := ciphers[0]

		value := "coder"
		var keyID sql.NullString

		encErr := store.encryptField(&value, &keyID)
		require.NoError(t, encErr)
		requireEncryptedEquals(t, active, value, "coder")
		require.Equal(t, active.HexDigest(), keyID.String)
		require.True(t, keyID.Valid)

		// Swap in a digest that no loaded cipher matches.
		keyID = sql.NullString{String: "missing", Valid: true}
		var derr *DecryptFailedError
		decErr := store.decryptField(&value, keyID)
		require.Error(t, decErr)
		require.ErrorAs(t, decErr, &derr)
	})

	t.Run("CantEncryptOrDecryptNil", func(t *testing.T) {
		t.Parallel()
		_, store, _ := setup(t)

		encErr := store.encryptField(nil, nil)
		require.ErrorContains(t, encErr, "developer error")

		decErr := store.decryptField(nil, sql.NullString{})
		require.ErrorContains(t, decErr, "developer error")
	})

	t.Run("EncryptEmptyString", func(t *testing.T) {
		t.Parallel()
		_, store, ciphers := setup(t)
		active := ciphers[0]

		value := ""
		var keyID sql.NullString

		require.NoError(t, store.encryptField(&value, &keyID))
		requireEncryptedEquals(t, active, value, "")
		require.Equal(t, active.HexDigest(), keyID.String)

		require.NoError(t, store.decryptField(&value, keyID))
		require.Empty(t, value)
	})

	t.Run("DecryptEmptyString", func(t *testing.T) {
		t.Parallel()
		_, store, ciphers := setup(t)

		value := ""
		keyID := sql.NullString{String: ciphers[0].HexDigest(), Valid: true}

		decErr := store.decryptField(&value, keyID)
		// Currently this has to fail because the ciphertext must at least
		// have a nonce. This may need to be changed depending on future
		// ciphers.
		require.ErrorContains(t, decErr, "ciphertext too short")
	})

	t.Run("InvalidBase64", func(t *testing.T) {
		t.Parallel()
		_, store, ciphers := setup(t)

		value := "not valid base64"
		keyID := sql.NullString{String: ciphers[0].HexDigest(), Valid: true}

		decErr := store.decryptField(&value, keyID)
		require.ErrorContains(t, decErr, "illegal base64 data")
	})
}
// expectInTx registers exactly one expected InTx call on the mock store. When
// the call happens, the supplied transaction function is run against the mock
// itself and its error is returned; the transaction options are ignored.
func expectInTx(mdb *dbmock.MockStore) *gomock.Call {
	runAgainstMock := func(txFn func(store database.Store) error, _ *database.TxOptions) error {
		return txFn(mdb)
	}
	return mdb.EXPECT().
		InTx(gomock.Any(), gomock.Any()).
		Times(1).
		DoAndReturn(runAgainstMock)
}
// requireEncryptedEquals fails the test unless value is valid standard
// base64 whose decoded bytes decrypt under c to exactly expected.
func requireEncryptedEquals(t *testing.T, c Cipher, value, expected string) {
	t.Helper()

	ciphertext, decodeErr := base64.StdEncoding.DecodeString(value)
	require.NoError(t, decodeErr, "invalid base64")

	plaintext, decryptErr := c.Decrypt(ciphertext)
	require.NoError(t, decryptErr, "failed to decrypt data")

	require.Equal(t, expected, string(plaintext), "decrypted data does not match")
}
// initCipher builds an AES-256 cipher from a freshly generated random key,
// failing the test if the key cannot be read or the cipher cannot be built.
func initCipher(t *testing.T) *aes256 {
	t.Helper()

	// AES-256 key size is 32 bytes
	const keySize = 32
	secret := make([]byte, keySize)
	_, readErr := io.ReadFull(rand.Reader, secret)
	require.NoError(t, readErr)

	aesCipher, initErr := cipherAES256(secret)
	require.NoError(t, initErr)
	return aesCipher
}
// setup creates a test database and wraps it with New using a single freshly
// generated cipher. It returns the raw store, the wrapper as its concrete
// *dbCrypt type, and the cipher slice that was passed to New.
func setup(t *testing.T) (db database.Store, cryptDB *dbCrypt, cs []Cipher) {
	t.Helper()

	rawDB, _ := dbtestutil.NewDB(t)
	cs = []Cipher{initCipher(t)}

	wrapped, err := New(context.Background(), rawDB, cs...)
	require.NoError(t, err)

	// The tests call unexported methods, so they need the concrete type.
	concrete, isCrypt := wrapped.(*dbCrypt)
	require.True(t, isCrypt)

	return rawDB, concrete, cs
}
// setupNoCiphers creates a test database and wraps it with New without
// passing any ciphers. It returns the raw store and the wrapper as its
// concrete *dbCrypt type.
func setupNoCiphers(t *testing.T) (db database.Store, cryptodb *dbCrypt) {
	t.Helper()

	rawDB, _ := dbtestutil.NewDB(t)

	wrapped, err := New(context.Background(), rawDB)
	require.NoError(t, err)

	concrete, isCrypt := wrapped.(*dbCrypt)
	require.True(t, isCrypt)

	return rawDB, concrete
}
// fakeBase64RandomData returns n random bytes encoded as standard base64.
// The tests use it as a stand-in for ciphertext that no cipher can decrypt.
func fakeBase64RandomData(t *testing.T, n int) string {
	t.Helper()

	raw := make([]byte, n)
	_, readErr := io.ReadFull(rand.Reader, raw)
	require.NoError(t, readErr)

	encoded := base64.StdEncoding.EncodeToString(raw)
	return encoded
}