mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: enable key rotation (#15066)
This PR contains the remaining logic necessary to hook up key rotation to the product.
This commit is contained in:
+21
-81
@@ -10,7 +10,6 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
@@ -62,6 +61,7 @@ import (
|
||||
"github.com/coder/serpent"
|
||||
"github.com/coder/wgtunnel/tunnelsdk"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/entitlements"
|
||||
"github.com/coder/coder/v2/coderd/notifications/reports"
|
||||
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||
@@ -97,7 +97,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/updatecheck"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
stringutil "github.com/coder/coder/v2/coderd/util/strings"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -743,90 +742,31 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
return xerrors.Errorf("set deployment id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Read the app signing key from the DB. We store it hex encoded
|
||||
// since the config table uses strings for the value and we
|
||||
// don't want to deal with automatic encoding issues.
|
||||
appSecurityKeyStr, err := tx.GetAppSecurityKey(ctx)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("get app signing key: %w", err)
|
||||
}
|
||||
// If the string in the DB is an invalid hex string or the
|
||||
// length is not equal to the current key length, generate a new
|
||||
// one.
|
||||
//
|
||||
// If the key is regenerated, old signed tokens and encrypted
|
||||
// strings will become invalid. New signed app tokens will be
|
||||
// generated automatically on failure. Any workspace app token
|
||||
// smuggling operations in progress may fail, although with a
|
||||
// helpful error.
|
||||
if decoded, err := hex.DecodeString(appSecurityKeyStr); err != nil || len(decoded) != len(workspaceapps.SecurityKey{}) {
|
||||
b := make([]byte, len(workspaceapps.SecurityKey{}))
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate fresh app signing key: %w", err)
|
||||
}
|
||||
|
||||
appSecurityKeyStr = hex.EncodeToString(b)
|
||||
err = tx.UpsertAppSecurityKey(ctx, appSecurityKeyStr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert freshly generated app signing key to database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
appSecurityKey, err := workspaceapps.KeyFromString(appSecurityKeyStr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("decode app signing key from database: %w", err)
|
||||
}
|
||||
|
||||
options.AppSecurityKey = appSecurityKey
|
||||
|
||||
// Read the oauth signing key from the database. Like the app security, generate a new one
|
||||
// if it is invalid for any reason.
|
||||
oauthSigningKeyStr, err := tx.GetOAuthSigningKey(ctx)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("get app oauth signing key: %w", err)
|
||||
}
|
||||
if decoded, err := hex.DecodeString(oauthSigningKeyStr); err != nil || len(decoded) != len(options.OAuthSigningKey) {
|
||||
b := make([]byte, len(options.OAuthSigningKey))
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate fresh oauth signing key: %w", err)
|
||||
}
|
||||
|
||||
oauthSigningKeyStr = hex.EncodeToString(b)
|
||||
err = tx.UpsertOAuthSigningKey(ctx, oauthSigningKeyStr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert freshly generated oauth signing key to database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
oauthKeyBytes, err := hex.DecodeString(oauthSigningKeyStr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("decode oauth signing key from database: %w", err)
|
||||
}
|
||||
if len(oauthKeyBytes) != len(options.OAuthSigningKey) {
|
||||
return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(oauthKeyBytes))
|
||||
}
|
||||
copy(options.OAuthSigningKey[:], oauthKeyBytes)
|
||||
if options.OAuthSigningKey == [32]byte{} {
|
||||
return xerrors.Errorf("oauth signing key in database is empty")
|
||||
}
|
||||
|
||||
// Read the coordinator resume token signing key from the
|
||||
// database.
|
||||
resumeTokenKey, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, tx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get coordinator resume token key from database: %w", err)
|
||||
}
|
||||
options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(resumeTokenKey, quartz.NewReal(), tailnet.DefaultResumeTokenExpiry)
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return xerrors.Errorf("set deployment id: %w", err)
|
||||
}
|
||||
|
||||
fetcher := &cryptokeys.DBFetcher{
|
||||
DB: options.Database,
|
||||
}
|
||||
|
||||
resumeKeycache, err := cryptokeys.NewSigningCache(ctx,
|
||||
logger,
|
||||
fetcher,
|
||||
codersdk.CryptoKeyFeatureTailnetResume,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Critical(ctx, "failed to properly instantiate tailnet resume signing cache", slog.Error(err))
|
||||
}
|
||||
|
||||
options.CoordinatorResumeTokenProvider = tailnet.NewResumeTokenKeyProvider(
|
||||
resumeKeycache,
|
||||
quartz.NewReal(),
|
||||
tailnet.DefaultResumeTokenExpiry,
|
||||
)
|
||||
|
||||
options.RuntimeConfig = runtimeconfig.NewManager()
|
||||
|
||||
// This should be output before the logs start streaming.
|
||||
|
||||
Generated
+13
-5
@@ -7646,6 +7646,15 @@ const docTemplate = `{
|
||||
],
|
||||
"summary": "Get workspace proxy crypto keys",
|
||||
"operationId": "get-workspace-proxy-crypto-keys",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Feature key",
|
||||
"name": "feature",
|
||||
"in": "query",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
@@ -10011,12 +10020,14 @@ const docTemplate = `{
|
||||
"codersdk.CryptoKeyFeature": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"workspace_apps",
|
||||
"workspace_apps_api_key",
|
||||
"workspace_apps_token",
|
||||
"oidc_convert",
|
||||
"tailnet_resume"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"CryptoKeyFeatureWorkspaceApp",
|
||||
"CryptoKeyFeatureWorkspaceAppsAPIKey",
|
||||
"CryptoKeyFeatureWorkspaceAppsToken",
|
||||
"CryptoKeyFeatureOIDCConvert",
|
||||
"CryptoKeyFeatureTailnetResume"
|
||||
]
|
||||
@@ -16244,9 +16255,6 @@ const docTemplate = `{
|
||||
"wsproxysdk.RegisterWorkspaceProxyResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"app_security_key": {
|
||||
"type": "string"
|
||||
},
|
||||
"derp_force_websockets": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
||||
Generated
+17
-5
@@ -6758,6 +6758,15 @@
|
||||
"tags": ["Enterprise"],
|
||||
"summary": "Get workspace proxy crypto keys",
|
||||
"operationId": "get-workspace-proxy-crypto-keys",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Feature key",
|
||||
"name": "feature",
|
||||
"in": "query",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
@@ -8914,9 +8923,15 @@
|
||||
},
|
||||
"codersdk.CryptoKeyFeature": {
|
||||
"type": "string",
|
||||
"enum": ["workspace_apps", "oidc_convert", "tailnet_resume"],
|
||||
"enum": [
|
||||
"workspace_apps_api_key",
|
||||
"workspace_apps_token",
|
||||
"oidc_convert",
|
||||
"tailnet_resume"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"CryptoKeyFeatureWorkspaceApp",
|
||||
"CryptoKeyFeatureWorkspaceAppsAPIKey",
|
||||
"CryptoKeyFeatureWorkspaceAppsToken",
|
||||
"CryptoKeyFeatureOIDCConvert",
|
||||
"CryptoKeyFeatureTailnetResume"
|
||||
]
|
||||
@@ -14853,9 +14868,6 @@
|
||||
"wsproxysdk.RegisterWorkspaceProxyResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"app_security_key": {
|
||||
"type": "string"
|
||||
},
|
||||
"derp_force_websockets": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
||||
+61
-11
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/entitlements"
|
||||
"github.com/coder/coder/v2/coderd/idpsync"
|
||||
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||
@@ -185,9 +186,6 @@ type Options struct {
|
||||
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
||||
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
|
||||
AccessControlStore *atomic.Pointer[dbauthz.AccessControlStore]
|
||||
// AppSecurityKey is the crypto key used to sign and encrypt tokens related to
|
||||
// workspace applications. It consists of both a signing and encryption key.
|
||||
AppSecurityKey workspaceapps.SecurityKey
|
||||
// CoordinatorResumeTokenProvider is used to provide and validate resume
|
||||
// tokens issued by and passed to the coordinator DRPC API.
|
||||
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
|
||||
@@ -251,6 +249,12 @@ type Options struct {
|
||||
|
||||
// OneTimePasscodeValidityPeriod specifies how long a one time passcode should be valid for.
|
||||
OneTimePasscodeValidityPeriod time.Duration
|
||||
|
||||
// Keycaches
|
||||
AppSigningKeyCache cryptokeys.SigningKeycache
|
||||
AppEncryptionKeyCache cryptokeys.EncryptionKeycache
|
||||
OIDCConvertKeyCache cryptokeys.SigningKeycache
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
// @title Coder API
|
||||
@@ -352,6 +356,9 @@ func New(options *Options) *API {
|
||||
if options.PrometheusRegistry == nil {
|
||||
options.PrometheusRegistry = prometheus.NewRegistry()
|
||||
}
|
||||
if options.Clock == nil {
|
||||
options.Clock = quartz.NewReal()
|
||||
}
|
||||
if options.DERPServer == nil && options.DeploymentValues.DERP.Server.Enable {
|
||||
options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp")))
|
||||
}
|
||||
@@ -444,6 +451,49 @@ func New(options *Options) *API {
|
||||
if err != nil {
|
||||
panic(xerrors.Errorf("get deployment ID: %w", err))
|
||||
}
|
||||
|
||||
fetcher := &cryptokeys.DBFetcher{
|
||||
DB: options.Database,
|
||||
}
|
||||
|
||||
if options.OIDCConvertKeyCache == nil {
|
||||
options.OIDCConvertKeyCache, err = cryptokeys.NewSigningCache(ctx,
|
||||
options.Logger.Named("oidc_convert_keycache"),
|
||||
fetcher,
|
||||
codersdk.CryptoKeyFeatureOIDCConvert,
|
||||
)
|
||||
if err != nil {
|
||||
options.Logger.Critical(ctx, "failed to properly instantiate oidc convert signing cache", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
if options.AppSigningKeyCache == nil {
|
||||
options.AppSigningKeyCache, err = cryptokeys.NewSigningCache(ctx,
|
||||
options.Logger.Named("app_signing_keycache"),
|
||||
fetcher,
|
||||
codersdk.CryptoKeyFeatureWorkspaceAppsToken,
|
||||
)
|
||||
if err != nil {
|
||||
options.Logger.Critical(ctx, "failed to properly instantiate app signing key cache", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
if options.AppEncryptionKeyCache == nil {
|
||||
options.AppEncryptionKeyCache, err = cryptokeys.NewEncryptionCache(ctx,
|
||||
options.Logger,
|
||||
fetcher,
|
||||
codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
)
|
||||
if err != nil {
|
||||
options.Logger.Critical(ctx, "failed to properly instantiate app encryption key cache", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Start a background process that rotates keys. We intentionally start this after the caches
|
||||
// are created to force initial requests for a key to populate the caches. This helps catch
|
||||
// bugs that may only occur when a key isn't precached in tests and the latency cost is minimal.
|
||||
cryptokeys.StartRotator(ctx, options.Logger, options.Database)
|
||||
|
||||
api := &API{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
@@ -464,7 +514,7 @@ func New(options *Options) *API {
|
||||
options.DeploymentValues,
|
||||
oauthConfigs,
|
||||
options.AgentInactiveDisconnectTimeout,
|
||||
options.AppSecurityKey,
|
||||
options.AppSigningKeyCache,
|
||||
),
|
||||
metricsCache: metricsCache,
|
||||
Auditor: atomic.Pointer[audit.Auditor]{},
|
||||
@@ -606,7 +656,7 @@ func New(options *Options) *API {
|
||||
ResumeTokenProvider: api.Options.CoordinatorResumeTokenProvider,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
|
||||
api.Logger.Fatal(context.Background(), "failed to initialize tailnet client service", slog.Error(err))
|
||||
}
|
||||
|
||||
api.statsReporter = workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
@@ -628,9 +678,6 @@ func New(options *Options) *API {
|
||||
options.WorkspaceAppsStatsCollectorOptions.Reporter = api.statsReporter
|
||||
}
|
||||
|
||||
if options.AppSecurityKey.IsZero() {
|
||||
api.Logger.Fatal(api.ctx, "app security key cannot be zero")
|
||||
}
|
||||
api.workspaceAppServer = &workspaceapps.Server{
|
||||
Logger: workspaceAppsLogger,
|
||||
|
||||
@@ -642,11 +689,11 @@ func New(options *Options) *API {
|
||||
|
||||
SignedTokenProvider: api.WorkspaceAppsProvider,
|
||||
AgentProvider: api.agentProvider,
|
||||
AppSecurityKey: options.AppSecurityKey,
|
||||
StatsCollector: workspaceapps.NewStatsCollector(options.WorkspaceAppsStatsCollectorOptions),
|
||||
|
||||
DisablePathApps: options.DeploymentValues.DisablePathApps.Value(),
|
||||
SecureAuthCookie: options.DeploymentValues.SecureAuthCookie.Value(),
|
||||
DisablePathApps: options.DeploymentValues.DisablePathApps.Value(),
|
||||
SecureAuthCookie: options.DeploymentValues.SecureAuthCookie.Value(),
|
||||
APIKeyEncryptionKeycache: options.AppEncryptionKeyCache,
|
||||
}
|
||||
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
@@ -1434,6 +1481,9 @@ func (api *API) Close() error {
|
||||
_ = api.agentProvider.Close()
|
||||
_ = api.statsReporter.Close()
|
||||
_ = api.NetworkTelemetryBatcher.Close()
|
||||
_ = api.OIDCConvertKeyCache.Close()
|
||||
_ = api.AppSigningKeyCache.Close()
|
||||
_ = api.AppEncryptionKeyCache.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/autobuild"
|
||||
"github.com/coder/coder/v2/coderd/awsidentity"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -88,12 +89,9 @@ import (
|
||||
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// AppSecurityKey is a 96-byte key used to sign JWTs and encrypt JWEs for
|
||||
// workspace app tokens in tests.
|
||||
var AppSecurityKey = must(workspaceapps.KeyFromString("6465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e207761732068657265206465616e2077617320686572"))
|
||||
|
||||
type Options struct {
|
||||
// AccessURL denotes a custom access URL. By default we use the httptest
|
||||
// server's URL. Setting this may result in unexpected behavior (especially
|
||||
@@ -161,8 +159,10 @@ type Options struct {
|
||||
DatabaseRolluper *dbrollup.Rolluper
|
||||
WorkspaceUsageTrackerFlush chan int
|
||||
WorkspaceUsageTrackerTick chan time.Time
|
||||
|
||||
NotificationsEnqueuer notifications.Enqueuer
|
||||
NotificationsEnqueuer notifications.Enqueuer
|
||||
APIKeyEncryptionCache cryptokeys.EncryptionKeycache
|
||||
OIDCConvertKeyCache cryptokeys.SigningKeycache
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
// New constructs a codersdk client connected to an in-memory API instance.
|
||||
@@ -525,7 +525,6 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
|
||||
DeploymentOptions: codersdk.DeploymentOptionsWithoutSecrets(options.DeploymentValues.Options()),
|
||||
UpdateCheckOptions: options.UpdateCheckOptions,
|
||||
SwaggerEndpoint: options.SwaggerEndpoint,
|
||||
AppSecurityKey: AppSecurityKey,
|
||||
SSHConfig: options.ConfigSSH,
|
||||
HealthcheckFunc: options.HealthcheckFunc,
|
||||
HealthcheckTimeout: options.HealthcheckTimeout,
|
||||
@@ -538,6 +537,9 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
|
||||
WorkspaceUsageTracker: wuTracker,
|
||||
NotificationsEnqueuer: options.NotificationsEnqueuer,
|
||||
OneTimePasscodeValidityPeriod: options.OneTimePasscodeValidityPeriod,
|
||||
Clock: options.Clock,
|
||||
AppEncryptionKeyCache: options.APIKeyEncryptionCache,
|
||||
OIDCConvertKeyCache: options.OIDCConvertKeyCache,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+61
-31
@@ -3,6 +3,7 @@ package cryptokeys
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
@@ -25,7 +26,7 @@ var (
|
||||
)
|
||||
|
||||
type Fetcher interface {
|
||||
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
|
||||
Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error)
|
||||
}
|
||||
|
||||
type EncryptionKeycache interface {
|
||||
@@ -62,27 +63,26 @@ const (
|
||||
)
|
||||
|
||||
type DBFetcher struct {
|
||||
DB database.Store
|
||||
Feature database.CryptoKeyFeature
|
||||
DB database.Store
|
||||
}
|
||||
|
||||
func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
|
||||
keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature)
|
||||
func (d *DBFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
|
||||
keys, err := d.DB.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeature(feature))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get crypto keys by feature: %w", err)
|
||||
}
|
||||
|
||||
return db2sdk.CryptoKeys(keys), nil
|
||||
return toSDKKeys(keys), nil
|
||||
}
|
||||
|
||||
// cache implements the caching functionality for both signing and encryption keys.
|
||||
type cache struct {
|
||||
clock quartz.Clock
|
||||
refreshCtx context.Context
|
||||
refreshCancel context.CancelFunc
|
||||
fetcher Fetcher
|
||||
logger slog.Logger
|
||||
feature codersdk.CryptoKeyFeature
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
clock quartz.Clock
|
||||
fetcher Fetcher
|
||||
logger slog.Logger
|
||||
feature codersdk.CryptoKeyFeature
|
||||
|
||||
mu sync.Mutex
|
||||
keys map[int32]codersdk.CryptoKey
|
||||
@@ -109,7 +109,8 @@ func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
|
||||
if !isSigningKeyFeature(feature) {
|
||||
return nil, xerrors.Errorf("invalid feature: %s", feature)
|
||||
}
|
||||
return newCache(ctx, logger, fetcher, feature, opts...)
|
||||
logger = logger.Named(fmt.Sprintf("%s_signing_keycache", feature))
|
||||
return newCache(ctx, logger, fetcher, feature, opts...), nil
|
||||
}
|
||||
|
||||
func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
|
||||
@@ -118,10 +119,11 @@ func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher
|
||||
if !isEncryptionKeyFeature(feature) {
|
||||
return nil, xerrors.Errorf("invalid feature: %s", feature)
|
||||
}
|
||||
return newCache(ctx, logger, fetcher, feature, opts...)
|
||||
logger = logger.Named(fmt.Sprintf("%s_encryption_keycache", feature))
|
||||
return newCache(ctx, logger, fetcher, feature, opts...), nil
|
||||
}
|
||||
|
||||
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) {
|
||||
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) *cache {
|
||||
cache := &cache{
|
||||
clock: quartz.NewReal(),
|
||||
logger: logger,
|
||||
@@ -134,16 +136,16 @@ func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature
|
||||
}
|
||||
|
||||
cache.cond = sync.NewCond(&cache.mu)
|
||||
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
|
||||
//nolint:gocritic // We need to be able to read the keys in order to cache them.
|
||||
cache.ctx, cache.cancel = context.WithCancel(dbauthz.AsKeyReader(ctx))
|
||||
cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh)
|
||||
|
||||
keys, err := cache.cryptoKeys(ctx)
|
||||
keys, err := cache.cryptoKeys(cache.ctx)
|
||||
if err != nil {
|
||||
cache.refreshCancel()
|
||||
return nil, xerrors.Errorf("initial fetch: %w", err)
|
||||
cache.logger.Critical(cache.ctx, "failed initial fetch", slog.Error(err))
|
||||
}
|
||||
cache.keys = keys
|
||||
return cache, nil
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) {
|
||||
@@ -151,6 +153,8 @@ func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error)
|
||||
return "", nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
//nolint:gocritic // cache can only read crypto keys.
|
||||
ctx = dbauthz.AsKeyReader(ctx)
|
||||
return c.cryptoKey(ctx, latestSequence)
|
||||
}
|
||||
|
||||
@@ -164,6 +168,8 @@ func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, erro
|
||||
return nil, xerrors.Errorf("parse id: %w", err)
|
||||
}
|
||||
|
||||
//nolint:gocritic // cache can only read crypto keys.
|
||||
ctx = dbauthz.AsKeyReader(ctx)
|
||||
_, secret, err := c.cryptoKey(ctx, int32(seq))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto key: %w", err)
|
||||
@@ -176,6 +182,8 @@ func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) {
|
||||
return "", nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
//nolint:gocritic // cache can only read crypto keys.
|
||||
ctx = dbauthz.AsKeyReader(ctx)
|
||||
return c.cryptoKey(ctx, latestSequence)
|
||||
}
|
||||
|
||||
@@ -188,7 +196,8 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse id: %w", err)
|
||||
}
|
||||
|
||||
//nolint:gocritic // cache can only read crypto keys.
|
||||
ctx = dbauthz.AsKeyReader(ctx)
|
||||
_, secret, err := c.cryptoKey(ctx, int32(seq))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto key: %w", err)
|
||||
@@ -198,12 +207,12 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error
|
||||
}
|
||||
|
||||
func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool {
|
||||
return feature == codersdk.CryptoKeyFeatureWorkspaceApp
|
||||
return feature == codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey
|
||||
}
|
||||
|
||||
func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool {
|
||||
switch feature {
|
||||
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert:
|
||||
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert, codersdk.CryptoKeyFeatureWorkspaceAppsToken:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -292,14 +301,15 @@ func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []
|
||||
func (c *cache) refresh() {
|
||||
now := c.clock.Now("CryptoKeyCache", "refresh")
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// If something's already fetching, we don't need to do anything.
|
||||
if c.fetching {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -307,20 +317,21 @@ func (c *cache) refresh() {
|
||||
// is ongoing but prior to the timer getting reset. In this case we want to
|
||||
// avoid double fetching.
|
||||
if now.Sub(c.lastFetch) < refreshInterval {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
c.fetching = true
|
||||
|
||||
c.mu.Unlock()
|
||||
keys, err := c.cryptoKeys(c.refreshCtx)
|
||||
keys, err := c.cryptoKeys(c.ctx)
|
||||
if err != nil {
|
||||
c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err))
|
||||
c.logger.Error(c.ctx, "fetch crypto keys", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// We don't defer an unlock here due to the deferred unlock at the top of the function.
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.lastFetch = c.clock.Now()
|
||||
c.refresher.Reset(refreshInterval)
|
||||
@@ -332,9 +343,9 @@ func (c *cache) refresh() {
|
||||
// cryptoKeys queries the control plane for the crypto keys.
|
||||
// Outside of initialization, this should only be called by fetch.
|
||||
func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
|
||||
keys, err := c.fetcher.Fetch(ctx)
|
||||
keys, err := c.fetcher.Fetch(ctx, c.feature)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto keys: %w", err)
|
||||
return nil, xerrors.Errorf("fetch: %w", err)
|
||||
}
|
||||
cache := toKeyMap(keys, c.clock.Now())
|
||||
return cache, nil
|
||||
@@ -361,9 +372,28 @@ func (c *cache) Close() error {
|
||||
}
|
||||
|
||||
c.closed = true
|
||||
c.refreshCancel()
|
||||
c.cancel()
|
||||
c.refresher.Stop()
|
||||
c.cond.Broadcast()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// We have to do this to avoid a circular dependency on db2sdk (cryptokeys -> db2sdk -> tailnet -> cryptokeys)
|
||||
func toSDKKeys(keys []database.CryptoKey) []codersdk.CryptoKey {
|
||||
into := make([]codersdk.CryptoKey, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
into = append(into, toSDK(key))
|
||||
}
|
||||
return into
|
||||
}
|
||||
|
||||
func toSDK(key database.CryptoKey) codersdk.CryptoKey {
|
||||
return codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeature(key.Feature),
|
||||
Sequence: key.Sequence,
|
||||
StartsAt: key.StartsAt,
|
||||
DeletesAt: key.DeletesAt.Time,
|
||||
Secret: key.Secret.String,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -488,7 +488,7 @@ type fakeFetcher struct {
|
||||
called int
|
||||
}
|
||||
|
||||
func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) {
|
||||
func (f *fakeFetcher) Fetch(_ context.Context, _ codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
|
||||
f.called++
|
||||
return f.keys, nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
@@ -53,10 +54,12 @@ func WithKeyDuration(keyDuration time.Duration) RotatorOption {
|
||||
// StartRotator starts a background process that rotates keys in the database.
|
||||
// It ensures there's at least one valid key per feature prior to returning.
|
||||
// Canceling the provided context will stop the background process.
|
||||
func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...RotatorOption) error {
|
||||
func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...RotatorOption) {
|
||||
//nolint:gocritic // KeyRotator can only rotate crypto keys.
|
||||
ctx = dbauthz.AsKeyRotator(ctx)
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
logger: logger,
|
||||
logger: logger.Named("keyrotator"),
|
||||
clock: quartz.NewReal(),
|
||||
keyDuration: DefaultKeyDuration,
|
||||
features: database.AllCryptoKeyFeatureValues(),
|
||||
@@ -68,12 +71,10 @@ func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, op
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("rotate keys: %w", err)
|
||||
kr.logger.Critical(ctx, "failed to rotate keys", slog.Error(err))
|
||||
}
|
||||
|
||||
go kr.start(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// start begins the process of rotating keys.
|
||||
@@ -227,9 +228,11 @@ func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database
|
||||
|
||||
func generateNewSecret(feature database.CryptoKeyFeature) (string, error) {
|
||||
switch feature {
|
||||
case database.CryptoKeyFeatureWorkspaceApps:
|
||||
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
|
||||
return generateKey(32)
|
||||
case database.CryptoKeyFeatureOidcConvert:
|
||||
case database.CryptoKeyFeatureWorkspaceAppsToken:
|
||||
return generateKey(64)
|
||||
case database.CryptoKeyFeatureOIDCConvert:
|
||||
return generateKey(64)
|
||||
case database.CryptoKeyFeatureTailnetResume:
|
||||
return generateKey(64)
|
||||
@@ -248,9 +251,11 @@ func generateKey(length int) (string, error) {
|
||||
|
||||
func tokenDuration(feature database.CryptoKeyFeature) time.Duration {
|
||||
switch feature {
|
||||
case database.CryptoKeyFeatureWorkspaceApps:
|
||||
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
|
||||
return WorkspaceAppsTokenDuration
|
||||
case database.CryptoKeyFeatureOidcConvert:
|
||||
case database.CryptoKeyFeatureWorkspaceAppsToken:
|
||||
return WorkspaceAppsTokenDuration
|
||||
case database.CryptoKeyFeatureOIDCConvert:
|
||||
return OIDCConvertTokenDuration
|
||||
case database.CryptoKeyFeatureTailnetResume:
|
||||
return TailnetResumeTokenDuration
|
||||
|
||||
@@ -38,7 +38,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
|
||||
// Seed the database with an existing key.
|
||||
oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now,
|
||||
Sequence: 15,
|
||||
})
|
||||
@@ -69,11 +69,11 @@ func Test_rotateKeys(t *testing.T) {
|
||||
|
||||
// The new key should be created and have a starts_at of the old key's expires_at.
|
||||
newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
Sequence: oldKey.Sequence + 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1)
|
||||
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1)
|
||||
|
||||
// Advance the clock just before the keys delete time.
|
||||
clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second)
|
||||
@@ -123,7 +123,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
|
||||
// Seed the database with an existing key
|
||||
existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now,
|
||||
Sequence: 123,
|
||||
})
|
||||
@@ -179,7 +179,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -187,7 +187,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
|
||||
// Seed the database with an existing key
|
||||
deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now.Add(-keyDuration),
|
||||
Sequence: 789,
|
||||
DeletesAt: sql.NullTime{
|
||||
@@ -232,7 +232,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
|
||||
// Seed the database with an existing key
|
||||
deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now,
|
||||
Sequence: 456,
|
||||
DeletesAt: sql.NullTime{
|
||||
@@ -281,7 +281,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -291,7 +291,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, clock.Now().UTC(), nullTime, 1)
|
||||
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceAppsAPIKey, clock.Now().UTC(), nullTime, 1)
|
||||
})
|
||||
|
||||
// Assert we insert a new key when the only key was manually deleted.
|
||||
@@ -312,14 +312,14 @@ func Test_rotateKeys(t *testing.T) {
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
},
|
||||
}
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
deletedkey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now,
|
||||
Sequence: 19,
|
||||
DeletesAt: sql.NullTime{
|
||||
@@ -338,7 +338,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, now, nullTime, deletedkey.Sequence+1)
|
||||
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceAppsAPIKey, now, nullTime, deletedkey.Sequence+1)
|
||||
})
|
||||
|
||||
// This tests ensures that rotation works with multiple
|
||||
@@ -365,9 +365,11 @@ func Test_rotateKeys(t *testing.T) {
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
// We'll test a scenario where one feature has no valid keys.
|
||||
// Another has a key that should be rotate. And one that
|
||||
// has a valid key that shouldn't trigger an action.
|
||||
// We'll test a scenario where:
|
||||
// - One feature has no valid keys.
|
||||
// - One has a key that should be rotated.
|
||||
// - One has a valid key that shouldn't trigger an action.
|
||||
// - One has no keys at all.
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureTailnetResume,
|
||||
StartsAt: now.Add(-keyDuration),
|
||||
@@ -377,6 +379,7 @@ func Test_rotateKeys(t *testing.T) {
|
||||
Valid: false,
|
||||
},
|
||||
})
|
||||
// Generate another deleted key to ensure we insert after the latest sequence.
|
||||
deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureTailnetResume,
|
||||
StartsAt: now.Add(-keyDuration),
|
||||
@@ -389,14 +392,14 @@ func Test_rotateKeys(t *testing.T) {
|
||||
|
||||
// Insert a key that should be rotated.
|
||||
rotatedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now.Add(-keyDuration + time.Hour),
|
||||
Sequence: 42,
|
||||
})
|
||||
|
||||
// Insert a key that should not trigger an action.
|
||||
validKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
Feature: database.CryptoKeyFeatureOIDCConvert,
|
||||
StartsAt: now,
|
||||
Sequence: 17,
|
||||
})
|
||||
@@ -406,26 +409,28 @@ func Test_rotateKeys(t *testing.T) {
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 4)
|
||||
require.Len(t, keys, 5)
|
||||
|
||||
kbf, err := keysByFeature(keys, database.AllCryptoKeyFeatureValues())
|
||||
require.NoError(t, err)
|
||||
|
||||
// No actions on OIDC convert.
|
||||
require.Len(t, kbf[database.CryptoKeyFeatureOidcConvert], 1)
|
||||
require.Len(t, kbf[database.CryptoKeyFeatureOIDCConvert], 1)
|
||||
// Workspace apps should have been rotated.
|
||||
require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceApps], 2)
|
||||
require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey], 2)
|
||||
// No existing key for tailnet resume should've
|
||||
// caused a key to be inserted.
|
||||
require.Len(t, kbf[database.CryptoKeyFeatureTailnetResume], 1)
|
||||
require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceAppsToken], 1)
|
||||
|
||||
oidcKey := kbf[database.CryptoKeyFeatureOidcConvert][0]
|
||||
oidcKey := kbf[database.CryptoKeyFeatureOIDCConvert][0]
|
||||
tailnetKey := kbf[database.CryptoKeyFeatureTailnetResume][0]
|
||||
requireKey(t, oidcKey, database.CryptoKeyFeatureOidcConvert, now, nullTime, validKey.Sequence)
|
||||
appTokenKey := kbf[database.CryptoKeyFeatureWorkspaceAppsToken][0]
|
||||
requireKey(t, oidcKey, database.CryptoKeyFeatureOIDCConvert, now, nullTime, validKey.Sequence)
|
||||
requireKey(t, tailnetKey, database.CryptoKeyFeatureTailnetResume, now, nullTime, deletedKey.Sequence+1)
|
||||
|
||||
newKey := kbf[database.CryptoKeyFeatureWorkspaceApps][0]
|
||||
oldKey := kbf[database.CryptoKeyFeatureWorkspaceApps][1]
|
||||
requireKey(t, appTokenKey, database.CryptoKeyFeatureWorkspaceAppsToken, now, nullTime, 1)
|
||||
newKey := kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey][0]
|
||||
oldKey := kbf[database.CryptoKeyFeatureWorkspaceAppsAPIKey][1]
|
||||
if newKey.Sequence == rotatedKey.Sequence {
|
||||
oldKey, newKey = newKey, oldKey
|
||||
}
|
||||
@@ -433,8 +438,8 @@ func Test_rotateKeys(t *testing.T) {
|
||||
Time: rotatedKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour),
|
||||
Valid: true,
|
||||
}
|
||||
requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence)
|
||||
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1)
|
||||
requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence)
|
||||
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceAppsAPIKey, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1)
|
||||
})
|
||||
|
||||
t.Run("UnknownFeature", func(t *testing.T) {
|
||||
@@ -478,11 +483,11 @@ func Test_rotateKeys(t *testing.T) {
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps},
|
||||
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceAppsAPIKey},
|
||||
}
|
||||
|
||||
expiringKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now.Add(-keyDuration),
|
||||
Sequence: 345,
|
||||
})
|
||||
@@ -522,19 +527,19 @@ func Test_rotateKeys(t *testing.T) {
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps},
|
||||
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceAppsAPIKey},
|
||||
}
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
expiredKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now.Add(-keyDuration - 2*time.Hour),
|
||||
Sequence: 19,
|
||||
})
|
||||
|
||||
deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now,
|
||||
Sequence: 20,
|
||||
Secret: sql.NullString{
|
||||
@@ -587,9 +592,11 @@ func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKey
|
||||
require.NoError(t, err)
|
||||
|
||||
switch key.Feature {
|
||||
case database.CryptoKeyFeatureOidcConvert:
|
||||
case database.CryptoKeyFeatureOIDCConvert:
|
||||
require.Len(t, secret, 64)
|
||||
case database.CryptoKeyFeatureWorkspaceApps:
|
||||
case database.CryptoKeyFeatureWorkspaceAppsToken:
|
||||
require.Len(t, secret, 64)
|
||||
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
|
||||
require.Len(t, secret, 32)
|
||||
case database.CryptoKeyFeatureTailnetResume:
|
||||
require.Len(t, secret, 64)
|
||||
|
||||
@@ -34,8 +34,7 @@ func TestRotator(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dbkeys, 0)
|
||||
|
||||
err = cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock))
|
||||
require.NoError(t, err)
|
||||
cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock))
|
||||
|
||||
// Fetch the keys from the database and ensure they
|
||||
// are as expected.
|
||||
@@ -58,7 +57,7 @@ func TestRotator(t *testing.T) {
|
||||
now := clock.Now().UTC()
|
||||
|
||||
rotatingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now.Add(-cryptokeys.DefaultKeyDuration + time.Hour + time.Minute),
|
||||
Sequence: 12345,
|
||||
})
|
||||
@@ -66,8 +65,7 @@ func TestRotator(t *testing.T) {
|
||||
trap := clock.Trap().TickerFunc()
|
||||
t.Cleanup(trap.Close)
|
||||
|
||||
err := cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock))
|
||||
require.NoError(t, err)
|
||||
cryptokeys.StartRotator(ctx, logger, db, cryptokeys.WithClock(clock))
|
||||
|
||||
initialKeyLen := len(database.AllCryptoKeyFeatureValues())
|
||||
// Fetch the keys from the database and ensure they
|
||||
@@ -85,7 +83,7 @@ func TestRotator(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, initialKeyLen+1)
|
||||
|
||||
newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
|
||||
newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rotatingKey.Sequence+1, newKey.Sequence)
|
||||
require.Equal(t, rotatingKey.ExpiresAt(cryptokeys.DefaultKeyDuration), newKey.StartsAt.UTC())
|
||||
|
||||
@@ -228,6 +228,42 @@ var (
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
|
||||
// See cryptokeys package.
|
||||
subjectCryptoKeyRotator = rbac.Subject{
|
||||
FriendlyName: "Crypto Key Rotator",
|
||||
ID: uuid.Nil.String(),
|
||||
Roles: rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleIdentifier{Name: "keyrotator"},
|
||||
DisplayName: "Key Rotator",
|
||||
Site: rbac.Permissions(map[string][]policy.Action{
|
||||
rbac.ResourceCryptoKey.Type: {policy.WildcardSymbol},
|
||||
}),
|
||||
Org: map[string][]rbac.Permission{},
|
||||
User: []rbac.Permission{},
|
||||
},
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
|
||||
// See cryptokeys package.
|
||||
subjectCryptoKeyReader = rbac.Subject{
|
||||
FriendlyName: "Crypto Key Reader",
|
||||
ID: uuid.Nil.String(),
|
||||
Roles: rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleIdentifier{Name: "keyrotator"},
|
||||
DisplayName: "Key Rotator",
|
||||
Site: rbac.Permissions(map[string][]policy.Action{
|
||||
rbac.ResourceCryptoKey.Type: {policy.WildcardSymbol},
|
||||
}),
|
||||
Org: map[string][]rbac.Permission{},
|
||||
User: []rbac.Permission{},
|
||||
},
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
|
||||
subjectSystemRestricted = rbac.Subject{
|
||||
FriendlyName: "System",
|
||||
ID: uuid.Nil.String(),
|
||||
@@ -281,6 +317,16 @@ func AsHangDetector(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, authContextKey{}, subjectHangDetector)
|
||||
}
|
||||
|
||||
// AsKeyRotator returns a context with an actor that has permissions required for rotating crypto keys.
|
||||
func AsKeyRotator(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyRotator)
|
||||
}
|
||||
|
||||
// AsKeyReader returns a context with an actor that has permissions required for reading crypto keys.
|
||||
func AsKeyReader(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, authContextKey{}, subjectCryptoKeyReader)
|
||||
}
|
||||
|
||||
// AsSystemRestricted returns a context with an actor that has permissions
|
||||
// required for various system operations (login, logout, metrics cache).
|
||||
func AsSystemRestricted(ctx context.Context) context.Context {
|
||||
|
||||
@@ -2243,13 +2243,13 @@ func (s *MethodTestSuite) TestCryptoKeys() {
|
||||
}))
|
||||
s.Run("InsertCryptoKey", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.InsertCryptoKeyParams{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
}).
|
||||
Asserts(rbac.ResourceCryptoKey, policy.ActionCreate)
|
||||
}))
|
||||
s.Run("DeleteCryptoKey", s.Subtest(func(db database.Store, check *expects) {
|
||||
key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
Sequence: 4,
|
||||
})
|
||||
check.Args(database.DeleteCryptoKeyParams{
|
||||
@@ -2259,7 +2259,7 @@ func (s *MethodTestSuite) TestCryptoKeys() {
|
||||
}))
|
||||
s.Run("GetCryptoKeyByFeatureAndSequence", s.Subtest(func(db database.Store, check *expects) {
|
||||
key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
Sequence: 4,
|
||||
})
|
||||
check.Args(database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
@@ -2269,14 +2269,14 @@ func (s *MethodTestSuite) TestCryptoKeys() {
|
||||
}))
|
||||
s.Run("GetLatestCryptoKeyByFeature", s.Subtest(func(db database.Store, check *expects) {
|
||||
dbgen.CryptoKey(s.T(), db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
Sequence: 4,
|
||||
})
|
||||
check.Args(database.CryptoKeyFeatureWorkspaceApps).Asserts(rbac.ResourceCryptoKey, policy.ActionRead)
|
||||
check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey).Asserts(rbac.ResourceCryptoKey, policy.ActionRead)
|
||||
}))
|
||||
s.Run("UpdateCryptoKeyDeletesAt", s.Subtest(func(db database.Store, check *expects) {
|
||||
key := dbgen.CryptoKey(s.T(), db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
Sequence: 4,
|
||||
})
|
||||
check.Args(database.UpdateCryptoKeyDeletesAtParams{
|
||||
@@ -2286,7 +2286,7 @@ func (s *MethodTestSuite) TestCryptoKeys() {
|
||||
}).Asserts(rbac.ResourceCryptoKey, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetCryptoKeysByFeature", s.Subtest(func(db database.Store, check *expects) {
|
||||
check.Args(database.CryptoKeyFeatureWorkspaceApps).
|
||||
check.Args(database.CryptoKeyFeatureWorkspaceAppsAPIKey).
|
||||
Asserts(rbac.ResourceCryptoKey, policy.ActionRead)
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -943,7 +943,7 @@ func CustomRole(t testing.TB, db database.Store, seed database.CustomRole) datab
|
||||
func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) database.CryptoKey {
|
||||
t.Helper()
|
||||
|
||||
seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps)
|
||||
seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
|
||||
|
||||
// An empty string for the secret is interpreted as
|
||||
// a caller wanting a new secret to be generated.
|
||||
@@ -1048,9 +1048,11 @@ func takeFirst[Value comparable](values ...Value) Value {
|
||||
|
||||
func newCryptoKeySecret(feature database.CryptoKeyFeature) (string, error) {
|
||||
switch feature {
|
||||
case database.CryptoKeyFeatureWorkspaceApps:
|
||||
case database.CryptoKeyFeatureWorkspaceAppsAPIKey:
|
||||
return generateCryptoKey(32)
|
||||
case database.CryptoKeyFeatureOidcConvert:
|
||||
case database.CryptoKeyFeatureWorkspaceAppsToken:
|
||||
return generateCryptoKey(64)
|
||||
case database.CryptoKeyFeatureOIDCConvert:
|
||||
return generateCryptoKey(64)
|
||||
case database.CryptoKeyFeatureTailnetResume:
|
||||
return generateCryptoKey(64)
|
||||
|
||||
Generated
+2
-1
@@ -38,7 +38,8 @@ CREATE TYPE build_reason AS ENUM (
|
||||
);
|
||||
|
||||
CREATE TYPE crypto_key_feature AS ENUM (
|
||||
'workspace_apps',
|
||||
'workspace_apps_token',
|
||||
'workspace_apps_api_key',
|
||||
'oidc_convert',
|
||||
'tailnet_resume'
|
||||
);
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
-- Step 1: Remove the new entries from crypto_keys table
|
||||
DELETE FROM crypto_keys
|
||||
WHERE feature IN ('workspace_apps_token', 'workspace_apps_api_key');
|
||||
|
||||
CREATE TYPE old_crypto_key_feature AS ENUM (
|
||||
'workspace_apps',
|
||||
'oidc_convert',
|
||||
'tailnet_resume'
|
||||
);
|
||||
|
||||
ALTER TABLE crypto_keys
|
||||
ALTER COLUMN feature TYPE old_crypto_key_feature
|
||||
USING (feature::text::old_crypto_key_feature);
|
||||
|
||||
DROP TYPE crypto_key_feature;
|
||||
|
||||
ALTER TYPE old_crypto_key_feature RENAME TO crypto_key_feature;
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
-- Create a new enum type with the desired values
|
||||
CREATE TYPE new_crypto_key_feature AS ENUM (
|
||||
'workspace_apps_token',
|
||||
'workspace_apps_api_key',
|
||||
'oidc_convert',
|
||||
'tailnet_resume'
|
||||
);
|
||||
|
||||
DELETE FROM crypto_keys WHERE feature = 'workspace_apps';
|
||||
|
||||
-- Drop the old type and rename the new one
|
||||
ALTER TABLE crypto_keys
|
||||
ALTER COLUMN feature TYPE new_crypto_key_feature
|
||||
USING (feature::text::new_crypto_key_feature);
|
||||
|
||||
DROP TYPE crypto_key_feature;
|
||||
|
||||
ALTER TYPE new_crypto_key_feature RENAME TO crypto_key_feature;
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at)
|
||||
VALUES (
|
||||
'workspace_apps_token',
|
||||
1,
|
||||
'abc',
|
||||
NULL,
|
||||
'1970-01-01 00:00:00 UTC'::timestamptz,
|
||||
'2100-01-01 00:00:00 UTC'::timestamptz
|
||||
);
|
||||
|
||||
INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at)
|
||||
VALUES (
|
||||
'workspace_apps_api_key',
|
||||
1,
|
||||
'def',
|
||||
NULL,
|
||||
'1970-01-01 00:00:00 UTC'::timestamptz,
|
||||
'2100-01-01 00:00:00 UTC'::timestamptz
|
||||
);
|
||||
|
||||
INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at)
|
||||
VALUES (
|
||||
'oidc_convert',
|
||||
2,
|
||||
'ghi',
|
||||
NULL,
|
||||
'1970-01-01 00:00:00 UTC'::timestamptz,
|
||||
'2100-01-01 00:00:00 UTC'::timestamptz
|
||||
);
|
||||
|
||||
INSERT INTO crypto_keys (feature, sequence, secret, secret_key_id, starts_at, deletes_at)
|
||||
VALUES (
|
||||
'tailnet_resume',
|
||||
2,
|
||||
'jkl',
|
||||
NULL,
|
||||
'1970-01-01 00:00:00 UTC'::timestamptz,
|
||||
'2100-01-01 00:00:00 UTC'::timestamptz
|
||||
);
|
||||
|
||||
@@ -345,9 +345,10 @@ func AllBuildReasonValues() []BuildReason {
|
||||
type CryptoKeyFeature string
|
||||
|
||||
const (
|
||||
CryptoKeyFeatureWorkspaceApps CryptoKeyFeature = "workspace_apps"
|
||||
CryptoKeyFeatureOidcConvert CryptoKeyFeature = "oidc_convert"
|
||||
CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume"
|
||||
CryptoKeyFeatureWorkspaceAppsToken CryptoKeyFeature = "workspace_apps_token"
|
||||
CryptoKeyFeatureWorkspaceAppsAPIKey CryptoKeyFeature = "workspace_apps_api_key"
|
||||
CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert"
|
||||
CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume"
|
||||
)
|
||||
|
||||
func (e *CryptoKeyFeature) Scan(src interface{}) error {
|
||||
@@ -387,8 +388,9 @@ func (ns NullCryptoKeyFeature) Value() (driver.Value, error) {
|
||||
|
||||
func (e CryptoKeyFeature) Valid() bool {
|
||||
switch e {
|
||||
case CryptoKeyFeatureWorkspaceApps,
|
||||
CryptoKeyFeatureOidcConvert,
|
||||
case CryptoKeyFeatureWorkspaceAppsToken,
|
||||
CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
CryptoKeyFeatureOIDCConvert,
|
||||
CryptoKeyFeatureTailnetResume:
|
||||
return true
|
||||
}
|
||||
@@ -397,8 +399,9 @@ func (e CryptoKeyFeature) Valid() bool {
|
||||
|
||||
func AllCryptoKeyFeatureValues() []CryptoKeyFeature {
|
||||
return []CryptoKeyFeature{
|
||||
CryptoKeyFeatureWorkspaceApps,
|
||||
CryptoKeyFeatureOidcConvert,
|
||||
CryptoKeyFeatureWorkspaceAppsToken,
|
||||
CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
CryptoKeyFeatureOIDCConvert,
|
||||
CryptoKeyFeatureTailnetResume,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,6 +135,8 @@ sql:
|
||||
api_key_id: APIKeyID
|
||||
callback_url: CallbackURL
|
||||
login_type_oauth2_provider_app: LoginTypeOAuth2ProviderApp
|
||||
crypto_key_feature_workspace_apps_api_key: CryptoKeyFeatureWorkspaceAppsAPIKey
|
||||
crypto_key_feature_oidc_convert: CryptoKeyFeatureOIDCConvert
|
||||
rules:
|
||||
- name: do-not-use-public-schema-in-queries
|
||||
message: "do not use public schema in queries"
|
||||
|
||||
@@ -65,6 +65,12 @@ func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string,
|
||||
return compact, nil
|
||||
}
|
||||
|
||||
func WithDecryptExpected(expected jwt.Expected) func(*DecryptOptions) {
|
||||
return func(opts *DecryptOptions) {
|
||||
opts.RegisteredClaims = expected
|
||||
}
|
||||
}
|
||||
|
||||
// DecryptOptions are options for decrypting a JWE.
|
||||
type DecryptOptions struct {
|
||||
RegisteredClaims jwt.Expected
|
||||
@@ -100,7 +106,7 @@ func Decrypt(ctx context.Context, d DecryptKeyProvider, token string, claims Cla
|
||||
|
||||
kid := object.Header.KeyID
|
||||
if kid == "" {
|
||||
return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey)
|
||||
return ErrMissingKeyID
|
||||
}
|
||||
|
||||
key, err := d.DecryptingKey(ctx, kid)
|
||||
|
||||
+61
-1
@@ -10,10 +10,27 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var ErrMissingKeyID = xerrors.New("missing key ID")
|
||||
|
||||
const (
|
||||
keyIDHeaderKey = "kid"
|
||||
)
|
||||
|
||||
// RegisteredClaims is a convenience type for embedding jwt.Claims. It should be
|
||||
// preferred over embedding jwt.Claims directly since it will ensure that certain fields are set.
|
||||
type RegisteredClaims jwt.Claims
|
||||
|
||||
func (r RegisteredClaims) Validate(e jwt.Expected) error {
|
||||
if r.Expiry == nil {
|
||||
return xerrors.Errorf("expiry is required")
|
||||
}
|
||||
if e.Time.IsZero() {
|
||||
return xerrors.Errorf("expected time is required")
|
||||
}
|
||||
|
||||
return (jwt.Claims(r)).Validate(e)
|
||||
}
|
||||
|
||||
// Claims defines the payload for a JWT. Most callers
|
||||
// should embed jwt.Claims
|
||||
type Claims interface {
|
||||
@@ -24,6 +41,11 @@ const (
|
||||
signingAlgo = jose.HS512
|
||||
)
|
||||
|
||||
type SigningKeyManager interface {
|
||||
SigningKeyProvider
|
||||
VerifyKeyProvider
|
||||
}
|
||||
|
||||
type SigningKeyProvider interface {
|
||||
SigningKey(ctx context.Context) (id string, key interface{}, err error)
|
||||
}
|
||||
@@ -75,6 +97,12 @@ type VerifyOptions struct {
|
||||
SignatureAlgorithm jose.SignatureAlgorithm
|
||||
}
|
||||
|
||||
func WithVerifyExpected(expected jwt.Expected) func(*VerifyOptions) {
|
||||
return func(opts *VerifyOptions) {
|
||||
opts.RegisteredClaims = expected
|
||||
}
|
||||
}
|
||||
|
||||
// Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims.
|
||||
func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claims, opts ...func(*VerifyOptions)) error {
|
||||
options := VerifyOptions{
|
||||
@@ -105,7 +133,7 @@ func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claim
|
||||
|
||||
kid := signature.Header.KeyID
|
||||
if kid == "" {
|
||||
return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey)
|
||||
return ErrMissingKeyID
|
||||
}
|
||||
|
||||
key, err := v.VerifyingKey(ctx, kid)
|
||||
@@ -125,3 +153,35 @@ func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claim
|
||||
|
||||
return claims.Validate(options.RegisteredClaims)
|
||||
}
|
||||
|
||||
// StaticKey fulfills the SigningKeycache and EncryptionKeycache interfaces. Useful for testing.
|
||||
type StaticKey struct {
|
||||
ID string
|
||||
Key interface{}
|
||||
}
|
||||
|
||||
func (s StaticKey) SigningKey(_ context.Context) (string, interface{}, error) {
|
||||
return s.ID, s.Key, nil
|
||||
}
|
||||
|
||||
func (s StaticKey) VerifyingKey(_ context.Context, id string) (interface{}, error) {
|
||||
if id != s.ID {
|
||||
return nil, xerrors.Errorf("invalid id %q", id)
|
||||
}
|
||||
return s.Key, nil
|
||||
}
|
||||
|
||||
func (s StaticKey) EncryptingKey(_ context.Context) (string, interface{}, error) {
|
||||
return s.ID, s.Key, nil
|
||||
}
|
||||
|
||||
func (s StaticKey) DecryptingKey(_ context.Context, id string) (interface{}, error) {
|
||||
if id != s.ID {
|
||||
return nil, xerrors.Errorf("invalid id %q", id)
|
||||
}
|
||||
return s.Key, nil
|
||||
}
|
||||
|
||||
func (StaticKey) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -236,11 +236,11 @@ func TestJWS(t *testing.T) {
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
Feature: database.CryptoKeyFeatureOIDCConvert,
|
||||
StartsAt: time.Now(),
|
||||
})
|
||||
log = slogtest.Make(t, nil)
|
||||
fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert}
|
||||
fetcher = &cryptokeys.DBFetcher{DB: db}
|
||||
)
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureOIDCConvert)
|
||||
@@ -326,15 +326,15 @@ func TestJWE(t *testing.T) {
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: time.Now(),
|
||||
})
|
||||
log = slogtest.Make(t, nil)
|
||||
|
||||
fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureWorkspaceApps}
|
||||
fetcher = &cryptokeys.DBFetcher{DB: db}
|
||||
)
|
||||
|
||||
cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceApp)
|
||||
cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := testClaims{
|
||||
|
||||
+22
-18
@@ -15,7 +15,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
@@ -23,6 +24,9 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/idpsync"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
@@ -32,7 +36,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/idpsync"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/promoauth"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
@@ -49,7 +52,7 @@ const (
|
||||
)
|
||||
|
||||
type OAuthConvertStateClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
jwtutils.RegisteredClaims
|
||||
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
State string `json:"state"`
|
||||
@@ -57,6 +60,10 @@ type OAuthConvertStateClaims struct {
|
||||
ToLoginType codersdk.LoginType `json:"to_login_type"`
|
||||
}
|
||||
|
||||
func (o *OAuthConvertStateClaims) Validate(e jwt.Expected) error {
|
||||
return o.RegisteredClaims.Validate(e)
|
||||
}
|
||||
|
||||
// postConvertLoginType replies with an oauth state token capable of converting
|
||||
// the user to an oauth user.
|
||||
//
|
||||
@@ -149,11 +156,11 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
|
||||
// Eg: Developers with more than 1 deployment.
|
||||
now := time.Now()
|
||||
claims := &OAuthConvertStateClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
RegisteredClaims: jwtutils.RegisteredClaims{
|
||||
Issuer: api.DeploymentID,
|
||||
Subject: stateString,
|
||||
Audience: []string{user.ID.String()},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(time.Minute * 5)),
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Minute * 5)),
|
||||
NotBefore: jwt.NewNumericDate(now.Add(time.Second * -1)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ID: uuid.NewString(),
|
||||
@@ -164,9 +171,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
|
||||
ToLoginType: req.ToType,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
|
||||
// Key must be a byte slice, not an array. So make sure to include the [:]
|
||||
tokenString, err := token.SignedString(api.OAuthSigningKey[:])
|
||||
token, err := jwtutils.Sign(ctx, api.OIDCConvertKeyCache, claims)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error signing state jwt.",
|
||||
@@ -176,8 +181,8 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
aReq.New = database.AuditOAuthConvertState{
|
||||
CreatedAt: claims.IssuedAt.Time,
|
||||
ExpiresAt: claims.ExpiresAt.Time,
|
||||
CreatedAt: claims.IssuedAt.Time(),
|
||||
ExpiresAt: claims.Expiry.Time(),
|
||||
FromLoginType: database.LoginType(claims.FromLoginType),
|
||||
ToLoginType: database.LoginType(claims.ToLoginType),
|
||||
UserID: claims.UserID,
|
||||
@@ -186,8 +191,8 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: OAuthConvertCookieValue,
|
||||
Path: "/",
|
||||
Value: tokenString,
|
||||
Expires: claims.ExpiresAt.Time,
|
||||
Value: token,
|
||||
Expires: claims.Expiry.Time(),
|
||||
Secure: api.SecureAuthCookie,
|
||||
HttpOnly: true,
|
||||
// Must be SameSite to work on the redirected auth flow from the
|
||||
@@ -196,7 +201,7 @@ func (api *API) postConvertLoginType(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.OAuthConversionResponse{
|
||||
StateString: stateString,
|
||||
ExpiresAt: claims.ExpiresAt.Time,
|
||||
ExpiresAt: claims.Expiry.Time(),
|
||||
ToType: claims.ToLoginType,
|
||||
UserID: claims.UserID,
|
||||
})
|
||||
@@ -1677,10 +1682,9 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data
|
||||
}
|
||||
}
|
||||
var claims OAuthConvertStateClaims
|
||||
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(_ *jwt.Token) (interface{}, error) {
|
||||
return api.OAuthSigningKey[:], nil
|
||||
})
|
||||
if xerrors.Is(err, jwt.ErrSignatureInvalid) || !token.Valid {
|
||||
|
||||
err = jwtutils.Verify(ctx, api.OIDCConvertKeyCache, jwtCookie.Value, &claims)
|
||||
if xerrors.Is(err, cryptokeys.ErrKeyNotFound) || xerrors.Is(err, cryptokeys.ErrKeyInvalid) || xerrors.Is(err, jose.ErrCryptoFailure) || xerrors.Is(err, jwtutils.ErrMissingKeyID) {
|
||||
// These errors are probably because the user is mixing 2 coder deployments.
|
||||
return database.User{}, idpsync.HTTPError{
|
||||
Code: http.StatusBadRequest,
|
||||
@@ -1709,7 +1713,7 @@ func (api *API) convertUserToOauth(ctx context.Context, r *http.Request, db data
|
||||
oauthConvertAudit.UserID = claims.UserID
|
||||
oauthConvertAudit.Old = user
|
||||
|
||||
if claims.RegisteredClaims.Issuer != api.DeploymentID {
|
||||
if claims.Issuer != api.DeploymentID {
|
||||
return database.User{}, idpsync.HTTPError{
|
||||
Code: http.StatusForbidden,
|
||||
Msg: "Request to convert login type failed. Issuer mismatch. Found a cookie from another coder deployment, please try again.",
|
||||
|
||||
+123
-1
@@ -3,6 +3,8 @@ package coderd_test
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -13,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
@@ -27,10 +30,12 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/promoauth"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -1316,6 +1321,7 @@ func TestUserOIDC(t *testing.T) {
|
||||
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
require.Equal(t, codersdk.LoginTypePassword, userData.LoginType)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"email": userData.Email,
|
||||
@@ -1323,15 +1329,17 @@ func TestUserOIDC(t *testing.T) {
|
||||
var err error
|
||||
user.HTTPClient.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err)
|
||||
user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{
|
||||
ToType: codersdk.LoginTypeOIDC,
|
||||
Password: "SomeSecurePassword!",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
fake.LoginWithClient(t, user, claims, func(r *http.Request) {
|
||||
_, _ = fake.LoginWithClient(t, user, claims, func(r *http.Request) {
|
||||
r.URL.RawQuery = url.Values{
|
||||
"oidc_merge_state": {convertResponse.StateString},
|
||||
}.Encode()
|
||||
@@ -1341,6 +1349,99 @@ func TestUserOIDC(t *testing.T) {
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
})
|
||||
|
||||
info, err := client.User(ctx, userData.ID.String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.LoginTypeOIDC, info.LoginType)
|
||||
})
|
||||
|
||||
t.Run("BadJWT", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitMedium)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
auditor := audit.NewMock()
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
return xerrors.New("refreshing token should never occur")
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
})
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
fetcher := &cryptokeys.DBFetcher{
|
||||
DB: db,
|
||||
}
|
||||
|
||||
kc, err := cryptokeys.NewSigningCache(ctx, logger, fetcher, codersdk.CryptoKeyFeatureOIDCConvert)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
OIDCConfig: cfg,
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
OIDCConvertKeyCache: kc,
|
||||
})
|
||||
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
user, userData := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"email": userData.Email,
|
||||
}
|
||||
user.HTTPClient.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err)
|
||||
user.HTTPClient.Transport = http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
convertResponse, err := user.ConvertLoginType(ctx, codersdk.ConvertLoginRequest{
|
||||
ToType: codersdk.LoginTypeOIDC,
|
||||
Password: "SomeSecurePassword!",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the cookie to use a bad signing key. We're asserting the behavior of the scenario
|
||||
// where a JWT gets minted on an old version of Coder but gets verified on a new version.
|
||||
_, resp := fake.AttemptLogin(t, user, claims, func(r *http.Request) {
|
||||
r.URL.RawQuery = url.Values{
|
||||
"oidc_merge_state": {convertResponse.StateString},
|
||||
}.Encode()
|
||||
r.Header.Set(codersdk.SessionTokenHeader, user.SessionToken())
|
||||
|
||||
cookies := user.HTTPClient.Jar.Cookies(user.URL)
|
||||
for i, cookie := range cookies {
|
||||
if cookie.Name != coderd.OAuthConvertCookieValue {
|
||||
continue
|
||||
}
|
||||
|
||||
jwt := cookie.Value
|
||||
var claims coderd.OAuthConvertStateClaims
|
||||
err := jwtutils.Verify(ctx, kc, jwt, &claims)
|
||||
require.NoError(t, err)
|
||||
badJWT := generateBadJWT(t, claims)
|
||||
cookie.Value = badJWT
|
||||
cookies[i] = cookie
|
||||
}
|
||||
|
||||
user.HTTPClient.Jar.SetCookies(user.URL, cookies)
|
||||
|
||||
for _, cookie := range cookies {
|
||||
fmt.Printf("cookie: %+v\n", cookie)
|
||||
r.AddCookie(cookie)
|
||||
}
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
var respErr codersdk.Response
|
||||
err = json.NewDecoder(resp.Body).Decode(&respErr)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, respErr.Message, "Using an invalid jwt to authorize this action.")
|
||||
})
|
||||
|
||||
t.Run("AlternateUsername", func(t *testing.T) {
|
||||
@@ -2022,3 +2123,24 @@ func inflateClaims(t testing.TB, seed jwt.MapClaims, size int) jwt.MapClaims {
|
||||
seed["random_data"] = junk
|
||||
return seed
|
||||
}
|
||||
|
||||
// generateBadJWT generates a JWT with a random key. It's intended to emulate the old-style JWT's we generated.
|
||||
func generateBadJWT(t *testing.T, claims interface{}) string {
|
||||
t.Helper()
|
||||
|
||||
var buf [64]byte
|
||||
_, err := rand.Read(buf[:])
|
||||
require.NoError(t, err)
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.HS512,
|
||||
Key: buf[:],
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
payload, err := json.Marshal(claims)
|
||||
require.NoError(t, err)
|
||||
signed, err := signer.Sign(payload)
|
||||
require.NoError(t, err)
|
||||
compact, err := signed.CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
return compact
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
@@ -852,8 +853,12 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
|
||||
)
|
||||
if resumeToken != "" {
|
||||
var err error
|
||||
peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(resumeToken)
|
||||
if err != nil {
|
||||
peerID, err = api.Options.CoordinatorResumeTokenProvider.VerifyResumeToken(ctx, resumeToken)
|
||||
// If the token is missing the key ID, it's probably an old token in which
|
||||
// case we just want to generate a new peer ID.
|
||||
if xerrors.Is(err, jwtutils.ErrMissingKeyID) {
|
||||
peerID = uuid.New()
|
||||
} else if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: workspacesdk.CoordinateAPIInvalidResumeToken,
|
||||
Detail: err.Error(),
|
||||
@@ -862,9 +867,10 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
|
||||
},
|
||||
})
|
||||
return
|
||||
} else {
|
||||
api.Logger.Debug(ctx, "accepted coordinate resume token for peer",
|
||||
slog.F("peer_id", peerID.String()))
|
||||
}
|
||||
api.Logger.Debug(ctx, "accepted coordinate resume token for peer",
|
||||
slog.F("peer_id", peerID.String()))
|
||||
}
|
||||
|
||||
api.WebsocketWaitMutex.Lock()
|
||||
|
||||
+130
-61
@@ -13,6 +13,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -36,6 +37,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -531,20 +533,20 @@ func newResumeTokenRecordingProvider(t testing.TB, underlying tailnet.ResumeToke
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resumeTokenRecordingProvider) GenerateResumeToken(peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) {
|
||||
func (r *resumeTokenRecordingProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) {
|
||||
select {
|
||||
case r.generateCalls <- peerID:
|
||||
return r.ResumeTokenProvider.GenerateResumeToken(peerID)
|
||||
return r.ResumeTokenProvider.GenerateResumeToken(ctx, peerID)
|
||||
default:
|
||||
r.t.Error("generateCalls full")
|
||||
return nil, xerrors.New("generateCalls full")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUID, error) {
|
||||
func (r *resumeTokenRecordingProvider) VerifyResumeToken(ctx context.Context, token string) (uuid.UUID, error) {
|
||||
select {
|
||||
case r.verifyCalls <- token:
|
||||
return r.ResumeTokenProvider.VerifyResumeToken(token)
|
||||
return r.ResumeTokenProvider.VerifyResumeToken(ctx, token)
|
||||
default:
|
||||
r.t.Error("verifyCalls full")
|
||||
return uuid.Nil, xerrors.New("verifyCalls full")
|
||||
@@ -554,69 +556,136 @@ func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUI
|
||||
func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
resumeTokenProvider := newResumeTokenRecordingProvider(
|
||||
t,
|
||||
tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour),
|
||||
)
|
||||
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
Coordinator: tailnet.NewCoordinator(logger),
|
||||
CoordinatorResumeTokenProvider: resumeTokenProvider,
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
mgr := jwtutils.StaticKey{
|
||||
ID: uuid.New().String(),
|
||||
Key: resumeTokenSigningKey[:],
|
||||
}
|
||||
require.NoError(t, err)
|
||||
resumeTokenProvider := newResumeTokenRecordingProvider(
|
||||
t,
|
||||
tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour),
|
||||
)
|
||||
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
Coordinator: tailnet.NewCoordinator(logger),
|
||||
CoordinatorResumeTokenProvider: resumeTokenProvider,
|
||||
})
|
||||
defer closer.Close()
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create a workspace with an agent. No need to connect it since clients can
|
||||
// still connect to the coordinator while the agent isn't connected.
|
||||
r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
agentTokenUUID, err := uuid.Parse(r.AgentToken)
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect with no resume token, and ensure that the peer ID is set to a
|
||||
// random value.
|
||||
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
|
||||
require.NoError(t, err)
|
||||
originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
|
||||
require.NotEqual(t, originalPeerID, uuid.Nil)
|
||||
|
||||
// Connect with a valid resume token, and ensure that the peer ID is set to
|
||||
// the stored value.
|
||||
clock.Advance(time.Second)
|
||||
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
|
||||
require.NoError(t, err)
|
||||
verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
|
||||
require.Equal(t, originalResumeToken, verifiedToken)
|
||||
newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
|
||||
require.Equal(t, originalPeerID, newPeerID)
|
||||
require.NotEqual(t, originalResumeToken, newResumeToken)
|
||||
|
||||
// Connect with an invalid resume token, and ensure that the request is
|
||||
// rejected.
|
||||
clock.Advance(time.Second)
|
||||
_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
|
||||
require.Len(t, sdkErr.Validations, 1)
|
||||
require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
|
||||
verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
|
||||
require.Equal(t, "invalid", verifiedToken)
|
||||
|
||||
select {
|
||||
case <-resumeTokenProvider.generateCalls:
|
||||
t.Fatal("unexpected peer ID in channel")
|
||||
default:
|
||||
}
|
||||
})
|
||||
defer closer.Close()
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create a workspace with an agent. No need to connect it since clients can
|
||||
// still connect to the coordinator while the agent isn't connected.
|
||||
r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
agentTokenUUID, err := uuid.Parse(r.AgentToken)
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
|
||||
require.NoError(t, err)
|
||||
t.Run("BadJWT", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Connect with no resume token, and ensure that the peer ID is set to a
|
||||
// random value.
|
||||
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
|
||||
require.NoError(t, err)
|
||||
originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
|
||||
require.NotEqual(t, originalPeerID, uuid.Nil)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
mgr := jwtutils.StaticKey{
|
||||
ID: uuid.New().String(),
|
||||
Key: resumeTokenSigningKey[:],
|
||||
}
|
||||
require.NoError(t, err)
|
||||
resumeTokenProvider := newResumeTokenRecordingProvider(
|
||||
t,
|
||||
tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour),
|
||||
)
|
||||
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
Coordinator: tailnet.NewCoordinator(logger),
|
||||
CoordinatorResumeTokenProvider: resumeTokenProvider,
|
||||
})
|
||||
defer closer.Close()
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Connect with a valid resume token, and ensure that the peer ID is set to
|
||||
// the stored value.
|
||||
clock.Advance(time.Second)
|
||||
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
|
||||
require.NoError(t, err)
|
||||
verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
|
||||
require.Equal(t, originalResumeToken, verifiedToken)
|
||||
newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
|
||||
require.Equal(t, originalPeerID, newPeerID)
|
||||
require.NotEqual(t, originalResumeToken, newResumeToken)
|
||||
// Create a workspace with an agent. No need to connect it since clients can
|
||||
// still connect to the coordinator while the agent isn't connected.
|
||||
r := dbfake.WorkspaceBuild(t, api.Database, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
agentTokenUUID, err := uuid.Parse(r.AgentToken)
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
agentAndBuild, err := api.Database.GetWorkspaceAgentAndLatestBuildByAuthToken(dbauthz.AsSystemRestricted(ctx), agentTokenUUID) //nolint
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect with an invalid resume token, and ensure that the request is
|
||||
// rejected.
|
||||
clock.Advance(time.Second)
|
||||
_, err = connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "invalid")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
|
||||
require.Len(t, sdkErr.Validations, 1)
|
||||
require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
|
||||
verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
|
||||
require.Equal(t, "invalid", verifiedToken)
|
||||
// Connect with no resume token, and ensure that the peer ID is set to a
|
||||
// random value.
|
||||
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
|
||||
require.NoError(t, err)
|
||||
originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
|
||||
require.NotEqual(t, originalPeerID, uuid.Nil)
|
||||
|
||||
select {
|
||||
case <-resumeTokenProvider.generateCalls:
|
||||
t.Fatal("unexpected peer ID in channel")
|
||||
default:
|
||||
}
|
||||
// Connect with an outdated token, and ensure that the peer ID is set to a
|
||||
// random value. We don't want to fail requests just because
|
||||
// a user got unlucky during a deployment upgrade.
|
||||
outdatedToken := generateBadJWT(t, jwtutils.RegisteredClaims{
|
||||
Subject: originalPeerID.String(),
|
||||
Expiry: jwt.NewNumericDate(clock.Now().Add(time.Minute)),
|
||||
})
|
||||
|
||||
clock.Advance(time.Second)
|
||||
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, outdatedToken)
|
||||
require.NoError(t, err)
|
||||
verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
|
||||
require.Equal(t, outdatedToken, verifiedToken)
|
||||
newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
|
||||
require.NotEqual(t, originalPeerID, newPeerID)
|
||||
require.NotEqual(t, originalResumeToken, newResumeToken)
|
||||
})
|
||||
}
|
||||
|
||||
// connectToCoordinatorAndFetchResumeToken connects to the tailnet coordinator
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
|
||||
@@ -122,10 +123,11 @@ func (api *API) workspaceApplicationAuth(rw http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
// Encrypt the API key.
|
||||
encryptedAPIKey, err := api.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{
|
||||
payload := workspaceapps.EncryptedAPIKeyPayload{
|
||||
APIKey: cookie.Value,
|
||||
})
|
||||
}
|
||||
payload.Fill(api.Clock.Now())
|
||||
encryptedAPIKey, err := jwtutils.Encrypt(ctx, api.AppEncryptionKeyCache, payload)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to encrypt API key.",
|
||||
|
||||
@@ -3,6 +3,7 @@ package apptest
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -408,6 +409,67 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
|
||||
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
assertWorkspaceLastUsedAtNotUpdated(t, appDetails)
|
||||
})
|
||||
|
||||
t.Run("BadJWT", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
appDetails := setupProxyTest(t, nil)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
u := appDetails.PathAppURL(appDetails.Apps.Owner)
|
||||
resp, err := requestWithRetries(ctx, t, appDetails.AppClient(t), http.MethodGet, u.String(), nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, proxyTestAppBody, string(body))
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
appTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie)
|
||||
require.NotNil(t, appTokenCookie, "no signed app token cookie in response")
|
||||
require.Equal(t, appTokenCookie.Path, u.Path, "incorrect path on app token cookie")
|
||||
|
||||
object, err := jose.ParseSigned(appTokenCookie.Value)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, object.Signatures, 1)
|
||||
|
||||
// Parse the payload.
|
||||
var tok workspaceapps.SignedToken
|
||||
//nolint:gosec
|
||||
err = json.Unmarshal(object.UnsafePayloadWithoutVerification(), &tok)
|
||||
require.NoError(t, err)
|
||||
|
||||
appTokenClient := appDetails.AppClient(t)
|
||||
apiKey := appTokenClient.SessionToken()
|
||||
appTokenClient.SetSessionToken("")
|
||||
appTokenClient.HTTPClient.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err)
|
||||
// Sign the token with an old-style key.
|
||||
appTokenCookie.Value = generateBadJWT(t, tok)
|
||||
appTokenClient.HTTPClient.Jar.SetCookies(u,
|
||||
[]*http.Cookie{
|
||||
appTokenCookie,
|
||||
{
|
||||
Name: codersdk.PathAppSessionTokenCookie,
|
||||
Value: apiKey,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
resp, err = requestWithRetries(ctx, t, appTokenClient, http.MethodGet, u.String(), nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, proxyTestAppBody, string(body))
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assertWorkspaceLastUsedAtUpdated(t, appDetails)
|
||||
|
||||
// Since the old token is invalid, the signed app token cookie should have a new value.
|
||||
newTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie)
|
||||
require.NotEqual(t, appTokenCookie.Value, newTokenCookie.Value)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("WorkspaceApplicationAuth", func(t *testing.T) {
|
||||
@@ -463,7 +525,7 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
|
||||
appClient.SetSessionToken("")
|
||||
|
||||
// Try to load the application without authentication.
|
||||
u := c.appURL
|
||||
u := *c.appURL
|
||||
u.Path = path.Join(u.Path, "/test")
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
require.NoError(t, err)
|
||||
@@ -500,7 +562,7 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
|
||||
|
||||
// Copy the query parameters and then check equality.
|
||||
u.RawQuery = gotLocation.RawQuery
|
||||
require.Equal(t, u, gotLocation)
|
||||
require.Equal(t, u, *gotLocation)
|
||||
|
||||
// Verify the API key is set.
|
||||
encryptedAPIKey := gotLocation.Query().Get(workspaceapps.SubdomainProxyAPIKeyParam)
|
||||
@@ -580,6 +642,38 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
|
||||
resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("BadJWE", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
currentKeyStr := appDetails.SDKClient.SessionToken()
|
||||
appClient := appDetails.AppClient(t)
|
||||
appClient.SetSessionToken("")
|
||||
u := *c.appURL
|
||||
u.Path = path.Join(u.Path, "/test")
|
||||
badToken := generateBadJWE(t, workspaceapps.EncryptedAPIKeyPayload{
|
||||
APIKey: currentKeyStr,
|
||||
})
|
||||
|
||||
u.RawQuery = (url.Values{
|
||||
workspaceapps.SubdomainProxyAPIKeyParam: {badToken},
|
||||
}).Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp *http.Response
|
||||
resp, err = doWithRetries(t, appClient, req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), "Could not decrypt API key. Please remove the query parameter and try again.")
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -1077,6 +1171,68 @@ func Run(t *testing.T, appHostIsPrimary bool, factory DeploymentFactory) {
|
||||
assertWorkspaceLastUsedAtNotUpdated(t, appDetails)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("BadJWT", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
appDetails := setupProxyTest(t, nil)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
u := appDetails.SubdomainAppURL(appDetails.Apps.Owner)
|
||||
resp, err := requestWithRetries(ctx, t, appDetails.AppClient(t), http.MethodGet, u.String(), nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, proxyTestAppBody, string(body))
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
appTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie)
|
||||
require.NotNil(t, appTokenCookie, "no signed token cookie in response")
|
||||
require.Equal(t, appTokenCookie.Path, "/", "incorrect path on signed token cookie")
|
||||
|
||||
object, err := jose.ParseSigned(appTokenCookie.Value)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, object.Signatures, 1)
|
||||
|
||||
// Parse the payload.
|
||||
var tok workspaceapps.SignedToken
|
||||
//nolint:gosec
|
||||
err = json.Unmarshal(object.UnsafePayloadWithoutVerification(), &tok)
|
||||
require.NoError(t, err)
|
||||
|
||||
appTokenClient := appDetails.AppClient(t)
|
||||
apiKey := appTokenClient.SessionToken()
|
||||
appTokenClient.SetSessionToken("")
|
||||
appTokenClient.HTTPClient.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err)
|
||||
// Sign the token with an old-style key.
|
||||
appTokenCookie.Value = generateBadJWT(t, tok)
|
||||
appTokenClient.HTTPClient.Jar.SetCookies(u,
|
||||
[]*http.Cookie{
|
||||
appTokenCookie,
|
||||
{
|
||||
Name: codersdk.SubdomainAppSessionTokenCookie,
|
||||
Value: apiKey,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// We should still be able to successfully proxy.
|
||||
resp, err = requestWithRetries(ctx, t, appTokenClient, http.MethodGet, u.String(), nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, proxyTestAppBody, string(body))
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assertWorkspaceLastUsedAtUpdated(t, appDetails)
|
||||
|
||||
// Since the old token is invalid, the signed app token cookie should have a new value.
|
||||
newTokenCookie := findCookie(resp.Cookies(), codersdk.SignedAppTokenCookie)
|
||||
require.NotEqual(t, appTokenCookie.Value, newTokenCookie.Value)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("PortSharing", func(t *testing.T) {
|
||||
@@ -1789,3 +1945,57 @@ func assertWorkspaceLastUsedAtNotUpdated(t testing.TB, details *Details) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, before.LastUsedAt, after.LastUsedAt, "workspace LastUsedAt updated when it should not have been")
|
||||
}
|
||||
|
||||
func generateBadJWE(t *testing.T, claims interface{}) string {
|
||||
t.Helper()
|
||||
var buf [32]byte
|
||||
_, err := rand.Read(buf[:])
|
||||
require.NoError(t, err)
|
||||
encrypt, err := jose.NewEncrypter(
|
||||
jose.A256GCM,
|
||||
jose.Recipient{
|
||||
Algorithm: jose.A256GCMKW,
|
||||
Key: buf[:],
|
||||
}, &jose.EncrypterOptions{
|
||||
Compression: jose.DEFLATE,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
payload, err := json.Marshal(claims)
|
||||
require.NoError(t, err)
|
||||
signed, err := encrypt.Encrypt(payload)
|
||||
require.NoError(t, err)
|
||||
compact, err := signed.CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
return compact
|
||||
}
|
||||
|
||||
// generateBadJWT generates a JWT with a random key. It's intended to emulate the old-style JWT's we generated.
|
||||
func generateBadJWT(t *testing.T, claims interface{}) string {
|
||||
t.Helper()
|
||||
|
||||
var buf [64]byte
|
||||
_, err := rand.Read(buf[:])
|
||||
require.NoError(t, err)
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.HS512,
|
||||
Key: buf[:],
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
payload, err := json.Marshal(claims)
|
||||
require.NoError(t, err)
|
||||
signed, err := signer.Sign(payload)
|
||||
require.NoError(t, err)
|
||||
compact, err := signed.CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
return compact
|
||||
}
|
||||
|
||||
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == name {
|
||||
return cookie
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,11 +13,15 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -35,12 +39,20 @@ type DBTokenProvider struct {
|
||||
DeploymentValues *codersdk.DeploymentValues
|
||||
OAuth2Configs *httpmw.OAuth2Configs
|
||||
WorkspaceAgentInactiveTimeout time.Duration
|
||||
SigningKey SecurityKey
|
||||
Keycache cryptokeys.SigningKeycache
|
||||
}
|
||||
|
||||
var _ SignedTokenProvider = &DBTokenProvider{}
|
||||
|
||||
func NewDBTokenProvider(log slog.Logger, accessURL *url.URL, authz rbac.Authorizer, db database.Store, cfg *codersdk.DeploymentValues, oauth2Cfgs *httpmw.OAuth2Configs, workspaceAgentInactiveTimeout time.Duration, signingKey SecurityKey) SignedTokenProvider {
|
||||
func NewDBTokenProvider(log slog.Logger,
|
||||
accessURL *url.URL,
|
||||
authz rbac.Authorizer,
|
||||
db database.Store,
|
||||
cfg *codersdk.DeploymentValues,
|
||||
oauth2Cfgs *httpmw.OAuth2Configs,
|
||||
workspaceAgentInactiveTimeout time.Duration,
|
||||
signer cryptokeys.SigningKeycache,
|
||||
) SignedTokenProvider {
|
||||
if workspaceAgentInactiveTimeout == 0 {
|
||||
workspaceAgentInactiveTimeout = 1 * time.Minute
|
||||
}
|
||||
@@ -53,12 +65,12 @@ func NewDBTokenProvider(log slog.Logger, accessURL *url.URL, authz rbac.Authoriz
|
||||
DeploymentValues: cfg,
|
||||
OAuth2Configs: oauth2Cfgs,
|
||||
WorkspaceAgentInactiveTimeout: workspaceAgentInactiveTimeout,
|
||||
SigningKey: signingKey,
|
||||
Keycache: signer,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DBTokenProvider) FromRequest(r *http.Request) (*SignedToken, bool) {
|
||||
return FromRequest(r, p.SigningKey)
|
||||
return FromRequest(r, p.Keycache)
|
||||
}
|
||||
|
||||
func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *http.Request, issueReq IssueTokenRequest) (*SignedToken, string, bool) {
|
||||
@@ -70,7 +82,7 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *
|
||||
dangerousSystemCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
|
||||
appReq := issueReq.AppRequest.Normalize()
|
||||
err := appReq.Validate()
|
||||
err := appReq.Check()
|
||||
if err != nil {
|
||||
WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "invalid app request")
|
||||
return nil, "", false
|
||||
@@ -210,9 +222,11 @@ func (p *DBTokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *
|
||||
return nil, "", false
|
||||
}
|
||||
|
||||
token.RegisteredClaims = jwtutils.RegisteredClaims{
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(DefaultTokenExpiry)),
|
||||
}
|
||||
// Sign the token.
|
||||
token.Expiry = time.Now().Add(DefaultTokenExpiry)
|
||||
tokenStr, err := p.SigningKey.SignToken(token)
|
||||
tokenStr, err := jwtutils.Sign(ctx, p.Keycache, token)
|
||||
if err != nil {
|
||||
WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "generate token")
|
||||
return nil, "", false
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -94,8 +96,7 @@ func Test_ResolveRequest(t *testing.T) {
|
||||
_ = closer.Close()
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
t.Cleanup(cancel)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
me, err := client.User(ctx, codersdk.Me)
|
||||
@@ -276,15 +277,17 @@ func Test_ResolveRequest(t *testing.T) {
|
||||
_ = w.Body.Close()
|
||||
|
||||
require.Equal(t, &workspaceapps.SignedToken{
|
||||
RegisteredClaims: jwtutils.RegisteredClaims{
|
||||
Expiry: jwt.NewNumericDate(token.Expiry.Time()),
|
||||
},
|
||||
Request: req,
|
||||
Expiry: token.Expiry, // ignored to avoid flakiness
|
||||
UserID: me.ID,
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: agentID,
|
||||
AppURL: appURL,
|
||||
}, token)
|
||||
require.NotZero(t, token.Expiry)
|
||||
require.WithinDuration(t, time.Now().Add(workspaceapps.DefaultTokenExpiry), token.Expiry, time.Minute)
|
||||
require.WithinDuration(t, time.Now().Add(workspaceapps.DefaultTokenExpiry), token.Expiry.Time(), time.Minute)
|
||||
|
||||
// Check that the token was set in the response and is valid.
|
||||
require.Len(t, w.Cookies(), 1)
|
||||
@@ -292,10 +295,11 @@ func Test_ResolveRequest(t *testing.T) {
|
||||
require.Equal(t, codersdk.SignedAppTokenCookie, cookie.Name)
|
||||
require.Equal(t, req.BasePath, cookie.Path)
|
||||
|
||||
parsedToken, err := api.AppSecurityKey.VerifySignedToken(cookie.Value)
|
||||
var parsedToken workspaceapps.SignedToken
|
||||
err := jwtutils.Verify(ctx, api.AppSigningKeyCache, cookie.Value, &parsedToken)
|
||||
require.NoError(t, err)
|
||||
// normalize expiry
|
||||
require.WithinDuration(t, token.Expiry, parsedToken.Expiry, 2*time.Second)
|
||||
require.WithinDuration(t, token.Expiry.Time(), parsedToken.Expiry.Time(), 2*time.Second)
|
||||
parsedToken.Expiry = token.Expiry
|
||||
require.Equal(t, token, &parsedToken)
|
||||
|
||||
@@ -314,7 +318,7 @@ func Test_ResolveRequest(t *testing.T) {
|
||||
})
|
||||
require.True(t, ok)
|
||||
// normalize expiry
|
||||
require.WithinDuration(t, token.Expiry, secondToken.Expiry, 2*time.Second)
|
||||
require.WithinDuration(t, token.Expiry.Time(), secondToken.Expiry.Time(), 2*time.Second)
|
||||
secondToken.Expiry = token.Expiry
|
||||
require.Equal(t, token, secondToken)
|
||||
}
|
||||
@@ -540,13 +544,16 @@ func Test_ResolveRequest(t *testing.T) {
|
||||
// App name differs
|
||||
AppSlugOrPort: appNamePublic,
|
||||
}).Normalize(),
|
||||
Expiry: time.Now().Add(time.Minute),
|
||||
RegisteredClaims: jwtutils.RegisteredClaims{
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)),
|
||||
},
|
||||
UserID: me.ID,
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: agentID,
|
||||
AppURL: appURL,
|
||||
}
|
||||
badTokenStr, err := api.AppSecurityKey.SignToken(badToken)
|
||||
|
||||
badTokenStr, err := jwtutils.Sign(ctx, api.AppSigningKeyCache, badToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := (workspaceapps.Request{
|
||||
@@ -589,7 +596,8 @@ func Test_ResolveRequest(t *testing.T) {
|
||||
require.Len(t, cookies, 1)
|
||||
require.Equal(t, cookies[0].Name, codersdk.SignedAppTokenCookie)
|
||||
require.NotEqual(t, cookies[0].Value, badTokenStr)
|
||||
parsedToken, err := api.AppSecurityKey.VerifySignedToken(cookies[0].Value)
|
||||
var parsedToken workspaceapps.SignedToken
|
||||
err = jwtutils.Verify(ctx, api.AppSigningKeyCache, cookies[0].Value, &parsedToken)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, appNameOwner, parsedToken.AppSlugOrPort)
|
||||
})
|
||||
|
||||
@@ -38,7 +38,7 @@ type ResolveRequestOptions struct {
|
||||
|
||||
func ResolveRequest(rw http.ResponseWriter, r *http.Request, opts ResolveRequestOptions) (*SignedToken, bool) {
|
||||
appReq := opts.AppRequest.Normalize()
|
||||
err := appReq.Validate()
|
||||
err := appReq.Check()
|
||||
if err != nil {
|
||||
// This is a 500 since it's a coder server or proxy that's making this
|
||||
// request struct based on details from the request. The values should
|
||||
@@ -79,7 +79,7 @@ func ResolveRequest(rw http.ResponseWriter, r *http.Request, opts ResolveRequest
|
||||
Name: codersdk.SignedAppTokenCookie,
|
||||
Value: tokenStr,
|
||||
Path: appReq.BasePath,
|
||||
Expires: token.Expiry,
|
||||
Expires: token.Expiry.Time(),
|
||||
})
|
||||
|
||||
return token, true
|
||||
|
||||
@@ -11,17 +11,21 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
|
||||
@@ -97,8 +101,8 @@ type Server struct {
|
||||
HostnameRegex *regexp.Regexp
|
||||
RealIPConfig *httpmw.RealIPConfig
|
||||
|
||||
SignedTokenProvider SignedTokenProvider
|
||||
AppSecurityKey SecurityKey
|
||||
SignedTokenProvider SignedTokenProvider
|
||||
APIKeyEncryptionKeycache cryptokeys.EncryptionKeycache
|
||||
|
||||
// DisablePathApps disables path-based apps. This is a security feature as path
|
||||
// based apps share the same cookie as the dashboard, and are susceptible to XSS
|
||||
@@ -176,7 +180,10 @@ func (s *Server) handleAPIKeySmuggling(rw http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
// Exchange the encoded API key for a real one.
|
||||
token, err := s.AppSecurityKey.DecryptAPIKey(encryptedAPIKey)
|
||||
var payload EncryptedAPIKeyPayload
|
||||
err := jwtutils.Decrypt(ctx, s.APIKeyEncryptionKeycache, encryptedAPIKey, &payload, jwtutils.WithDecryptExpected(jwt.Expected{
|
||||
Time: time.Now(),
|
||||
}))
|
||||
if err != nil {
|
||||
s.Logger.Debug(ctx, "could not decrypt smuggled workspace app API key", slog.Error(err))
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
@@ -225,7 +232,7 @@ func (s *Server) handleAPIKeySmuggling(rw http.ResponseWriter, r *http.Request,
|
||||
// server using the wrong value.
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: AppConnectSessionTokenCookieName(accessMethod),
|
||||
Value: token,
|
||||
Value: payload.APIKey,
|
||||
Domain: domain,
|
||||
Path: "/",
|
||||
MaxAge: 0,
|
||||
|
||||
@@ -124,9 +124,9 @@ func (r Request) Normalize() Request {
|
||||
return req
|
||||
}
|
||||
|
||||
// Validate ensures the request is correct and contains the necessary
|
||||
// Check ensures the request is correct and contains the necessary
|
||||
// parameters.
|
||||
func (r Request) Validate() error {
|
||||
func (r Request) Check() error {
|
||||
switch r.AccessMethod {
|
||||
case AccessMethodPath, AccessMethodSubdomain, AccessMethodTerminal:
|
||||
default:
|
||||
|
||||
@@ -279,7 +279,7 @@ func Test_RequestValidate(t *testing.T) {
|
||||
if !c.noNormalize {
|
||||
req = c.req.Normalize()
|
||||
}
|
||||
err := req.Validate()
|
||||
err := req.Check()
|
||||
if c.errContains == "" {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
|
||||
+24
-187
@@ -1,35 +1,27 @@
|
||||
package workspaceapps
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenSigningAlgorithm = jose.HS512
|
||||
apiKeyEncryptionAlgorithm = jose.A256GCMKW
|
||||
)
|
||||
|
||||
// SignedToken is the struct data contained inside a workspace app JWE. It
|
||||
// contains the details of the workspace app that the token is valid for to
|
||||
// avoid database queries.
|
||||
type SignedToken struct {
|
||||
jwtutils.RegisteredClaims
|
||||
// Request details.
|
||||
Request `json:"request"`
|
||||
|
||||
// Trusted resolved details.
|
||||
Expiry time.Time `json:"expiry"` // set by GenerateToken if unset
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
WorkspaceID uuid.UUID `json:"workspace_id"`
|
||||
AgentID uuid.UUID `json:"agent_id"`
|
||||
@@ -57,191 +49,32 @@ func (t SignedToken) MatchesRequest(req Request) bool {
|
||||
t.AppSlugOrPort == req.AppSlugOrPort
|
||||
}
|
||||
|
||||
// SecurityKey is used for signing and encrypting app tokens and API keys.
|
||||
//
|
||||
// The first 64 bytes of the key are used for signing tokens with HMAC-SHA256,
|
||||
// and the last 32 bytes are used for encrypting API keys with AES-256-GCM.
|
||||
// We use a single key for both operations to avoid having to store and manage
|
||||
// two keys.
|
||||
type SecurityKey [96]byte
|
||||
|
||||
func (k SecurityKey) IsZero() bool {
|
||||
return k == SecurityKey{}
|
||||
}
|
||||
|
||||
func (k SecurityKey) String() string {
|
||||
return hex.EncodeToString(k[:])
|
||||
}
|
||||
|
||||
func (k SecurityKey) signingKey() []byte {
|
||||
return k[:64]
|
||||
}
|
||||
|
||||
func (k SecurityKey) encryptionKey() []byte {
|
||||
return k[64:]
|
||||
}
|
||||
|
||||
func KeyFromString(str string) (SecurityKey, error) {
|
||||
var key SecurityKey
|
||||
decoded, err := hex.DecodeString(str)
|
||||
if err != nil {
|
||||
return key, xerrors.Errorf("decode key: %w", err)
|
||||
}
|
||||
if len(decoded) != len(key) {
|
||||
return key, xerrors.Errorf("expected key to be %d bytes, got %d", len(key), len(decoded))
|
||||
}
|
||||
copy(key[:], decoded)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SignToken generates a signed workspace app token with the given payload. If
|
||||
// the payload doesn't have an expiry, it will be set to the current time plus
|
||||
// the default expiry.
|
||||
func (k SecurityKey) SignToken(payload SignedToken) (string, error) {
|
||||
if payload.Expiry.IsZero() {
|
||||
payload.Expiry = time.Now().Add(DefaultTokenExpiry)
|
||||
}
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("marshal payload to JSON: %w", err)
|
||||
}
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: tokenSigningAlgorithm,
|
||||
Key: k.signingKey(),
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("create signer: %w", err)
|
||||
}
|
||||
|
||||
signedObject, err := signer.Sign(payloadBytes)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("sign payload: %w", err)
|
||||
}
|
||||
|
||||
serialized, err := signedObject.CompactSerialize()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("serialize JWS: %w", err)
|
||||
}
|
||||
|
||||
return serialized, nil
|
||||
}
|
||||
|
||||
// VerifySignedToken parses a signed workspace app token with the given key and
|
||||
// returns the payload. If the token is invalid or expired, an error is
|
||||
// returned.
|
||||
func (k SecurityKey) VerifySignedToken(str string) (SignedToken, error) {
|
||||
object, err := jose.ParseSigned(str)
|
||||
if err != nil {
|
||||
return SignedToken{}, xerrors.Errorf("parse JWS: %w", err)
|
||||
}
|
||||
if len(object.Signatures) != 1 {
|
||||
return SignedToken{}, xerrors.New("expected 1 signature")
|
||||
}
|
||||
if object.Signatures[0].Header.Algorithm != string(tokenSigningAlgorithm) {
|
||||
return SignedToken{}, xerrors.Errorf("expected token signing algorithm to be %q, got %q", tokenSigningAlgorithm, object.Signatures[0].Header.Algorithm)
|
||||
}
|
||||
|
||||
output, err := object.Verify(k.signingKey())
|
||||
if err != nil {
|
||||
return SignedToken{}, xerrors.Errorf("verify JWS: %w", err)
|
||||
}
|
||||
|
||||
var tok SignedToken
|
||||
err = json.Unmarshal(output, &tok)
|
||||
if err != nil {
|
||||
return SignedToken{}, xerrors.Errorf("unmarshal payload: %w", err)
|
||||
}
|
||||
if tok.Expiry.Before(time.Now()) {
|
||||
return SignedToken{}, xerrors.New("signed app token expired")
|
||||
}
|
||||
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
type EncryptedAPIKeyPayload struct {
|
||||
APIKey string `json:"api_key"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
jwtutils.RegisteredClaims
|
||||
APIKey string `json:"api_key"`
|
||||
}
|
||||
|
||||
// EncryptAPIKey encrypts an API key for subdomain token smuggling.
|
||||
func (k SecurityKey) EncryptAPIKey(payload EncryptedAPIKeyPayload) (string, error) {
|
||||
if payload.APIKey == "" {
|
||||
return "", xerrors.New("API key is empty")
|
||||
}
|
||||
if payload.ExpiresAt.IsZero() {
|
||||
// Very short expiry as these keys are only used once as part of an
|
||||
// automatic redirection flow.
|
||||
payload.ExpiresAt = dbtime.Now().Add(time.Minute)
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("marshal payload: %w", err)
|
||||
}
|
||||
|
||||
// JWEs seem to apply a nonce themselves.
|
||||
encrypter, err := jose.NewEncrypter(
|
||||
jose.A256GCM,
|
||||
jose.Recipient{
|
||||
Algorithm: apiKeyEncryptionAlgorithm,
|
||||
Key: k.encryptionKey(),
|
||||
},
|
||||
&jose.EncrypterOptions{
|
||||
Compression: jose.DEFLATE,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("initializer jose encrypter: %w", err)
|
||||
}
|
||||
encryptedObject, err := encrypter.Encrypt(payloadBytes)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("encrypt jwe: %w", err)
|
||||
}
|
||||
|
||||
encrypted := encryptedObject.FullSerialize()
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(encrypted)), nil
|
||||
func (e *EncryptedAPIKeyPayload) Fill(now time.Time) {
|
||||
e.Issuer = "coderd"
|
||||
e.Audience = jwt.Audience{"wsproxy"}
|
||||
e.Expiry = jwt.NewNumericDate(now.Add(time.Minute))
|
||||
e.NotBefore = jwt.NewNumericDate(now.Add(-time.Minute))
|
||||
}
|
||||
|
||||
// DecryptAPIKey undoes EncryptAPIKey and is used in the subdomain app handler.
|
||||
func (k SecurityKey) DecryptAPIKey(encryptedAPIKey string) (string, error) {
|
||||
encrypted, err := base64.RawURLEncoding.DecodeString(encryptedAPIKey)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("base64 decode encrypted API key: %w", err)
|
||||
func (e EncryptedAPIKeyPayload) Validate(ex jwt.Expected) error {
|
||||
if e.NotBefore == nil {
|
||||
return xerrors.Errorf("not before is required")
|
||||
}
|
||||
|
||||
object, err := jose.ParseEncrypted(string(encrypted))
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("parse encrypted API key: %w", err)
|
||||
}
|
||||
if object.Header.Algorithm != string(apiKeyEncryptionAlgorithm) {
|
||||
return "", xerrors.Errorf("expected API key encryption algorithm to be %q, got %q", apiKeyEncryptionAlgorithm, object.Header.Algorithm)
|
||||
}
|
||||
ex.Issuer = "coderd"
|
||||
ex.AnyAudience = jwt.Audience{"wsproxy"}
|
||||
|
||||
// Decrypt using the hashed secret.
|
||||
decrypted, err := object.Decrypt(k.encryptionKey())
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("decrypt API key: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal the payload.
|
||||
var payload EncryptedAPIKeyPayload
|
||||
if err := json.Unmarshal(decrypted, &payload); err != nil {
|
||||
return "", xerrors.Errorf("unmarshal decrypted payload: %w", err)
|
||||
}
|
||||
|
||||
// Validate expiry.
|
||||
if payload.ExpiresAt.Before(dbtime.Now()) {
|
||||
return "", xerrors.New("encrypted API key expired")
|
||||
}
|
||||
|
||||
return payload.APIKey, nil
|
||||
return e.RegisteredClaims.Validate(ex)
|
||||
}
|
||||
|
||||
// FromRequest returns the signed token from the request, if it exists and is
|
||||
// valid. The caller must check that the token matches the request.
|
||||
func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) {
|
||||
func FromRequest(r *http.Request, mgr cryptokeys.SigningKeycache) (*SignedToken, bool) {
|
||||
// Get all signed app tokens from the request. This includes the query
|
||||
// parameter and all matching cookies sent with the request. If there are
|
||||
// somehow multiple signed app token cookies, we want to try all of them
|
||||
@@ -270,8 +103,12 @@ func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) {
|
||||
tokens = tokens[:4]
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
for _, tokenStr := range tokens {
|
||||
token, err := key.VerifySignedToken(tokenStr)
|
||||
var token SignedToken
|
||||
err := jwtutils.Verify(ctx, mgr, tokenStr, &token, jwtutils.WithVerifyExpected(jwt.Expected{
|
||||
Time: time.Now(),
|
||||
}))
|
||||
if err == nil {
|
||||
req := token.Request.Normalize()
|
||||
if hasQueryParam && req.AccessMethod != AccessMethodTerminal {
|
||||
@@ -280,7 +117,7 @@ func FromRequest(r *http.Request, key SecurityKey) (*SignedToken, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
err := req.Validate()
|
||||
err := req.Check()
|
||||
if err == nil {
|
||||
// The request has a valid signed app token, which is a valid
|
||||
// token signed by us. The caller must check that it matches
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
package workspaceapps_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"crypto/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
)
|
||||
|
||||
func Test_TokenMatchesRequest(t *testing.T) {
|
||||
@@ -283,129 +283,6 @@ func Test_TokenMatchesRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GenerateToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SetExpiry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tokenStr, err := coderdtest.AppSecurityKey.SignToken(workspaceapps.SignedToken{
|
||||
Request: workspaceapps.Request{
|
||||
AccessMethod: workspaceapps.AccessMethodPath,
|
||||
BasePath: "/app",
|
||||
UsernameOrID: "foo",
|
||||
WorkspaceNameOrID: "bar",
|
||||
AgentNameOrID: "baz",
|
||||
AppSlugOrPort: "qux",
|
||||
},
|
||||
|
||||
Expiry: time.Time{},
|
||||
UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"),
|
||||
WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"),
|
||||
AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"),
|
||||
AppURL: "http://127.0.0.1:8080",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := coderdtest.AppSecurityKey.VerifySignedToken(tokenStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.WithinDuration(t, time.Now().Add(time.Minute), token.Expiry, 15*time.Second)
|
||||
})
|
||||
|
||||
future := time.Now().Add(time.Hour)
|
||||
cases := []struct {
|
||||
name string
|
||||
token workspaceapps.SignedToken
|
||||
parseErrContains string
|
||||
}{
|
||||
{
|
||||
name: "OK1",
|
||||
token: workspaceapps.SignedToken{
|
||||
Request: workspaceapps.Request{
|
||||
AccessMethod: workspaceapps.AccessMethodPath,
|
||||
BasePath: "/app",
|
||||
UsernameOrID: "foo",
|
||||
WorkspaceNameOrID: "bar",
|
||||
AgentNameOrID: "baz",
|
||||
AppSlugOrPort: "qux",
|
||||
},
|
||||
|
||||
Expiry: future,
|
||||
UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"),
|
||||
WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"),
|
||||
AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"),
|
||||
AppURL: "http://127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OK2",
|
||||
token: workspaceapps.SignedToken{
|
||||
Request: workspaceapps.Request{
|
||||
AccessMethod: workspaceapps.AccessMethodSubdomain,
|
||||
BasePath: "/",
|
||||
UsernameOrID: "oof",
|
||||
WorkspaceNameOrID: "rab",
|
||||
AgentNameOrID: "zab",
|
||||
AppSlugOrPort: "xuq",
|
||||
},
|
||||
|
||||
Expiry: future,
|
||||
UserID: uuid.MustParse("6fa684a3-11aa-49fd-8512-ab527bd9b900"),
|
||||
WorkspaceID: uuid.MustParse("b2d816cc-505c-441d-afdf-dae01781bc0b"),
|
||||
AgentID: uuid.MustParse("6c4396e1-af88-4a8a-91a3-13ea54fc29fb"),
|
||||
AppURL: "http://localhost:9090",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Expired",
|
||||
token: workspaceapps.SignedToken{
|
||||
Request: workspaceapps.Request{
|
||||
AccessMethod: workspaceapps.AccessMethodSubdomain,
|
||||
BasePath: "/",
|
||||
UsernameOrID: "foo",
|
||||
WorkspaceNameOrID: "bar",
|
||||
AgentNameOrID: "baz",
|
||||
AppSlugOrPort: "qux",
|
||||
},
|
||||
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"),
|
||||
WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"),
|
||||
AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"),
|
||||
AppURL: "http://127.0.0.1:8080",
|
||||
},
|
||||
parseErrContains: "token expired",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
c := c
|
||||
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
str, err := coderdtest.AppSecurityKey.SignToken(c.token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Tokens aren't deterministic as they have a random nonce, so we
|
||||
// can't compare them directly.
|
||||
|
||||
token, err := coderdtest.AppSecurityKey.VerifySignedToken(str)
|
||||
if c.parseErrContains != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, c.parseErrContains)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
// normalize the expiry
|
||||
require.WithinDuration(t, c.token.Expiry, token.Expiry, 10*time.Second)
|
||||
c.token.Expiry = token.Expiry
|
||||
require.Equal(t, c.token, token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FromRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -419,7 +296,13 @@ func Test_FromRequest(t *testing.T) {
|
||||
Value: "invalid",
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
signer := newSigner(t)
|
||||
|
||||
token := workspaceapps.SignedToken{
|
||||
RegisteredClaims: jwtutils.RegisteredClaims{
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
||||
},
|
||||
Request: workspaceapps.Request{
|
||||
AccessMethod: workspaceapps.AccessMethodSubdomain,
|
||||
BasePath: "/",
|
||||
@@ -429,7 +312,6 @@ func Test_FromRequest(t *testing.T) {
|
||||
AgentNameOrID: "agent",
|
||||
AppSlugOrPort: "app",
|
||||
},
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
UserID: uuid.New(),
|
||||
WorkspaceID: uuid.New(),
|
||||
AgentID: uuid.New(),
|
||||
@@ -438,16 +320,15 @@ func Test_FromRequest(t *testing.T) {
|
||||
|
||||
// Add an expired cookie
|
||||
expired := token
|
||||
expired.Expiry = time.Now().Add(time.Hour * -1)
|
||||
expiredStr, err := coderdtest.AppSecurityKey.SignToken(token)
|
||||
expired.RegisteredClaims.Expiry = jwt.NewNumericDate(time.Now().Add(time.Hour * -1))
|
||||
expiredStr, err := jwtutils.Sign(ctx, signer, expired)
|
||||
require.NoError(t, err)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: codersdk.SignedAppTokenCookie,
|
||||
Value: expiredStr,
|
||||
})
|
||||
|
||||
// Add a valid token
|
||||
validStr, err := coderdtest.AppSecurityKey.SignToken(token)
|
||||
validStr, err := jwtutils.Sign(ctx, signer, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
r.AddCookie(&http.Cookie{
|
||||
@@ -455,147 +336,27 @@ func Test_FromRequest(t *testing.T) {
|
||||
Value: validStr,
|
||||
})
|
||||
|
||||
signed, ok := workspaceapps.FromRequest(r, coderdtest.AppSecurityKey)
|
||||
signed, ok := workspaceapps.FromRequest(r, signer)
|
||||
require.True(t, ok, "expected a token to be found")
|
||||
// Confirm it is the correct token.
|
||||
require.Equal(t, signed.UserID, token.UserID)
|
||||
})
|
||||
}
|
||||
|
||||
// The ParseToken fn is tested quite thoroughly in the GenerateToken test as
|
||||
// well.
|
||||
func Test_ParseToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
func newSigner(t *testing.T) jwtutils.StaticKey {
|
||||
t.Helper()
|
||||
|
||||
t.Run("InvalidJWS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token, err := coderdtest.AppSecurityKey.VerifySignedToken("invalid")
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "parse JWS")
|
||||
require.Equal(t, workspaceapps.SignedToken{}, token)
|
||||
})
|
||||
|
||||
t.Run("VerifySignature", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a valid token using a different key.
|
||||
var otherKey workspaceapps.SecurityKey
|
||||
copy(otherKey[:], coderdtest.AppSecurityKey[:])
|
||||
for i := range otherKey {
|
||||
otherKey[i] ^= 0xff
|
||||
}
|
||||
require.NotEqual(t, coderdtest.AppSecurityKey, otherKey)
|
||||
|
||||
tokenStr, err := otherKey.SignToken(workspaceapps.SignedToken{
|
||||
Request: workspaceapps.Request{
|
||||
AccessMethod: workspaceapps.AccessMethodPath,
|
||||
BasePath: "/app",
|
||||
UsernameOrID: "foo",
|
||||
WorkspaceNameOrID: "bar",
|
||||
AgentNameOrID: "baz",
|
||||
AppSlugOrPort: "qux",
|
||||
},
|
||||
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
UserID: uuid.MustParse("b1530ba9-76f3-415e-b597-4ddd7cd466a4"),
|
||||
WorkspaceID: uuid.MustParse("1e6802d3-963e-45ac-9d8c-bf997016ffed"),
|
||||
AgentID: uuid.MustParse("9ec18681-d2c9-4c9e-9186-f136efb4edbe"),
|
||||
AppURL: "http://127.0.0.1:8080",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the token is invalid.
|
||||
token, err := coderdtest.AppSecurityKey.VerifySignedToken(tokenStr)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "verify JWS")
|
||||
require.Equal(t, workspaceapps.SignedToken{}, token)
|
||||
})
|
||||
|
||||
t.Run("InvalidBody", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a signature for an invalid body.
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS512, Key: coderdtest.AppSecurityKey[:64]}, nil)
|
||||
require.NoError(t, err)
|
||||
signedObject, err := signer.Sign([]byte("hi"))
|
||||
require.NoError(t, err)
|
||||
serialized, err := signedObject.CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := coderdtest.AppSecurityKey.VerifySignedToken(serialized)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "unmarshal payload")
|
||||
require.Equal(t, workspaceapps.SignedToken{}, token)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIKeyEncryption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
genAPIKey := func(t *testing.T) string {
|
||||
id, _ := cryptorand.String(10)
|
||||
secret, _ := cryptorand.String(22)
|
||||
|
||||
return fmt.Sprintf("%s-%s", id, secret)
|
||||
return jwtutils.StaticKey{
|
||||
ID: "test",
|
||||
Key: generateSecret(t, 64),
|
||||
}
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key := genAPIKey(t)
|
||||
encrypted, err := coderdtest.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{
|
||||
APIKey: key,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, key, decryptedKey)
|
||||
})
|
||||
|
||||
t.Run("Verifies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Expiry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key := genAPIKey(t)
|
||||
encrypted, err := coderdtest.AppSecurityKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{
|
||||
APIKey: key,
|
||||
ExpiresAt: dbtime.Now().Add(-1 * time.Hour),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "expired")
|
||||
require.Empty(t, decryptedKey)
|
||||
})
|
||||
|
||||
t.Run("EncryptionKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a valid token using a different key.
|
||||
var otherKey workspaceapps.SecurityKey
|
||||
copy(otherKey[:], coderdtest.AppSecurityKey[:])
|
||||
for i := range otherKey {
|
||||
otherKey[i] ^= 0xff
|
||||
}
|
||||
require.NotEqual(t, coderdtest.AppSecurityKey, otherKey)
|
||||
|
||||
// Encrypt with the other key.
|
||||
key := genAPIKey(t)
|
||||
encrypted, err := otherKey.EncryptAPIKey(workspaceapps.EncryptedAPIKeyPayload{
|
||||
APIKey: key,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decrypt with the original key.
|
||||
decryptedKey, err := coderdtest.AppSecurityKey.DecryptAPIKey(encrypted)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "decrypt API key")
|
||||
require.Empty(t, decryptedKey)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func generateSecret(t *testing.T, size int) []byte {
|
||||
t.Helper()
|
||||
|
||||
secret := make([]byte, size)
|
||||
_, err := rand.Read(secret)
|
||||
require.NoError(t, err)
|
||||
return secret
|
||||
}
|
||||
|
||||
@@ -5,16 +5,23 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestGetAppHost(t *testing.T) {
|
||||
@@ -181,16 +188,28 @@ func TestWorkspaceApplicationAuth(t *testing.T) {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
logger := slogtest.Make(t, nil)
|
||||
accessURL, err := url.Parse(c.accessURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
fetcher := &cryptokeys.DBFetcher{
|
||||
DB: db,
|
||||
}
|
||||
|
||||
kc, err := cryptokeys.NewEncryptionCache(ctx, logger, fetcher, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
AccessURL: accessURL,
|
||||
AppHostname: c.appHostname,
|
||||
AccessURL: accessURL,
|
||||
AppHostname: c.appHostname,
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
APIKeyEncryptionCache: kc,
|
||||
Clock: clock,
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
@@ -240,7 +259,15 @@ func TestWorkspaceApplicationAuth(t *testing.T) {
|
||||
loc.RawQuery = q.Encode()
|
||||
require.Equal(t, c.expectRedirect, loc.String())
|
||||
|
||||
// The decrypted key is verified in the apptest test suite.
|
||||
var token workspaceapps.EncryptedAPIKeyPayload
|
||||
err = jwtutils.Decrypt(ctx, kc, encryptedAPIKey, &token, jwtutils.WithDecryptExpected(jwt.Expected{
|
||||
Time: clock.Now(),
|
||||
AnyAudience: jwt.Audience{"wsproxy"},
|
||||
Issuer: "coderd",
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, jwt.NewNumericDate(clock.Now().Add(time.Minute)), token.Expiry)
|
||||
require.Equal(t, jwt.NewNumericDate(clock.Now().Add(-time.Minute)), token.NotBefore)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3113,9 +3113,11 @@ func (c *Client) SSHConfiguration(ctx context.Context) (SSHConfigResponse, error
|
||||
type CryptoKeyFeature string
|
||||
|
||||
const (
|
||||
CryptoKeyFeatureWorkspaceApp CryptoKeyFeature = "workspace_apps"
|
||||
CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert"
|
||||
CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume"
|
||||
CryptoKeyFeatureWorkspaceAppsAPIKey CryptoKeyFeature = "workspace_apps_api_key"
|
||||
//nolint:gosec // This denotes a type of key, not a literal.
|
||||
CryptoKeyFeatureWorkspaceAppsToken CryptoKeyFeature = "workspace_apps_token"
|
||||
CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert"
|
||||
CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume"
|
||||
)
|
||||
|
||||
type CryptoKey struct {
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
@@ -61,7 +62,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
|
||||
NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -165,13 +166,17 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
|
||||
mgr := jwtutils.StaticKey{
|
||||
ID: "123",
|
||||
Key: resumeTokenSigningKey[:],
|
||||
}
|
||||
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour)
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
|
||||
NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: resumeTokenProvider,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -190,7 +195,7 @@ func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
|
||||
t.Logf("received resume token: %s", resumeToken)
|
||||
assert.Equal(t, expectResumeToken, resumeToken)
|
||||
if resumeToken != "" {
|
||||
peerID, err = resumeTokenProvider.VerifyResumeToken(resumeToken)
|
||||
peerID, err = resumeTokenProvider.VerifyResumeToken(ctx, resumeToken)
|
||||
assert.NoError(t, err, "failed to parse resume token")
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
||||
@@ -280,13 +285,17 @@ func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
|
||||
mgr := jwtutils.StaticKey{
|
||||
ID: uuid.New().String(),
|
||||
Key: resumeTokenSigningKey[:],
|
||||
}
|
||||
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour)
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
|
||||
NetworkTelemetryHandler: func(_ []*proto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: resumeTokenProvider,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
Generated
+9
-10
@@ -1454,7 +1454,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
```json
|
||||
{
|
||||
"deletes_at": "2019-08-24T14:15:22Z",
|
||||
"feature": "workspace_apps",
|
||||
"feature": "workspace_apps_api_key",
|
||||
"secret": "string",
|
||||
"sequence": 0,
|
||||
"starts_at": "2019-08-24T14:15:22Z"
|
||||
@@ -1474,18 +1474,19 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
## codersdk.CryptoKeyFeature
|
||||
|
||||
```json
|
||||
"workspace_apps"
|
||||
"workspace_apps_api_key"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value |
|
||||
| ---------------- |
|
||||
| `workspace_apps` |
|
||||
| `oidc_convert` |
|
||||
| `tailnet_resume` |
|
||||
| Value |
|
||||
| ------------------------ |
|
||||
| `workspace_apps_api_key` |
|
||||
| `workspace_apps_token` |
|
||||
| `oidc_convert` |
|
||||
| `tailnet_resume` |
|
||||
|
||||
## codersdk.CustomRoleRequest
|
||||
|
||||
@@ -9893,7 +9894,7 @@ _None_
|
||||
"crypto_keys": [
|
||||
{
|
||||
"deletes_at": "2019-08-24T14:15:22Z",
|
||||
"feature": "workspace_apps",
|
||||
"feature": "workspace_apps_api_key",
|
||||
"secret": "string",
|
||||
"sequence": 0,
|
||||
"starts_at": "2019-08-24T14:15:22Z"
|
||||
@@ -9971,7 +9972,6 @@ _None_
|
||||
|
||||
```json
|
||||
{
|
||||
"app_security_key": "string",
|
||||
"derp_force_websockets": true,
|
||||
"derp_map": {
|
||||
"homeParams": {
|
||||
@@ -10052,7 +10052,6 @@ _None_
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
| ----------------------- | --------------------------------------------- | -------- | ------------ | -------------------------------------------------------------------------------------- |
|
||||
| `app_security_key` | string | false | | |
|
||||
| `derp_force_websockets` | boolean | false | | |
|
||||
| `derp_map` | [tailcfg.DERPMap](#tailcfgderpmap) | false | | |
|
||||
| `derp_mesh_key` | string | false | | |
|
||||
|
||||
@@ -65,6 +65,8 @@ type WorkspaceProxy struct {
|
||||
// owner client. If a token is provided, the proxy will become a replica of the
|
||||
// existing proxy region.
|
||||
func NewWorkspaceProxyReplica(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Client, options *ProxyOptions) WorkspaceProxy {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancelFunc)
|
||||
|
||||
@@ -142,8 +144,10 @@ func NewWorkspaceProxyReplica(t *testing.T, coderdAPI *coderd.API, owner *coders
|
||||
statsCollectorOptions.Flush = options.FlushStats
|
||||
}
|
||||
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug).With(slog.F("server_url", serverURL.String()))
|
||||
|
||||
wssrv, err := wsproxy.New(ctx, &wsproxy.Options{
|
||||
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).With(slog.F("server_url", serverURL.String())),
|
||||
Logger: logger,
|
||||
Experiments: options.Experiments,
|
||||
DashboardURL: coderdAPI.AccessURL,
|
||||
AccessURL: accessURL,
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -33,6 +34,13 @@ import (
|
||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||
)
|
||||
|
||||
// whitelistedCryptoKeyFeatures is a list of crypto key features that are
|
||||
// allowed to be queried with workspace proxies.
|
||||
var whitelistedCryptoKeyFeatures = []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceAppsToken,
|
||||
database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
}
|
||||
|
||||
// forceWorkspaceProxyHealthUpdate forces an update of the proxy health.
|
||||
// This is useful when a proxy is created or deleted. Errors will be logged.
|
||||
func (api *API) forceWorkspaceProxyHealthUpdate(ctx context.Context) {
|
||||
@@ -700,7 +708,6 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, wsproxysdk.RegisterWorkspaceProxyResponse{
|
||||
AppSecurityKey: api.AppSecurityKey.String(),
|
||||
DERPMeshKey: api.DERPServer.MeshKey(),
|
||||
DERPRegionID: regionID,
|
||||
DERPMap: api.AGPL.DERPMap(),
|
||||
@@ -721,13 +728,29 @@ func (api *API) workspaceProxyRegister(rw http.ResponseWriter, r *http.Request)
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Enterprise
|
||||
// @Param feature query string true "Feature key"
|
||||
// @Success 200 {object} wsproxysdk.CryptoKeysResponse
|
||||
// @Router /workspaceproxies/me/crypto-keys [get]
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceProxyCryptoKeys(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
keys, err := api.Database.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
|
||||
feature := database.CryptoKeyFeature(r.URL.Query().Get("feature"))
|
||||
if feature == "" {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing feature query parameter.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !slices.Contains(whitelistedCryptoKeyFeatures, feature) {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Invalid feature: %q", feature),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := api.Database.GetCryptoKeysByFeature(ctx, feature)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
|
||||
@@ -320,7 +320,6 @@ func TestProxyRegisterDeregister(t *testing.T) {
|
||||
}
|
||||
registerRes1, err := proxyClient.RegisterWorkspaceProxy(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, registerRes1.AppSecurityKey)
|
||||
require.NotEmpty(t, registerRes1.DERPMeshKey)
|
||||
require.EqualValues(t, 10001, registerRes1.DERPRegionID)
|
||||
require.Empty(t, registerRes1.SiblingReplicas)
|
||||
@@ -609,11 +608,8 @@ func TestProxyRegisterDeregister(t *testing.T) {
|
||||
func TestIssueSignedAppToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
client, user := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
IncludeProvisionerDaemon: true,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
@@ -716,6 +712,10 @@ func TestReconnectingPTYSignedToken(t *testing.T) {
|
||||
closer.Close()
|
||||
})
|
||||
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsToken,
|
||||
})
|
||||
|
||||
// Create a workspace + apps
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
@@ -915,51 +915,86 @@ func TestGetCryptoKeys(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
expectedKey1 := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now.Add(-time.Hour),
|
||||
Sequence: 2,
|
||||
})
|
||||
key1 := db2sdk.CryptoKey(expectedKey1)
|
||||
encryptionKey := db2sdk.CryptoKey(expectedKey1)
|
||||
|
||||
expectedKey2 := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsToken,
|
||||
StartsAt: now,
|
||||
Sequence: 3,
|
||||
})
|
||||
key2 := db2sdk.CryptoKey(expectedKey2)
|
||||
signingKey := db2sdk.CryptoKey(expectedKey2)
|
||||
|
||||
// Create a deleted key.
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
StartsAt: now.Add(-time.Hour),
|
||||
Secret: sql.NullString{
|
||||
String: "secret1",
|
||||
Valid: false,
|
||||
},
|
||||
Sequence: 1,
|
||||
})
|
||||
|
||||
// Create a key with different features.
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureTailnetResume,
|
||||
StartsAt: now.Add(-time.Hour),
|
||||
Sequence: 1,
|
||||
})
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
StartsAt: now.Add(-time.Hour),
|
||||
Sequence: 1,
|
||||
Sequence: 4,
|
||||
})
|
||||
|
||||
proxy := coderdenttest.NewWorkspaceProxyReplica(t, api, cclient, &coderdenttest.ProxyOptions{
|
||||
Name: testutil.GetRandomName(t),
|
||||
})
|
||||
|
||||
keys, err := proxy.SDKClient.CryptoKeys(ctx)
|
||||
keys, err := proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, keys)
|
||||
// 1 key is generated on startup, the other we manually generated.
|
||||
require.Equal(t, 2, len(keys.CryptoKeys))
|
||||
requireContainsKeys(t, keys.CryptoKeys, key1, key2)
|
||||
requireContainsKeys(t, keys.CryptoKeys, encryptionKey)
|
||||
requireNotContainsKeys(t, keys.CryptoKeys, signingKey)
|
||||
|
||||
keys, err = proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsToken)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, keys)
|
||||
// 1 key is generated on startup, the other we manually generated.
|
||||
require.Equal(t, 2, len(keys.CryptoKeys))
|
||||
requireContainsKeys(t, keys.CryptoKeys, signingKey)
|
||||
requireNotContainsKeys(t, keys.CryptoKeys, encryptionKey)
|
||||
})
|
||||
|
||||
t.Run("InvalidFeature", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
cclient, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
IncludeProvisionerDaemon: true,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureWorkspaceProxy: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
proxy := coderdenttest.NewWorkspaceProxyReplica(t, api, cclient, &coderdenttest.ProxyOptions{
|
||||
Name: testutil.GetRandomName(t),
|
||||
})
|
||||
|
||||
_, err := proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureOIDCConvert)
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
_, err = proxy.SDKClient.CryptoKeys(ctx, codersdk.CryptoKeyFeatureTailnetResume)
|
||||
require.Error(t, err)
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
_, err = proxy.SDKClient.CryptoKeys(ctx, "invalid")
|
||||
require.Error(t, err)
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("Unauthorized", func(t *testing.T) {
|
||||
@@ -987,7 +1022,7 @@ func TestGetCryptoKeys(t *testing.T) {
|
||||
client := wsproxysdk.New(cclient.URL)
|
||||
client.SetSessionToken(cclient.SessionToken())
|
||||
|
||||
_, err := client.CryptoKeys(ctx)
|
||||
_, err := client.CryptoKeys(ctx, codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey)
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
@@ -995,6 +1030,18 @@ func TestGetCryptoKeys(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func requireNotContainsKeys(t *testing.T, keys []codersdk.CryptoKey, unexpected ...codersdk.CryptoKey) {
|
||||
t.Helper()
|
||||
|
||||
for _, unexpectedKey := range unexpected {
|
||||
for _, key := range keys {
|
||||
if key.Feature == unexpectedKey.Feature && key.Sequence == unexpectedKey.Sequence {
|
||||
t.Fatalf("unexpected key %+v found", unexpectedKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func requireContainsKeys(t *testing.T, keys []codersdk.CryptoKey, expected ...codersdk.CryptoKey) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -397,12 +397,12 @@ func TestCryptoKeys(t *testing.T) {
|
||||
_ = dbgen.CryptoKey(t, crypt, database.CryptoKey{
|
||||
Secret: sql.NullString{String: "test", Valid: true},
|
||||
})
|
||||
key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
|
||||
key, err := crypt.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", key.Secret.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
|
||||
|
||||
key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
|
||||
key, err = db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceAppsAPIKey)
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], key.Secret.String, "test")
|
||||
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
|
||||
@@ -415,7 +415,7 @@ func TestCryptoKeys(t *testing.T) {
|
||||
Secret: sql.NullString{String: "test", Valid: true},
|
||||
})
|
||||
key, err := crypt.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
Sequence: key.Sequence,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -423,7 +423,7 @@ func TestCryptoKeys(t *testing.T) {
|
||||
require.Equal(t, ciphers[0].HexDigest(), key.SecretKeyID.String)
|
||||
|
||||
key, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
Sequence: key.Sequence,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -459,7 +459,7 @@ func TestCryptoKeys(t *testing.T) {
|
||||
Secret: sql.NullString{String: "test", Valid: true},
|
||||
})
|
||||
_ = dbgen.CryptoKey(t, crypt, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
Sequence: 43,
|
||||
})
|
||||
keys, err := crypt.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureTailnetResume)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps/apptest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -36,6 +37,9 @@ func TestWorkspaceApps(t *testing.T) {
|
||||
flushStatsCollectorCh <- flushStatsCollectorDone
|
||||
<-flushStatsCollectorDone
|
||||
}
|
||||
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
|
||||
client, _, _, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
@@ -51,6 +55,8 @@ func TestWorkspaceApps(t *testing.T) {
|
||||
},
|
||||
},
|
||||
WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions,
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
|
||||
@@ -13,12 +13,11 @@ import (
|
||||
var _ cryptokeys.Fetcher = &ProxyFetcher{}
|
||||
|
||||
type ProxyFetcher struct {
|
||||
Client *wsproxysdk.Client
|
||||
Feature codersdk.CryptoKeyFeature
|
||||
Client *wsproxysdk.Client
|
||||
}
|
||||
|
||||
func (p *ProxyFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
|
||||
keys, err := p.Client.CryptoKeys(ctx)
|
||||
func (p *ProxyFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
|
||||
keys, err := p.Client.CryptoKeys(ctx, feature)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto keys: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||
)
|
||||
@@ -18,18 +20,19 @@ type TokenProvider struct {
|
||||
AccessURL *url.URL
|
||||
AppHostname string
|
||||
|
||||
Client *wsproxysdk.Client
|
||||
SecurityKey workspaceapps.SecurityKey
|
||||
Logger slog.Logger
|
||||
Client *wsproxysdk.Client
|
||||
TokenSigningKeycache cryptokeys.SigningKeycache
|
||||
APIKeyEncryptionKeycache cryptokeys.EncryptionKeycache
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
func (p *TokenProvider) FromRequest(r *http.Request) (*workspaceapps.SignedToken, bool) {
|
||||
return workspaceapps.FromRequest(r, p.SecurityKey)
|
||||
return workspaceapps.FromRequest(r, p.TokenSigningKeycache)
|
||||
}
|
||||
|
||||
func (p *TokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *http.Request, issueReq workspaceapps.IssueTokenRequest) (*workspaceapps.SignedToken, string, bool) {
|
||||
appReq := issueReq.AppRequest.Normalize()
|
||||
err := appReq.Validate()
|
||||
err := appReq.Check()
|
||||
if err != nil {
|
||||
workspaceapps.WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "invalid app request")
|
||||
return nil, "", false
|
||||
@@ -42,7 +45,8 @@ func (p *TokenProvider) Issue(ctx context.Context, rw http.ResponseWriter, r *ht
|
||||
}
|
||||
|
||||
// Check that it verifies properly and matches the string.
|
||||
token, err := p.SecurityKey.VerifySignedToken(resp.SignedTokenStr)
|
||||
var token workspaceapps.SignedToken
|
||||
err = jwtutils.Verify(ctx, p.TokenSigningKeycache, resp.SignedTokenStr, &token)
|
||||
if err != nil {
|
||||
workspaceapps.WriteWorkspaceApp500(p.Logger, p.DashboardURL, rw, r, &appReq, err, "failed to verify newly generated signed token")
|
||||
return nil, "", false
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/cli/cliutil"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
@@ -130,6 +131,13 @@ type Server struct {
|
||||
// the moon's token.
|
||||
SDKClient *wsproxysdk.Client
|
||||
|
||||
// apiKeyEncryptionKeycache manages the encryption keys for smuggling API
|
||||
// tokens to the alternate domain when using workspace apps.
|
||||
apiKeyEncryptionKeycache cryptokeys.EncryptionKeycache
|
||||
// appTokenSigningKeycache manages the signing keys for signing the app
|
||||
// tokens we use for workspace apps.
|
||||
appTokenSigningKeycache cryptokeys.SigningKeycache
|
||||
|
||||
// DERP
|
||||
derpMesh *derpmesh.Mesh
|
||||
derpMeshTLSConfig *tls.Config
|
||||
@@ -195,19 +203,42 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
|
||||
derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(opts.Logger.Named("net.derp")))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
encryptionCache, err := cryptokeys.NewEncryptionCache(ctx,
|
||||
opts.Logger,
|
||||
&ProxyFetcher{Client: client},
|
||||
codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, xerrors.Errorf("create api key encryption cache: %w", err)
|
||||
}
|
||||
signingCache, err := cryptokeys.NewSigningCache(ctx,
|
||||
opts.Logger,
|
||||
&ProxyFetcher{Client: client},
|
||||
codersdk.CryptoKeyFeatureWorkspaceAppsToken,
|
||||
)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, xerrors.Errorf("create api token signing cache: %w", err)
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
s := &Server{
|
||||
Options: opts,
|
||||
Handler: r,
|
||||
DashboardURL: opts.DashboardURL,
|
||||
Logger: opts.Logger.Named("net.workspace-proxy"),
|
||||
TracerProvider: opts.Tracing,
|
||||
PrometheusRegistry: opts.PrometheusRegistry,
|
||||
SDKClient: client,
|
||||
derpMesh: derpmesh.New(opts.Logger.Named("net.derpmesh"), derpServer, meshTLSConfig),
|
||||
derpMeshTLSConfig: meshTLSConfig,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
|
||||
Options: opts,
|
||||
Handler: r,
|
||||
DashboardURL: opts.DashboardURL,
|
||||
Logger: opts.Logger.Named("net.workspace-proxy"),
|
||||
TracerProvider: opts.Tracing,
|
||||
PrometheusRegistry: opts.PrometheusRegistry,
|
||||
SDKClient: client,
|
||||
derpMesh: derpmesh.New(opts.Logger.Named("net.derpmesh"), derpServer, meshTLSConfig),
|
||||
derpMeshTLSConfig: meshTLSConfig,
|
||||
apiKeyEncryptionKeycache: encryptionCache,
|
||||
appTokenSigningKeycache: signingCache,
|
||||
}
|
||||
|
||||
// Register the workspace proxy with the primary coderd instance and start a
|
||||
@@ -240,11 +271,6 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
|
||||
return nil, xerrors.Errorf("handle register: %w", err)
|
||||
}
|
||||
|
||||
secKey, err := workspaceapps.KeyFromString(regResp.AppSecurityKey)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse app security key: %w", err)
|
||||
}
|
||||
|
||||
agentProvider, err := coderd.NewServerTailnet(ctx,
|
||||
s.Logger,
|
||||
nil,
|
||||
@@ -277,20 +303,21 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
|
||||
HostnameRegex: opts.AppHostnameRegex,
|
||||
RealIPConfig: opts.RealIPConfig,
|
||||
SignedTokenProvider: &TokenProvider{
|
||||
DashboardURL: opts.DashboardURL,
|
||||
AccessURL: opts.AccessURL,
|
||||
AppHostname: opts.AppHostname,
|
||||
Client: client,
|
||||
SecurityKey: secKey,
|
||||
Logger: s.Logger.Named("proxy_token_provider"),
|
||||
DashboardURL: opts.DashboardURL,
|
||||
AccessURL: opts.AccessURL,
|
||||
AppHostname: opts.AppHostname,
|
||||
Client: client,
|
||||
TokenSigningKeycache: signingCache,
|
||||
APIKeyEncryptionKeycache: encryptionCache,
|
||||
Logger: s.Logger.Named("proxy_token_provider"),
|
||||
},
|
||||
AppSecurityKey: secKey,
|
||||
|
||||
DisablePathApps: opts.DisablePathApps,
|
||||
SecureAuthCookie: opts.SecureAuthCookie,
|
||||
|
||||
AgentProvider: agentProvider,
|
||||
StatsCollector: workspaceapps.NewStatsCollector(opts.StatsCollectorOptions),
|
||||
AgentProvider: agentProvider,
|
||||
StatsCollector: workspaceapps.NewStatsCollector(opts.StatsCollectorOptions),
|
||||
APIKeyEncryptionKeycache: encryptionCache,
|
||||
}
|
||||
|
||||
derpHandler := derphttp.Handler(derpServer)
|
||||
@@ -435,6 +462,8 @@ func (s *Server) Close() error {
|
||||
err = multierror.Append(err, agentProviderErr)
|
||||
}
|
||||
s.SDKClient.SDKClient.HTTPClient.CloseIdleConnections()
|
||||
_ = s.appTokenSigningKeycache.Close()
|
||||
_ = s.apiKeyEncryptionKeycache.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,9 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps/apptest"
|
||||
@@ -932,6 +935,9 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) {
|
||||
if opts.PrimaryAppHost == "" {
|
||||
opts.PrimaryAppHost = "*.primary.test.coder.com"
|
||||
}
|
||||
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
|
||||
client, closer, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
@@ -947,6 +953,8 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) {
|
||||
},
|
||||
},
|
||||
WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions,
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
@@ -959,6 +967,13 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) {
|
||||
_ = closer.Close()
|
||||
})
|
||||
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsToken,
|
||||
})
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
})
|
||||
|
||||
// Create the external proxy
|
||||
if opts.DisableSubdomainApps {
|
||||
opts.AppHost = ""
|
||||
@@ -1002,6 +1017,8 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) {
|
||||
if opts.PrimaryAppHost == "" {
|
||||
opts.PrimaryAppHost = "*.primary.test.coder.com"
|
||||
}
|
||||
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
client, closer, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
@@ -1017,6 +1034,8 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) {
|
||||
},
|
||||
},
|
||||
WorkspaceAppsStatsCollectorOptions: opts.StatsCollectorOptions,
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
@@ -1029,6 +1048,13 @@ func TestWorkspaceProxyWorkspaceApps_BlockDirect(t *testing.T) {
|
||||
_ = closer.Close()
|
||||
})
|
||||
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsToken,
|
||||
})
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceAppsAPIKey,
|
||||
})
|
||||
|
||||
// Create the external proxy
|
||||
if opts.DisableSubdomainApps {
|
||||
opts.AppHost = ""
|
||||
|
||||
@@ -205,7 +205,6 @@ type RegisterWorkspaceProxyRequest struct {
|
||||
}
|
||||
|
||||
type RegisterWorkspaceProxyResponse struct {
|
||||
AppSecurityKey string `json:"app_security_key"`
|
||||
DERPMeshKey string `json:"derp_mesh_key"`
|
||||
DERPRegionID int32 `json:"derp_region_id"`
|
||||
DERPMap *tailcfg.DERPMap `json:"derp_map"`
|
||||
@@ -372,12 +371,6 @@ func (l *RegisterWorkspaceProxyLoop) Start(ctx context.Context) (RegisterWorkspa
|
||||
}
|
||||
failedAttempts = 0
|
||||
|
||||
// Check for consistency.
|
||||
if originalRes.AppSecurityKey != resp.AppSecurityKey {
|
||||
l.failureFn(xerrors.New("app security key has changed, proxy must be restarted"))
|
||||
return
|
||||
}
|
||||
|
||||
if originalRes.DERPMeshKey != resp.DERPMeshKey {
|
||||
l.failureFn(xerrors.New("DERP mesh key has changed, proxy must be restarted"))
|
||||
return
|
||||
@@ -586,10 +579,10 @@ type CryptoKeysResponse struct {
|
||||
CryptoKeys []codersdk.CryptoKey `json:"crypto_keys"`
|
||||
}
|
||||
|
||||
func (c *Client) CryptoKeys(ctx context.Context) (CryptoKeysResponse, error) {
|
||||
func (c *Client) CryptoKeys(ctx context.Context, feature codersdk.CryptoKeyFeature) (CryptoKeysResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet,
|
||||
"/api/v2/workspaceproxies/me/crypto-keys",
|
||||
nil,
|
||||
"/api/v2/workspaceproxies/me/crypto-keys", nil,
|
||||
codersdk.WithQueryParam("feature", string(feature)),
|
||||
)
|
||||
if err != nil {
|
||||
return CryptoKeysResponse{}, xerrors.Errorf("make request: %w", err)
|
||||
|
||||
Generated
+2
-2
@@ -2110,8 +2110,8 @@ export type BuildReason = "autostart" | "autostop" | "initiator"
|
||||
export const BuildReasons: BuildReason[] = ["autostart", "autostop", "initiator"]
|
||||
|
||||
// From codersdk/deployment.go
|
||||
export type CryptoKeyFeature = "oidc_convert" | "tailnet_resume" | "workspace_apps"
|
||||
export const CryptoKeyFeatures: CryptoKeyFeature[] = ["oidc_convert", "tailnet_resume", "workspace_apps"]
|
||||
export type CryptoKeyFeature = "oidc_convert" | "tailnet_resume" | "workspace_apps_api_key" | "workspace_apps_token"
|
||||
export const CryptoKeyFeatures: CryptoKeyFeature[] = ["oidc_convert", "tailnet_resume", "workspace_apps_api_key", "workspace_apps_token"]
|
||||
|
||||
// From codersdk/workspaceagents.go
|
||||
export type DisplayApp = "port_forwarding_helper" | "ssh_helper" | "vscode" | "vscode_insiders" | "web_terminal"
|
||||
|
||||
+26
-117
@@ -3,32 +3,23 @@ package tailnet
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultResumeTokenExpiry = 24 * time.Hour
|
||||
|
||||
resumeTokenSigningAlgorithm = jose.HS512
|
||||
)
|
||||
|
||||
// resumeTokenSigningKeyID is a fixed key ID for the resume token signing key.
|
||||
// If/when we add support for multiple keys (e.g. key rotation), this will move
|
||||
// to the database instead.
|
||||
var resumeTokenSigningKeyID = uuid.MustParse("97166747-9309-4d7f-9071-a230e257c2a4")
|
||||
|
||||
// NewInsecureTestResumeTokenProvider returns a ResumeTokenProvider that uses a
|
||||
// random key with short expiry for testing purposes. If any errors occur while
|
||||
// generating the key, the function panics.
|
||||
@@ -37,12 +28,15 @@ func NewInsecureTestResumeTokenProvider() ResumeTokenProvider {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return NewResumeTokenKeyProvider(key, quartz.NewReal(), time.Hour)
|
||||
return NewResumeTokenKeyProvider(jwtutils.StaticKey{
|
||||
ID: uuid.New().String(),
|
||||
Key: key[:],
|
||||
}, quartz.NewReal(), time.Hour)
|
||||
}
|
||||
|
||||
type ResumeTokenProvider interface {
|
||||
GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error)
|
||||
VerifyResumeToken(token string) (uuid.UUID, error)
|
||||
GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error)
|
||||
VerifyResumeToken(ctx context.Context, token string) (uuid.UUID, error)
|
||||
}
|
||||
|
||||
type ResumeTokenSigningKey [64]byte
|
||||
@@ -56,104 +50,37 @@ func GenerateResumeTokenSigningKey() (ResumeTokenSigningKey, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
type ResumeTokenSigningKeyDatabaseStore interface {
|
||||
GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error)
|
||||
UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, key string) error
|
||||
}
|
||||
|
||||
// ResumeTokenSigningKeyFromDatabase retrieves the coordinator resume token
|
||||
// signing key from the database. If the key is not found, a new key is
|
||||
// generated and inserted into the database.
|
||||
func ResumeTokenSigningKeyFromDatabase(ctx context.Context, db ResumeTokenSigningKeyDatabaseStore) (ResumeTokenSigningKey, error) {
|
||||
var resumeTokenKey ResumeTokenSigningKey
|
||||
resumeTokenKeyStr, err := db.GetCoordinatorResumeTokenSigningKey(ctx)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return resumeTokenKey, xerrors.Errorf("get coordinator resume token key: %w", err)
|
||||
}
|
||||
if decoded, err := hex.DecodeString(resumeTokenKeyStr); err != nil || len(decoded) != len(resumeTokenKey) {
|
||||
newKey, err := GenerateResumeTokenSigningKey()
|
||||
if err != nil {
|
||||
return resumeTokenKey, xerrors.Errorf("generate fresh coordinator resume token key: %w", err)
|
||||
}
|
||||
|
||||
resumeTokenKeyStr = hex.EncodeToString(newKey[:])
|
||||
err = db.UpsertCoordinatorResumeTokenSigningKey(ctx, resumeTokenKeyStr)
|
||||
if err != nil {
|
||||
return resumeTokenKey, xerrors.Errorf("insert freshly generated coordinator resume token key to database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
resumeTokenKeyBytes, err := hex.DecodeString(resumeTokenKeyStr)
|
||||
if err != nil {
|
||||
return resumeTokenKey, xerrors.Errorf("decode coordinator resume token key from database: %w", err)
|
||||
}
|
||||
if len(resumeTokenKeyBytes) != len(resumeTokenKey) {
|
||||
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is not the correct length, expect %d got %d", len(resumeTokenKey), len(resumeTokenKeyBytes))
|
||||
}
|
||||
copy(resumeTokenKey[:], resumeTokenKeyBytes)
|
||||
if resumeTokenKey == [64]byte{} {
|
||||
return resumeTokenKey, xerrors.Errorf("coordinator resume token key in database is empty")
|
||||
}
|
||||
return resumeTokenKey, nil
|
||||
}
|
||||
|
||||
type ResumeTokenKeyProvider struct {
|
||||
key ResumeTokenSigningKey
|
||||
key jwtutils.SigningKeyManager
|
||||
clock quartz.Clock
|
||||
expiry time.Duration
|
||||
}
|
||||
|
||||
func NewResumeTokenKeyProvider(key ResumeTokenSigningKey, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider {
|
||||
func NewResumeTokenKeyProvider(key jwtutils.SigningKeyManager, clock quartz.Clock, expiry time.Duration) ResumeTokenProvider {
|
||||
if expiry <= 0 {
|
||||
expiry = DefaultResumeTokenExpiry
|
||||
}
|
||||
return ResumeTokenKeyProvider{
|
||||
key: key,
|
||||
clock: clock,
|
||||
expiry: DefaultResumeTokenExpiry,
|
||||
expiry: expiry,
|
||||
}
|
||||
}
|
||||
|
||||
type resumeTokenPayload struct {
|
||||
PeerID uuid.UUID `json:"sub"`
|
||||
Expiry int64 `json:"exp"`
|
||||
}
|
||||
|
||||
func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
|
||||
func (p ResumeTokenKeyProvider) GenerateResumeToken(ctx context.Context, peerID uuid.UUID) (*proto.RefreshResumeTokenResponse, error) {
|
||||
exp := p.clock.Now().Add(p.expiry)
|
||||
payload := resumeTokenPayload{
|
||||
PeerID: peerID,
|
||||
Expiry: exp.Unix(),
|
||||
}
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal payload to JSON: %w", err)
|
||||
payload := jwtutils.RegisteredClaims{
|
||||
Subject: peerID.String(),
|
||||
Expiry: jwt.NewNumericDate(exp),
|
||||
}
|
||||
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: resumeTokenSigningAlgorithm,
|
||||
Key: p.key[:],
|
||||
}, &jose.SignerOptions{
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"kid": resumeTokenSigningKeyID.String(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create signer: %w", err)
|
||||
}
|
||||
|
||||
signedObject, err := signer.Sign(payloadBytes)
|
||||
token, err := jwtutils.Sign(ctx, p.key, payload)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("sign payload: %w", err)
|
||||
}
|
||||
|
||||
serialized, err := signedObject.CompactSerialize()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("serialize JWS: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RefreshResumeTokenResponse{
|
||||
Token: serialized,
|
||||
Token: token,
|
||||
RefreshIn: durationpb.New(p.expiry / 2),
|
||||
ExpiresAt: timestamppb.New(exp),
|
||||
}, nil
|
||||
@@ -162,35 +89,17 @@ func (p ResumeTokenKeyProvider) GenerateResumeToken(peerID uuid.UUID) (*proto.Re
|
||||
// VerifyResumeToken parses a signed tailnet resume token with the given key and
|
||||
// returns the payload. If the token is invalid or expired, an error is
|
||||
// returned.
|
||||
func (p ResumeTokenKeyProvider) VerifyResumeToken(str string) (uuid.UUID, error) {
|
||||
object, err := jose.ParseSigned(str)
|
||||
func (p ResumeTokenKeyProvider) VerifyResumeToken(ctx context.Context, str string) (uuid.UUID, error) {
|
||||
var tok jwt.Claims
|
||||
err := jwtutils.Verify(ctx, p.key, str, &tok, jwtutils.WithVerifyExpected(jwt.Expected{
|
||||
Time: p.clock.Now(),
|
||||
}))
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("parse JWS: %w", err)
|
||||
return uuid.Nil, xerrors.Errorf("verify payload: %w", err)
|
||||
}
|
||||
if len(object.Signatures) != 1 {
|
||||
return uuid.Nil, xerrors.New("expected 1 signature")
|
||||
}
|
||||
if object.Signatures[0].Header.Algorithm != string(resumeTokenSigningAlgorithm) {
|
||||
return uuid.Nil, xerrors.Errorf("expected token signing algorithm to be %q, got %q", resumeTokenSigningAlgorithm, object.Signatures[0].Header.Algorithm)
|
||||
}
|
||||
if object.Signatures[0].Header.KeyID != resumeTokenSigningKeyID.String() {
|
||||
return uuid.Nil, xerrors.Errorf("expected token key ID to be %q, got %q", resumeTokenSigningKeyID, object.Signatures[0].Header.KeyID)
|
||||
}
|
||||
|
||||
output, err := object.Verify(p.key[:])
|
||||
parsed, err := uuid.Parse(tok.Subject)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("verify JWS: %w", err)
|
||||
return uuid.Nil, xerrors.Errorf("parse peerID from token: %w", err)
|
||||
}
|
||||
|
||||
var tok resumeTokenPayload
|
||||
err = json.Unmarshal(output, &tok)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("unmarshal payload: %w", err)
|
||||
}
|
||||
exp := time.Unix(tok.Expiry, 0)
|
||||
if exp.Before(p.clock.Now()) {
|
||||
return uuid.Nil, xerrors.New("signed resume token expired")
|
||||
}
|
||||
|
||||
return tok.PeerID, nil
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
+34
-116
@@ -1,117 +1,20 @@
|
||||
package tailnet_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestResumeTokenSigningKeyFromDatabase(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assertRandomKey := func(t *testing.T, key tailnet.ResumeTokenSigningKey) {
|
||||
t.Helper()
|
||||
assert.NotEqual(t, tailnet.ResumeTokenSigningKey{}, key, "key should not be empty")
|
||||
assert.NotEqualValues(t, [64]byte{1}, key, "key should not be all 1s")
|
||||
}
|
||||
|
||||
t.Run("GenerateRetrieve", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
key1, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.NoError(t, err)
|
||||
assertRandomKey(t, key1)
|
||||
|
||||
key2, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, key1, key2, "keys should not be different")
|
||||
})
|
||||
|
||||
t.Run("GetError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", assert.AnError)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
})
|
||||
|
||||
t.Run("UpsertError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("", nil)
|
||||
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(assert.AnError)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.ErrorIs(t, err, assert.AnError)
|
||||
})
|
||||
|
||||
t.Run("DecodeErrorShouldRegenerate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("invalid", nil)
|
||||
|
||||
var storedKey tailnet.ResumeTokenSigningKey
|
||||
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Do(func(_ context.Context, value string) error {
|
||||
keyBytes, err := hex.DecodeString(value)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keyBytes, len(storedKey))
|
||||
copy(storedKey[:], keyBytes)
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.NoError(t, err)
|
||||
assertRandomKey(t, key)
|
||||
require.Equal(t, storedKey, key, "key should match stored value")
|
||||
})
|
||||
|
||||
t.Run("LengthErrorShouldRegenerate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("deadbeef", nil)
|
||||
db.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), gomock.Any()).Return(nil)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
key, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.NoError(t, err)
|
||||
assertRandomKey(t, key)
|
||||
})
|
||||
|
||||
t.Run("EmptyError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbmock.NewMockStore(gomock.NewController(t))
|
||||
emptyKey := hex.EncodeToString(make([]byte, 64))
|
||||
db.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return(emptyKey, nil)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
_, err := tailnet.ResumeTokenSigningKeyFromDatabase(ctx, db)
|
||||
require.ErrorContains(t, err, "is empty")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResumeTokenKeyProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -121,17 +24,18 @@ func TestResumeTokenKeyProvider(t *testing.T) {
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
id := uuid.New()
|
||||
clock := quartz.NewMock(t)
|
||||
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
|
||||
token, err := provider.GenerateResumeToken(id)
|
||||
provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), clock, tailnet.DefaultResumeTokenExpiry)
|
||||
token, err := provider.GenerateResumeToken(ctx, id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, token)
|
||||
require.NotEmpty(t, token.Token)
|
||||
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
|
||||
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
|
||||
|
||||
gotID, err := provider.VerifyResumeToken(token.Token)
|
||||
gotID, err := provider.VerifyResumeToken(ctx, token.Token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, gotID)
|
||||
})
|
||||
@@ -139,43 +43,57 @@ func TestResumeTokenKeyProvider(t *testing.T) {
|
||||
t.Run("Expired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
id := uuid.New()
|
||||
clock := quartz.NewMock(t)
|
||||
provider := tailnet.NewResumeTokenKeyProvider(key, clock, tailnet.DefaultResumeTokenExpiry)
|
||||
token, err := provider.GenerateResumeToken(id)
|
||||
provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), clock, tailnet.DefaultResumeTokenExpiry)
|
||||
token, err := provider.GenerateResumeToken(ctx, id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, token)
|
||||
require.NotEmpty(t, token.Token)
|
||||
require.Equal(t, tailnet.DefaultResumeTokenExpiry/2, token.RefreshIn.AsDuration())
|
||||
require.WithinDuration(t, clock.Now().Add(tailnet.DefaultResumeTokenExpiry), token.ExpiresAt.AsTime(), time.Second)
|
||||
|
||||
// Advance time past expiry
|
||||
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second)
|
||||
// Advance time past expiry. Account for leeway.
|
||||
_ = clock.Advance(tailnet.DefaultResumeTokenExpiry + time.Second*61)
|
||||
|
||||
_, err = provider.VerifyResumeToken(token.Token)
|
||||
require.ErrorContains(t, err, "expired")
|
||||
_, err = provider.VerifyResumeToken(ctx, token.Token)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, jwt.ErrExpired)
|
||||
})
|
||||
|
||||
t.Run("InvalidToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
|
||||
_, err := provider.VerifyResumeToken("invalid")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
provider := tailnet.NewResumeTokenKeyProvider(newKeySigner(key), quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
|
||||
_, err := provider.VerifyResumeToken(ctx, "invalid")
|
||||
require.ErrorContains(t, err, "parse JWS")
|
||||
})
|
||||
|
||||
t.Run("VerifyError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
// Generate a resume token with a different key
|
||||
otherKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
otherProvider := tailnet.NewResumeTokenKeyProvider(otherKey, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
|
||||
token, err := otherProvider.GenerateResumeToken(uuid.New())
|
||||
otherSigner := newKeySigner(otherKey)
|
||||
otherProvider := tailnet.NewResumeTokenKeyProvider(otherSigner, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
|
||||
token, err := otherProvider.GenerateResumeToken(ctx, uuid.New())
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := tailnet.NewResumeTokenKeyProvider(key, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
|
||||
_, err = provider.VerifyResumeToken(token.Token)
|
||||
require.ErrorContains(t, err, "verify JWS")
|
||||
signer := newKeySigner(key)
|
||||
signer.ID = otherSigner.ID
|
||||
provider := tailnet.NewResumeTokenKeyProvider(signer, quartz.NewMock(t), tailnet.DefaultResumeTokenExpiry)
|
||||
_, err = provider.VerifyResumeToken(ctx, token.Token)
|
||||
require.ErrorIs(t, err, jose.ErrCryptoFailure)
|
||||
})
|
||||
}
|
||||
|
||||
func newKeySigner(key tailnet.ResumeTokenSigningKey) jwtutils.StaticKey {
|
||||
return jwtutils.StaticKey{
|
||||
ID: "123",
|
||||
Key: key[:],
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -177,7 +177,7 @@ func (s *DRPCService) RefreshResumeToken(ctx context.Context, _ *proto.RefreshRe
|
||||
return nil, xerrors.New("no Stream ID")
|
||||
}
|
||||
|
||||
res, err := s.ResumeTokenProvider.GenerateResumeToken(streamID.ID)
|
||||
res, err := s.ResumeTokenProvider.GenerateResumeToken(ctx, streamID.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("generate resume token: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user