mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
bddb808b25
Fixes all our Go file imports to match the preferred spec that we've _mostly_ been using. For example: ``` import ( "context" "time" "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" "gopkg.in/natefinch/lumberjack.v2" "cdr.dev/slog/v3" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/serpent" ) ``` 3 groups: standard library, 3rd-party libs, Coder libs. This PR makes the change across the codebase. The PR in the stack above modifies our formatting to maintain this state of affairs, and is a separate PR so it's possible to review that one in detail.
306 lines
9.3 KiB
Go
306 lines
9.3 KiB
Go
package cryptokeys
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"time"
|
|
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
const (
	// WorkspaceAppsTokenDuration is the token lifetime used for the
	// workspace apps features (see tokenDuration).
	WorkspaceAppsTokenDuration = time.Minute
	// OIDCConvertTokenDuration is the token lifetime used for the OIDC
	// convert feature.
	OIDCConvertTokenDuration = 5 * time.Minute
	// TailnetResumeTokenDuration is the token lifetime used for the tailnet
	// resume feature.
	TailnetResumeTokenDuration = 24 * time.Hour

	// defaultRotationInterval is how often the background process checks
	// whether any key needs to be rotated.
	defaultRotationInterval = 10 * time.Minute
	// DefaultKeyDuration is how long a key stays valid unless overridden.
	// The same value applies to every feature.
	DefaultKeyDuration = 30 * 24 * time.Hour
)
|
|
|
|
// rotator is responsible for rotating keys in the database.
|
|
type rotator struct {
|
|
db database.Store
|
|
logger slog.Logger
|
|
clock quartz.Clock
|
|
keyDuration time.Duration
|
|
|
|
features []database.CryptoKeyFeature
|
|
}
|
|
|
|
// RotatorOption customizes a rotator. Options are applied by StartRotator
// after the defaults have been set.
type RotatorOption func(*rotator)
|
|
|
|
func WithClock(clock quartz.Clock) RotatorOption {
|
|
return func(r *rotator) {
|
|
r.clock = clock
|
|
}
|
|
}
|
|
|
|
func WithKeyDuration(keyDuration time.Duration) RotatorOption {
|
|
return func(r *rotator) {
|
|
r.keyDuration = keyDuration
|
|
}
|
|
}
|
|
|
|
// StartRotator starts a background process that rotates keys in the database.
|
|
// It ensures there's at least one valid key per feature prior to returning.
|
|
// Canceling the provided context will stop the background process.
|
|
func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...RotatorOption) {
|
|
//nolint:gocritic // KeyRotator can only rotate crypto keys.
|
|
ctx = dbauthz.AsKeyRotator(ctx)
|
|
kr := &rotator{
|
|
db: db,
|
|
logger: logger.Named("keyrotator"),
|
|
clock: quartz.NewReal(),
|
|
keyDuration: DefaultKeyDuration,
|
|
features: database.AllCryptoKeyFeatureValues(),
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(kr)
|
|
}
|
|
|
|
err := kr.rotateKeys(ctx)
|
|
if err != nil {
|
|
kr.logger.Critical(ctx, "failed to rotate keys", slog.Error(err))
|
|
}
|
|
|
|
go kr.start(ctx)
|
|
}
|
|
|
|
// start begins the process of rotating keys.
|
|
// Canceling the context will stop the rotation process.
|
|
func (k *rotator) start(ctx context.Context) {
|
|
w := k.clock.TickerFunc(ctx, defaultRotationInterval, func() error {
|
|
err := k.rotateKeys(ctx)
|
|
if err != nil {
|
|
k.logger.Error(ctx, "failed to rotate keys", slog.Error(err))
|
|
}
|
|
return nil
|
|
})
|
|
err := w.Wait()
|
|
k.logger.Debug(ctx, "stopping key rotation", slog.Error(err))
|
|
}
|
|
|
|
// rotateKeys checks for any keys needing rotation or deletion and
|
|
// may insert a new key if it detects that a valid one does
|
|
// not exist for a feature.
|
|
func (k *rotator) rotateKeys(ctx context.Context) error {
|
|
return k.db.InTx(
|
|
func(tx database.Store) error {
|
|
err := tx.AcquireLock(ctx, database.LockIDCryptoKeyRotation)
|
|
if err != nil {
|
|
return xerrors.Errorf("acquire lock: %w", err)
|
|
}
|
|
|
|
cryptokeys, err := tx.GetCryptoKeys(ctx)
|
|
if err != nil {
|
|
return xerrors.Errorf("get keys: %w", err)
|
|
}
|
|
|
|
featureKeys, err := keysByFeature(cryptokeys, k.features)
|
|
if err != nil {
|
|
return xerrors.Errorf("keys by feature: %w", err)
|
|
}
|
|
|
|
now := dbtime.Time(k.clock.Now().UTC())
|
|
for feature, keys := range featureKeys {
|
|
// We'll use a counter to determine if we should insert a new key. We should always have at least one key for a feature.
|
|
var validKeys int
|
|
for _, key := range keys {
|
|
switch {
|
|
case shouldDeleteKey(key, now):
|
|
_, err := tx.DeleteCryptoKey(ctx, database.DeleteCryptoKeyParams{
|
|
Feature: key.Feature,
|
|
Sequence: key.Sequence,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("delete key: %w", err)
|
|
}
|
|
k.logger.Debug(ctx, "deleted key",
|
|
slog.F("key", key.Sequence),
|
|
slog.F("feature", key.Feature),
|
|
)
|
|
case shouldRotateKey(key, k.keyDuration, now):
|
|
_, err := k.rotateKey(ctx, tx, key, now)
|
|
if err != nil {
|
|
return xerrors.Errorf("rotate key: %w", err)
|
|
}
|
|
k.logger.Debug(ctx, "rotated key",
|
|
slog.F("key", key.Sequence),
|
|
slog.F("feature", key.Feature),
|
|
)
|
|
validKeys++
|
|
default:
|
|
// We only consider keys without a populated deletes_at field as valid.
|
|
// This is because under normal circumstances the deletes_at field
|
|
// is set during rotation (meaning a new key was generated)
|
|
// but it's possible if the database was manually altered to
|
|
// delete the new key we may be in a situation where there
|
|
// isn't a key to replace the one scheduled for deletion.
|
|
if !key.DeletesAt.Valid {
|
|
validKeys++
|
|
}
|
|
}
|
|
}
|
|
if validKeys == 0 {
|
|
k.logger.Debug(ctx, "no valid keys detected, inserting new key",
|
|
slog.F("feature", feature),
|
|
)
|
|
_, err := k.insertNewKey(ctx, tx, feature, now)
|
|
if err != nil {
|
|
return xerrors.Errorf("insert new key: %w", err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}, &database.TxOptions{
|
|
Isolation: sql.LevelRepeatableRead,
|
|
TxIdentifier: "rotate_keys",
|
|
})
|
|
}
|
|
|
|
func (k *rotator) insertNewKey(ctx context.Context, tx database.Store, feature database.CryptoKeyFeature, startsAt time.Time) (database.CryptoKey, error) {
|
|
secret, err := generateNewSecret(feature)
|
|
if err != nil {
|
|
return database.CryptoKey{}, xerrors.Errorf("generate new secret: %w", err)
|
|
}
|
|
|
|
latestKey, err := tx.GetLatestCryptoKeyByFeature(ctx, feature)
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
return database.CryptoKey{}, xerrors.Errorf("get latest key: %w", err)
|
|
}
|
|
|
|
newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{
|
|
Feature: feature,
|
|
Sequence: latestKey.Sequence + 1,
|
|
Secret: sql.NullString{
|
|
String: secret,
|
|
Valid: true,
|
|
},
|
|
// Set by dbcrypt if it's required.
|
|
SecretKeyID: sql.NullString{},
|
|
StartsAt: startsAt.UTC(),
|
|
})
|
|
if err != nil {
|
|
return database.CryptoKey{}, xerrors.Errorf("inserting new key: %w", err)
|
|
}
|
|
|
|
k.logger.Debug(ctx, "inserted new key for feature", slog.F("feature", feature), slog.F("sequence", newKey.Sequence))
|
|
return newKey, nil
|
|
}
|
|
|
|
func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database.CryptoKey, now time.Time) ([]database.CryptoKey, error) {
|
|
startsAt := minStartsAt(key, now, k.keyDuration)
|
|
newKey, err := k.insertNewKey(ctx, tx, key.Feature, startsAt)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("insert new key: %w", err)
|
|
}
|
|
|
|
// Set old key's deletes_at to an hour + however long the token
|
|
// for this feature is expected to be valid for. This should
|
|
// allow for sufficient time for the new key to propagate to
|
|
// dependent services (i.e. Workspace Proxies).
|
|
deletesAt := startsAt.Add(time.Hour).Add(tokenDuration(key.Feature))
|
|
|
|
updatedKey, err := tx.UpdateCryptoKeyDeletesAt(ctx, database.UpdateCryptoKeyDeletesAtParams{
|
|
Feature: key.Feature,
|
|
Sequence: key.Sequence,
|
|
DeletesAt: sql.NullTime{
|
|
Time: deletesAt.UTC(),
|
|
Valid: true,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("update old key's deletes_at: %w", err)
|
|
}
|
|
|
|
return []database.CryptoKey{updatedKey, newKey}, nil
|
|
}
|
|
|
|
func generateNewSecret(feature database.CryptoKeyFeature) (string, error) {
|
|
switch feature {
|
|
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
|
|
return generateKey(32)
|
|
case database.CryptoKeyFeatureWorkspaceAppsToken:
|
|
return generateKey(64)
|
|
case database.CryptoKeyFeatureOIDCConvert:
|
|
return generateKey(64)
|
|
case database.CryptoKeyFeatureTailnetResume:
|
|
return generateKey(64)
|
|
}
|
|
return "", xerrors.Errorf("unknown feature: %s", feature)
|
|
}
|
|
|
|
func generateKey(length int) (string, error) {
|
|
b := make([]byte, length)
|
|
_, err := rand.Read(b)
|
|
if err != nil {
|
|
return "", xerrors.Errorf("rand read: %w", err)
|
|
}
|
|
return hex.EncodeToString(b), nil
|
|
}
|
|
|
|
func tokenDuration(feature database.CryptoKeyFeature) time.Duration {
|
|
switch feature {
|
|
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
|
|
return WorkspaceAppsTokenDuration
|
|
case database.CryptoKeyFeatureWorkspaceAppsToken:
|
|
return WorkspaceAppsTokenDuration
|
|
case database.CryptoKeyFeatureOIDCConvert:
|
|
return OIDCConvertTokenDuration
|
|
case database.CryptoKeyFeatureTailnetResume:
|
|
return TailnetResumeTokenDuration
|
|
default:
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func shouldDeleteKey(key database.CryptoKey, now time.Time) bool {
|
|
return key.DeletesAt.Valid && !now.Before(key.DeletesAt.Time.UTC())
|
|
}
|
|
|
|
func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time.Time) bool {
|
|
// If deletes_at is set, we've already inserted a key.
|
|
if key.DeletesAt.Valid {
|
|
return false
|
|
}
|
|
expirationTime := key.ExpiresAt(keyDuration)
|
|
return !now.Add(time.Hour).UTC().Before(expirationTime)
|
|
}
|
|
|
|
func keysByFeature(keys []database.CryptoKey, features []database.CryptoKeyFeature) (map[database.CryptoKeyFeature][]database.CryptoKey, error) {
|
|
m := map[database.CryptoKeyFeature][]database.CryptoKey{}
|
|
for _, feature := range features {
|
|
m[feature] = []database.CryptoKey{}
|
|
}
|
|
for _, key := range keys {
|
|
if _, ok := m[key.Feature]; !ok {
|
|
return nil, xerrors.Errorf("unknown feature: %s", key.Feature)
|
|
}
|
|
|
|
m[key.Feature] = append(m[key.Feature], key)
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
// minStartsAt ensures the minimum starts_at time we use for a new
|
|
// key is no less than 3*the default rotation interval.
|
|
func minStartsAt(key database.CryptoKey, now time.Time, keyDuration time.Duration) time.Time {
|
|
expiresAt := key.ExpiresAt(keyDuration)
|
|
minStartsAt := now.Add(3 * defaultRotationInterval)
|
|
if expiresAt.Before(minStartsAt) {
|
|
return minStartsAt
|
|
}
|
|
return expiresAt
|
|
}
|