Files
Spike Curtis bddb808b25 chore: arrange imports in a standard way (#21452)
Fixes all our Go file imports to match the preferred spec that we've _mostly_ been using. For example:

```
import (
	"context"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/xerrors"
	"gopkg.in/natefinch/lumberjack.v2"

	"cdr.dev/slog/v3"
	"github.com/coder/coder/v2/codersdk/agentsdk"
	"github.com/coder/serpent"
)
```

3 groups: standard library, 3rd party libs, Coder libs.

This PR makes the change across the codebase. The PR in the stack above modifies our formatting to maintain this state of affairs, and is a separate PR so it's possible to review that one in detail.
2026-01-08 15:24:11 +04:00

405 lines
11 KiB
Go

package cryptokeys
import (
"context"
"encoding/hex"
"fmt"
"io"
"strconv"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
// ErrKeyNotFound is returned when the requested key is not in the cache even
// after a fresh fetch from the Fetcher.
var ErrKeyNotFound = xerrors.New("key not found")

// ErrKeyInvalid is returned when a key exists but may not currently be used
// for the requested operation (signing/encrypting or verifying/decrypting).
var ErrKeyInvalid = xerrors.New("key is invalid for use")

// ErrClosed is returned by key lookups on a cache that has been closed.
var ErrClosed = xerrors.New("closed")

// ErrInvalidFeature is returned when an operation is requested from a cache
// whose feature does not support it (e.g. signing with an encryption cache).
var ErrInvalidFeature = xerrors.New("invalid feature for this operation")
// Fetcher is the source a key cache loads its keys from (for example
// DBFetcher).
type Fetcher interface {
	// Fetch returns the crypto keys belonging to feature.
	Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error)
}
// EncryptionKeycache hands out keys for encrypting and decrypting payloads.
// Implementations release their resources when closed.
type EncryptionKeycache interface {
	// EncryptingKey returns the id and secret of the latest valid key for
	// encrypting payloads. A valid key is one that is both past its start
	// time and before its deletion time.
	EncryptingKey(ctx context.Context) (id string, key any, err error)

	// DecryptingKey returns the key whose id maps to its sequence number.
	// The key is valid for decryption as long as it is not deleted or past
	// its deletion date. Keys that have not yet reached their start time
	// must still be accepted to account for clock skew between peers (one
	// key may be past its start time on one machine while another is not).
	DecryptingKey(ctx context.Context, id string) (key any, err error)

	io.Closer
}
// SigningKeycache hands out keys for signing and verifying payloads.
// Implementations release their resources when closed.
type SigningKeycache interface {
	// SigningKey returns the id and secret of the latest valid key for
	// signing. A valid key is one that is both past its start time and
	// before its deletion time.
	SigningKey(ctx context.Context) (id string, key any, err error)

	// VerifyingKey returns the key whose id maps to its sequence number. The
	// key is valid for verifying as long as it is not deleted or past its
	// deletion date. Keys that have not yet reached their start time must
	// still be accepted to account for clock skew between peers (one key may
	// be past its start time on one machine while another is not).
	VerifyingKey(ctx context.Context, id string) (key any, err error)

	io.Closer
}
const (
	// latestSequence is the sentinel map key under which the cache stores the
	// highest-sequence key that could sign at fetch time.
	latestSequence = -1

	// refreshInterval is how often the cache refetches its keys in the
	// background.
	refreshInterval = 10 * time.Minute
)
// DBFetcher is a Fetcher backed directly by the database store.
type DBFetcher struct {
	DB database.Store
}

// Fetch loads the keys the store reports for feature and converts them to
// their codersdk representation.
func (d *DBFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
	dbFeature := database.CryptoKeyFeature(feature)
	rows, err := d.DB.GetCryptoKeysByFeature(ctx, dbFeature)
	if err != nil {
		return nil, xerrors.Errorf("get crypto keys by feature: %w", err)
	}
	sdkKeys := toSDKKeys(rows)
	return sdkKeys, nil
}
// cache implements the caching functionality for both signing and encryption
// keys. Each cache serves a single feature, and its public methods check
// whether that feature is a signing or an encryption feature before serving
// a key.
type cache struct {
	// Set in newCache and not modified afterwards. ctx is cancelled by Close
	// and is the context used for background refreshes.
	ctx     context.Context
	cancel  context.CancelFunc
	clock   quartz.Clock
	fetcher Fetcher
	logger  slog.Logger
	feature codersdk.CryptoKeyFeature

	// After construction, mu guards keys, lastFetch, refresher, fetching and
	// closed.
	mu sync.Mutex
	// keys maps sequence numbers to keys; the latestSequence entry holds the
	// highest-sequence key that could sign as of the last fetch.
	keys map[int32]codersdk.CryptoKey
	// lastFetch is the time of the last fetch completed by cryptoKey or
	// refresh; the initial fetch in newCache does not set it.
	lastFetch time.Time
	// refresher runs refresh when it fires; cryptoKey and refresh re-arm it
	// via Reset after fetching, and Close stops it.
	refresher *quartz.Timer
	// fetching is true while a fetch is in flight with mu released.
	fetching bool
	// closed is set by Close.
	closed bool
	// cond is bound to mu and broadcast when a fetch finishes or the cache is
	// closed, waking goroutines waiting for an in-flight fetch.
	cond *sync.Cond
}

// CacheOption configures a cache. Options are applied in newCache before the
// refresh timer is created and the initial fetch runs.
type CacheOption func(*cache)
// WithCacheClock replaces the cache's clock (a real clock by default). The
// clock drives the refresh timer and all key validity checks, which makes it
// the hook for controlling time in tests.
func WithCacheClock(clock quartz.Clock) CacheOption {
	setClock := func(c *cache) {
		c.clock = clock
	}
	return setClock
}
// NewSigningCache returns a SigningKeycache for feature, which must be one of
// the signing features; any other feature is rejected with an error. Close
// should be called to release resources associated with its internal timer.
func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
	feature codersdk.CryptoKeyFeature, opts ...func(*cache),
) (SigningKeycache, error) {
	if isSigningKeyFeature(feature) {
		named := logger.Named(fmt.Sprintf("%s_signing_keycache", feature))
		return newCache(ctx, named, fetcher, feature, opts...), nil
	}
	return nil, xerrors.Errorf("invalid feature: %s", feature)
}
// NewEncryptionCache returns an EncryptionKeycache for feature, which must be
// an encryption feature; any other feature is rejected with an error. Close
// should be called to release resources associated with its internal timer.
func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
	feature codersdk.CryptoKeyFeature, opts ...func(*cache),
) (EncryptionKeycache, error) {
	if isEncryptionKeyFeature(feature) {
		named := logger.Named(fmt.Sprintf("%s_encryption_keycache", feature))
		return newCache(ctx, named, fetcher, feature, opts...), nil
	}
	return nil, xerrors.Errorf("invalid feature: %s", feature)
}
// newCache builds a cache for feature and applies opts. It then arms the
// periodic refresh timer and performs a best-effort initial fetch. If that
// fetch fails, the error is logged and the cache starts empty; missing keys
// are fetched on demand by cryptoKey.
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) *cache {
	c := &cache{
		clock:   quartz.NewReal(),
		logger:  logger.With(slog.F("feature", feature)),
		fetcher: fetcher,
		feature: feature,
	}
	for _, apply := range opts {
		apply(c)
	}
	c.logger.Debug(ctx, "created new key cache")
	c.cond = sync.NewCond(&c.mu)

	//nolint:gocritic // We need to be able to read the keys in order to cache them.
	readerCtx := dbauthz.AsKeyReader(ctx)
	c.ctx, c.cancel = context.WithCancel(readerCtx)
	c.refresher = c.clock.AfterFunc(refreshInterval, c.refresh)

	initial, err := c.cryptoKeys(c.ctx)
	if err != nil {
		c.logger.Critical(c.ctx, "failed initial fetch", slog.Error(err))
	}
	c.keys = initial
	return c
}
// EncryptingKey returns the id and secret of the newest key that can
// currently sign, for use as an encryption key. It fails with
// ErrInvalidFeature unless the cache serves an encryption feature. Lookup
// errors (ErrClosed, ErrKeyNotFound, ErrKeyInvalid, or a fetch failure) are
// returned unwrapped, and on error the key is an untyped nil.
func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) {
	if !isEncryptionKeyFeature(c.feature) {
		return "", nil, ErrInvalidFeature
	}
	//nolint:gocritic // cache can only read crypto keys.
	ctx = dbauthz.AsKeyReader(ctx)
	id, secret, err := c.cryptoKey(ctx, latestSequence)
	if err != nil {
		// Return an untyped nil. Returning the nil []byte directly would box
		// it into a non-nil interface value, unlike DecryptingKey.
		return "", nil, err
	}
	return id, secret, nil
}
// DecryptingKey returns the secret of the key identified by id. The id is the
// decimal sequence number that EncryptingKey hands out. Whether the key may
// still be used is decided by CryptoKey.CanVerify (via checkKey). It fails
// with ErrInvalidFeature unless the cache serves an encryption feature.
func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, error) {
	if !isEncryptionKeyFeature(c.feature) {
		return nil, ErrInvalidFeature
	}
	seq, err := strconv.ParseInt(id, 10, 32)
	if err != nil {
		return nil, xerrors.Errorf("parse id: %w", err)
	}
	// toKeyMap only promotes keys with a positive sequence to "latest", so
	// this cache never issues a negative id. Rejecting negatives stops an
	// externally supplied id such as "-1" from aliasing the internal
	// latestSequence sentinel and silently selecting the newest key.
	if seq < 0 {
		return nil, xerrors.Errorf("parse id: negative sequence %d", seq)
	}
	//nolint:gocritic // cache can only read crypto keys.
	ctx = dbauthz.AsKeyReader(ctx)
	_, secret, err := c.cryptoKey(ctx, int32(seq))
	if err != nil {
		return nil, xerrors.Errorf("crypto key: %w", err)
	}
	return secret, nil
}
// SigningKey returns the id and secret of the newest key that can currently
// sign. It fails with ErrInvalidFeature unless the cache serves a signing
// feature. Lookup errors (ErrClosed, ErrKeyNotFound, ErrKeyInvalid, or a
// fetch failure) are returned unwrapped, and on error the key is an untyped
// nil.
func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) {
	if !isSigningKeyFeature(c.feature) {
		return "", nil, ErrInvalidFeature
	}
	//nolint:gocritic // cache can only read crypto keys.
	ctx = dbauthz.AsKeyReader(ctx)
	id, secret, err := c.cryptoKey(ctx, latestSequence)
	if err != nil {
		// Return an untyped nil. Returning the nil []byte directly would box
		// it into a non-nil interface value, unlike VerifyingKey.
		return "", nil, err
	}
	return id, secret, nil
}
// VerifyingKey returns the secret of the key identified by id. The id is the
// decimal sequence number that SigningKey hands out. Whether the key may
// still be used is decided by CryptoKey.CanVerify (via checkKey). It fails
// with ErrInvalidFeature unless the cache serves a signing feature.
func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error) {
	if !isSigningKeyFeature(c.feature) {
		return nil, ErrInvalidFeature
	}
	seq, err := strconv.ParseInt(id, 10, 32)
	if err != nil {
		return nil, xerrors.Errorf("parse id: %w", err)
	}
	// toKeyMap only promotes keys with a positive sequence to "latest", so
	// this cache never issues a negative id. Rejecting negatives stops an
	// externally supplied id (e.g. a token's key id of "-1") from aliasing
	// the internal latestSequence sentinel and silently selecting the newest
	// key.
	if seq < 0 {
		return nil, xerrors.Errorf("parse id: negative sequence %d", seq)
	}
	//nolint:gocritic // cache can only read crypto keys.
	ctx = dbauthz.AsKeyReader(ctx)
	_, secret, err := c.cryptoKey(ctx, int32(seq))
	if err != nil {
		return nil, xerrors.Errorf("crypto key: %w", err)
	}
	return secret, nil
}
// isEncryptionKeyFeature reports whether feature's keys are used for
// encryption and decryption.
func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool {
	switch feature {
	case codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey:
		return true
	default:
		return false
	}
}
// isSigningKeyFeature reports whether feature's keys are used for signing and
// verification.
func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool {
	return feature == codersdk.CryptoKeyFeatureTailnetResume ||
		feature == codersdk.CryptoKeyFeatureOIDCConvert ||
		feature == codersdk.CryptoKeyFeatureWorkspaceAppsToken
}
// idSecret returns k's public id (its sequence number in base 10) together
// with its secret decoded from hex.
func idSecret(k codersdk.CryptoKey) (string, []byte, error) {
	secret, err := hex.DecodeString(k.Secret)
	if err != nil {
		return "", nil, xerrors.Errorf("decode key: %w", err)
	}
	id := strconv.Itoa(int(k.Sequence))
	return id, secret, nil
}
// cryptoKey returns the id and secret of the key with the given sequence. If
// sequence is latestSequence, it returns the newest key that can currently
// sign instead.
//
// It serves from memory when possible. On a miss it first waits for any fetch
// already in flight, re-checking the cache each time it is woken. If the key
// is still missing, it fetches itself with c.mu released. It returns:
//   - ErrClosed if the cache is, or becomes, closed;
//   - ErrKeyNotFound if the key is absent even after a fresh fetch;
//   - ErrKeyInvalid (via checkKey) if the key exists but may not currently be
//     used for the requested operation.
func (c *cache) cryptoKey(ctx context.Context, sequence int32) (string, []byte, error) {
	c.logger.Debug(ctx, "request for key", slog.F("sequence", sequence))

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return "", nil, ErrClosed
	}

	// Re-check the cache after every wakeup. The fetch we waited on may have
	// stored the key we want, in which case we use it instead of issuing a
	// redundant fetch.
	key, ok := c.key(sequence)
	for !ok && c.fetching && !c.closed {
		c.cond.Wait()
		key, ok = c.key(sequence)
	}
	if c.closed {
		return "", nil, ErrClosed
	}
	if ok {
		return checkKey(key, sequence, c.clock.Now())
	}

	// Fetch without holding the lock; c.fetching tells other goroutines to
	// wait for this fetch rather than start their own.
	c.fetching = true
	c.mu.Unlock()
	keys, err := c.cryptoKeys(ctx)
	c.mu.Lock()

	// Clear the in-flight flag and wake waiters on every path, failures
	// included. Leaving c.fetching set would block every later cache miss
	// forever and make refresh skip every tick.
	c.fetching = false
	c.cond.Broadcast()

	if c.closed {
		// Close has already stopped the timer; don't re-arm it.
		return "", nil, ErrClosed
	}
	if err != nil {
		// refresh skips its tick while a fetch is in flight, so re-arm the
		// timer here to keep background refreshes going.
		c.refresher.Reset(refreshInterval)
		return "", nil, xerrors.Errorf("get keys: %w", err)
	}

	// Record lastFetch before re-arming the timer. Otherwise refresh could
	// treat the next tick as "too soon" and stop refreshing.
	c.lastFetch = c.clock.Now()
	c.refresher.Reset(refreshInterval)
	c.keys = keys

	key, ok = c.key(sequence)
	if !ok {
		return "", nil, ErrKeyNotFound
	}
	return checkKey(key, sequence, c.clock.Now())
}
// key looks up sequence in the in-memory cache; the caller must hold c.mu.
// A specific sequence is reported as found if present. The latestSequence
// entry is only reported as found if it can still sign at the current clock
// time.
func (c *cache) key(sequence int32) (codersdk.CryptoKey, bool) {
	if sequence != latestSequence {
		k, found := c.keys[sequence]
		return k, found
	}
	latest := c.keys[latestSequence]
	return latest, latest.CanSign(c.clock.Now())
}
// checkKey validates key for the kind of request implied by sequence and
// returns its id and decoded secret. A latestSequence request (signing or
// encrypting) requires CanSign; a request by explicit sequence (verifying or
// decrypting) requires CanVerify. Otherwise it returns ErrKeyInvalid.
func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []byte, error) {
	var usable bool
	if sequence == latestSequence {
		usable = key.CanSign(now)
	} else {
		usable = key.CanVerify(now)
	}
	if !usable {
		return "", nil, ErrKeyInvalid
	}
	return idSecret(key)
}
// refresh is the refresher timer callback. It refetches all keys using c.ctx,
// replaces the cached set and re-arms the timer. It does nothing in three
// cases:
//   - the cache is closed;
//   - another fetch is in flight (that fetch re-arms the timer);
//   - a fetch completed less than refreshInterval ago, i.e. a fetch finished
//     and re-armed the timer between this tick firing and refresh acquiring
//     the lock.
func (c *cache) refresh() {
	now := c.clock.Now("CryptoKeyCache", "refresh")

	c.mu.Lock()
	if c.closed || c.fetching || now.Sub(c.lastFetch) < refreshInterval {
		c.mu.Unlock()
		return
	}
	c.fetching = true
	c.mu.Unlock()

	keys, err := c.cryptoKeys(c.ctx)

	c.mu.Lock()
	defer c.mu.Unlock()

	// Clear the in-flight flag and wake waiters on every path. Leaving it set
	// after a failure would block every later cache miss in cryptoKey and
	// make all future ticks bail out.
	c.fetching = false
	c.cond.Broadcast()

	if c.closed {
		// Close cancelled c.ctx (so a failure here is most likely that
		// cancellation) and stopped the timer. Don't log or re-arm.
		return
	}
	if err != nil {
		c.logger.Error(c.ctx, "fetch crypto keys", slog.Error(err))
		// The timer only fires again after Reset; re-arm it so a transient
		// failure doesn't end background refreshes for good.
		c.refresher.Reset(refreshInterval)
		return
	}
	// Record lastFetch before re-arming, so the next tick is never mistaken
	// for one that raced with a just-completed fetch.
	c.lastFetch = c.clock.Now()
	c.refresher.Reset(refreshInterval)
	c.keys = keys
}
// cryptoKeys loads every key for c.feature from c.fetcher and indexes them by
// sequence via toKeyMap, evaluated at the current clock time. It is used by
// newCache for the initial load and afterwards by cryptoKey and refresh,
// which release c.mu before calling it so the lock is not held across the
// fetch.
func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
	c.logger.Debug(ctx, "fetching crypto keys")
	fetched, err := c.fetcher.Fetch(ctx, c.feature)
	if err != nil {
		return nil, xerrors.Errorf("fetch: %w", err)
	}
	bySequence := toKeyMap(fetched, c.clock.Now())
	c.logger.Debug(ctx, "crypto key fetch complete")
	return bySequence, nil
}
// toKeyMap indexes keys by sequence number. It also stores, under
// latestSequence, the highest-sequence key (sequence > 0) that can sign at
// now; if no such key exists, that entry is absent.
func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey {
	bySeq := make(map[int32]codersdk.CryptoKey)
	var newestSeq int32
	for _, k := range keys {
		bySeq[k.Sequence] = k
		if k.Sequence <= newestSeq || !k.CanSign(now) {
			continue
		}
		bySeq[latestSequence] = k
		newestSeq = k.Sequence
	}
	return bySeq
}
// Close marks the cache closed and cancels its internal context. It also stops
// the refresh timer and wakes any goroutines waiting on an in-flight fetch.
// Later lookups fail with ErrClosed (possibly wrapped). Calling Close more
// than once is harmless, and it always returns nil.
func (c *cache) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closed {
		c.closed = true
		c.cancel()
		c.refresher.Stop()
		c.cond.Broadcast()
	}
	return nil
}
// toSDKKeys converts database keys to their codersdk form, preserving order.
// The conversion lives here rather than in db2sdk to avoid a circular
// dependency (cryptokeys -> db2sdk -> tailnet -> cryptokeys).
func toSDKKeys(keys []database.CryptoKey) []codersdk.CryptoKey {
	out := make([]codersdk.CryptoKey, len(keys))
	for i, k := range keys {
		out[i] = toSDK(k)
	}
	return out
}
// toSDK converts a single database key to codersdk form. The nullable
// DeletesAt and Secret columns are copied from their inner Time/String
// values; other fields are left at their zero values.
func toSDK(key database.CryptoKey) codersdk.CryptoKey {
	var out codersdk.CryptoKey
	out.Feature = codersdk.CryptoKeyFeature(key.Feature)
	out.Sequence = key.Sequence
	out.StartsAt = key.StartsAt
	out.DeletesAt = key.DeletesAt.Time
	out.Secret = key.Secret.String
	return out
}