chore: refactor keycache implementation to reduce duplication (#15100)

This commit is contained in:
Jon Ayers
2024-10-16 20:01:45 +01:00
committed by GitHub
parent 8e254cbb07
commit f537193682
10 changed files with 512 additions and 1339 deletions
-224
View File
@@ -1,224 +0,0 @@
package wsproxy
import (
"context"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
const (
// latestSequence is a special sequence number that represents the latest key.
latestSequence = -1
// refreshInterval is the interval at which the key cache will refresh.
refreshInterval = time.Minute * 10
)
type Fetcher interface {
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
}
type CryptoKeyCache struct {
Clock quartz.Clock
refreshCtx context.Context
refreshCancel context.CancelFunc
fetcher Fetcher
logger slog.Logger
mu sync.Mutex
keys map[int32]codersdk.CryptoKey
lastFetch time.Time
refresher *quartz.Timer
fetching bool
closed bool
cond *sync.Cond
}
func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client Fetcher, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) {
cache := &CryptoKeyCache{
Clock: quartz.NewReal(),
logger: log,
fetcher: client,
}
for _, opt := range opts {
opt(cache)
}
cache.cond = sync.NewCond(&cache.mu)
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
cache.refresher = cache.Clock.AfterFunc(refreshInterval, cache.refresh)
keys, err := cache.cryptoKeys(ctx)
if err != nil {
cache.refreshCancel()
return nil, xerrors.Errorf("initial fetch: %w", err)
}
cache.keys = keys
return cache, nil
}
func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) {
return k.cryptoKey(ctx, latestSequence)
}
func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
return k.cryptoKey(ctx, sequence)
}
func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
k.mu.Lock()
defer k.mu.Unlock()
if k.closed {
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
}
var key codersdk.CryptoKey
var ok bool
for key, ok = k.key(sequence); !ok && k.fetching && !k.closed; {
k.cond.Wait()
}
if k.closed {
return codersdk.CryptoKey{}, cryptokeys.ErrClosed
}
if ok {
return checkKey(key, sequence, k.Clock.Now())
}
k.fetching = true
k.mu.Unlock()
keys, err := k.cryptoKeys(ctx)
if err != nil {
return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err)
}
k.mu.Lock()
k.lastFetch = k.Clock.Now()
k.refresher.Reset(refreshInterval)
k.keys = keys
k.fetching = false
k.cond.Broadcast()
key, ok = k.key(sequence)
if !ok {
return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound
}
return checkKey(key, sequence, k.Clock.Now())
}
func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) {
if sequence == latestSequence {
return k.keys[latestSequence], k.keys[latestSequence].CanSign(k.Clock.Now())
}
key, ok := k.keys[sequence]
return key, ok
}
func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.CryptoKey, error) {
if sequence == latestSequence {
if !key.CanSign(now) {
return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
}
return key, nil
}
if !key.CanVerify(now) {
return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
}
return key, nil
}
// refresh fetches the keys from the control plane and updates the cache.
func (k *CryptoKeyCache) refresh() {
now := k.Clock.Now("CryptoKeyCache", "refresh")
k.mu.Lock()
if k.closed {
k.mu.Unlock()
return
}
// If something's already fetching, we don't need to do anything.
if k.fetching {
k.mu.Unlock()
return
}
// There's a window we must account for where the timer fires while a fetch
// is ongoing but prior to the timer getting reset. In this case we want to
// avoid double fetching.
if now.Sub(k.lastFetch) < refreshInterval {
k.mu.Unlock()
return
}
k.fetching = true
k.mu.Unlock()
keys, err := k.cryptoKeys(k.refreshCtx)
if err != nil {
k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err))
return
}
k.mu.Lock()
defer k.mu.Unlock()
k.lastFetch = k.Clock.Now()
k.refresher.Reset(refreshInterval)
k.keys = keys
k.fetching = false
k.cond.Broadcast()
}
// cryptoKeys queries the control plane for the crypto keys.
// Outside of initialization, this should only be called by fetch.
func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
keys, err := k.fetcher.Fetch(ctx)
if err != nil {
return nil, xerrors.Errorf("crypto keys: %w", err)
}
cache := toKeyMap(keys, k.Clock.Now())
return cache, nil
}
func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey {
m := make(map[int32]codersdk.CryptoKey)
var latest codersdk.CryptoKey
for _, key := range keys {
m[key.Sequence] = key
if key.Sequence > latest.Sequence && key.CanSign(now) {
m[latestSequence] = key
}
}
return m
}
func (k *CryptoKeyCache) Close() {
k.mu.Lock()
defer k.mu.Unlock()
if k.closed {
return
}
k.closed = true
k.refreshCancel()
k.refresher.Stop()
k.cond.Broadcast()
}
-485
View File
@@ -1,485 +0,0 @@
package wsproxy_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/wsproxy"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestCryptoKeyCache(t *testing.T) {
t.Parallel()
t.Run("Signing", func(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key2",
Sequence: 2,
StartsAt: now,
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{expected},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
got, err := cache.Signing(ctx)
require.NoError(t, err)
require.Equal(t, expected, got)
require.Equal(t, 1, ff.called)
})
t.Run("MissesCache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key1",
Sequence: 12,
StartsAt: clock.Now().UTC(),
}
ff.keys = []codersdk.CryptoKey{expected}
got, err := cache.Signing(ctx)
require.NoError(t, err)
require.Equal(t, expected, got)
// 1 on startup + missing cache.
require.Equal(t, 2, ff.called)
// Ensure the cache gets hit this time.
got, err = cache.Signing(ctx)
require.NoError(t, err)
require.Equal(t, expected, got)
// 1 on startup + missing cache.
require.Equal(t, 2, ff.called)
})
t.Run("IgnoresInvalid", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key1",
Sequence: 1,
StartsAt: clock.Now().UTC(),
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key2",
Sequence: 2,
StartsAt: now.Add(-time.Second),
DeletesAt: now,
},
},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
got, err := cache.Signing(ctx)
require.NoError(t, err)
require.Equal(t, expected, got)
require.Equal(t, 1, ff.called)
})
t.Run("KeyNotFound", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
_, err = cache.Signing(ctx)
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
})
})
t.Run("Verifying", func(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key1",
Sequence: 12,
StartsAt: now,
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key2",
Sequence: 13,
StartsAt: now,
},
},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
got, err := cache.Verifying(ctx, expected.Sequence)
require.NoError(t, err)
require.Equal(t, expected, got)
require.Equal(t, 1, ff.called)
})
t.Run("MissesCache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key1",
Sequence: 12,
StartsAt: clock.Now().UTC(),
}
ff.keys = []codersdk.CryptoKey{expected}
got, err := cache.Verifying(ctx, expected.Sequence)
require.NoError(t, err)
require.Equal(t, expected, got)
require.Equal(t, 2, ff.called)
// Ensure the cache gets hit this time.
got, err = cache.Verifying(ctx, expected.Sequence)
require.NoError(t, err)
require.Equal(t, expected, got)
require.Equal(t, 2, ff.called)
})
t.Run("AllowsBeforeStartsAt", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key1",
Sequence: 12,
StartsAt: now.Add(-time.Second),
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
got, err := cache.Verifying(ctx, expected.Sequence)
require.NoError(t, err)
require.Equal(t, expected, got)
require.Equal(t, 1, ff.called)
})
t.Run("KeyInvalid", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key1",
Sequence: 12,
StartsAt: now.Add(-time.Second),
DeletesAt: now,
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
_, err = cache.Verifying(ctx, expected.Sequence)
require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid)
require.Equal(t, 1, ff.called)
})
t.Run("KeyNotFound", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
_, err = cache.Verifying(ctx, 1)
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
})
})
t.Run("CacheRefreshes", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key1",
Sequence: 12,
StartsAt: now,
DeletesAt: now.Add(time.Minute * 10),
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
got, err := cache.Signing(ctx)
require.NoError(t, err)
require.Equal(t, expected, got)
require.Equal(t, 1, ff.called)
newKey := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key2",
Sequence: 13,
StartsAt: now,
}
ff.keys = []codersdk.CryptoKey{newKey}
// The ticker should fire and cause a request to coderd.
dur, advance := clock.AdvanceNext()
advance.MustWait(ctx)
require.Equal(t, 2, ff.called)
require.Equal(t, time.Minute*10, dur)
// Assert hits cache.
got, err = cache.Signing(ctx)
require.NoError(t, err)
require.Equal(t, newKey, got)
require.Equal(t, 2, ff.called)
// We check again to ensure the timer has been reset.
_, advance = clock.AdvanceNext()
advance.MustWait(ctx)
require.Equal(t, 3, ff.called)
require.Equal(t, time.Minute*10, dur)
})
// This test ensures that if the refresh timer races with an inflight request
// and loses that it doesn't cause a redundant fetch.
t.Run("RefreshNoDoubleFetch", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key1",
Sequence: 12,
StartsAt: now,
DeletesAt: now.Add(time.Minute * 10),
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
// Create a trap that blocks when the refresh timer fires.
trap := clock.Trap().Now("refresh")
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
_, wait := clock.AdvanceNext()
trapped := trap.MustWait(ctx)
newKey := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key2",
Sequence: 13,
StartsAt: now,
}
ff.keys = []codersdk.CryptoKey{newKey}
_, err = cache.Verifying(ctx, newKey.Sequence)
require.NoError(t, err)
require.Equal(t, 2, ff.called)
trapped.Release()
wait.MustWait(ctx)
require.Equal(t, 2, ff.called)
trap.Close()
// The next timer should fire in 10 minutes.
dur, wait := clock.AdvanceNext()
wait.MustWait(ctx)
require.Equal(t, time.Minute*10, dur)
require.Equal(t, 3, ff.called)
})
t.Run("Closed", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureWorkspaceApp,
Secret: "key1",
Sequence: 12,
StartsAt: now,
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
require.NoError(t, err)
got, err := cache.Signing(ctx)
require.NoError(t, err)
require.Equal(t, expected, got)
require.Equal(t, 1, ff.called)
got, err = cache.Verifying(ctx, expected.Sequence)
require.NoError(t, err)
require.Equal(t, expected, got)
require.Equal(t, 1, ff.called)
cache.Close()
_, err = cache.Signing(ctx)
require.ErrorIs(t, err, cryptokeys.ErrClosed)
_, err = cache.Verifying(ctx, expected.Sequence)
require.ErrorIs(t, err, cryptokeys.ErrClosed)
})
}
type fakeFetcher struct {
keys []codersdk.CryptoKey
called int
}
func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) {
f.called++
return f.keys, nil
}
func withClock(clock quartz.Clock) func(*wsproxy.CryptoKeyCache) {
return func(cache *wsproxy.CryptoKeyCache) {
cache.Clock = clock
}
}
+25
View File
@@ -0,0 +1,25 @@
package wsproxy
import (
"context"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
"golang.org/x/xerrors"
)
var _ cryptokeys.Fetcher = &ProxyFetcher{}
type ProxyFetcher struct {
Client *wsproxysdk.Client
Feature codersdk.CryptoKeyFeature
}
func (p *ProxyFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
keys, err := p.Client.CryptoKeys(ctx)
if err != nil {
return nil, xerrors.Errorf("crypto keys: %w", err)
}
return keys.CryptoKeys, nil
}