feat: enable key rotation (#15066)

This PR contains the remaining logic necessary to hook up key rotation
to the product.
This commit is contained in:
Jon Ayers
2024-10-25 17:14:35 +01:00
committed by GitHub
parent ccfffc6911
commit cd890aa3a0
54 changed files with 1412 additions and 1129 deletions
+21 -81
View File
@@ -10,7 +10,6 @@ import (
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/hex"
"errors"
"flag"
"fmt"
@@ -62,6 +61,7 @@ import (
"github.com/coder/serpent"
"github.com/coder/wgtunnel/tunnelsdk"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/coderd/notifications/reports"
"github.com/coder/coder/v2/coderd/runtimeconfig"
@@ -97,7 +97,6 @@ import (
"github.com/coder/coder/v2/coderd/updatecheck"
"github.com/coder/coder/v2/coderd/util/slice"
stringutil "github.com/coder/coder/v2/coderd/util/strings"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/codersdk"
@@ -743,90 +742,31 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("set deployment id: %w", err)
}
}
// Read the app signing key from the DB. We store it hex encoded
// since the config table uses strings for the value and we
// don't want to deal with automatic encoding issues.
appSecurityKeyStr, err := tx.GetAppSecurityKey(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get app signing key: %w", err)
}
// If the string in the DB is an invalid hex string or the
// length is not equal to the current key length, generate a new
// one.
//
// If the key is regenerated, old signed tokens and encrypted
// strings will become invalid. New signed app tokens will be
// generated automatically on failure. Any workspace app token
// smuggling operations in progress may fail, although with a
// helpful error.
if decoded, err := hex.DecodeString(appSecurityKeyStr); err != nil || len(decoded) != len(workspaceapps.SecurityKey{}) {
b := make([]byte, len(workspaceapps.SecurityKey{}))
_, err := rand.Read(b)
if err != nil {
return xerrors.Errorf("generate fresh app signing key: %w", err)
}
appSecurityKeyStr = hex.EncodeToString(b)
err = tx.UpsertAppSecurityKey(ctx, appSecurityKeyStr)
if err != nil {
return xerrors.Errorf("insert freshly generated app signing key to database: %w", err)
}
}
appSecurityKey, err := workspaceapps.KeyFromString(appSecurityKeyStr)
if err != nil {
return xerrors.Errorf("decode app signing key from database: %w", err)
}
options.AppSecurityKey = appSecurityKey
// Read the oauth signing key from the database. Like the app security, generate a new one
// if it is invalid for any reason.
oauthSigningKeyStr, err := tx.GetOAuthSigningKey(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get app oauth signing key: %w", err)
}
if decoded, err := hex.DecodeString(oauthSigningKeyStr); err != nil || len(decoded) != len(options.OAuthSigningKey) {
b := make([]byte, len(options.OAuthSigningKey))
_, err := rand.Read(b)
if err != nil {
return xerrors.Errorf("generate fresh oauth signing key: %w", err)
}
oauthSigningKeyStr = hex.EncodeToString(b)
err = tx.UpsertOAuthSigningKey(ctx, oauthSigningKeyStr)
if err != nil {
return xerrors.Errorf("insert freshly generated oauth signing key to database: %w", err)
}
}
oauthKeyBytes, err := hex.DecodeString(oauthSigningKeyStr)
if err != nil {
return xerrors.Errorf("decode oauth signing key from database: %w", err)
}
if len(oauthKeyBytes) != len(options.OAuthSigningKey) {
return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(oauthKeyBytes))
}
copy(options.OAuthSigningKey[:], oauthKeyBytes)
if options.OAuthSigningKey == [32]byte{} {
return xerrors.Errorf("oauth signing key in database is empty")
}
// Read the coordinator resume token signing key from the
// database.
resumeTokenKey, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, tx)
if err != nil {
return xerrors.Errorf("get coordinator resume token key from database: %w", err)
}
options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(resumeTokenKey, quartz.NewReal(), tailnet.DefaultResumeTokenExpiry)
return nil
}, nil)
if err != nil {
return err
return xerrors.Errorf("set deployment id: %w", err)
}
fetcher := &cryptokeys.DBFetcher{
DB: options.Database,
}
resumeKeycache, err := cryptokeys.NewSigningCache(ctx,
logger,
fetcher,
codersdk.CryptoKeyFeatureTailnetResume,
)
if err != nil {
logger.Critical(ctx, "failed to properly instantiate tailnet resume signing cache", slog.Error(err))
}
options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(
resumeKeycache,
quartz.NewReal(),
tailnet.DefaultResumeTokenExpiry,
)
options.RuntimeConfig = runtimeconfig.NewManager()
// This should be output before the logs start streaming.
+13 -5
View File
@@ -7646,6 +7646,15 @@ const docTemplate = `{
],
"summary": "Get workspace proxy crypto keys",
"operationId": "get-workspace-proxy-crypto-keys",
"parameters": [
{
"type": "string",
"description": "Feature key",
"name": "feature",
"in": "query",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
@@ -10011,12 +10020,14 @@ const docTemplate = `{
"codersdk.CryptoKeyFeature": {
"type": "string",
"enum": [
"workspace_apps",
"workspace_apps_api_key",
"workspace_apps_token",
"oidc_convert",
"tailnet_resume"
],
"x-enum-varnames": [
"CryptoKeyFeatureWorkspaceApp",
"CryptoKeyFeatureWorkspaceAppsAPIKey",
"CryptoKeyFeatureWorkspaceAppsToken",
"CryptoKeyFeatureOIDCConvert",
"CryptoKeyFeatureTailnetResume"
]
@@ -16244,9 +16255,6 @@ const docTemplate = `{
"wsproxysdk.RegisterWorkspaceProxyResponse": {
"type": "object",
"properties": {
"app_security_key": {
"type": "string"
},
"derp_force_websockets": {
"type": "boolean"
},
+17 -5
View File
@@ -6758,6 +6758,15 @@
"tags": ["Enterprise"],
"summary": "Get workspace proxy crypto keys",
"operationId": "get-workspace-proxy-crypto-keys",
"parameters": [
{
"type": "string",
"description": "Feature key",
"name": "feature",
"in": "query",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
@@ -8914,9 +8923,15 @@
},
"codersdk.CryptoKeyFeature": {
"type": "string",
"enum": ["workspace_apps", "oidc_convert", "tailnet_resume"],
"enum": [
"workspace_apps_api_key",
"workspace_apps_token",
"oidc_convert",
"tailnet_resume"
],
"x-enum-varnames": [
"CryptoKeyFeatureWorkspaceApp",
"CryptoKeyFeatureWorkspaceAppsAPIKey",
"CryptoKeyFeatureWorkspaceAppsToken",
"CryptoKeyFeatureOIDCConvert",
"CryptoKeyFeatureTailnetResume"
]
@@ -14853,9 +14868,6 @@
"wsproxysdk.RegisterWorkspaceProxyResponse": {
"type": "object",
"properties": {
"app_security_key": {
"type": "string"
},
"derp_force_websockets": {
"type": "boolean"
},
+61 -11
View File
@@ -40,6 +40,7 @@ import (
"github.com/coder/quartz"
"github.com/coder/serpent"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/runtimeconfig"
@@ -185,9 +186,6 @@ type Options struct {
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
AccessControlStore *atomic.Pointer[dbauthz.AccessControlStore]
// AppSecurityKey is the crypto key used to sign and encrypt tokens related to
// workspace applications. It consists of both a signing and encryption key.
AppSecurityKey workspaceapps.SecurityKey
// CoordinatorResumeTokenProvider is used to provide and validate resume
// tokens issued by and passed to the coordinator DRPC API.
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
@@ -251,6 +249,12 @@ type Options struct {
// OneTimePasscodeValidityPeriod specifies how long a one time passcode should be valid for.
OneTimePasscodeValidityPeriod time.Duration
// Keycaches
AppSigningKeyCache cryptokeys.SigningKeycache
AppEncryptionKeyCache cryptokeys.EncryptionKeycache
OIDCConvertKeyCache cryptokeys.SigningKeycache
Clock quartz.Clock
}
// @title Coder API
@@ -352,6 +356,9 @@ func New(options *Options) *API {
if options.PrometheusRegistry == nil {
options.PrometheusRegistry = prometheus.NewRegistry()
}
if options.Clock == nil {
options.Clock = quartz.NewReal()
}
if options.DERPServer == nil && options.DeploymentValues.DERP.Server.Enable {
options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp")))
}
@@ -444,6 +451,49 @@ func New(options *Options) *API {
if err != nil {
panic(xerrors.Errorf("get deployment ID: %w", err))
}
fetcher := &cryptokeys.DBFetcher{
DB: options.Database,
}
if options.OIDCConvertKeyCache == nil {
options.OIDCConvertKeyCache, err = cryptokeys.NewSigningCache(ctx,
options.Logger.Named("oidc_convert_keycache"),
fetcher,
codersdk.CryptoKeyFeatureOIDCConvert,
)
if err != nil {
options.Logger.Critical(ctx, "failed to properly instantiate oidc convert signing cache", slog.Error(err))
}
}
if options.AppSigningKeyCache == nil {
options.AppSigningKeyCache, err = cryptokeys.NewSigningCache(ctx,
options.Logger.Named("app_signing_keycache"),
fetcher,
codersdk.CryptoKeyFeatureWorkspaceAppsToken,
)
if err != nil {
options.Logger.Critical(ctx, "failed to properly instantiate app signing key cache", slog.Error(err))
}
}
if options.AppEncryptionKeyCache == nil {
options.AppEncryptionKeyCache, err = cryptokeys.NewEncryptionCache(ctx,
options.Logger,
fetcher,
codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey,
)
if err != nil {
options.Logger.Critical(ctx, "failed to properly instantiate app encryption key cache", slog.Error(err))
}
}
// Start a background process that rotates keys. We intentionally start this after the caches
// are created to force initial requests for a key to populate the caches. This helps catch
// bugs that may only occur when a key isn't precached in tests and the latency cost is minimal.
cryptokeys.StartRotator(ctx, options.Logger, options.Database)
api := &API{
ctx: ctx,
cancel: cancel,
@@ -464,7 +514,7 @@ func New(options *Options) *API {
options.DeploymentValues,
oauthConfigs,
options.AgentInactiveDisconnectTimeout,
options.AppSecurityKey,
options.AppSigningKeyCache,
),
metricsCache: metricsCache,
Auditor: atomic.Pointer[audit.Auditor]{},
@@ -606,7 +656,7 @@ func New(options *Options) *API {
ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider,
})
if err != nil {
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
api.Logger.Fatal(context.Background(), "failed to initialize tailnet client service", slog.Error(err))
}
api.statsReporter = workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -628,9 +678,6 @@ func New(options *Options) *API {
options.WorkspaceAppsStatsCollectorOptions.Reporter = api.statsReporter
}
if options.AppSecurityKey.IsZero() {
api.Logger.Fatal(api.ctx, "app security key cannot be zero")
}
api.workspaceAppServer = &workspaceapps.Server{
Logger: workspaceAppsLogger,
@@ -642,11 +689,11 @@ func New(options *Options) *API {
SignedTokenProvider: api.WorkspaceAppsProvider,
AgentProvider: api.agentProvider,
AppSecurityKey: options.AppSecurityKey,
StatsCollector: workspaceapps.NewStatsCollector(options.WorkspaceAppsStatsCollectorOptions),
DisablePathApps: options.DeploymentValues.DisablePathApps.Value(),
SecureAuthCookie: options.DeploymentValues.SecureAuthCookie.Value(),
DisablePathApps: options.DeploymentValues.DisablePathApps.Value(),
SecureAuthCookie: options.DeploymentValues.SecureAuthCookie.Value(),
APIKeyEncryptionKeycache: options.AppEncryptionKeyCache,
}
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
@@ -1434,6 +1481,9 @@ func (api *API) Close() error {
_ = api.agentProvider.Close()
_ = api.statsReporter.Close()
_ = api.NetworkTelemetryBatcher.Close()
_ = api.OIDCConvertKeyCache.Close()
_ = api.AppSigningKeyCache.Close()
_ = api.AppEncryptionKeyCache.Close()
return nil
}
+9 -7
View File
@@ -55,6 +55,7 @@ import (
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/autobuild"
"github.com/coder/coder/v2/coderd/awsidentity"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -88,12 +89,9 @@ import (
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// AppSecurityKey is a 96-byte key used to sign JWTs and encrypt JWEs for
// workspace app tokens in tests.
var AppSecurityKey = must(workspaceapps.KeyFromString("6465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e2077617320686572"))
type Options struct {
// AccessURL denotes a custom access URL. By default we use the httptest
// server's URL. Setting this may result in unexpected behavior (especially
@@ -161,8 +159,10 @@ type Options struct {
DatabaseRolluper *dbrollup.Rolluper
WorkspaceUsageTrackerFlush chan int
WorkspaceUsageTrackerTick chan time.Time
NotificationsEnqueuer notifications.Enqueuer
NotificationsEnqueuer notifications.Enqueuer
APIKeyEncryptionCache cryptokeys.EncryptionKeycache
OIDCConvertKeyCache cryptokeys.SigningKeycache
Clock quartz.Clock
}
// New constructs a codersdk client connected to an in-memory API instance.
@@ -525,7 +525,6 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
DeploymentOptions: codersdk.DeploymentOptionsWithoutSecrets(options.DeploymentValues.Options()),
UpdateCheckOptions: options.UpdateCheckOptions,
SwaggerEndpoint: options.SwaggerEndpoint,
AppSecurityKey: AppSecurityKey,
SSHConfig: options.ConfigSSH,
HealthcheckFunc: options.HealthcheckFunc,
HealthcheckTimeout: options.HealthcheckTimeout,
@@ -538,6 +537,9 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
WorkspaceUsageTracker: wuTracker,
NotificationsEnqueuer: options.NotificationsEnqueuer,
OneTimePasscodeValidityPeriod: options.OneTimePasscodeValidityPeriod,
Clock: options.Clock,
AppEncryptionKeyCache: options.APIKeyEncryptionCache,
OIDCConvertKeyCache: options.OIDCConvertKeyCache,
}
}
+61 -31
View File
@@ -3,6 +3,7 @@ package cryptokeys
import (
"context"
"encoding/hex"
"fmt"
"io"
"strconv"
"sync"
@@ -12,7 +13,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
@@ -25,7 +26,7 @@ var (
)
type Fetcher interface {
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error)
}
type EncryptionKeycache interface {
@@ -62,27 +63,26 @@ const (
)
type DBFetcher struct {
DB database.Store
Feature database.CryptoKeyFeature
DB database.Store
}
func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature)
func (d *DBFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
keys, err := d.DB.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeature(feature))
if err != nil {
return nil, xerrors.Errorf("get crypto keys by feature: %w", err)
}
return db2sdk.CryptoKeys(keys), nil
return toSDKKeys(keys), nil
}
// cache implements the caching functionality for both signing and encryption keys.
type cache struct {
clock quartz.Clock
refreshCtx context.Context
refreshCancel context.CancelFunc
fetcher Fetcher
logger slog.Logger
feature codersdk.CryptoKeyFeature
ctx context.Context
cancel context.CancelFunc
clock quartz.Clock
fetcher Fetcher
logger slog.Logger
feature codersdk.CryptoKeyFeature
mu sync.Mutex
keys map[int32]codersdk.CryptoKey
@@ -109,7 +109,8 @@ func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
if !isSigningKeyFeature(feature) {
return nil, xerrors.Errorf("invalid feature: %s", feature)
}
return newCache(ctx, logger, fetcher, feature, opts...)
logger = logger.Named(fmt.Sprintf("%s_signing_keycache", feature))
return newCache(ctx, logger, fetcher, feature, opts...), nil
}
func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
@@ -118,10 +119,11 @@ func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher
if !isEncryptionKeyFeature(feature) {
return nil, xerrors.Errorf("invalid feature: %s", feature)
}
return newCache(ctx, logger, fetcher, feature, opts...)
logger = logger.Named(fmt.Sprintf("%s_encryption_keycache", feature))
return newCache(ctx, logger, fetcher, feature, opts...), nil
}
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) {
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) *cache {
cache := &cache{
clock: quartz.NewReal(),
logger: logger,
@@ -134,16 +136,16 @@ func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature
}
cache.cond = sync.NewCond(&cache.mu)
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
//nolint:gocritic // We need to be able to read the keys in order to cache them.
cache.ctx, cache.cancel = context.WithCancel(dbauthz.AsKeyReader(ctx))
cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh)
keys, err := cache.cryptoKeys(ctx)
keys, err := cache.cryptoKeys(cache.ctx)
if err != nil {
cache.refreshCancel()
return nil, xerrors.Errorf("initial fetch: %w", err)
cache.logger.Critical(cache.ctx, "failed initial fetch", slog.Error(err))
}
cache.keys = keys
return cache, nil
return cache
}
func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) {
@@ -151,6 +153,8 @@ func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error)
return "", nil, ErrInvalidFeature
}
//nolint:gocritic // cache can only read crypto keys.
ctx = dbauthz.AsKeyReader(ctx)
return c.cryptoKey(ctx, latestSequence)
}
@@ -164,6 +168,8 @@ func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, erro
return nil, xerrors.Errorf("parse id: %w", err)
}
//nolint:gocritic // cache can only read crypto keys.
ctx = dbauthz.AsKeyReader(ctx)
_, secret, err := c.cryptoKey(ctx, int32(seq))
if err != nil {
return nil, xerrors.Errorf("crypto key: %w", err)
@@ -176,6 +182,8 @@ func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) {
return "", nil, ErrInvalidFeature
}
//nolint:gocritic // cache can only read crypto keys.
ctx = dbauthz.AsKeyReader(ctx)
return c.cryptoKey(ctx, latestSequence)
}
@@ -188,7 +196,8 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error
if err != nil {
return nil, xerrors.Errorf("parse id: %w", err)
}
//nolint:gocritic // cache can only read crypto keys.
ctx = dbauthz.AsKeyReader(ctx)
_, secret, err := c.cryptoKey(ctx, int32(seq))
if err != nil {
return nil, xerrors.Errorf("crypto key: %w", err)
@@ -198,12 +207,12 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error
}
func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool {
return feature == codersdk.CryptoKeyFeatureWorkspaceApp
return feature == codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey
}
func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool {
switch feature {
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert:
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert, codersdk.CryptoKeyFeatureWorkspaceAppsToken:
return true
default:
return false
@@ -292,14 +301,15 @@ func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []
func (c *cache) refresh() {
now := c.clock.Now("CryptoKeyCache", "refresh")
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
c.mu.Unlock()
return
}
// If something's already fetching, we don't need to do anything.
if c.fetching {
c.mu.Unlock()
return
}
@@ -307,20 +317,21 @@ func (c *cache) refresh() {
// is ongoing but prior to the timer getting reset. In this case we want to
// avoid double fetching.
if now.Sub(c.lastFetch) < refreshInterval {
c.mu.Unlock()
return
}
c.fetching = true
c.mu.Unlock()
keys, err := c.cryptoKeys(c.refreshCtx)
keys, err := c.cryptoKeys(c.ctx)
if err != nil {
c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err))
c.logger.Error(c.ctx, "fetch crypto keys", slog.Error(err))
return
}
// We don't defer an unlock here due to the deferred unlock at the top of the function.
c.mu.Lock()
defer c.mu.Unlock()
c.lastFetch = c.clock.Now()
c.refresher.Reset(refreshInterval)
@@ -332,9 +343,9 @@ func (c *cache) refresh() {
// cryptoKeys queries the control plane for the crypto keys.
// Outside of initialization, this should only be called by fetch.
func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
keys, err := c.fetcher.Fetch(ctx)
keys, err := c.fetcher.Fetch(ctx, c.feature)
if err != nil {
return nil, xerrors.Errorf("crypto keys: %w", err)
return nil, xerrors.Errorf("fetch: %w", err)
}
cache := toKeyMap(keys, c.clock.Now())
return cache, nil
@@ -361,9 +372,28 @@ func (c *cache) Close() error {
}
c.closed = true
c.refreshCancel()
c.cancel()
c.refresher.Stop()
c.cond.Broadcast()
return nil
}
// We have to do this to avoid a circular dependency on db2sdk (cryptokeys -> db2sdk -> tailnet -> cryptokeys)
func toSDKKeys(keys []database.CryptoKey) []codersdk.CryptoKey {
into := make([]codersdk.CryptoKey, 0, len(keys))
for _, key := range keys {
into = append(into, toSDK(key))
}
return into
}
func toSDK(key database.CryptoKey) codersdk.CryptoKey {
return codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeature(key.Feature),
Sequence: key.Sequence,
StartsAt: key.StartsAt,
DeletesAt: key.DeletesAt.Time,
Secret: key.Secret.String,
}
}
+1 -1
View File
@@ -488,7 +488,7 @@ type fakeFetcher struct {
called int
}
func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) {
func (f *fakeFetcher) Fetch(_ context.Context, _ codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
f.called++
return f.keys, nil
}
+14 -9
View File
@@ -11,6 +11,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/quartz"
)
@@ -53,10 +54,12 @@ func WithKeyDuration(keyDuration time.Duration) RotatorOption {
// StartRotator starts a background process that rotates keys in the database.
// It ensures there's at least one valid key per feature prior to returning.
// Canceling the provided context will stop the background process.
func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...RotatorOption) error {
func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...RotatorOption) {
//nolint:gocritic // KeyRotator can only rotate crypto keys.
ctx = dbauthz.AsKeyRotator(ctx)
kr := &rotator{
db: db,
logger: logger,
logger: logger.Named("keyrotator"),
clock: quartz.NewReal(),
keyDuration: DefaultKeyDuration,
features: database.AllCryptoKeyFeatureValues(),
@@ -68,12 +71,10 @@ func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, op
err := kr.rotateKeys(ctx)
if err != nil {
return xerrors.Errorf("rotate keys: %w", err)
kr.logger.Critical(ctx, "failed to rotate keys", slog.Error(err))
}
go kr.start(ctx)
return nil
}
// start begins the process of rotating keys.
@@ -227,9 +228,11 @@ func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database
func generateNewSecret(feature database.CryptoKeyFeature) (string, error) {
switch feature {
case database.CryptoKeyFeatureWorkspaceApps:
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
return generateKey(32)
case database.CryptoKeyFeatureOidcConvert:
case database.CryptoKeyFeatureWorkspaceAppsToken:
return generateKey(64)
case database.CryptoKeyFeatureOIDCConvert:
return generateKey(64)
case database.CryptoKeyFeatureTailnetResume:
return generateKey(64)
@@ -248,9 +251,11 @@ func generateKey(length int) (string, error) {
func tokenDuration(feature database.CryptoKeyFeature) time.Duration {
switch feature {
case database.CryptoKeyFeatureWorkspaceApps:
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
return WorkspaceAppsTokenDuration
case database.CryptoKeyFeatureOidcConvert:
case database.CryptoKeyFeatureWorkspaceAppsToken:
return WorkspaceAppsTokenDuration
case database.CryptoKeyFeatureOIDCConvert:
return OIDCConvertTokenDuration
case database.CryptoKeyFeatureTailnetResume:
return TailnetResumeTokenDuration
+44 -37
View File
@@ -38,7 +38,7 @@ func Test_rotateKeys(t *testing.T) {
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
},
}
@@ -46,7 +46,7 @@ func Test_rotateKeys(t *testing.T) {
// Seed the database with an existing key.
oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now,
Sequence: 15,
})
@@ -69,11 +69,11 @@ func Test_rotateKeys(t *testing.T) {
// The new key should be created and have a starts_at of the old key's expires_at.
newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
Sequence: oldKey.Sequence + 1,
})
require.NoError(t, err)
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1)
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1)
// Advance the clock just before the keys delete time.
clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second)
@@ -123,7 +123,7 @@ func Test_rotateKeys(t *testing.T) {
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
},
}
@@ -131,7 +131,7 @@ func Test_rotateKeys(t *testing.T) {
// Seed the database with an existing key
existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now,
Sequence: 123,
})
@@ -179,7 +179,7 @@ func Test_rotateKeys(t *testing.T) {
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
},
}
@@ -187,7 +187,7 @@ func Test_rotateKeys(t *testing.T) {
// Seed the database with an existing key
deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now.Add(-keyDuration),
Sequence: 789,
DeletesAt: sql.NullTime{
@@ -232,7 +232,7 @@ func Test_rotateKeys(t *testing.T) {
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
},
}
@@ -240,7 +240,7 @@ func Test_rotateKeys(t *testing.T) {
// Seed the database with an existing key
deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now,
Sequence: 456,
DeletesAt: sql.NullTime{
@@ -281,7 +281,7 @@ func Test_rotateKeys(t *testing.T) {
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
},
}
@@ -291,7 +291,7 @@ func Test_rotateKeys(t *testing.T) {
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, clock.Now().UTC(), nullTime, 1)
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceAppsAPIKey, clock.Now().UTC(), nullTime, 1)
})
// Assert we insert a new key when the only key was manually deleted.
@@ -312,14 +312,14 @@ func Test_rotateKeys(t *testing.T) {
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
},
}
now := dbnow(clock)
deletedkey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now,
Sequence: 19,
DeletesAt: sql.NullTime{
@@ -338,7 +338,7 @@ func Test_rotateKeys(t *testing.T) {
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, now, nullTime, deletedkey.Sequence+1)
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceAppsAPIKey, now, nullTime, deletedkey.Sequence+1)
})
// This tests ensures that rotation works with multiple
@@ -365,9 +365,11 @@ func Test_rotateKeys(t *testing.T) {
now := dbnow(clock)
// We'll test a scenario where one feature has no valid keys.
// Another has a key that should be rotate. And one that
// has a valid key that shouldn't trigger an action.
// We'll test a scenario where:
// - One feature has no valid keys.
// - One has a key that should be rotated.
// - One has a valid key that shouldn't trigger an action.
// - One has no keys at all.
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureTailnetResume,
StartsAt: now.Add(-keyDuration),
@@ -377,6 +379,7 @@ func Test_rotateKeys(t *testing.T) {
Valid: false,
},
})
// Generate another deleted key to ensure we insert after the latest sequence.
deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureTailnetResume,
StartsAt: now.Add(-keyDuration),
@@ -389,14 +392,14 @@ func Test_rotateKeys(t *testing.T) {
// Insert a key that should be rotated.
rotatedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now.Add(-keyDuration + time.Hour),
Sequence: 42,
})
// Insert a key that should not trigger an action.
validKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
Feature: database.CryptoKeyFeatureOIDCConvert,
StartsAt: now,
Sequence: 17,
})
@@ -406,26 +409,28 @@ func Test_rotateKeys(t *testing.T) {
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 4)
require.Len(t, keys, 5)
kbf, err := keysByFeature(keys, database.AllCryptoKeyFeatureValues())
require.NoError(t, err)
// No actions on OIDC convert.
require.Len(t, kbf[database.CryptoKeyFeatureOidcConvert], 1)
require.Len(t, kbf[database.CryptoKeyFeatureOIDCConvert], 1)
// Workspace apps should have been rotated.
require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceApps], 2)
require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey], 2)
// No existing key for tailnet resume should've
// caused a key to be inserted.
require.Len(t, kbf[database.CryptoKeyFeatureTailnetResume], 1)
require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceAppsToken], 1)
oidcKey := kbf[database.CryptoKeyFeatureOidcConvert][0]
oidcKey := kbf[database.CryptoKeyFeatureOIDCConvert][0]
tailnetKey := kbf[database.CryptoKeyFeatureTailnetResume][0]
requireKey(t, oidcKey, database.CryptoKeyFeatureOidcConvert, now, nullTime, validKey.Sequence)
appTokenKey := kbf[database.CryptoKeyFeatureWorkspaceAppsToken][0]
requireKey(t, oidcKey, database.CryptoKeyFeatureOIDCConvert, now, nullTime, validKey.Sequence)
requireKey(t, tailnetKey, database.CryptoKeyFeatureTailnetResume, now, nullTime, deletedKey.Sequence+1)
newKey := kbf[database.CryptoKeyFeatureWorkspaceApps][0]
oldKey := kbf[database.CryptoKeyFeatureWorkspaceApps][1]
requireKey(t, appTokenKey, database.CryptoKeyFeatureWorkspaceAppsToken, now, nullTime, 1)
newKey := kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey][0]
oldKey := kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey][1]
if newKey.Sequence == rotatedKey.Sequence {
oldKey, newKey = newKey, oldKey
}
@@ -433,8 +438,8 @@ func Test_rotateKeys(t *testing.T) {
Time: rotatedKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour),
Valid: true,
}
requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence)
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1)
requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence)
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1)
})
t.Run("UnknownFeature", func(t *testing.T) {
@@ -478,11 +483,11 @@ func Test_rotateKeys(t *testing.T) {
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps},
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceAppsAPIKey},
}
expiringKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now.Add(-keyDuration),
Sequence: 345,
})
@@ -522,19 +527,19 @@ func Test_rotateKeys(t *testing.T) {
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps},
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceAppsAPIKey},
}
now := dbnow(clock)
expiredKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now.Add(-keyDuration - 2*time.Hour),
Sequence: 19,
})
deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now,
Sequence: 20,
Secret: sql.NullString{
@@ -587,9 +592,11 @@ func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKey
require.NoError(t, err)
switch key.Feature {
case database.CryptoKeyFeatureOidcConvert:
case database.CryptoKeyFeatureOIDCConvert:
require.Len(t, secret, 64)
case database.CryptoKeyFeatureWorkspaceApps:
case database.CryptoKeyFeatureWorkspaceAppsToken:
require.Len(t, secret, 64)
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
require.Len(t, secret, 32)
case database.CryptoKeyFeatureTailnetResume:
require.Len(t, secret, 64)
+4 -6
View File
@@ -34,8 +34,7 @@ func TestRotator(t *testing.T) {
require.NoError(t, err)
require.Len(t, dbkeys, 0)
err = cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock))
require.NoError(t, err)
cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock))
// Fetch the keys from the database and ensure they
// are as expected.
@@ -58,7 +57,7 @@ func TestRotator(t *testing.T) {
now := clock.Now().UTC()
rotatingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now.Add(-cryptokeys.DefaultKeyDuration + time.Hour + time.Minute),
Sequence: 12345,
})
@@ -66,8 +65,7 @@ func TestRotator(t *testing.T) {
trap := clock.Trap().TickerFunc()
t.Cleanup(trap.Close)
err := cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock))
require.NoError(t, err)
cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock))
initialKeyLen := len(database.AllCryptoKeyFeatureValues())
// Fetch the keys from the database and ensure they
@@ -85,7 +83,7 @@ func TestRotator(t *testing.T) {
require.NoError(t, err)
require.Len(t, keys, initialKeyLen+1)
newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
require.NoError(t, err)
require.Equal(t, rotatingKey.Sequence+1, newKey.Sequence)
require.Equal(t, rotatingKey.ExpiresAt(cryptokeys.DefaultKeyDuration), newKey.StartsAt.UTC())
+46
View File
@@ -228,6 +228,42 @@ var (
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
// See cryptokeys package.
subjectCryptoKeyRotator = rbac.Subject{
FriendlyName: "Crypto Key Rotator",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "keyrotator"},
DisplayName: "Key Rotator",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceCryptoKey.Type: {policy.WildcardSymbol},
}),
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
// See cryptokeys package.
subjectCryptoKeyReader = rbac.Subject{
FriendlyName: "Crypto Key Reader",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "keyrotator"},
DisplayName: "Key Rotator",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceCryptoKey.Type: {policy.WildcardSymbol},
}),
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectSystemRestricted = rbac.Subject{
FriendlyName: "System",
ID: uuid.Nil.String(),
@@ -281,6 +317,16 @@ func AsHangDetector(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectHangDetector)
}
// AsKeyRotator returns a context with an actor that has permissions required for rotating crypto keys.
func AsKeyRotator(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyRotator)
}
// AsKeyReader returns a context with an actor that has permissions required for reading crypto keys.
func AsKeyReader(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyReader)
}
// AsSystemRestricted returns a context with an actor that has permissions
// required for various system operations (login, logout, metrics cache).
func AsSystemRestricted(ctx context.Context) context.Context {
+7 -7
View File
@@ -2243,13 +2243,13 @@ func (s *MethodTestSuite) TestCryptoKeys() {
}))
s.Run("InsertCryptoKey", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertCryptoKeyParams{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
}).
Asserts(rbac.ResourceCryptoKey, policy.ActionCreate)
}))
s.Run("DeleteCryptoKey", s.Subtest(func(db database.Store, check *expects) {
key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
Sequence: 4,
})
check.Args(database.DeleteCryptoKeyParams{
@@ -2259,7 +2259,7 @@ func (s *MethodTestSuite) TestCryptoKeys() {
}))
s.Run("GetCryptoKeyByFeatureAndSequence", s.Subtest(func(db database.Store, check *expects) {
key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
Sequence: 4,
})
check.Args(database.GetCryptoKeyByFeatureAndSequenceParams{
@@ -2269,14 +2269,14 @@ func (s *MethodTestSuite) TestCryptoKeys() {
}))
s.Run("GetLatestCryptoKeyByFeature", s.Subtest(func(db database.Store, check *expects) {
dbgen.CryptoKey(s.T(), db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
Sequence: 4,
})
check.Args(database.CryptoKeyFeatureWorkspaceApps).Asserts(rbac.ResourceCryptoKey, policy.ActionRead)
check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey).Asserts(rbac.ResourceCryptoKey, policy.ActionRead)
}))
s.Run("UpdateCryptoKeyDeletesAt", s.Subtest(func(db database.Store, check *expects) {
key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
Sequence: 4,
})
check.Args(database.UpdateCryptoKeyDeletesAtParams{
@@ -2286,7 +2286,7 @@ func (s *MethodTestSuite) TestCryptoKeys() {
}).Asserts(rbac.ResourceCryptoKey, policy.ActionUpdate)
}))
s.Run("GetCryptoKeysByFeature", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.CryptoKeyFeatureWorkspaceApps).
check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey).
Asserts(rbac.ResourceCryptoKey, policy.ActionRead)
}))
}
+5 -3
View File
@@ -943,7 +943,7 @@ func CustomRole(t testing.TB, db database.Store, seed database.CustomRole) datab
func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) database.CryptoKey {
t.Helper()
seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps)
seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
// An empty string for the secret is interpreted as
// a caller wanting a new secret to be generated.
@@ -1048,9 +1048,11 @@ func takeFirst[Value comparable](values ...Value) Value {
func newCryptoKeySecret(feature database.CryptoKeyFeature) (string, error) {
switch feature {
case database.CryptoKeyFeatureWorkspaceApps:
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
return generateCryptoKey(32)
case database.CryptoKeyFeatureOidcConvert:
case database.CryptoKeyFeatureWorkspaceAppsToken:
return generateCryptoKey(64)
case database.CryptoKeyFeatureOIDCConvert:
return generateCryptoKey(64)
case database.CryptoKeyFeatureTailnetResume:
return generateCryptoKey(64)
+2 -1
View File
@@ -38,7 +38,8 @@ CREATE TYPE build_reason AS ENUM (
);
CREATE TYPE crypto_key_feature AS ENUM (
'workspace_apps',
'workspace_apps_token',
'workspace_apps_api_key',
'oidc_convert',
'tailnet_resume'
);
@@ -0,0 +1,18 @@
-- Step 1: Remove the new entries from crypto_keys table
DELETE FROM crypto_keys
WHERE feature IN ('workspace_apps_token', 'workspace_apps_api_key');
CREATE TYPE old_crypto_key_feature AS ENUM (
'workspace_apps',
'oidc_convert',
'tailnet_resume'
);
ALTER TABLE crypto_keys
ALTER COLUMN feature TYPE old_crypto_key_feature
USING (feature::text::old_crypto_key_feature);
DROP TYPE crypto_key_feature;
ALTER TYPE old_crypto_key_feature RENAME TO crypto_key_feature;
@@ -0,0 +1,18 @@
-- Create a new enum type with the desired values
CREATE TYPE new_crypto_key_feature AS ENUM (
'workspace_apps_token',
'workspace_apps_api_key',
'oidc_convert',
'tailnet_resume'
);
DELETE FROM crypto_keys WHERE feature = 'workspace_apps';
-- Drop the old type and rename the new one
ALTER TABLE crypto_keys
ALTER COLUMN feature TYPE new_crypto_key_feature
USING (feature::text::new_crypto_key_feature);
DROP TYPE crypto_key_feature;
ALTER TYPE new_crypto_key_feature RENAME TO crypto_key_feature;
@@ -0,0 +1,40 @@
INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at)
VALUES (
'workspace_apps_token',
1,
'abc',
NULL,
'1970-01-01 00:00:00 UTC'::timestamptz,
'2100-01-01 00:00:00 UTC'::timestamptz
);
INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at)
VALUES (
'workspace_apps_api_key',
1,
'def',
NULL,
'1970-01-01 00:00:00 UTC'::timestamptz,
'2100-01-01 00:00:00 UTC'::timestamptz
);
INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at)
VALUES (
'oidc_convert',
2,
'ghi',
NULL,
'1970-01-01 00:00:00 UTC'::timestamptz,
'2100-01-01 00:00:00 UTC'::timestamptz
);
INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at)
VALUES (
'tailnet_resume',
2,
'jkl',
NULL,
'1970-01-01 00:00:00 UTC'::timestamptz,
'2100-01-01 00:00:00 UTC'::timestamptz
);
+10 -7
View File
@@ -345,9 +345,10 @@ func AllBuildReasonValues() []BuildReason {
type CryptoKeyFeature string
const (
CryptoKeyFeatureWorkspaceApps CryptoKeyFeature = "workspace_apps"
CryptoKeyFeatureOidcConvert CryptoKeyFeature = "oidc_convert"
CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume"
CryptoKeyFeatureWorkspaceAppsToken CryptoKeyFeature = "workspace_apps_token"
CryptoKeyFeatureWorkspaceAppsAPIKey CryptoKeyFeature = "workspace_apps_api_key"
CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert"
CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume"
)
func (e *CryptoKeyFeature) Scan(src interface{}) error {
@@ -387,8 +388,9 @@ func (ns NullCryptoKeyFeature) Value() (driver.Value, error) {
func (e CryptoKeyFeature) Valid() bool {
switch e {
case CryptoKeyFeatureWorkspaceApps,
CryptoKeyFeatureOidcConvert,
case CryptoKeyFeatureWorkspaceAppsToken,
CryptoKeyFeatureWorkspaceAppsAPIKey,
CryptoKeyFeatureOIDCConvert,
CryptoKeyFeatureTailnetResume:
return true
}
@@ -397,8 +399,9 @@ func (e CryptoKeyFeature) Valid() bool {
func AllCryptoKeyFeatureValues() []CryptoKeyFeature {
return []CryptoKeyFeature{
CryptoKeyFeatureWorkspaceApps,
CryptoKeyFeatureOidcConvert,
CryptoKeyFeatureWorkspaceAppsToken,
CryptoKeyFeatureWorkspaceAppsAPIKey,
CryptoKeyFeatureOIDCConvert,
CryptoKeyFeatureTailnetResume,
}
}
+2
View File
@@ -135,6 +135,8 @@ sql:
api_key_id: APIKeyID
callback_url: CallbackURL
login_type_oauth2_provider_app: LoginTypeOAuth2ProviderApp
crypto_key_feature_workspace_apps_api_key: CryptoKeyFeatureWorkspaceAppsAPIKey
crypto_key_feature_oidc_convert: CryptoKeyFeatureOIDCConvert
rules:
- name: do-not-use-public-schema-in-queries
message: "do not use public schema in queries"
+7 -1
View File
@@ -65,6 +65,12 @@ func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string,
return compact, nil
}
func WithDecryptExpected(expected jwt.Expected) func(*DecryptOptions) {
return func(opts *DecryptOptions) {
opts.RegisteredClaims = expected
}
}
// DecryptOptions are options for decrypting a JWE.
type DecryptOptions struct {
RegisteredClaims jwt.Expected
@@ -100,7 +106,7 @@ func Decrypt(ctx context.Context, d DecryptKeyProvider, token string, claims Cla
kid := object.Header.KeyID
if kid == "" {
return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey)
return ErrMissingKeyID
}
key, err := d.DecryptingKey(ctx, kid)
+61 -1
View File
@@ -10,10 +10,27 @@ import (
"golang.org/x/xerrors"
)
var ErrMissingKeyID = xerrors.New("missing key ID")
const (
keyIDHeaderKey = "kid"
)
// RegisteredClaims is a convenience type for embedding jwt.Claims. It should be
// preferred over embedding jwt.Claims directly since it will ensure that certain fields are set.
type RegisteredClaims jwt.Claims
func (r RegisteredClaims) Validate(e jwt.Expected) error {
if r.Expiry == nil {
return xerrors.Errorf("expiry is required")
}
if e.Time.IsZero() {
return xerrors.Errorf("expected time is required")
}
return (jwt.Claims(r)).Validate(e)
}
// Claims defines the payload for a JWT. Most callers
// should embed jwt.Claims
type Claims interface {
@@ -24,6 +41,11 @@ const (
signingAlgo = jose.HS512
)
type SigningKeyManager interface {
SigningKeyProvider
VerifyKeyProvider
}
type SigningKeyProvider interface {
SigningKey(ctx context.Context) (id string, key interface{}, err error)
}
@@ -75,6 +97,12 @@ type VerifyOptions struct {
SignatureAlgorithm jose.SignatureAlgorithm
}
func WithVerifyExpected(expected jwt.Expected) func(*VerifyOptions) {
return func(opts *VerifyOptions) {
opts.RegisteredClaims = expected
}
}
// Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims.
func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claims, opts ...func(*VerifyOptions)) error {
options := VerifyOptions{
@@ -105,7 +133,7 @@ func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claim
kid := signature.Header.KeyID
if kid == "" {
return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey)
return ErrMissingKeyID
}
key, err := v.VerifyingKey(ctx, kid)
@@ -125,3 +153,35 @@ func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claim
return claims.Validate(options.RegisteredClaims)
}
// StaticKey fulfills the SigningKeycache and EncryptionKeycache interfaces. Useful for testing.
type StaticKey struct {
ID string
Key interface{}
}
func (s StaticKey) SigningKey(_ context.Context) (string, interface{}, error) {
return s.ID, s.Key, nil
}
func (s StaticKey) VerifyingKey(_ context.Context, id string) (interface{}, error) {
if id != s.ID {
return nil, xerrors.Errorf("invalid id %q", id)
}
return s.Key, nil
}
func (s StaticKey) EncryptingKey(_ context.Context) (string, interface{}, error) {
return s.ID, s.Key, nil
}
func (s StaticKey) DecryptingKey(_ context.Context, id string) (interface{}, error) {
if id != s.ID {
return nil, xerrors.Errorf("invalid id %q", id)
}
return s.Key, nil
}
func (StaticKey) Close() error {
return nil
}
+5 -5
View File
@@ -236,11 +236,11 @@ func TestJWS(t *testing.T) {
ctx = testutil.Context(t, testutil.WaitShort)
db, _ = dbtestutil.NewDB(t)
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
Feature: database.CryptoKeyFeatureOIDCConvert,
StartsAt: time.Now(),
})
log = slogtest.Make(t, nil)
fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert}
fetcher = &cryptokeys.DBFetcher{DB: db}
)
cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureOIDCConvert)
@@ -326,15 +326,15 @@ func TestJWE(t *testing.T) {
ctx = testutil.Context(t, testutil.WaitShort)
db, _ = dbtestutil.NewDB(t)
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: time.Now(),
})
log = slogtest.Make(t, nil)
fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureWorkspaceApps}
fetcher = &cryptokeys.DBFetcher{DB: db}
)
cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceApp)
cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
require.NoError(t, err)
claims := testClaims{
+22 -18
View File
@@ -15,7 +15,8 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v4"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
@@ -23,6 +24,9 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/audit"
@@ -32,7 +36,6 @@ import (
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/coderd/rbac"
@@ -49,7 +52,7 @@ const (
)
type OAuthConvertStateClaims struct {
jwt.RegisteredClaims
jwtutils.RegisteredClaims
UserID uuid.UUID `json:"user_id"`
State string `json:"state"`
@@ -57,6 +60,10 @@ type OAuthConvertStateClaims struct {
ToLoginType codersdk.LoginType `json:"to_login_type"`
}
func (o *OAuthConvertStateClaims) Validate(e jwt.Expected) error {
return o.RegisteredClaims.Validate(e)
}
// postConvertLoginType replies with an oauth state token capable of converting
// the user to an oauth user.
//
@@ -149,11 +156,11 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
// Eg: Developers with more than 1 deployment.
now := time.Now()
claims := &OAuthConvertStateClaims{
RegisteredClaims: jwt.RegisteredClaims{
RegisteredClaims: jwtutils.RegisteredClaims{
Issuer: api.DeploymentID,
Subject: stateString,
Audience: []string{user.ID.String()},
ExpiresAt: jwt.NewNumericDate(now.Add(time.Minute * 5)),
Expiry: jwt.NewNumericDate(now.Add(time.Minute * 5)),
NotBefore: jwt.NewNumericDate(now.Add(time.Second * -1)),
IssuedAt: jwt.NewNumericDate(now),
ID: uuid.NewString(),
@@ -164,9 +171,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
ToLoginType: req.ToType,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
// Key must be a byte slice, not an array. So make sure to include the [:]
tokenString, err := token.SignedString(api.OAuthSigningKey[:])
token, err := jwtutils.Sign(ctx, api.OIDCConvertKeyCache, claims)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error signing state jwt.",
@@ -176,8 +181,8 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
}
aReq.New = database.AuditOAuthConvertState{
CreatedAt: claims.IssuedAt.Time,
ExpiresAt: claims.ExpiresAt.Time,
CreatedAt: claims.IssuedAt.Time(),
ExpiresAt: claims.Expiry.Time(),
FromLoginType: database.LoginType(claims.FromLoginType),
ToLoginType: database.LoginType(claims.ToLoginType),
UserID: claims.UserID,
@@ -186,8 +191,8 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
http.SetCookie(rw, &http.Cookie{
Name: OAuthConvertCookieValue,
Path: "/",
Value: tokenString,
Expires: claims.ExpiresAt.Time,
Value: token,
Expires: claims.Expiry.Time(),
Secure: api.SecureAuthCookie,
HttpOnly: true,
// Must be SameSite to work on the redirected auth flow from the
@@ -196,7 +201,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
})
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.OAuthConversionResponse{
StateString: stateString,
ExpiresAt: claims.ExpiresAt.Time,
ExpiresAt: claims.Expiry.Time(),
ToType: claims.ToLoginType,
UserID: claims.UserID,
})
@@ -1677,10 +1682,9 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data
}
}
var claims OAuthConvertStateClaims
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(_ *jwt.Token) (interface{}, error) {
return api.OAuthSigningKey[:], nil
})
if xerrors.Is(err, jwt.ErrSignatureInvalid) || !token.Valid {
err = jwtutils.Verify(ctx, api.OIDCConvertKeyCache, jwtCookie.Value, &claims)
if xerrors.Is(err, cryptokeys.ErrKeyNotFound) || xerrors.Is(err, cryptokeys.ErrKeyInvalid) || xerrors.Is(err, jose.ErrCryptoFailure) || xerrors.Is(err, jwtutils.ErrMissingKeyID) {
// These errors are probably because the user is mixing 2 coder deployments.
return database.User{}, idpsync.HTTPError{
Code: http.StatusBadRequest,
@@ -1709,7 +1713,7 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data
oauthConvertAudit.UserID = claims.UserID
oauthConvertAudit.Old = user
if claims.RegisteredClaims.Issuer != api.DeploymentID {
if claims.Issuer != api.DeploymentID {
return database.User{}, idpsync.HTTPError{
Code: http.StatusForbidden,
Msg: "Request to convert login type failed. Issuer mismatch. Found a cookie from another coder deployment, please try again.",
+123 -1
View File
@@ -3,6 +3,8 @@ package coderd_test
import (
"context"
"crypto"
"crypto/rand"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -13,6 +15,7 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v4"
"github.com/golang-jwt/jwt/v4"
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
@@ -27,10 +30,12 @@ import (
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/codersdk"
@@ -1316,6 +1321,7 @@ func TestUserOIDC(t *testing.T) {
owner := coderdtest.CreateFirstUser(t, client)
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
require.Equal(t, codersdk.LoginTypePassword, userData.LoginType)
claims := jwt.MapClaims{
"email": userData.Email,
@@ -1323,15 +1329,17 @@ func TestUserOIDC(t *testing.T) {
var err error
user.HTTPClient.Jar, err = cookiejar.New(nil)
require.NoError(t, err)
user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone()
ctx := testutil.Context(t, testutil.WaitShort)
convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{
ToType: codersdk.LoginTypeOIDC,
Password: "SomeSecurePassword!",
})
require.NoError(t, err)
fake.LoginWithClient(t, user, claims, func(r *http.Request) {
_, _ = fake.LoginWithClient(t, user, claims, func(r *http.Request) {
r.URL.RawQuery = url.Values{
"oidc_merge_state": {convertResponse.StateString},
}.Encode()
@@ -1341,6 +1349,99 @@ func TestUserOIDC(t *testing.T) {
r.AddCookie(cookie)
}
})
info, err := client.User(ctx, userData.ID.String())
require.NoError(t, err)
require.Equal(t, codersdk.LoginTypeOIDC, info.LoginType)
})
t.Run("BadJWT", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitMedium)
logger = slogtest.Make(t, nil)
)
auditor := audit.NewMock()
fake := oidctest.NewFakeIDP(t,
oidctest.WithRefresh(func(_ string) error {
return xerrors.New("refreshing token should never occur")
}),
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
db, ps := dbtestutil.NewDB(t)
fetcher := &cryptokeys.DBFetcher{
DB: db,
}
kc, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, codersdk.CryptoKeyFeatureOIDCConvert)
require.NoError(t, err)
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: cfg,
Database: db,
Pubsub: ps,
OIDCConvertKeyCache: kc,
})
owner := coderdtest.CreateFirstUser(t, client)
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
claims := jwt.MapClaims{
"email": userData.Email,
}
user.HTTPClient.Jar, err = cookiejar.New(nil)
require.NoError(t, err)
user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone()
convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{
ToType: codersdk.LoginTypeOIDC,
Password: "SomeSecurePassword!",
})
require.NoError(t, err)
// Update the cookie to use a bad signing key. We're asserting the behavior of the scenario
// where a JWT gets minted on an old version of Coder but gets verified on a new version.
_, resp := fake.AttemptLogin(t, user, claims, func(r *http.Request) {
r.URL.RawQuery = url.Values{
"oidc_merge_state": {convertResponse.StateString},
}.Encode()
r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken())
cookies := user.HTTPClient.Jar.Cookies(user.URL)
for i, cookie := range cookies {
if cookie.Name != coderd.OAuthConvertCookieValue {
continue
}
jwt := cookie.Value
var claims coderd.OAuthConvertStateClaims
err := jwtutils.Verify(ctx, kc, jwt, &claims)
require.NoError(t, err)
badJWT := generateBadJWT(t, claims)
cookie.Value = badJWT
cookies[i] = cookie
}
user.HTTPClient.Jar.SetCookies(user.URL, cookies)
for _, cookie := range cookies {
fmt.Printf("cookie: %+v\n", cookie)
r.AddCookie(cookie)
}
})
defer resp.Body.Close()
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
var respErr codersdk.Response
err = json.NewDecoder(resp.Body).Decode(&respErr)
require.NoError(t, err)
require.Contains(t, respErr.Message, "Using an invalid jwt to authorize this action.")
})
t.Run("AlternateUsername", func(t *testing.T) {
@@ -2022,3 +2123,24 @@ func inflateClaims(t testing.TB, seed jwt.MapClaims, size int) jwt.MapClaims {
seed["random_data"] = junk
return seed
}
// generateBadJWT generates a JWT with a random key. It's intended to emulate the old-style JWT's we generated.
func generateBadJWT(t *testing.T, claims interface{}) string {
t.Helper()
var buf [64]byte
_, err := rand.Read(buf[:])
require.NoError(t, err)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.HS512,
Key: buf[:],
}, nil)
require.NoError(t, err)
payload, err := json.Marshal(claims)
require.NoError(t, err)
signed, err := signer.Sign(payload)
require.NoError(t, err)
compact, err := signed.CompactSerialize()
require.NoError(t, err)
return compact
}
+10 -4
View File
@@ -32,6 +32,7 @@ import (
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
@@ -852,8 +853,12 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
)
if resumeToken != "" {
var err error
peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(resumeToken)
if err != nil {
peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(ctx, resumeToken)
// If the token is missing the key ID, it's probably an old token in which
// case we just want to generate a new peer ID.
if xerrors.Is(err, jwtutils.ErrMissingKeyID) {
peerID = uuid.New()
} else if err != nil {
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
Message: workspacesdk.CoordinateAPIInvalidResumeToken,
Detail: err.Error(),
@@ -862,9 +867,10 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
},
})
return
} else {
api.Logger.Debug(ctx, "accepted coordinate resume token for peer",
slog.F("peer_id", peerID.String()))
}
api.Logger.Debug(ctx, "accepted coordinate resume token for peer",
slog.F("peer_id", peerID.String()))
}
api.WebsocketWaitMutex.Lock()
+130 -61
View File
@@ -13,6 +13,7 @@ import (
"testing"
"time"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -36,6 +37,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -531,20 +533,20 @@ func newResumeTokenRecordingProvider(t testing.TB, underlying tailnet.ResumeToke
}
}
func (r *resumeTokenRecordingProvider) GenerateResumeToken(peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) {
func (r *resumeTokenRecordingProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) {
select {
case r.generateCalls <- peerID:
return r.ResumeTokenProvider.GenerateResumeToken(peerID)
return r.ResumeTokenProvider.GenerateResumeToken(ctx, peerID)
default:
r.t.Error("generateCalls full")
return nil, xerrors.New("generateCalls full")
}
}
func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUID, error) {
func (r *resumeTokenRecordingProvider) VerifyResumeToken(ctx context.Context, token string) (uuid.UUID, error) {
select {
case r.verifyCalls <- token:
return r.ResumeTokenProvider.VerifyResumeToken(token)
return r.ResumeTokenProvider.VerifyResumeToken(ctx, token)
default:
r.t.Error("verifyCalls full")
return uuid.Nil, xerrors.New("verifyCalls full")
@@ -554,69 +556,136 @@ func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUI
func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := newResumeTokenRecordingProvider(
t,
tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour),
)
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Coordinator: tailnet.NewCoordinator(logger),
CoordinatorResumeTokenProvider: resumeTokenProvider,
t.Run("OK", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
mgr := jwtutils.StaticKey{
ID: uuid.New().String(),
Key: resumeTokenSigningKey[:],
}
require.NoError(t, err)
resumeTokenProvider := newResumeTokenRecordingProvider(
t,
tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour),
)
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Coordinator: tailnet.NewCoordinator(logger),
CoordinatorResumeTokenProvider: resumeTokenProvider,
})
defer closer.Close()
user := coderdtest.CreateFirstUser(t, client)
// Create a workspace with an agent. No need to connect it since clients can
// still connect to the coordinator while the agent isn't connected.
r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()
agentTokenUUID, err := uuid.Parse(r.AgentToken)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitLong)
agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
require.NoError(t, err)
// Connect with no resume token, and ensure that the peer ID is set to a
// random value.
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
require.NoError(t, err)
originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.NotEqual(t, originalPeerID, uuid.Nil)
// Connect with a valid resume token, and ensure that the peer ID is set to
// the stored value.
clock.Advance(time.Second)
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
require.NoError(t, err)
verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
require.Equal(t, originalResumeToken, verifiedToken)
newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.Equal(t, originalPeerID, newPeerID)
require.NotEqual(t, originalResumeToken, newResumeToken)
// Connect with an invalid resume token, and ensure that the request is
// rejected.
clock.Advance(time.Second)
_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
require.Len(t, sdkErr.Validations, 1)
require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
require.Equal(t, "invalid", verifiedToken)
select {
case <-resumeTokenProvider.generateCalls:
t.Fatal("unexpected peer ID in channel")
default:
}
})
defer closer.Close()
user := coderdtest.CreateFirstUser(t, client)
// Create a workspace with an agent. No need to connect it since clients can
// still connect to the coordinator while the agent isn't connected.
r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()
agentTokenUUID, err := uuid.Parse(r.AgentToken)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitLong)
agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
require.NoError(t, err)
t.Run("BadJWT", func(t *testing.T) {
t.Parallel()
// Connect with no resume token, and ensure that the peer ID is set to a
// random value.
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
require.NoError(t, err)
originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.NotEqual(t, originalPeerID, uuid.Nil)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
mgr := jwtutils.StaticKey{
ID: uuid.New().String(),
Key: resumeTokenSigningKey[:],
}
require.NoError(t, err)
resumeTokenProvider := newResumeTokenRecordingProvider(
t,
tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour),
)
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Coordinator: tailnet.NewCoordinator(logger),
CoordinatorResumeTokenProvider: resumeTokenProvider,
})
defer closer.Close()
user := coderdtest.CreateFirstUser(t, client)
// Connect with a valid resume token, and ensure that the peer ID is set to
// the stored value.
clock.Advance(time.Second)
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
require.NoError(t, err)
verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
require.Equal(t, originalResumeToken, verifiedToken)
newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.Equal(t, originalPeerID, newPeerID)
require.NotEqual(t, originalResumeToken, newResumeToken)
// Create a workspace with an agent. No need to connect it since clients can
// still connect to the coordinator while the agent isn't connected.
r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()
agentTokenUUID, err := uuid.Parse(r.AgentToken)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitLong)
agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
require.NoError(t, err)
// Connect with an invalid resume token, and ensure that the request is
// rejected.
clock.Advance(time.Second)
_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
require.Len(t, sdkErr.Validations, 1)
require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
require.Equal(t, "invalid", verifiedToken)
// Connect with no resume token, and ensure that the peer ID is set to a
// random value.
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
require.NoError(t, err)
originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.NotEqual(t, originalPeerID, uuid.Nil)
select {
case <-resumeTokenProvider.generateCalls:
t.Fatal("unexpected peer ID in channel")
default:
}
// Connect with an outdated token, and ensure that the peer ID is set to a
// random value. We don't want to fail requests just because
// a user got unlucky during a deployment upgrade.
outdatedToken := generateBadJWT(t, jwtutils.RegisteredClaims{
Subject: originalPeerID.String(),
Expiry: jwt.NewNumericDate(clock.Now().Add(time.Minute)),
})
clock.Advance(time.Second)
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, outdatedToken)
require.NoError(t, err)
verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
require.Equal(t, outdatedToken, verifiedToken)
newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.NotEqual(t, originalPeerID, newPeerID)
require.NotEqual(t, originalResumeToken, newResumeToken)
})
}
// connectToCoordinatorAndFetchResumeToken connects to the tailnet coordinator
+5 -3
View File
@@ -16,6 +16,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
@@ -122,10 +123,11 @@ func (api *API) workspaceApplicationAuth(rw http.ResponseWriter, r *http.Request
return
}
// Encrypt the API key.
encryptedAPIKey, err := api.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{
payload := workspaceapps.EncryptedAPIKeyPayload{
APIKey: cookie.Value,
})
}
payload.Fill(api.Clock.Now())
encryptedAPIKey, err := jwtutils.Encrypt(ctx, api.AppEncryptionKeyCache, payload)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to encrypt API key.",
+212 -2
View File
@@ -3,6 +3,7 @@ package apptest
import (
"bufio"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"io"
@@ -408,6 +409,67 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
assertWorkspaceLastUsedAtNotUpdated(t, appDetails)
})
t.Run("BadJWT", func(t *testing.T) {
t.Parallel()
appDetails := setupProxyTest(t, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
u := appDetails.PathAppURL(appDetails.Apps.Owner)
resp, err := requestWithRetries(ctx, t, appDetails.AppClient(t), http.MethodGet, u.String(), nil)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, proxyTestAppBody, string(body))
require.Equal(t, http.StatusOK, resp.StatusCode)
appTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie)
require.NotNil(t, appTokenCookie, "no signed app token cookie in response")
require.Equal(t, appTokenCookie.Path, u.Path, "incorrect path on app token cookie")
object, err := jose.ParseSigned(appTokenCookie.Value)
require.NoError(t, err)
require.Len(t, object.Signatures, 1)
// Parse the payload.
var tok workspaceapps.SignedToken
//nolint:gosec
err = json.Unmarshal(object.UnsafePayloadWithoutVerification(), &tok)
require.NoError(t, err)
appTokenClient := appDetails.AppClient(t)
apiKey := appTokenClient.SessionToken()
appTokenClient.SetSessionToken("")
appTokenClient.HTTPClient.Jar, err = cookiejar.New(nil)
require.NoError(t, err)
// Sign the token with an old-style key.
appTokenCookie.Value = generateBadJWT(t, tok)
appTokenClient.HTTPClient.Jar.SetCookies(u,
[]*http.Cookie{
appTokenCookie,
{
Name: codersdk.PathAppSessionTokenCookie,
Value: apiKey,
},
},
)
resp, err = requestWithRetries(ctx, t, appTokenClient, http.MethodGet, u.String(), nil)
require.NoError(t, err)
defer resp.Body.Close()
body, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, proxyTestAppBody, string(body))
require.Equal(t, http.StatusOK, resp.StatusCode)
assertWorkspaceLastUsedAtUpdated(t, appDetails)
// Since the old token is invalid, the signed app token cookie should have a new value.
newTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie)
require.NotEqual(t, appTokenCookie.Value, newTokenCookie.Value)
})
})
t.Run("WorkspaceApplicationAuth", func(t *testing.T) {
@@ -463,7 +525,7 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
appClient.SetSessionToken("")
// Try to load the application without authentication.
u := c.appURL
u := *c.appURL
u.Path = path.Join(u.Path, "/test")
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
require.NoError(t, err)
@@ -500,7 +562,7 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
// Copy the query parameters and then check equality.
u.RawQuery = gotLocation.RawQuery
require.Equal(t, u, gotLocation)
require.Equal(t, u, *gotLocation)
// Verify the API key is set.
encryptedAPIKey := gotLocation.Query().Get(workspaceapps.SubdomainProxyAPIKeyParam)
@@ -580,6 +642,38 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
})
t.Run("BadJWE", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
currentKeyStr := appDetails.SDKClient.SessionToken()
appClient := appDetails.AppClient(t)
appClient.SetSessionToken("")
u := *c.appURL
u.Path = path.Join(u.Path, "/test")
badToken := generateBadJWE(t, workspaceapps.EncryptedAPIKeyPayload{
APIKey: currentKeyStr,
})
u.RawQuery = (url.Values{
workspaceapps.SubdomainProxyAPIKeyParam: {badToken},
}).Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
require.NoError(t, err)
var resp *http.Response
resp, err = doWithRetries(t, appClient, req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), "Could not decrypt API key. Please remove the query parameter and try again.")
})
}
})
})
@@ -1077,6 +1171,68 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
assertWorkspaceLastUsedAtNotUpdated(t, appDetails)
})
})
t.Run("BadJWT", func(t *testing.T) {
t.Parallel()
appDetails := setupProxyTest(t, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
u := appDetails.SubdomainAppURL(appDetails.Apps.Owner)
resp, err := requestWithRetries(ctx, t, appDetails.AppClient(t), http.MethodGet, u.String(), nil)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, proxyTestAppBody, string(body))
require.Equal(t, http.StatusOK, resp.StatusCode)
appTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie)
require.NotNil(t, appTokenCookie, "no signed token cookie in response")
require.Equal(t, appTokenCookie.Path, "/", "incorrect path on signed token cookie")
object, err := jose.ParseSigned(appTokenCookie.Value)
require.NoError(t, err)
require.Len(t, object.Signatures, 1)
// Parse the payload.
var tok workspaceapps.SignedToken
//nolint:gosec
err = json.Unmarshal(object.UnsafePayloadWithoutVerification(), &tok)
require.NoError(t, err)
appTokenClient := appDetails.AppClient(t)
apiKey := appTokenClient.SessionToken()
appTokenClient.SetSessionToken("")
appTokenClient.HTTPClient.Jar, err = cookiejar.New(nil)
require.NoError(t, err)
// Sign the token with an old-style key.
appTokenCookie.Value = generateBadJWT(t, tok)
appTokenClient.HTTPClient.Jar.SetCookies(u,
[]*http.Cookie{
appTokenCookie,
{
Name: codersdk.SubdomainAppSessionTokenCookie,
Value: apiKey,
},
},
)
// We should still be able to successfully proxy.
resp, err = requestWithRetries(ctx, t, appTokenClient, http.MethodGet, u.String(), nil)
require.NoError(t, err)
defer resp.Body.Close()
body, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, proxyTestAppBody, string(body))
require.Equal(t, http.StatusOK, resp.StatusCode)
assertWorkspaceLastUsedAtUpdated(t, appDetails)
// Since the old token is invalid, the signed app token cookie should have a new value.
newTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie)
require.NotEqual(t, appTokenCookie.Value, newTokenCookie.Value)
})
})
t.Run("PortSharing", func(t *testing.T) {
@@ -1789,3 +1945,57 @@ func assertWorkspaceLastUsedAtNotUpdated(t testing.TB, details *Details) {
require.NoError(t, err)
require.Equal(t, before.LastUsedAt, after.LastUsedAt, "workspace LastUsedAt updated when it should not have been")
}
func generateBadJWE(t *testing.T, claims interface{}) string {
t.Helper()
var buf [32]byte
_, err := rand.Read(buf[:])
require.NoError(t, err)
encrypt, err := jose.NewEncrypter(
jose.A256GCM,
jose.Recipient{
Algorithm: jose.A256GCMKW,
Key: buf[:],
}, &jose.EncrypterOptions{
Compression: jose.DEFLATE,
},
)
require.NoError(t, err)
payload, err := json.Marshal(claims)
require.NoError(t, err)
signed, err := encrypt.Encrypt(payload)
require.NoError(t, err)
compact, err := signed.CompactSerialize()
require.NoError(t, err)
return compact
}
// generateBadJWT generates a JWT with a random key. It's intended to emulate the old-style JWT's we generated.
func generateBadJWT(t *testing.T, claims interface{}) string {
t.Helper()
var buf [64]byte
_, err := rand.Read(buf[:])
require.NoError(t, err)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.HS512,
Key: buf[:],
}, nil)
require.NoError(t, err)
payload, err := json.Marshal(claims)
require.NoError(t, err)
signed, err := signer.Sign(payload)
require.NoError(t, err)
compact, err := signed.CompactSerialize()
require.NoError(t, err)
return compact
}
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
for _, cookie := range cookies {
if cookie.Name == name {
return cookie
}
}
return nil
}
+21 -7
View File
@@ -13,11 +13,15 @@ import (
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"github.com/go-jose/go-jose/v4/jwt"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/codersdk"
@@ -35,12 +39,20 @@ type DBTokenProvider struct {
DeploymentValues *codersdk.DeploymentValues
OAuth2Configs *httpmw.OAuth2Configs
WorkspaceAgentInactiveTimeout time.Duration
SigningKey SecurityKey
Keycache cryptokeys.SigningKeycache
}
var _ SignedTokenProvider = &DBTokenProvider{}
func NewDBTokenProvider(log slog.Logger, accessURL *url.URL, authz rbac.Authorizer, db database.Store, cfg *codersdk.DeploymentValues, oauth2Cfgs *httpmw.OAuth2Configs, workspaceAgentInactiveTimeout time.Duration, signingKey SecurityKey) SignedTokenProvider {
func NewDBTokenProvider(log slog.Logger,
accessURL *url.URL,
authz rbac.Authorizer,
db database.Store,
cfg *codersdk.DeploymentValues,
oauth2Cfgs *httpmw.OAuth2Configs,
workspaceAgentInactiveTimeout time.Duration,
signer cryptokeys.SigningKeycache,
) SignedTokenProvider {
if workspaceAgentInactiveTimeout == 0 {
workspaceAgentInactiveTimeout = 1 * time.Minute
}
@@ -53,12 +65,12 @@ func NewDBTokenProvider(log slog.Logger, accessURL *url.URL, authz rbac.Authoriz
DeploymentValues: cfg,
OAuth2Configs: oauth2Cfgs,
WorkspaceAgentInactiveTimeout: workspaceAgentInactiveTimeout,
SigningKey: signingKey,
Keycache: signer,
}
}
func (p *DBTokenProvider) FromRequest(r *http.Request) (*SignedToken, bool) {
return FromRequest(r, p.SigningKey)
return FromRequest(r, p.Keycache)
}
func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *http.Request, issueReq IssueTokenRequest) (*SignedToken, string, bool) {
@@ -70,7 +82,7 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *
dangerousSystemCtx := dbauthz.AsSystemRestricted(ctx)
appReq := issueReq.AppRequest.Normalize()
err := appReq.Validate()
err := appReq.Check()
if err != nil {
WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "invalid app request")
return nil, "", false
@@ -210,9 +222,11 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *
return nil, "", false
}
token.RegisteredClaims = jwtutils.RegisteredClaims{
Expiry: jwt.NewNumericDate(time.Now().Add(DefaultTokenExpiry)),
}
// Sign the token.
token.Expiry = time.Now().Add(DefaultTokenExpiry)
tokenStr, err := p.SigningKey.SignToken(token)
tokenStr, err := jwtutils.Sign(ctx, p.Keycache, token)
if err != nil {
WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "generate token")
return nil, "", false
+18 -10
View File
@@ -13,6 +13,7 @@ import (
"testing"
"time"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -20,6 +21,7 @@ import (
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/codersdk"
@@ -94,8 +96,7 @@ func Test_ResolveRequest(t *testing.T) {
_ = closer.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
t.Cleanup(cancel)
ctx := testutil.Context(t, testutil.WaitMedium)
firstUser := coderdtest.CreateFirstUser(t, client)
me, err := client.User(ctx, codersdk.Me)
@@ -276,15 +277,17 @@ func Test_ResolveRequest(t *testing.T) {
_ = w.Body.Close()
require.Equal(t, &workspaceapps.SignedToken{
RegisteredClaims: jwtutils.RegisteredClaims{
Expiry: jwt.NewNumericDate(token.Expiry.Time()),
},
Request: req,
Expiry: token.Expiry, // ignored to avoid flakiness
UserID: me.ID,
WorkspaceID: workspace.ID,
AgentID: agentID,
AppURL: appURL,
}, token)
require.NotZero(t, token.Expiry)
require.WithinDuration(t, time.Now().Add(workspaceapps.DefaultTokenExpiry), token.Expiry, time.Minute)
require.WithinDuration(t, time.Now().Add(workspaceapps.DefaultTokenExpiry), token.Expiry.Time(), time.Minute)
// Check that the token was set in the response and is valid.
require.Len(t, w.Cookies(), 1)
@@ -292,10 +295,11 @@ func Test_ResolveRequest(t *testing.T) {
require.Equal(t, codersdk.SignedAppTokenCookie, cookie.Name)
require.Equal(t, req.BasePath, cookie.Path)
parsedToken, err := api.AppSecurityKey.VerifySignedToken(cookie.Value)
var parsedToken workspaceapps.SignedToken
err := jwtutils.Verify(ctx, api.AppSigningKeyCache, cookie.Value, &parsedToken)
require.NoError(t, err)
// normalize expiry
require.WithinDuration(t, token.Expiry, parsedToken.Expiry, 2*time.Second)
require.WithinDuration(t, token.Expiry.Time(), parsedToken.Expiry.Time(), 2*time.Second)
parsedToken.Expiry = token.Expiry
require.Equal(t, token, &parsedToken)
@@ -314,7 +318,7 @@ func Test_ResolveRequest(t *testing.T) {
})
require.True(t, ok)
// normalize expiry
require.WithinDuration(t, token.Expiry, secondToken.Expiry, 2*time.Second)
require.WithinDuration(t, token.Expiry.Time(), secondToken.Expiry.Time(), 2*time.Second)
secondToken.Expiry = token.Expiry
require.Equal(t, token, secondToken)
}
@@ -540,13 +544,16 @@ func Test_ResolveRequest(t *testing.T) {
// App name differs
AppSlugOrPort: appNamePublic,
}).Normalize(),
Expiry: time.Now().Add(time.Minute),
RegisteredClaims: jwtutils.RegisteredClaims{
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
UserID: me.ID,
WorkspaceID: workspace.ID,
AgentID: agentID,
AppURL: appURL,
}
badTokenStr, err := api.AppSecurityKey.SignToken(badToken)
badTokenStr, err := jwtutils.Sign(ctx, api.AppSigningKeyCache, badToken)
require.NoError(t, err)
req := (workspaceapps.Request{
@@ -589,7 +596,8 @@ func Test_ResolveRequest(t *testing.T) {
require.Len(t, cookies, 1)
require.Equal(t, cookies[0].Name, codersdk.SignedAppTokenCookie)
require.NotEqual(t, cookies[0].Value, badTokenStr)
parsedToken, err := api.AppSecurityKey.VerifySignedToken(cookies[0].Value)
var parsedToken workspaceapps.SignedToken
err = jwtutils.Verify(ctx, api.AppSigningKeyCache, cookies[0].Value, &parsedToken)
require.NoError(t, err)
require.Equal(t, appNameOwner, parsedToken.AppSlugOrPort)
})
+2 -2
View File
@@ -38,7 +38,7 @@ type ResolveRequestOptions struct {
func ResolveRequest(rw http.ResponseWriter, r *http.Request, opts ResolveRequestOptions) (*SignedToken, bool) {
appReq := opts.AppRequest.Normalize()
err := appReq.Validate()
err := appReq.Check()
if err != nil {
// This is a 500 since it's a coder server or proxy that's making this
// request struct based on details from the request. The values should
@@ -79,7 +79,7 @@ func ResolveRequest(rw http.ResponseWriter, r *http.Request, opts ResolveRequest
Name: codersdk.SignedAppTokenCookie,
Value: tokenStr,
Path: appReq.BasePath,
Expires: token.Expiry,
Expires: token.Expiry.Time(),
})
return token, true
+11 -4
View File
@@ -11,17 +11,21 @@ import (
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
"nhooyr.io/websocket"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
@@ -97,8 +101,8 @@ type Server struct {
HostnameRegex *regexp.Regexp
RealIPConfig *httpmw.RealIPConfig
SignedTokenProvider SignedTokenProvider
AppSecurityKey SecurityKey
SignedTokenProvider SignedTokenProvider
APIKeyEncryptionKeycache cryptokeys.EncryptionKeycache
// DisablePathApps disables path-based apps. This is a security feature as path
// based apps share the same cookie as the dashboard, and are susceptible to XSS
@@ -176,7 +180,10 @@ func (s *Server) handleAPIKeySmuggling(rw http.ResponseWriter, r *http.Request,
}
// Exchange the encoded API key for a real one.
token, err := s.AppSecurityKey.DecryptAPIKey(encryptedAPIKey)
var payload EncryptedAPIKeyPayload
err := jwtutils.Decrypt(ctx, s.APIKeyEncryptionKeycache, encryptedAPIKey, &payload, jwtutils.WithDecryptExpected(jwt.Expected{
Time: time.Now(),
}))
if err != nil {
s.Logger.Debug(ctx, "could not decrypt smuggled workspace app API key", slog.Error(err))
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
@@ -225,7 +232,7 @@ func (s *Server) handleAPIKeySmuggling(rw http.ResponseWriter, r *http.Request,
// server using the wrong value.
http.SetCookie(rw, &http.Cookie{
Name: AppConnectSessionTokenCookieName(accessMethod),
Value: token,
Value: payload.APIKey,
Domain: domain,
Path: "/",
MaxAge: 0,
+2 -2
View File
@@ -124,9 +124,9 @@ func (r Request) Normalize() Request {
return req
}
// Validate ensures the request is correct and contains the necessary
// Check ensures the request is correct and contains the necessary
// parameters.
func (r Request) Validate() error {
func (r Request) Check() error {
switch r.AccessMethod {
case AccessMethodPath, AccessMethodSubdomain, AccessMethodTerminal:
default:
+1 -1
View File
@@ -279,7 +279,7 @@ func Test_RequestValidate(t *testing.T) {
if !c.noNormalize {
req = c.req.Normalize()
}
err := req.Validate()
err := req.Check()
if c.errContains == "" {
require.NoError(t, err)
} else {
+24 -187
View File
@@ -1,35 +1,27 @@
package workspaceapps
import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/codersdk"
)
const (
tokenSigningAlgorithm = jose.HS512
apiKeyEncryptionAlgorithm = jose.A256GCMKW
)
// SignedToken is the struct data contained inside a workspace app JWE. It
// contains the details of the workspace app that the token is valid for to
// avoid database queries.
type SignedToken struct {
jwtutils.RegisteredClaims
// Request details.
Request `json:"request"`
// Trusted resolved details.
Expiry time.Time `json:"expiry"` // set by GenerateToken if unset
UserID uuid.UUID `json:"user_id"`
WorkspaceID uuid.UUID `json:"workspace_id"`
AgentID uuid.UUID `json:"agent_id"`
@@ -57,191 +49,32 @@ func (t SignedToken) MatchesRequest(req Request) bool {
t.AppSlugOrPort == req.AppSlugOrPort
}
// SecurityKey is used for signing and encrypting app tokens and API keys.
//
// The first 64 bytes of the key are used for signing tokens with HMAC-SHA256,
// and the last 32 bytes are used for encrypting API keys with AES-256-GCM.
// We use a single key for both operations to avoid having to store and manage
// two keys.
type SecurityKey [96]byte
func (k SecurityKey) IsZero() bool {
return k == SecurityKey{}
}
func (k SecurityKey) String() string {
return hex.EncodeToString(k[:])
}
func (k SecurityKey) signingKey() []byte {
return k[:64]
}
func (k SecurityKey) encryptionKey() []byte {
return k[64:]
}
func KeyFromString(str string) (SecurityKey, error) {
var key SecurityKey
decoded, err := hex.DecodeString(str)
if err != nil {
return key, xerrors.Errorf("decode key: %w", err)
}
if len(decoded) != len(key) {
return key, xerrors.Errorf("expected key to be %d bytes, got %d", len(key), len(decoded))
}
copy(key[:], decoded)
return key, nil
}
// SignToken generates a signed workspace app token with the given payload. If
// the payload doesn't have an expiry, it will be set to the current time plus
// the default expiry.
func (k SecurityKey) SignToken(payload SignedToken) (string, error) {
if payload.Expiry.IsZero() {
payload.Expiry = time.Now().Add(DefaultTokenExpiry)
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return "", xerrors.Errorf("marshal payload to JSON: %w", err)
}
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: tokenSigningAlgorithm,
Key: k.signingKey(),
}, nil)
if err != nil {
return "", xerrors.Errorf("create signer: %w", err)
}
signedObject, err := signer.Sign(payloadBytes)
if err != nil {
return "", xerrors.Errorf("sign payload: %w", err)
}
serialized, err := signedObject.CompactSerialize()
if err != nil {
return "", xerrors.Errorf("serialize JWS: %w", err)
}
return serialized, nil
}
// VerifySignedToken parses a signed workspace app token with the given key and
// returns the payload. If the token is invalid or expired, an error is
// returned.
func (k SecurityKey) VerifySignedToken(str string) (SignedToken, error) {
object, err := jose.ParseSigned(str)
if err != nil {
return SignedToken{}, xerrors.Errorf("parse JWS: %w", err)
}
if len(object.Signatures) != 1 {
return SignedToken{}, xerrors.New("expected 1 signature")
}
if object.Signatures[0].Header.Algorithm != string(tokenSigningAlgorithm) {
return SignedToken{}, xerrors.Errorf("expected token signing algorithm to be %q, got %q", tokenSigningAlgorithm, object.Signatures[0].Header.Algorithm)
}
output, err := object.Verify(k.signingKey())
if err != nil {
return SignedToken{}, xerrors.Errorf("verify JWS: %w", err)
}
var tok SignedToken
err = json.Unmarshal(output, &tok)
if err != nil {
return SignedToken{}, xerrors.Errorf("unmarshal payload: %w", err)
}
if tok.Expiry.Before(time.Now()) {
return SignedToken{}, xerrors.New("signed app token expired")
}
return tok, nil
}
type EncryptedAPIKeyPayload struct {
APIKey string `json:"api_key"`
ExpiresAt time.Time `json:"expires_at"`
jwtutils.RegisteredClaims
APIKey string `json:"api_key"`
}
// EncryptAPIKey encrypts an API key for subdomain token smuggling.
func (k SecurityKey) EncryptAPIKey(payload EncryptedAPIKeyPayload) (string, error) {
if payload.APIKey == "" {
return "", xerrors.New("API key is empty")
}
if payload.ExpiresAt.IsZero() {
// Very short expiry as these keys are only used once as part of an
// automatic redirection flow.
payload.ExpiresAt = dbtime.Now().Add(time.Minute)
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return "", xerrors.Errorf("marshal payload: %w", err)
}
// JWEs seem to apply a nonce themselves.
encrypter, err := jose.NewEncrypter(
jose.A256GCM,
jose.Recipient{
Algorithm: apiKeyEncryptionAlgorithm,
Key: k.encryptionKey(),
},
&jose.EncrypterOptions{
Compression: jose.DEFLATE,
},
)
if err != nil {
return "", xerrors.Errorf("initializer jose encrypter: %w", err)
}
encryptedObject, err := encrypter.Encrypt(payloadBytes)
if err != nil {
return "", xerrors.Errorf("encrypt jwe: %w", err)
}
encrypted := encryptedObject.FullSerialize()
return base64.RawURLEncoding.EncodeToString([]byte(encrypted)), nil
func (e *EncryptedAPIKeyPayload) Fill(now time.Time) {
e.Issuer = "coderd"
e.Audience = jwt.Audience{"wsproxy"}
e.Expiry = jwt.NewNumericDate(now.Add(time.Minute))
e.NotBefore = jwt.NewNumericDate(now.Add(-time.Minute))
}
// DecryptAPIKey undoes EncryptAPIKey and is used in the subdomain app handler.
func (k SecurityKey) DecryptAPIKey(encryptedAPIKey string) (string, error) {
encrypted, err := base64.RawURLEncoding.DecodeString(encryptedAPIKey)
if err != nil {
return "", xerrors.Errorf("base64 decode encrypted API key: %w", err)
func (e EncryptedAPIKeyPayload) Validate(ex jwt.Expected) error {
if e.NotBefore == nil {
return xerrors.Errorf("not before is required")
}
object, err := jose.ParseEncrypted(string(encrypted))
if err != nil {
return "", xerrors.Errorf("parse encrypted API key: %w", err)
}
if object.Header.Algorithm != string(apiKeyEncryptionAlgorithm) {
return "", xerrors.Errorf("expected API key encryption algorithm to be %q, got %q", apiKeyEncryptionAlgorithm, object.Header.Algorithm)
}
ex.Issuer = "coderd"
ex.AnyAudience = jwt.Audience{"wsproxy"}
// Decrypt using the hashed secret.
decrypted, err := object.Decrypt(k.encryptionKey())
if err != nil {
return "", xerrors.Errorf("decrypt API key: %w", err)
}
// Unmarshal the payload.
var payload EncryptedAPIKeyPayload
if err := json.Unmarshal(decrypted, &payload); err != nil {
return "", xerrors.Errorf("unmarshal decrypted payload: %w", err)
}
// Validate expiry.
if payload.ExpiresAt.Before(dbtime.Now()) {
return "", xerrors.New("encrypted API key expired")
}
return payload.APIKey, nil
return e.RegisteredClaims.Validate(ex)
}
// FromRequest returns the signed token from the request, if it exists and is
// valid. The caller must check that the token matches the request.
func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) {
func FromRequest(r *http.Request, mgr cryptokeys.SigningKeycache) (*SignedToken, bool) {
// Get all signed app tokens from the request. This includes the query
// parameter and all matching cookies sent with the request. If there are
// somehow multiple signed app token cookies, we want to try all of them
@@ -270,8 +103,12 @@ func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) {
tokens = tokens[:4]
}
ctx := r.Context()
for _, tokenStr := range tokens {
token, err := key.VerifySignedToken(tokenStr)
var token SignedToken
err := jwtutils.Verify(ctx, mgr, tokenStr, &token, jwtutils.WithVerifyExpected(jwt.Expected{
Time: time.Now(),
}))
if err == nil {
req := token.Request.Normalize()
if hasQueryParam && req.AccessMethod != AccessMethodTerminal {
@@ -280,7 +117,7 @@ func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) {
return nil, false
}
err := req.Validate()
err := req.Check()
if err == nil {
// The request has a valid signed app token, which is a valid
// token signed by us. The caller must check that it matches
+30 -269
View File
@@ -1,22 +1,22 @@
package workspaceapps_test
import (
"fmt"
"crypto/rand"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/coder/coder/v2/codersdk"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/go-jose/go-jose/v3"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/cryptorand"
)
func Test_TokenMatchesRequest(t *testing.T) {
@@ -283,129 +283,6 @@ func Test_TokenMatchesRequest(t *testing.T) {
}
}
func Test_GenerateToken(t *testing.T) {
t.Parallel()
t.Run("SetExpiry", func(t *testing.T) {
t.Parallel()
tokenStr, err := coderdtest.AppSecurityKey.SignToken(workspaceapps.SignedToken{
Request: workspaceapps.Request{
AccessMethod: workspaceapps.AccessMethodPath,
BasePath: "/app",
UsernameOrID: "foo",
WorkspaceNameOrID: "bar",
AgentNameOrID: "baz",
AppSlugOrPort: "qux",
},
Expiry: time.Time{},
UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"),
WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"),
AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"),
AppURL: "http://127.0.0.1:8080",
})
require.NoError(t, err)
token, err := coderdtest.AppSecurityKey.VerifySignedToken(tokenStr)
require.NoError(t, err)
require.WithinDuration(t, time.Now().Add(time.Minute), token.Expiry, 15*time.Second)
})
future := time.Now().Add(time.Hour)
cases := []struct {
name string
token workspaceapps.SignedToken
parseErrContains string
}{
{
name: "OK1",
token: workspaceapps.SignedToken{
Request: workspaceapps.Request{
AccessMethod: workspaceapps.AccessMethodPath,
BasePath: "/app",
UsernameOrID: "foo",
WorkspaceNameOrID: "bar",
AgentNameOrID: "baz",
AppSlugOrPort: "qux",
},
Expiry: future,
UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"),
WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"),
AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"),
AppURL: "http://127.0.0.1:8080",
},
},
{
name: "OK2",
token: workspaceapps.SignedToken{
Request: workspaceapps.Request{
AccessMethod: workspaceapps.AccessMethodSubdomain,
BasePath: "/",
UsernameOrID: "oof",
WorkspaceNameOrID: "rab",
AgentNameOrID: "zab",
AppSlugOrPort: "xuq",
},
Expiry: future,
UserID: uuid.MustParse("6fa684a3-11aa-49fd-8512-ab527bd9b900"),
WorkspaceID: uuid.MustParse("b2d816cc-505c-441d-afdf-dae01781bc0b"),
AgentID: uuid.MustParse("6c4396e1-af88-4a8a-91a3-13ea54fc29fb"),
AppURL: "http://localhost:9090",
},
},
{
name: "Expired",
token: workspaceapps.SignedToken{
Request: workspaceapps.Request{
AccessMethod: workspaceapps.AccessMethodSubdomain,
BasePath: "/",
UsernameOrID: "foo",
WorkspaceNameOrID: "bar",
AgentNameOrID: "baz",
AppSlugOrPort: "qux",
},
Expiry: time.Now().Add(-time.Hour),
UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"),
WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"),
AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"),
AppURL: "http://127.0.0.1:8080",
},
parseErrContains: "token expired",
},
}
for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
str, err := coderdtest.AppSecurityKey.SignToken(c.token)
require.NoError(t, err)
// Tokens aren't deterministic as they have a random nonce, so we
// can't compare them directly.
token, err := coderdtest.AppSecurityKey.VerifySignedToken(str)
if c.parseErrContains != "" {
require.Error(t, err)
require.ErrorContains(t, err, c.parseErrContains)
} else {
require.NoError(t, err)
// normalize the expiry
require.WithinDuration(t, c.token.Expiry, token.Expiry, 10*time.Second)
c.token.Expiry = token.Expiry
require.Equal(t, c.token, token)
}
})
}
}
func Test_FromRequest(t *testing.T) {
t.Parallel()
@@ -419,7 +296,13 @@ func Test_FromRequest(t *testing.T) {
Value: "invalid",
})
ctx := testutil.Context(t, testutil.WaitShort)
signer := newSigner(t)
token := workspaceapps.SignedToken{
RegisteredClaims: jwtutils.RegisteredClaims{
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
Request: workspaceapps.Request{
AccessMethod: workspaceapps.AccessMethodSubdomain,
BasePath: "/",
@@ -429,7 +312,6 @@ func Test_FromRequest(t *testing.T) {
AgentNameOrID: "agent",
AppSlugOrPort: "app",
},
Expiry: time.Now().Add(time.Hour),
UserID: uuid.New(),
WorkspaceID: uuid.New(),
AgentID: uuid.New(),
@@ -438,16 +320,15 @@ func Test_FromRequest(t *testing.T) {
// Add an expired cookie
expired := token
expired.Expiry = time.Now().Add(time.Hour * -1)
expiredStr, err := coderdtest.AppSecurityKey.SignToken(token)
expired.RegisteredClaims.Expiry = jwt.NewNumericDate(time.Now().Add(time.Hour * -1))
expiredStr, err := jwtutils.Sign(ctx, signer, expired)
require.NoError(t, err)
r.AddCookie(&http.Cookie{
Name: codersdk.SignedAppTokenCookie,
Value: expiredStr,
})
// Add a valid token
validStr, err := coderdtest.AppSecurityKey.SignToken(token)
validStr, err := jwtutils.Sign(ctx, signer, token)
require.NoError(t, err)
r.AddCookie(&http.Cookie{
@@ -455,147 +336,27 @@ func Test_FromRequest(t *testing.T) {
Value: validStr,
})
signed, ok := workspaceapps.FromRequest(r, coderdtest.AppSecurityKey)
signed, ok := workspaceapps.FromRequest(r, signer)
require.True(t, ok, "expected a token to be found")
// Confirm it is the correct token.
require.Equal(t, signed.UserID, token.UserID)
})
}
// The ParseToken fn is tested quite thoroughly in the GenerateToken test as
// well.
func Test_ParseToken(t *testing.T) {
t.Parallel()
func newSigner(t *testing.T) jwtutils.StaticKey {
t.Helper()
t.Run("InvalidJWS", func(t *testing.T) {
t.Parallel()
token, err := coderdtest.AppSecurityKey.VerifySignedToken("invalid")
require.Error(t, err)
require.ErrorContains(t, err, "parse JWS")
require.Equal(t, workspaceapps.SignedToken{}, token)
})
t.Run("VerifySignature", func(t *testing.T) {
t.Parallel()
// Create a valid token using a different key.
var otherKey workspaceapps.SecurityKey
copy(otherKey[:], coderdtest.AppSecurityKey[:])
for i := range otherKey {
otherKey[i] ^= 0xff
}
require.NotEqual(t, coderdtest.AppSecurityKey, otherKey)
tokenStr, err := otherKey.SignToken(workspaceapps.SignedToken{
Request: workspaceapps.Request{
AccessMethod: workspaceapps.AccessMethodPath,
BasePath: "/app",
UsernameOrID: "foo",
WorkspaceNameOrID: "bar",
AgentNameOrID: "baz",
AppSlugOrPort: "qux",
},
Expiry: time.Now().Add(time.Hour),
UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"),
WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"),
AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"),
AppURL: "http://127.0.0.1:8080",
})
require.NoError(t, err)
// Verify the token is invalid.
token, err := coderdtest.AppSecurityKey.VerifySignedToken(tokenStr)
require.Error(t, err)
require.ErrorContains(t, err, "verify JWS")
require.Equal(t, workspaceapps.SignedToken{}, token)
})
t.Run("InvalidBody", func(t *testing.T) {
t.Parallel()
// Create a signature for an invalid body.
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS512, Key: coderdtest.AppSecurityKey[:64]}, nil)
require.NoError(t, err)
signedObject, err := signer.Sign([]byte("hi"))
require.NoError(t, err)
serialized, err := signedObject.CompactSerialize()
require.NoError(t, err)
token, err := coderdtest.AppSecurityKey.VerifySignedToken(serialized)
require.Error(t, err)
require.ErrorContains(t, err, "unmarshal payload")
require.Equal(t, workspaceapps.SignedToken{}, token)
})
}
func TestAPIKeyEncryption(t *testing.T) {
t.Parallel()
genAPIKey := func(t *testing.T) string {
id, _ := cryptorand.String(10)
secret, _ := cryptorand.String(22)
return fmt.Sprintf("%s-%s", id, secret)
return jwtutils.StaticKey{
ID: "test",
Key: generateSecret(t, 64),
}
t.Run("OK", func(t *testing.T) {
t.Parallel()
key := genAPIKey(t)
encrypted, err := coderdtest.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{
APIKey: key,
})
require.NoError(t, err)
decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted)
require.NoError(t, err)
require.Equal(t, key, decryptedKey)
})
t.Run("Verifies", func(t *testing.T) {
t.Parallel()
t.Run("Expiry", func(t *testing.T) {
t.Parallel()
key := genAPIKey(t)
encrypted, err := coderdtest.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{
APIKey: key,
ExpiresAt: dbtime.Now().Add(-1 * time.Hour),
})
require.NoError(t, err)
decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted)
require.Error(t, err)
require.ErrorContains(t, err, "expired")
require.Empty(t, decryptedKey)
})
t.Run("EncryptionKey", func(t *testing.T) {
t.Parallel()
// Create a valid token using a different key.
var otherKey workspaceapps.SecurityKey
copy(otherKey[:], coderdtest.AppSecurityKey[:])
for i := range otherKey {
otherKey[i] ^= 0xff
}
require.NotEqual(t, coderdtest.AppSecurityKey, otherKey)
// Encrypt with the other key.
key := genAPIKey(t)
encrypted, err := otherKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{
APIKey: key,
})
require.NoError(t, err)
// Decrypt with the original key.
decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted)
require.Error(t, err)
require.ErrorContains(t, err, "decrypt API key")
require.Empty(t, decryptedKey)
})
})
}
func generateSecret(t *testing.T, size int) []byte {
t.Helper()
secret := make([]byte, size)
_, err := rand.Read(secret)
require.NoError(t, err)
return secret
}
+34 -7
View File
@@ -5,16 +5,23 @@ import (
"net/http"
"net/url"
"testing"
"time"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestGetAppHost(t *testing.T) {
@@ -181,16 +188,28 @@ func TestWorkspaceApplicationAuth(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
logger := slogtest.Make(t, nil)
accessURL, err := url.Parse(c.accessURL)
require.NoError(t, err)
db, ps := dbtestutil.NewDB(t)
fetcher := &cryptokeys.DBFetcher{
DB: db,
}
kc, err := cryptokeys.NewEncryptionCache(ctx, logger, fetcher, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
require.NoError(t, err)
clock := quartz.NewMock(t)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: pubsub,
AccessURL: accessURL,
AppHostname: c.appHostname,
AccessURL: accessURL,
AppHostname: c.appHostname,
Database: db,
Pubsub: ps,
APIKeyEncryptionCache: kc,
Clock: clock,
})
_ = coderdtest.CreateFirstUser(t, client)
@@ -240,7 +259,15 @@ func TestWorkspaceApplicationAuth(t *testing.T) {
loc.RawQuery = q.Encode()
require.Equal(t, c.expectRedirect, loc.String())
// The decrypted key is verified in the apptest test suite.
var token workspaceapps.EncryptedAPIKeyPayload
err = jwtutils.Decrypt(ctx, kc, encryptedAPIKey, &token, jwtutils.WithDecryptExpected(jwt.Expected{
Time: clock.Now(),
AnyAudience: jwt.Audience{"wsproxy"},
Issuer: "coderd",
}))
require.NoError(t, err)
require.Equal(t, jwt.NewNumericDate(clock.Now().Add(time.Minute)), token.Expiry)
require.Equal(t, jwt.NewNumericDate(clock.Now().Add(-time.Minute)), token.NotBefore)
})
}
}
+5 -3
View File
@@ -3113,9 +3113,11 @@ func (c *Client) SSHConfiguration(ctx context.Context) (SSHConfigResponse, error
type CryptoKeyFeature string
const (
CryptoKeyFeatureWorkspaceApp CryptoKeyFeature = "workspace_apps"
CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert"
CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume"
CryptoKeyFeatureWorkspaceAppsAPIKey CryptoKeyFeature = "workspace_apps_api_key"
//nolint:gosec // This denotes a type of key, not a literal.
CryptoKeyFeatureWorkspaceAppsToken CryptoKeyFeature = "workspace_apps_token"
CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert"
CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume"
)
type CryptoKey struct {
@@ -25,6 +25,7 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
@@ -61,7 +62,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {},
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
@@ -165,13 +166,17 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
mgr := jwtutils.StaticKey{
ID: "123",
Key: resumeTokenSigningKey[:],
}
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {},
ResumeTokenProvider: resumeTokenProvider,
})
require.NoError(t, err)
@@ -190,7 +195,7 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
t.Logf("received resume token: %s", resumeToken)
assert.Equal(t, expectResumeToken, resumeToken)
if resumeToken != "" {
peerID, err = resumeTokenProvider.VerifyResumeToken(resumeToken)
peerID, err = resumeTokenProvider.VerifyResumeToken(ctx, resumeToken)
assert.NoError(t, err, "failed to parse resume token")
if err != nil {
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
@@ -280,13 +285,17 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
mgr := jwtutils.StaticKey{
ID: uuid.New().String(),
Key: resumeTokenSigningKey[:],
}
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
NetworkTelemetryHandler: func(_ []*proto.TelemetryEvent) {},
ResumeTokenProvider: resumeTokenProvider,
})
require.NoError(t, err)
+9 -10
View File
@@ -1454,7 +1454,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
```json
{
"deletes_at": "2019-08-24T14:15:22Z",
"feature": "workspace_apps",
"feature": "workspace_apps_api_key",
"secret": "string",
"sequence": 0,
"starts_at": "2019-08-24T14:15:22Z"
@@ -1474,18 +1474,19 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
## codersdk.CryptoKeyFeature
```json
"workspace_apps"
"workspace_apps_api_key"
```
### Properties
#### Enumerated Values
| Value |
| ---------------- |
| `workspace_apps` |
| `oidc_convert` |
| `tailnet_resume` |
| Value |
| ------------------------ |
| `workspace_apps_api_key` |
| `workspace_apps_token` |
| `oidc_convert` |
| `tailnet_resume` |
## codersdk.CustomRoleRequest
@@ -9893,7 +9894,7 @@ _None_
"crypto_keys": [
{
"deletes_at": "2019-08-24T14:15:22Z",
"feature": "workspace_apps",
"feature": "workspace_apps_api_key",
"secret": "string",
"sequence": 0,
"starts_at": "2019-08-24T14:15:22Z"
@@ -9971,7 +9972,6 @@ _None_
```json
{
"app_security_key": "string",
"derp_force_websockets": true,
"derp_map": {
"homeParams": {
@@ -10052,7 +10052,6 @@ _None_
| Name | Type | Required | Restrictions | Description |
| ----------------------- | --------------------------------------------- | -------- | ------------ | -------------------------------------------------------------------------------------- |
| `app_security_key` | string | false | | |
| `derp_force_websockets` | boolean | false | | |
| `derp_map` | [tailcfg.DERPMap](#tailcfgderpmap) | false | | |
| `derp_mesh_key` | string | false | | |
+5 -1
View File
@@ -65,6 +65,8 @@ type WorkspaceProxy struct {
// owner client. If a token is provided, the proxy will become a replica of the
// existing proxy region.
func NewWorkspaceProxyReplica(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Client, options *ProxyOptions) WorkspaceProxy {
t.Helper()
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(cancelFunc)
@@ -142,8 +144,10 @@ func NewWorkspaceProxyReplica(t *testing.T, coderdAPI *coderd.API, owner *coders
statsCollectorOptions.Flush = options.FlushStats
}
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug).With(slog.F("server_url", serverURL.String()))
wssrv, err := wsproxy.New(ctx, &wsproxy.Options{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).With(slog.F("server_url", serverURL.String())),
Logger: logger,
Experiments: options.Experiments,
DashboardURL: coderdAPI.AccessURL,
AccessURL: accessURL,
+25 -2
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"time"
@@ -33,6 +34,13 @@ import (
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
)
// whitelistedCryptoKeyFeatures is a list of crypto key features that are
// allowed to be queried with workspace proxies.
var whitelistedCryptoKeyFeatures = []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceAppsToken,
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
}
// forceWorkspaceProxyHealthUpdate forces an update of the proxy health.
// This is useful when a proxy is created or deleted. Errors will be logged.
func (api *API) forceWorkspaceProxyHealthUpdate(ctx context.Context) {
@@ -700,7 +708,6 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request)
}
httpapi.Write(ctx, rw, http.StatusCreated, wsproxysdk.RegisterWorkspaceProxyResponse{
AppSecurityKey: api.AppSecurityKey.String(),
DERPMeshKey: api.DERPServer.MeshKey(),
DERPRegionID: regionID,
DERPMap: api.AGPL.DERPMap(),
@@ -721,13 +728,29 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request)
// @Security CoderSessionToken
// @Produce json
// @Tags Enterprise
// @Param feature query string true "Feature key"
// @Success 200 {object} wsproxysdk.CryptoKeysResponse
// @Router /workspaceproxies/me/crypto-keys [get]
// @x-apidocgen {"skip": true}
func (api *API) workspaceProxyCryptoKeys(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
keys, err := api.Database.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
feature := database.CryptoKeyFeature(r.URL.Query().Get("feature"))
if feature == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing feature query parameter.",
})
return
}
if !slices.Contains(whitelistedCryptoKeyFeatures, feature) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid feature: %q", feature),
})
return
}
keys, err := api.Database.GetCryptoKeysByFeature(ctx, feature)
if err != nil {
httpapi.InternalServerError(rw, err)
return
+72 -25
View File
@@ -320,7 +320,6 @@ func TestProxyRegisterDeregister(t *testing.T) {
}
registerRes1, err := proxyClient.RegisterWorkspaceProxy(ctx, req)
require.NoError(t, err)
require.NotEmpty(t, registerRes1.AppSecurityKey)
require.NotEmpty(t, registerRes1.DERPMeshKey)
require.EqualValues(t, 10001, registerRes1.DERPRegionID)
require.Empty(t, registerRes1.SiblingReplicas)
@@ -609,11 +608,8 @@ func TestProxyRegisterDeregister(t *testing.T) {
func TestIssueSignedAppToken(t *testing.T) {
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
client, user := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
IncludeProvisionerDaemon: true,
},
LicenseOptions: &coderdenttest.LicenseOptions{
@@ -716,6 +712,10 @@ func TestReconnectingPTYSignedToken(t *testing.T) {
closer.Close()
})
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceAppsToken,
})
// Create a workspace + apps
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
@@ -915,51 +915,86 @@ func TestGetCryptoKeys(t *testing.T) {
now := time.Now()
expectedKey1 := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now.Add(-time.Hour),
Sequence: 2,
})
key1 := db2sdk.CryptoKey(expectedKey1)
encryptionKey := db2sdk.CryptoKey(expectedKey1)
expectedKey2 := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsToken,
StartsAt: now,
Sequence: 3,
})
key2 := db2sdk.CryptoKey(expectedKey2)
signingKey := db2sdk.CryptoKey(expectedKey2)
// Create a deleted key.
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
StartsAt: now.Add(-time.Hour),
Secret: sql.NullString{
String: "secret1",
Valid: false,
},
Sequence: 1,
})
// Create a key with different features.
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureTailnetResume,
StartsAt: now.Add(-time.Hour),
Sequence: 1,
})
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
StartsAt: now.Add(-time.Hour),
Sequence: 1,
Sequence: 4,
})
proxy := coderdenttest.NewWorkspaceProxyReplica(t, api, cclient, &coderdenttest.ProxyOptions{
Name: testutil.GetRandomName(t),
})
keys, err := proxy.SDKClient.CryptoKeys(ctx)
keys, err := proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
require.NoError(t, err)
require.NotEmpty(t, keys)
// 1 key is generated on startup, the other we manually generated.
require.Equal(t, 2, len(keys.CryptoKeys))
requireContainsKeys(t, keys.CryptoKeys, key1, key2)
requireContainsKeys(t, keys.CryptoKeys, encryptionKey)
requireNotContainsKeys(t, keys.CryptoKeys, signingKey)
keys, err = proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsToken)
require.NoError(t, err)
require.NotEmpty(t, keys)
// 1 key is generated on startup, the other we manually generated.
require.Equal(t, 2, len(keys.CryptoKeys))
requireContainsKeys(t, keys.CryptoKeys, signingKey)
requireNotContainsKeys(t, keys.CryptoKeys, encryptionKey)
})
t.Run("InvalidFeature", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
db, pubsub := dbtestutil.NewDB(t)
cclient, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
IncludeProvisionerDaemon: true,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureWorkspaceProxy: 1,
},
},
})
proxy := coderdenttest.NewWorkspaceProxyReplica(t, api, cclient, &coderdenttest.ProxyOptions{
Name: testutil.GetRandomName(t),
})
_, err := proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureOIDCConvert)
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
_, err = proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureTailnetResume)
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
_, err = proxy.SDKClient.CryptoKeys(ctx, "invalid")
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("Unauthorized", func(t *testing.T) {
@@ -987,7 +1022,7 @@ func TestGetCryptoKeys(t *testing.T) {
client := wsproxysdk.New(cclient.URL)
client.SetSessionToken(cclient.SessionToken())
_, err := client.CryptoKeys(ctx)
_, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
@@ -995,6 +1030,18 @@ func TestGetCryptoKeys(t *testing.T) {
})
}
func requireNotContainsKeys(t *testing.T, keys []codersdk.CryptoKey, unexpected ...codersdk.CryptoKey) {
t.Helper()
for _, unexpectedKey := range unexpected {
for _, key := range keys {
if key.Feature == unexpectedKey.Feature && key.Sequence == unexpectedKey.Sequence {
t.Fatalf("unexpected key %+v found", unexpectedKey)
}
}
}
}
func requireContainsKeys(t *testing.T, keys []codersdk.CryptoKey, expected ...codersdk.CryptoKey) {
t.Helper()
+5 -5
View File
@@ -397,12 +397,12 @@ func TestCryptoKeys(t *testing.T) {
_ = dbgen.CryptoKey(t, crypt, database.CryptoKey{
Secret: sql.NullString{String: "test", Valid: true},
})
key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
require.NoError(t, err)
require.Equal(t, "test", key.Secret.String)
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test")
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
@@ -415,7 +415,7 @@ func TestCryptoKeys(t *testing.T) {
Secret: sql.NullString{String: "test", Valid: true},
})
key, err := crypt.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
Sequence: key.Sequence,
})
require.NoError(t, err)
@@ -423,7 +423,7 @@ func TestCryptoKeys(t *testing.T) {
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
key, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
Sequence: key.Sequence,
})
require.NoError(t, err)
@@ -459,7 +459,7 @@ func TestCryptoKeys(t *testing.T) {
Secret: sql.NullString{String: "test", Valid: true},
})
_ = dbgen.CryptoKey(t, crypt, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
Sequence: 43,
})
keys, err := crypt.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureTailnetResume)
+6
View File
@@ -5,6 +5,7 @@ import (
"testing"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/workspaceapps/apptest"
"github.com/coder/coder/v2/codersdk"
@@ -36,6 +37,9 @@ func TestWorkspaceApps(t *testing.T) {
flushStatsCollectorCh <- flushStatsCollectorDone
<-flushStatsCollectorDone
}
db, pubsub := dbtestutil.NewDB(t)
client, _, _, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: deploymentValues,
@@ -51,6 +55,8 @@ func TestWorkspaceApps(t *testing.T) {
},
},
WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions,
Database: db,
Pubsub: pubsub,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
+3 -4
View File
@@ -13,12 +13,11 @@ import (
var _ cryptokeys.Fetcher = &ProxyFetcher{}
type ProxyFetcher struct {
Client *wsproxysdk.Client
Feature codersdk.CryptoKeyFeature
Client *wsproxysdk.Client
}
func (p *ProxyFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
keys, err := p.Client.CryptoKeys(ctx)
func (p *ProxyFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
keys, err := p.Client.CryptoKeys(ctx, feature)
if err != nil {
return nil, xerrors.Errorf("crypto keys: %w", err)
}
+10 -6
View File
@@ -7,6 +7,8 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
)
@@ -18,18 +20,19 @@ type TokenProvider struct {
AccessURL *url.URL
AppHostname string
Client *wsproxysdk.Client
SecurityKey workspaceapps.SecurityKey
Logger slog.Logger
Client *wsproxysdk.Client
TokenSigningKeycache cryptokeys.SigningKeycache
APIKeyEncryptionKeycache cryptokeys.EncryptionKeycache
Logger slog.Logger
}
func (p *TokenProvider) FromRequest(r *http.Request) (*workspaceapps.SignedToken, bool) {
return workspaceapps.FromRequest(r, p.SecurityKey)
return workspaceapps.FromRequest(r, p.TokenSigningKeycache)
}
func (p *TokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *http.Request, issueReq workspaceapps.IssueTokenRequest) (*workspaceapps.SignedToken, string, bool) {
appReq := issueReq.AppRequest.Normalize()
err := appReq.Validate()
err := appReq.Check()
if err != nil {
workspaceapps.WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "invalid app request")
return nil, "", false
@@ -42,7 +45,8 @@ func (p *TokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *ht
}
// Check that it verifies properly and matches the string.
token, err := p.SecurityKey.VerifySignedToken(resp.SignedTokenStr)
var token workspaceapps.SignedToken
err = jwtutils.Verify(ctx, p.TokenSigningKeycache, resp.SignedTokenStr, &token)
if err != nil {
workspaceapps.WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "failed to verify newly generated signed token")
return nil, "", false
+54 -25
View File
@@ -31,6 +31,7 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli/cliutil"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/tracing"
@@ -130,6 +131,13 @@ type Server struct {
// the moon's token.
SDKClient *wsproxysdk.Client
// apiKeyEncryptionKeycache manages the encryption keys for smuggling API
// tokens to the alternate domain when using workspace apps.
apiKeyEncryptionKeycache cryptokeys.EncryptionKeycache
// appTokenSigningKeycache manages the signing keys for signing the app
// tokens we use for workspace apps.
appTokenSigningKeycache cryptokeys.SigningKeycache
// DERP
derpMesh *derpmesh.Mesh
derpMeshTLSConfig *tls.Config
@@ -195,19 +203,42 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(opts.Logger.Named("net.derp")))
ctx, cancel := context.WithCancel(context.Background())
encryptionCache, err := cryptokeys.NewEncryptionCache(ctx,
opts.Logger,
&ProxyFetcher{Client: client},
codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey,
)
if err != nil {
cancel()
return nil, xerrors.Errorf("create api key encryption cache: %w", err)
}
signingCache, err := cryptokeys.NewSigningCache(ctx,
opts.Logger,
&ProxyFetcher{Client: client},
codersdk.CryptoKeyFeatureWorkspaceAppsToken,
)
if err != nil {
cancel()
return nil, xerrors.Errorf("create api token signing cache: %w", err)
}
r := chi.NewRouter()
s := &Server{
Options: opts,
Handler: r,
DashboardURL: opts.DashboardURL,
Logger: opts.Logger.Named("net.workspace-proxy"),
TracerProvider: opts.Tracing,
PrometheusRegistry: opts.PrometheusRegistry,
SDKClient: client,
derpMesh: derpmesh.New(opts.Logger.Named("net.derpmesh"), derpServer, meshTLSConfig),
derpMeshTLSConfig: meshTLSConfig,
ctx: ctx,
cancel: cancel,
ctx: ctx,
cancel: cancel,
Options: opts,
Handler: r,
DashboardURL: opts.DashboardURL,
Logger: opts.Logger.Named("net.workspace-proxy"),
TracerProvider: opts.Tracing,
PrometheusRegistry: opts.PrometheusRegistry,
SDKClient: client,
derpMesh: derpmesh.New(opts.Logger.Named("net.derpmesh"), derpServer, meshTLSConfig),
derpMeshTLSConfig: meshTLSConfig,
apiKeyEncryptionKeycache: encryptionCache,
appTokenSigningKeycache: signingCache,
}
// Register the workspace proxy with the primary coderd instance and start a
@@ -240,11 +271,6 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
return nil, xerrors.Errorf("handle register: %w", err)
}
secKey, err := workspaceapps.KeyFromString(regResp.AppSecurityKey)
if err != nil {
return nil, xerrors.Errorf("parse app security key: %w", err)
}
agentProvider, err := coderd.NewServerTailnet(ctx,
s.Logger,
nil,
@@ -277,20 +303,21 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
HostnameRegex: opts.AppHostnameRegex,
RealIPConfig: opts.RealIPConfig,
SignedTokenProvider: &TokenProvider{
DashboardURL: opts.DashboardURL,
AccessURL: opts.AccessURL,
AppHostname: opts.AppHostname,
Client: client,
SecurityKey: secKey,
Logger: s.Logger.Named("proxy_token_provider"),
DashboardURL: opts.DashboardURL,
AccessURL: opts.AccessURL,
AppHostname: opts.AppHostname,
Client: client,
TokenSigningKeycache: signingCache,
APIKeyEncryptionKeycache: encryptionCache,
Logger: s.Logger.Named("proxy_token_provider"),
},
AppSecurityKey: secKey,
DisablePathApps: opts.DisablePathApps,
SecureAuthCookie: opts.SecureAuthCookie,
AgentProvider: agentProvider,
StatsCollector: workspaceapps.NewStatsCollector(opts.StatsCollectorOptions),
AgentProvider: agentProvider,
StatsCollector: workspaceapps.NewStatsCollector(opts.StatsCollectorOptions),
APIKeyEncryptionKeycache: encryptionCache,
}
derpHandler := derphttp.Handler(derpServer)
@@ -435,6 +462,8 @@ func (s *Server) Close() error {
err = multierror.Append(err, agentProviderErr)
}
s.SDKClient.SDKClient.HTTPClient.CloseIdleConnections()
_ = s.appTokenSigningKeycache.Close()
_ = s.apiKeyEncryptionKeycache.Close()
return err
}
+26
View File
@@ -25,6 +25,9 @@ import (
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/workspaceapps/apptest"
@@ -932,6 +935,9 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) {
if opts.PrimaryAppHost == "" {
opts.PrimaryAppHost = "*.primary.test.coder.com"
}
db, pubsub := dbtestutil.NewDB(t)
client, closer, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: deploymentValues,
@@ -947,6 +953,8 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) {
},
},
WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions,
Database: db,
Pubsub: pubsub,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
@@ -959,6 +967,13 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) {
_ = closer.Close()
})
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceAppsToken,
})
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
})
// Create the external proxy
if opts.DisableSubdomainApps {
opts.AppHost = ""
@@ -1002,6 +1017,8 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) {
if opts.PrimaryAppHost == "" {
opts.PrimaryAppHost = "*.primary.test.coder.com"
}
db, pubsub := dbtestutil.NewDB(t)
client, closer, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: deploymentValues,
@@ -1017,6 +1034,8 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) {
},
},
WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions,
Database: db,
Pubsub: pubsub,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
@@ -1029,6 +1048,13 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) {
_ = closer.Close()
})
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceAppsToken,
})
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
})
// Create the external proxy
if opts.DisableSubdomainApps {
opts.AppHost = ""
+3 -10
View File
@@ -205,7 +205,6 @@ type RegisterWorkspaceProxyRequest struct {
}
type RegisterWorkspaceProxyResponse struct {
AppSecurityKey string `json:"app_security_key"`
DERPMeshKey string `json:"derp_mesh_key"`
DERPRegionID int32 `json:"derp_region_id"`
DERPMap *tailcfg.DERPMap `json:"derp_map"`
@@ -372,12 +371,6 @@ func (l *RegisterWorkspaceProxyLoop) Start(ctx context.Context) (RegisterWorkspa
}
failedAttempts = 0
// Check for consistency.
if originalRes.AppSecurityKey != resp.AppSecurityKey {
l.failureFn(xerrors.New("app security key has changed, proxy must be restarted"))
return
}
if originalRes.DERPMeshKey != resp.DERPMeshKey {
l.failureFn(xerrors.New("DERP mesh key has changed, proxy must be restarted"))
return
@@ -586,10 +579,10 @@ type CryptoKeysResponse struct {
CryptoKeys []codersdk.CryptoKey `json:"crypto_keys"`
}
func (c *Client) CryptoKeys(ctx context.Context) (CryptoKeysResponse, error) {
func (c *Client) CryptoKeys(ctx context.Context, feature codersdk.CryptoKeyFeature) (CryptoKeysResponse, error) {
res, err := c.Request(ctx, http.MethodGet,
"/api/v2/workspaceproxies/me/crypto-keys",
nil,
"/api/v2/workspaceproxies/me/crypto-keys", nil,
codersdk.WithQueryParam("feature", string(feature)),
)
if err != nil {
return CryptoKeysResponse{}, xerrors.Errorf("make request: %w", err)
+2 -2
View File
@@ -2110,8 +2110,8 @@ export type BuildReason = "autostart" | "autostop" | "initiator"
export const BuildReasons: BuildReason[] = ["autostart", "autostop", "initiator"]
// From codersdk/deployment.go
export type CryptoKeyFeature = "oidc_convert" | "tailnet_resume" | "workspace_apps"
export const CryptoKeyFeatures: CryptoKeyFeature[] = ["oidc_convert", "tailnet_resume", "workspace_apps"]
export type CryptoKeyFeature = "oidc_convert" | "tailnet_resume" | "workspace_apps_api_key" | "workspace_apps_token"
export const CryptoKeyFeatures: CryptoKeyFeature[] = ["oidc_convert", "tailnet_resume", "workspace_apps_api_key", "workspace_apps_token"]
// From codersdk/workspaceagents.go
export type DisplayApp = "port_forwarding_helper" | "ssh_helper" | "vscode" | "vscode_insiders" | "web_terminal"
+26 -117
View File
@@ -3,32 +3,23 @@ package tailnet
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"encoding/json"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
)
const (
DefaultResumeTokenExpiry = 24 * time.Hour
resumeTokenSigningAlgorithm = jose.HS512
)
// resumeTokenSigningKeyID is a fixed key ID for the resume token signing key.
// If/when we add support for multiple keys (e.g. key rotation), this will move
// to the database instead.
var resumeTokenSigningKeyID = uuid.MustParse("97166747-9309-4d7f-9071-a230e257c2a4")
// NewInsecureTestResumeTokenProvider returns a ResumeTokenProvider that uses a
// random key with short expiry for testing purposes. If any errors occur while
// generating the key, the function panics.
@@ -37,12 +28,15 @@ func NewInsecureTestResumeTokenProvider() ResumeTokenProvider {
if err != nil {
panic(err)
}
return NewResumeTokenKeyProvider(key, quartz.NewReal(), time.Hour)
return NewResumeTokenKeyProvider(jwtutils.StaticKey{
ID: uuid.New().String(),
Key: key[:],
}, quartz.NewReal(), time.Hour)
}
type ResumeTokenProvider interface {
GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error)
VerifyResumeToken(token string) (uuid.UUID, error)
GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error)
VerifyResumeToken(ctx context.Context, token string) (uuid.UUID, error)
}
type ResumeTokenSigningKey [64]byte
@@ -56,104 +50,37 @@ func GenerateResumeTokenSigningKey() (ResumeTokenSigningKey, error) {
return key, nil
}
type ResumeTokenSigningKeyDatabaseStore interface {
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, key string) error
}
// ResumeTokenSigningKeyFromDatabase retrieves the coordinator resume token
// signing key from the database. If the key is not found, a new key is
// generated and inserted into the database.
func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSigningKeyDatabaseStore) (ResumeTokenSigningKey, error) {
var resumeTokenKey ResumeTokenSigningKey
resumeTokenKeyStr, err := db.GetCoordinatorResumeTokenSigningKey(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return resumeTokenKey, xerrors.Errorf("get coordinator resume token key: %w", err)
}
if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) {
newKey, err := GenerateResumeTokenSigningKey()
if err != nil {
return resumeTokenKey, xerrors.Errorf("generate fresh coordinator resume token key: %w", err)
}
resumeTokenKeyStr = hex.EncodeToString(newKey[:])
err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr)
if err != nil {
return resumeTokenKey, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err)
}
}
resumeTokenKeyBytes, err := hex.DecodeString(resumeTokenKeyStr)
if err != nil {
return resumeTokenKey, xerrors.Errorf("decode coordinator resume token key from database: %w", err)
}
if len(resumeTokenKeyBytes) != len(resumeTokenKey) {
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is not the correct length, expect %d got %d", len(resumeTokenKey), len(resumeTokenKeyBytes))
}
copy(resumeTokenKey[:], resumeTokenKeyBytes)
if resumeTokenKey == [64]byte{} {
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is empty")
}
return resumeTokenKey, nil
}
type ResumeTokenKeyProvider struct {
key ResumeTokenSigningKey
key jwtutils.SigningKeyManager
clock quartz.Clock
expiry time.Duration
}
func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider {
func NewResumeTokenKeyProvider(key jwtutils.SigningKeyManager, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider {
if expiry <= 0 {
expiry = DefaultResumeTokenExpiry
}
return ResumeTokenKeyProvider{
key: key,
clock: clock,
expiry: DefaultResumeTokenExpiry,
expiry: expiry,
}
}
type resumeTokenPayload struct {
PeerID uuid.UUID `json:"sub"`
Expiry int64 `json:"exp"`
}
func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
func (p ResumeTokenKeyProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
exp := p.clock.Now().Add(p.expiry)
payload := resumeTokenPayload{
PeerID: peerID,
Expiry: exp.Unix(),
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, xerrors.Errorf("marshal payload to JSON: %w", err)
payload := jwtutils.RegisteredClaims{
Subject: peerID.String(),
Expiry: jwt.NewNumericDate(exp),
}
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: resumeTokenSigningAlgorithm,
Key: p.key[:],
}, &jose.SignerOptions{
ExtraHeaders: map[jose.HeaderKey]interface{}{
"kid": resumeTokenSigningKeyID.String(),
},
})
if err != nil {
return nil, xerrors.Errorf("create signer: %w", err)
}
signedObject, err := signer.Sign(payloadBytes)
token, err := jwtutils.Sign(ctx, p.key, payload)
if err != nil {
return nil, xerrors.Errorf("sign payload: %w", err)
}
serialized, err := signedObject.CompactSerialize()
if err != nil {
return nil, xerrors.Errorf("serialize JWS: %w", err)
}
return &proto.RefreshResumeTokenResponse{
Token: serialized,
Token: token,
RefreshIn: durationpb.New(p.expiry / 2),
ExpiresAt: timestamppb.New(exp),
}, nil
@@ -162,35 +89,17 @@ func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.Re
// VerifyResumeToken parses a signed tailnet resume token with the given key and
// returns the payload. If the token is invalid or expired, an error is
// returned.
func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error) {
object, err := jose.ParseSigned(str)
func (p ResumeTokenKeyProvider) VerifyResumeToken(ctx context.Context, str string) (uuid.UUID, error) {
var tok jwt.Claims
err := jwtutils.Verify(ctx, p.key, str, &tok, jwtutils.WithVerifyExpected(jwt.Expected{
Time: p.clock.Now(),
}))
if err != nil {
return uuid.Nil, xerrors.Errorf("parse JWS: %w", err)
return uuid.Nil, xerrors.Errorf("verify payload: %w", err)
}
if len(object.Signatures) != 1 {
return uuid.Nil, xerrors.New("expected 1 signature")
}
if object.Signatures[0].Header.Algorithm != string(resumeTokenSigningAlgorithm) {
return uuid.Nil, xerrors.Errorf("expected token signing algorithm to be %q, got %q", resumeTokenSigningAlgorithm, object.Signatures[0].Header.Algorithm)
}
if object.Signatures[0].Header.KeyID != resumeTokenSigningKeyID.String() {
return uuid.Nil, xerrors.Errorf("expected token key ID to be %q, got %q", resumeTokenSigningKeyID, object.Signatures[0].Header.KeyID)
}
output, err := object.Verify(p.key[:])
parsed, err := uuid.Parse(tok.Subject)
if err != nil {
return uuid.Nil, xerrors.Errorf("verify JWS: %w", err)
return uuid.Nil, xerrors.Errorf("parse peerID from token: %w", err)
}
var tok resumeTokenPayload
err = json.Unmarshal(output, &tok)
if err != nil {
return uuid.Nil, xerrors.Errorf("unmarshal payload: %w", err)
}
exp := time.Unix(tok.Expiry, 0)
if exp.Before(p.clock.Now()) {
return uuid.Nil, xerrors.New("signed resume token expired")
}
return tok.PeerID, nil
return parsed, nil
}
+34 -116
View File
@@ -1,117 +1,20 @@
package tailnet_test
import (
"context"
"encoding/hex"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
t.Parallel()
assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) {
t.Helper()
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty")
assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s")
}
t.Run("GenerateRetrieve", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
key1, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
assertRandomKey(t, key1)
key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
require.Equal(t, key1, key2, "keys should not be different")
})
t.Run("GetError", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", assert.AnError)
ctx := testutil.Context(t, testutil.WaitShort)
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.ErrorIs(t, err, assert.AnError)
})
t.Run("UpsertError", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", nil)
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(assert.AnError)
ctx := testutil.Context(t, testutil.WaitShort)
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.ErrorIs(t, err, assert.AnError)
})
t.Run("DecodeErrorShouldRegenerate", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil)
var storedKey tailnet.ResumeTokenSigningKey
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error {
keyBytes, err := hex.DecodeString(value)
require.NoError(t, err)
require.Len(t, keyBytes, len(storedKey))
copy(storedKey[:], keyBytes)
return nil
})
ctx := testutil.Context(t, testutil.WaitShort)
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
assertRandomKey(t, key)
require.Equal(t, storedKey, key, "key should match stored value")
})
t.Run("LengthErrorShouldRegenerate", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("deadbeef", nil)
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil)
ctx := testutil.Context(t, testutil.WaitShort)
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.NoError(t, err)
assertRandomKey(t, key)
})
t.Run("EmptyError", func(t *testing.T) {
t.Parallel()
db := dbmock.NewMockStore(gomock.NewController(t))
emptyKey := hex.EncodeToString(make([]byte, 64))
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return(emptyKey, nil)
ctx := testutil.Context(t, testutil.WaitShort)
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
require.ErrorContains(t, err, "is empty")
})
}
func TestResumeTokenKeyProvider(t *testing.T) {
t.Parallel()
@@ -121,17 +24,18 @@ func TestResumeTokenKeyProvider(t *testing.T) {
t.Run("OK", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
id := uuid.New()
clock := quartz.NewMock(t)
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
token, err := provider.GenerateResumeToken(id)
provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), clock, tailnet.DefaultResumeTokenExpiry)
token, err := provider.GenerateResumeToken(ctx, id)
require.NoError(t, err)
require.NotNil(t, token)
require.NotEmpty(t, token.Token)
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
gotID, err := provider.VerifyResumeToken(token.Token)
gotID, err := provider.VerifyResumeToken(ctx, token.Token)
require.NoError(t, err)
require.Equal(t, id, gotID)
})
@@ -139,43 +43,57 @@ func TestResumeTokenKeyProvider(t *testing.T) {
t.Run("Expired", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
id := uuid.New()
clock := quartz.NewMock(t)
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
token, err := provider.GenerateResumeToken(id)
provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), clock, tailnet.DefaultResumeTokenExpiry)
token, err := provider.GenerateResumeToken(ctx, id)
require.NoError(t, err)
require.NotNil(t, token)
require.NotEmpty(t, token.Token)
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
// Advance time past expiry
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second)
// Advance time past expiry. Account for leeway.
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second*61)
_, err = provider.VerifyResumeToken(token.Token)
require.ErrorContains(t, err, "expired")
_, err = provider.VerifyResumeToken(ctx, token.Token)
require.Error(t, err)
require.ErrorIs(t, err, jwt.ErrExpired)
})
t.Run("InvalidToken", func(t *testing.T) {
t.Parallel()
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
_, err := provider.VerifyResumeToken("invalid")
ctx := testutil.Context(t, testutil.WaitShort)
provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
_, err := provider.VerifyResumeToken(ctx, "invalid")
require.ErrorContains(t, err, "parse JWS")
})
t.Run("VerifyError", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
// Generate a resume token with a different key
otherKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
otherProvider := tailnet.NewResumeTokenKeyProvider(otherKey, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
token, err := otherProvider.GenerateResumeToken(uuid.New())
otherSigner := newKeySigner(otherKey)
otherProvider := tailnet.NewResumeTokenKeyProvider(otherSigner, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
token, err := otherProvider.GenerateResumeToken(ctx, uuid.New())
require.NoError(t, err)
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
_, err = provider.VerifyResumeToken(token.Token)
require.ErrorContains(t, err, "verify JWS")
signer := newKeySigner(key)
signer.ID = otherSigner.ID
provider := tailnet.NewResumeTokenKeyProvider(signer, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
_, err = provider.VerifyResumeToken(ctx, token.Token)
require.ErrorIs(t, err, jose.ErrCryptoFailure)
})
}
func newKeySigner(key tailnet.ResumeTokenSigningKey) jwtutils.StaticKey {
return jwtutils.StaticKey{
ID: "123",
Key: key[:],
}
}
+1 -1
View File
@@ -177,7 +177,7 @@ func (s *DRPCService) RefreshResumeToken(ctx context.Context, _ *proto.RefreshRe
return nil, xerrors.New("no Stream ID")
}
res, err := s.ResumeTokenProvider.GenerateResumeToken(streamID.ID)
res, err := s.ResumeTokenProvider.GenerateResumeToken(ctx, streamID.ID)
if err != nil {
return nil, xerrors.Errorf("generate resume token: %w", err)
}