From 2d5c06852516deb35287bb6862db20c6f62b1fc1 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 19 Sep 2024 19:12:44 +0100 Subject: [PATCH] feat: implement key rotation system (#14710) --- coderd/database/dbgen/dbgen.go | 6 +- coderd/database/lock.go | 1 + coderd/keyrotate/rotate.go | 298 +++++++++++ coderd/keyrotate/rotate_internal_test.go | 601 +++++++++++++++++++++++ coderd/keyrotate/rotate_test.go | 124 +++++ 5 files changed, 1029 insertions(+), 1 deletion(-) create mode 100644 coderd/keyrotate/rotate.go create mode 100644 coderd/keyrotate/rotate_internal_test.go create mode 100644 coderd/keyrotate/rotate_test.go diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index b7c7ef54d8..d18da855be 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -902,7 +902,11 @@ func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) databas seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps) - if !seed.Secret.Valid { + // An empty string for the secret is interpreted as + // a caller wanting a new secret to be generated. + // To generate a key with a NULL secret set Valid=false + // and String to a non-empty string. + if seed.Secret.String == "" { secret, err := newCryptoKeySecret(seed.Feature) require.NoError(t, err, "generate secret") seed.Secret = sql.NullString{ diff --git a/coderd/database/lock.go b/coderd/database/lock.go index 0ebf6b0f14..0bc8b2a75d 100644 --- a/coderd/database/lock.go +++ b/coderd/database/lock.go @@ -11,6 +11,7 @@ const ( LockIDDBRollup LockIDDBPurge LockIDNotificationsReportGenerator + LockIDCryptoKeyRotation ) // GenLockID generates a unique and consistent lock ID from a given string. diff --git a/coderd/keyrotate/rotate.go b/coderd/keyrotate/rotate.go new file mode 100644 index 0000000000..b3046161aa --- /dev/null +++ b/coderd/keyrotate/rotate.go @@ -0,0 +1,298 @@ +package keyrotate + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "time" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/quartz" +) + +const ( + WorkspaceAppsTokenDuration = time.Minute + OIDCConvertTokenDuration = time.Minute * 5 + TailnetResumeTokenDuration = time.Hour * 24 + + // defaultRotationInterval is the default interval at which keys are checked for rotation. + defaultRotationInterval = time.Minute * 10 + // DefaultKeyDuration is the default duration for which a key is valid. It applies to all features. + DefaultKeyDuration = time.Hour * 24 * 30 +) + +// rotator is responsible for rotating keys in the database. +type rotator struct { + db database.Store + logger slog.Logger + clock quartz.Clock + keyDuration time.Duration + + features []database.CryptoKeyFeature +} + +type Option func(*rotator) + +func WithClock(clock quartz.Clock) Option { + return func(r *rotator) { + r.clock = clock + } +} + +func WithKeyDuration(keyDuration time.Duration) Option { + return func(r *rotator) { + r.keyDuration = keyDuration + } +} + +// StartRotator starts a background process that rotates keys in the database. +// It ensures there's at least one valid key per feature prior to returning. +// Canceling the provided context will stop the background process. +func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...Option) error { + kr := &rotator{ + db: db, + logger: logger, + clock: quartz.NewReal(), + keyDuration: DefaultKeyDuration, + features: database.AllCryptoKeyFeatureValues(), + } + + for _, opt := range opts { + opt(kr) + } + + err := kr.rotateKeys(ctx) + if err != nil { + return xerrors.Errorf("rotate keys: %w", err) + } + + go kr.start(ctx) + + return nil +} + +// start begins the process of rotating keys. +// Canceling the context will stop the rotation process. +func (k *rotator) start(ctx context.Context) { + k.clock.TickerFunc(ctx, defaultRotationInterval, func() error { + err := k.rotateKeys(ctx) + if err != nil { + k.logger.Error(ctx, "failed to rotate keys", slog.Error(err)) + } + return nil + }) + k.logger.Debug(ctx, "ctx canceled, stopping key rotation") +} + +// rotateKeys checks for any keys needing rotation or deletion and +// may insert a new key if it detects that a valid one does +// not exist for a feature. +func (k *rotator) rotateKeys(ctx context.Context) error { + return k.db.InTx( + func(tx database.Store) error { + err := tx.AcquireLock(ctx, database.LockIDCryptoKeyRotation) + if err != nil { + return xerrors.Errorf("acquire lock: %w", err) + } + + cryptokeys, err := tx.GetCryptoKeys(ctx) + if err != nil { + return xerrors.Errorf("get keys: %w", err) + } + + featureKeys, err := keysByFeature(cryptokeys, k.features) + if err != nil { + return xerrors.Errorf("keys by feature: %w", err) + } + + now := dbtime.Time(k.clock.Now().UTC()) + for feature, keys := range featureKeys { + // We'll use a counter to determine if we should insert a new key. We should always have at least one key for a feature. + var validKeys int + for _, key := range keys { + switch { + case shouldDeleteKey(key, now): + _, err := tx.DeleteCryptoKey(ctx, database.DeleteCryptoKeyParams{ + Feature: key.Feature, + Sequence: key.Sequence, + }) + if err != nil { + return xerrors.Errorf("delete key: %w", err) + } + k.logger.Debug(ctx, "deleted key", + slog.F("key", key.Sequence), + slog.F("feature", key.Feature), + ) + case shouldRotateKey(key, k.keyDuration, now): + _, err := k.rotateKey(ctx, tx, key, now) + if err != nil { + return xerrors.Errorf("rotate key: %w", err) + } + k.logger.Debug(ctx, "rotated key", + slog.F("key", key.Sequence), + slog.F("feature", key.Feature), + ) + validKeys++ + default: + // We only consider keys without a populated deletes_at field as valid. + // This is because under normal circumstances the deletes_at field + // is set during rotation (meaning a new key was generated) + // but it's possible if the database was manually altered to + // delete the new key we may be in a situation where there + // isn't a key to replace the one scheduled for deletion. + if !key.DeletesAt.Valid { + validKeys++ + } + } + } + if validKeys == 0 { + k.logger.Info(ctx, "no valid keys detected, inserting new key", + slog.F("feature", feature), + ) + _, err := k.insertNewKey(ctx, tx, feature, now) + if err != nil { + return xerrors.Errorf("insert new key: %w", err) + } + } + } + return nil + }, &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + }) +} + +func (k *rotator) insertNewKey(ctx context.Context, tx database.Store, feature database.CryptoKeyFeature, startsAt time.Time) (database.CryptoKey, error) { + secret, err := generateNewSecret(feature) + if err != nil { + return database.CryptoKey{}, xerrors.Errorf("generate new secret: %w", err) + } + + latestKey, err := tx.GetLatestCryptoKeyByFeature(ctx, feature) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return database.CryptoKey{}, xerrors.Errorf("get latest key: %w", err) + } + + newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{ + Feature: feature, + Sequence: latestKey.Sequence + 1, + Secret: sql.NullString{ + String: secret, + Valid: true, + }, + // Set by dbcrypt if it's required. + SecretKeyID: sql.NullString{}, + StartsAt: startsAt.UTC(), + }) + if err != nil { + return database.CryptoKey{}, xerrors.Errorf("inserting new key: %w", err) + } + + k.logger.Info(ctx, "inserted new key for feature", slog.F("feature", feature)) + return newKey, nil +} + +func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database.CryptoKey, now time.Time) ([]database.CryptoKey, error) { + startsAt := minStartsAt(key, now, k.keyDuration) + newKey, err := k.insertNewKey(ctx, tx, key.Feature, startsAt) + if err != nil { + return nil, xerrors.Errorf("insert new key: %w", err) + } + + // Set old key's deletes_at to an hour + however long the token + // for this feature is expected to be valid for. This should + // allow for sufficient time for the new key to propagate to + // dependent services (i.e. Workspace Proxies). + deletesAt := startsAt.Add(time.Hour).Add(tokenDuration(key.Feature)) + + updatedKey, err := tx.UpdateCryptoKeyDeletesAt(ctx, database.UpdateCryptoKeyDeletesAtParams{ + Feature: key.Feature, + Sequence: key.Sequence, + DeletesAt: sql.NullTime{ + Time: deletesAt.UTC(), + Valid: true, + }, + }) + if err != nil { + return nil, xerrors.Errorf("update old key's deletes_at: %w", err) + } + + return []database.CryptoKey{updatedKey, newKey}, nil +} + +func generateNewSecret(feature database.CryptoKeyFeature) (string, error) { + switch feature { + case database.CryptoKeyFeatureWorkspaceApps: + return generateKey(96) + case database.CryptoKeyFeatureOidcConvert: + return generateKey(32) + case database.CryptoKeyFeatureTailnetResume: + return generateKey(64) + } + return "", xerrors.Errorf("unknown feature: %s", feature) +} + +func generateKey(length int) (string, error) { + b := make([]byte, length) + _, err := rand.Read(b) + if err != nil { + return "", xerrors.Errorf("rand read: %w", err) + } + return hex.EncodeToString(b), nil +} + +func tokenDuration(feature database.CryptoKeyFeature) time.Duration { + switch feature { + case database.CryptoKeyFeatureWorkspaceApps: + return WorkspaceAppsTokenDuration + case database.CryptoKeyFeatureOidcConvert: + return OIDCConvertTokenDuration + case database.CryptoKeyFeatureTailnetResume: + return TailnetResumeTokenDuration + default: + return 0 + } +} + +func shouldDeleteKey(key database.CryptoKey, now time.Time) bool { + return key.DeletesAt.Valid && !now.Before(key.DeletesAt.Time.UTC()) +} + +func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time.Time) bool { + // If deletes_at is set, we've already inserted a key. + if key.DeletesAt.Valid { + return false + } + expirationTime := key.ExpiresAt(keyDuration) + return !now.Add(time.Hour).UTC().Before(expirationTime) +} + +func keysByFeature(keys []database.CryptoKey, features []database.CryptoKeyFeature) (map[database.CryptoKeyFeature][]database.CryptoKey, error) { + m := map[database.CryptoKeyFeature][]database.CryptoKey{} + for _, feature := range features { + m[feature] = []database.CryptoKey{} + } + for _, key := range keys { + if _, ok := m[key.Feature]; !ok { + return nil, xerrors.Errorf("unknown feature: %s", key.Feature) + } + + m[key.Feature] = append(m[key.Feature], key) + } + return m, nil +} + +// minStartsAt ensures the minimum starts_at time we use for a new +// key is no less than 3*the default rotation interval. +func minStartsAt(key database.CryptoKey, now time.Time, keyDuration time.Duration) time.Time { + expiresAt := key.ExpiresAt(keyDuration) + minStartsAt := now.Add(3 * defaultRotationInterval) + if expiresAt.Before(minStartsAt) { + return minStartsAt + } + return expiresAt +} diff --git a/coderd/keyrotate/rotate_internal_test.go b/coderd/keyrotate/rotate_internal_test.go new file mode 100644 index 0000000000..94160a947b --- /dev/null +++ b/coderd/keyrotate/rotate_internal_test.go @@ -0,0 +1,601 @@ +package keyrotate + +import ( + "database/sql" + "encoding/hex" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func Test_rotateKeys(t *testing.T) { + t.Parallel() + + t.Run("RotatesKeysNearExpiration", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + // Seed the database with an existing key. + oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 15, + }) + + // Advance the window to just inside rotation time. + _ = clock.Advance(keyDuration - time.Minute*59) + err := kr.rotateKeys(ctx) + require.NoError(t, err) + + now = dbnow(clock) + expectedDeletesAt := oldKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour) + + // Fetch the old key, it should have an deletes_at now. + oldKey, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: oldKey.Feature, + Sequence: oldKey.Sequence, + }) + require.NoError(t, err) + require.Equal(t, oldKey.DeletesAt.Time.UTC(), expectedDeletesAt) + + // The new key should be created and have a starts_at of the old key's expires_at. + newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + Sequence: oldKey.Sequence + 1, + }) + require.NoError(t, err) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1) + + // Advance the clock just before the keys delete time. + clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second) + + // No action should be taken. + err = kr.rotateKeys(ctx) + require.NoError(t, err) + + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 2) + + // Advance the clock just past the keys delete time. + clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) + time.Second) + + // We should have deleted the old key. + err = kr.rotateKeys(ctx) + require.NoError(t, err) + + // The old key should be "deleted". + _, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: oldKey.Feature, + Sequence: oldKey.Sequence, + }) + require.ErrorIs(t, err, sql.ErrNoRows) + + keys, err = db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, newKey, keys[0]) + }) + + t.Run("DoesNotRotateValidKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + // Seed the database with an existing key + existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 123, + }) + + // Advance the clock by 6 days, 22 hours. Once we + // breach the last hour we will insert a new key. + clock.Advance(keyDuration - 2*time.Hour) + + err := kr.rotateKeys(ctx) + require.NoError(t, err) + + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, existingKey, keys[0]) + + // Advance it again to just before the key is scheduled to be rotated for sanity purposes. + clock.Advance(time.Hour - time.Second) + + err = kr.rotateKeys(ctx) + require.NoError(t, err) + + // Verify that the existing key is still the only key in the database + keys, err = db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + requireKey(t, keys[0], existingKey.Feature, existingKey.StartsAt.UTC(), nullTime, existingKey.Sequence) + }) + + // Simulate a situation where the database was manually altered such that we only have a key that is scheduled to be deleted and assert we insert a new key. + t.Run("DeletesExpiredKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + // Seed the database with an existing key + deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now.Add(-keyDuration), + Sequence: 789, + DeletesAt: sql.NullTime{ + Time: now, + Valid: true, + }, + }) + + err := kr.rotateKeys(ctx) + require.NoError(t, err) + + // We should only get one key since the old key + // should be deleted. + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + requireKey(t, keys[0], deletingKey.Feature, deletingKey.DeletesAt.Time.UTC(), nullTime, deletingKey.Sequence+1) + // The old key should be "deleted". + _, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: deletingKey.Feature, + Sequence: deletingKey.Sequence, + }) + require.ErrorIs(t, err, sql.ErrNoRows) + }) + + // This tests a situation where we have a key scheduled for deletion but it's still valid for use. + // If no other key is detected we should insert a new key. + t.Run("AddsKeyForDeletingKey", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + // Seed the database with an existing key + deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 456, + DeletesAt: sql.NullTime{ + Time: now.Add(time.Hour), + Valid: true, + }, + }) + + // We should only have inserted a key. + err := kr.rotateKeys(ctx) + require.NoError(t, err) + + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 2) + oldKey, newKey := keys[0], keys[1] + if oldKey.Sequence != deletingKey.Sequence { + oldKey, newKey = newKey, oldKey + } + requireKey(t, oldKey, deletingKey.Feature, deletingKey.StartsAt.UTC(), deletingKey.DeletesAt, deletingKey.Sequence) + requireKey(t, newKey, deletingKey.Feature, now, nullTime, deletingKey.Sequence+1) + }) + + t.Run("NoKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + err := kr.rotateKeys(ctx) + require.NoError(t, err) + + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, clock.Now().UTC(), nullTime, 1) + }) + + // Assert we insert a new key when the only key was manually deleted. + t.Run("OnlyDeletedKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: []database.CryptoKeyFeature{ + database.CryptoKeyFeatureWorkspaceApps, + }, + } + + now := dbnow(clock) + + deletedkey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 19, + DeletesAt: sql.NullTime{ + Time: now.Add(time.Hour), + Valid: true, + }, + Secret: sql.NullString{ + String: "deleted", + Valid: false, + }, + }) + + err := kr.rotateKeys(ctx) + require.NoError(t, err) + + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, now, nullTime, deletedkey.Sequence+1) + }) + + // This tests ensures that rotation works with multiple + // features. It's mainly a sanity test since some bugs + // are not unveiled in the simple n=1 case. + t.Run("AllFeatures", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 30 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: database.AllCryptoKeyFeatureValues(), + } + + now := dbnow(clock) + + // We'll test a scenario where one feature has no valid keys. + // Another has a key that should be rotate. And one that + // has a valid key that shouldn't trigger an action. + _ = dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureTailnetResume, + StartsAt: now.Add(-keyDuration), + Sequence: 5, + Secret: sql.NullString{ + String: "older key", + Valid: false, + }, + }) + deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureTailnetResume, + StartsAt: now.Add(-keyDuration), + Sequence: 19, + Secret: sql.NullString{ + String: "old key", + Valid: false, + }, + }) + + // Insert a key that should be rotated. + rotatedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now.Add(-keyDuration + time.Hour), + Sequence: 42, + }) + + // Insert a key that should not trigger an action. + validKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureOidcConvert, + StartsAt: now, + Sequence: 17, + }) + + err := kr.rotateKeys(ctx) + require.NoError(t, err) + + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 4) + + kbf, err := keysByFeature(keys, database.AllCryptoKeyFeatureValues()) + require.NoError(t, err) + + // No actions on OIDC convert. + require.Len(t, kbf[database.CryptoKeyFeatureOidcConvert], 1) + // Workspace apps should have been rotated. + require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceApps], 2) + // No existing key for tailnet resume should've + // caused a key to be inserted. + require.Len(t, kbf[database.CryptoKeyFeatureTailnetResume], 1) + + oidcKey := kbf[database.CryptoKeyFeatureOidcConvert][0] + tailnetKey := kbf[database.CryptoKeyFeatureTailnetResume][0] + requireKey(t, oidcKey, database.CryptoKeyFeatureOidcConvert, now, nullTime, validKey.Sequence) + requireKey(t, tailnetKey, database.CryptoKeyFeatureTailnetResume, now, nullTime, deletedKey.Sequence+1) + + newKey := kbf[database.CryptoKeyFeatureWorkspaceApps][0] + oldKey := kbf[database.CryptoKeyFeatureWorkspaceApps][1] + if newKey.Sequence == rotatedKey.Sequence { + oldKey, newKey = newKey, oldKey + } + deletesAt := sql.NullTime{ + Time: rotatedKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour), + Valid: true, + } + requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence) + requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1) + }) + + t.Run("UnknownFeature", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 7 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: []database.CryptoKeyFeature{database.CryptoKeyFeature("unknown")}, + } + + err := kr.rotateKeys(ctx) + require.Error(t, err) + }) + + t.Run("MinStartsAt", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 5 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + now := dbnow(clock) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps}, + } + + expiringKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now.Add(-keyDuration), + Sequence: 345, + }) + + err := kr.rotateKeys(ctx) + require.NoError(t, err) + + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 2) + + rotatedKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: expiringKey.Feature, + Sequence: expiringKey.Sequence + 1, + }) + require.NoError(t, err) + require.Equal(t, now.Add(defaultRotationInterval*3), rotatedKey.StartsAt.UTC()) + }) + + // Test that the the deletes_at of a key that is well past its expiration + // Has its deletes_at field set to value that is relative + // to the current time to afford propagation time for the + // new key. + t.Run("ExtensivelyExpiredKey", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + keyDuration = time.Hour * 24 * 3 + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + kr := &rotator{ + db: db, + keyDuration: keyDuration, + clock: clock, + logger: logger, + features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps}, + } + + now := dbnow(clock) + + expiredKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now.Add(-keyDuration - 2*time.Hour), + Sequence: 19, + }) + + deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now, + Sequence: 20, + Secret: sql.NullString{ + String: "deleted", + Valid: false, + }, + }) + + err := kr.rotateKeys(ctx) + require.NoError(t, err) + + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 2) + + deletesAtKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: expiredKey.Feature, + Sequence: expiredKey.Sequence, + }) + + deletesAt := sql.NullTime{ + Time: now.Add(defaultRotationInterval * 3).Add(WorkspaceAppsTokenDuration + time.Hour), + Valid: true, + } + require.NoError(t, err) + requireKey(t, deletesAtKey, expiredKey.Feature, expiredKey.StartsAt.UTC(), deletesAt, expiredKey.Sequence) + + newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: expiredKey.Feature, + Sequence: deletedKey.Sequence + 1, + }) + require.NoError(t, err) + requireKey(t, newKey, expiredKey.Feature, now.Add(defaultRotationInterval*3), nullTime, deletedKey.Sequence+1) + }) +} + +func dbnow(c quartz.Clock) time.Time { + return dbtime.Time(c.Now().UTC()) +} + +func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKeyFeature, startsAt time.Time, deletesAt sql.NullTime, sequence int32) { + t.Helper() + require.Equal(t, feature, key.Feature) + require.Equal(t, startsAt, key.StartsAt.UTC()) + require.Equal(t, deletesAt.Valid, key.DeletesAt.Valid) + require.Equal(t, deletesAt.Time.UTC(), key.DeletesAt.Time.UTC()) + require.Equal(t, sequence, key.Sequence) + + secret, err := hex.DecodeString(key.Secret.String) + require.NoError(t, err) + + switch key.Feature { + case database.CryptoKeyFeatureOidcConvert: + require.Len(t, secret, 32) + case database.CryptoKeyFeatureWorkspaceApps: + require.Len(t, secret, 96) + case database.CryptoKeyFeatureTailnetResume: + require.Len(t, secret, 64) + default: + t.Fatalf("unknown key feature: %s", key.Feature) + } +} + +var nullTime = sql.NullTime{Time: time.Time{}, Valid: false} diff --git a/coderd/keyrotate/rotate_test.go b/coderd/keyrotate/rotate_test.go new file mode 100644 index 0000000000..43a62ac451 --- /dev/null +++ b/coderd/keyrotate/rotate_test.go @@ -0,0 +1,124 @@ +package keyrotate_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/keyrotate" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestRotator(t *testing.T) { + t.Parallel() + + t.Run("NoKeysOnInit", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + dbkeys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, dbkeys, 0) + + err = keyrotate.StartRotator(ctx, logger, db, keyrotate.WithClock(clock)) + require.NoError(t, err) + + // Fetch the keys from the database and ensure they + // are as expected. + dbkeys, err = db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, dbkeys, len(database.AllCryptoKeyFeatureValues())) + requireContainsAllFeatures(t, dbkeys) + }) + + t.Run("RotateKeys", func(t *testing.T) { + t.Parallel() + + var ( + db, _ = dbtestutil.NewDB(t) + clock = quartz.NewMock(t) + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx = testutil.Context(t, testutil.WaitShort) + ) + + now := clock.Now().UTC() + + rotatingKey := dbgen.CryptoKey(t, db, database.CryptoKey{ + Feature: database.CryptoKeyFeatureWorkspaceApps, + StartsAt: now.Add(-keyrotate.DefaultKeyDuration + time.Hour + time.Minute), + Sequence: 12345, + }) + + trap := clock.Trap().TickerFunc() + t.Cleanup(trap.Close) + + err := keyrotate.StartRotator(ctx, logger, db, keyrotate.WithClock(clock)) + require.NoError(t, err) + + initialKeyLen := len(database.AllCryptoKeyFeatureValues()) + // Fetch the keys from the database and ensure they + // are as expected. + dbkeys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, dbkeys, initialKeyLen) + requireContainsAllFeatures(t, dbkeys) + + trap.MustWait(ctx).Release() + _, wait := clock.AdvanceNext() + wait.MustWait(ctx) + + keys, err := db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, initialKeyLen+1) + + newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps) + require.NoError(t, err) + require.Equal(t, rotatingKey.Sequence+1, newKey.Sequence) + require.Equal(t, rotatingKey.ExpiresAt(keyrotate.DefaultKeyDuration), newKey.StartsAt.UTC()) + require.False(t, newKey.DeletesAt.Valid) + + oldKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{ + Feature: rotatingKey.Feature, + Sequence: rotatingKey.Sequence, + }) + expectedDeletesAt := rotatingKey.StartsAt.Add(keyrotate.DefaultKeyDuration + time.Hour + keyrotate.WorkspaceAppsTokenDuration) + require.NoError(t, err) + require.Equal(t, rotatingKey.StartsAt, oldKey.StartsAt) + require.True(t, oldKey.DeletesAt.Valid) + require.Equal(t, expectedDeletesAt, oldKey.DeletesAt.Time) + + // Try rotating again and ensure no keys are rotated. + _, wait = clock.AdvanceNext() + wait.MustWait(ctx) + + keys, err = db.GetCryptoKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, initialKeyLen+1) + }) +} + +func requireContainsAllFeatures(t *testing.T, keys []database.CryptoKey) { + t.Helper() + + features := make(map[database.CryptoKeyFeature]bool) + for _, key := range keys { + features[key.Feature] = true + } + for _, feature := range database.AllCryptoKeyFeatureValues() { + require.True(t, features[feature]) + } +}