mirror of
https://github.com/coder/coder.git
synced 2026-06-03 21:18:24 +00:00
d6e9344e03
The Azure instance-identity authentication endpoint parsed the PKCS7 envelope and verified the certificate chain, but never verified the PKCS7 signature itself. An attacker could therefore forge a PKCS7 envelope containing a legitimate, publicly obtainable Azure certificate and arbitrary vmId content to obtain any agent's auth token. Add verifyPKCS7Signature(), a custom PKCS7 signature verification that handles Azure's non-standard use of sha256WithRSAEncryption (OID 1.2.840.113549.1.1.11) as the DigestAlgorithm; the upstream go.mozilla.org/pkcs7 library's Verify() rejects this combination. The verification checks that: (1) the content digest matches the signed message-digest attribute, and (2) the signature over the authenticated attributes is valid. Tests added: TestValidate_TamperedContent forges a PKCS7 with a modified vmId and confirms it is rejected; TestValidate_UntrustedCertWithValidSignature uses a valid PKCS7 signature with an untrusted certificate chain and confirms it is rejected. Co-authored-by: Jakub Domeracki <jakub@coder.com>
1858 lines
68 KiB
Go
1858 lines
68 KiB
Go
package coderdtest
|
||
|
||
import (
|
||
"archive/tar"
|
||
"bytes"
|
||
"context"
|
||
"crypto"
|
||
"crypto/rand"
|
||
"crypto/rsa"
|
||
"crypto/sha256"
|
||
"crypto/tls"
|
||
"crypto/x509"
|
||
"crypto/x509/pkix"
|
||
"database/sql"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"encoding/pem"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"math/big"
|
||
"net"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"net/url"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"testing"
|
||
"time"
|
||
|
||
"cloud.google.com/go/compute/metadata"
|
||
"github.com/go-chi/chi/v5"
|
||
"github.com/golang-jwt/jwt/v4"
|
||
"github.com/google/uuid"
|
||
"github.com/prometheus/client_golang/prometheus"
|
||
"github.com/smallstep/pkcs7"
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"golang.org/x/text/cases"
|
||
"golang.org/x/text/language"
|
||
"golang.org/x/xerrors"
|
||
"google.golang.org/api/idtoken"
|
||
"google.golang.org/api/option"
|
||
"tailscale.com/derp"
|
||
"tailscale.com/net/stun/stuntest"
|
||
"tailscale.com/tailcfg"
|
||
"tailscale.com/types/key"
|
||
"tailscale.com/types/nettype"
|
||
|
||
"cdr.dev/slog/v3"
|
||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||
"github.com/coder/coder/v2/archive"
|
||
"github.com/coder/coder/v2/coderd"
|
||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||
"github.com/coder/coder/v2/coderd/audit"
|
||
"github.com/coder/coder/v2/coderd/autobuild"
|
||
"github.com/coder/coder/v2/coderd/awsidentity"
|
||
"github.com/coder/coder/v2/coderd/azureidentity"
|
||
"github.com/coder/coder/v2/coderd/connectionlog"
|
||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||
"github.com/coder/coder/v2/coderd/database"
|
||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||
"github.com/coder/coder/v2/coderd/database/dbrollup"
|
||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||
"github.com/coder/coder/v2/coderd/externalauth"
|
||
"github.com/coder/coder/v2/coderd/files"
|
||
"github.com/coder/coder/v2/coderd/gitsshkey"
|
||
"github.com/coder/coder/v2/coderd/healthcheck"
|
||
"github.com/coder/coder/v2/coderd/httpmw"
|
||
"github.com/coder/coder/v2/coderd/jobreaper"
|
||
"github.com/coder/coder/v2/coderd/notifications"
|
||
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
|
||
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
||
"github.com/coder/coder/v2/coderd/rbac"
|
||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||
"github.com/coder/coder/v2/coderd/rbac/rolestore"
|
||
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||
"github.com/coder/coder/v2/coderd/schedule"
|
||
"github.com/coder/coder/v2/coderd/telemetry"
|
||
"github.com/coder/coder/v2/coderd/updatecheck"
|
||
"github.com/coder/coder/v2/coderd/usage"
|
||
"github.com/coder/coder/v2/coderd/util/namesgenerator"
|
||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||
"github.com/coder/coder/v2/coderd/util/slice"
|
||
"github.com/coder/coder/v2/coderd/webpush"
|
||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
|
||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||
"github.com/coder/coder/v2/coderd/wsbuilder"
|
||
"github.com/coder/coder/v2/codersdk"
|
||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||
"github.com/coder/coder/v2/cryptorand"
|
||
"github.com/coder/coder/v2/provisioner/echo"
|
||
"github.com/coder/coder/v2/provisionerd"
|
||
provisionerdproto "github.com/coder/coder/v2/provisionerd/proto"
|
||
"github.com/coder/coder/v2/provisionersdk"
|
||
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
||
"github.com/coder/coder/v2/tailnet"
|
||
"github.com/coder/coder/v2/testutil"
|
||
"github.com/coder/quartz"
|
||
)
|
||
|
||
const (
	// DefaultDERPMeshKey is the DERP mesh key used by test deployments.
	// NewOptions inserts the same value ("test-key") into the database.
	DefaultDERPMeshKey = "test-key"

	// defaultTestDaemonName names the in-memory provisioner daemon started by
	// NewProvisionerDaemon and by Options.IncludeProvisionerDaemon.
	defaultTestDaemonName = "test-daemon"
)
|
||
|
||
// Options configures the in-memory coderd instance built by New, NewWithAPI
// and the related helpers. Zero-valued fields are replaced with test defaults
// by NewOptions, which writes those defaults back into the struct it is given.
type Options struct {
	// AccessURL denotes a custom access URL. By default we use the httptest
	// server's URL. Setting this may result in unexpected behavior (especially
	// with running agents).
	AccessURL *url.URL
	// AppHostname, when non-empty, is compiled into the app hostname pattern
	// passed to coderd.
	AppHostname     string
	AWSCertificates awsidentity.Certificates
	// Authorizer defaults to a strict caching authorizer, wrapped in a
	// RecordingAuthorizer when the test handle is a *testing.T (but not in
	// benchmarks).
	Authorizer         rbac.Authorizer
	AzureCertificates  azureidentity.Options
	GithubOAuth2Config *coderd.GithubOAuth2Config
	RealIPConfig       *httpmw.RealIPConfig
	OIDCConfig         *coderd.OIDCConfig
	// GoogleTokenValidator defaults to a validator created with
	// option.WithoutAuthentication().
	GoogleTokenValidator *idtoken.Validator
	// SSHKeygenAlgorithm defaults to gitsshkey.AlgorithmEd25519 to match the
	// CLI default.
	SSHKeygenAlgorithm gitsshkey.Algorithm
	// AutobuildTicker drives the autobuild executor. If nil, an unbuffered
	// channel is created and closed during test cleanup.
	AutobuildTicker <-chan time.Time
	// AutobuildStats, if non-nil, receives executor stats and is closed during
	// test cleanup, so callers must not close it themselves.
	AutobuildStats chan<- autobuild.Stats
	// Auditor defaults to audit.NewNop().
	Auditor audit.Auditor
	// TLSCertificates, when non-nil, starts the test server with TLS (minimum
	// TLS 1.2); otherwise the embedded DERP node is forced to plain HTTP.
	TLSCertificates     []tls.Certificate
	ExternalAuthConfigs []*externalauth.Config
	TrialGenerator      func(ctx context.Context, body codersdk.LicensorTrialRequest) error
	RefreshEntitlements func(ctx context.Context) error
	// TemplateScheduleStore defaults to schedule.NewAGPLTemplateScheduleStore().
	TemplateScheduleStore schedule.TemplateScheduleStore
	// Coordinator is passed to coderd as the tailnet coordinator.
	Coordinator tailnet.Coordinator
	// CoordinatorResumeTokenProvider defaults to
	// tailnet.NewInsecureTestResumeTokenProvider().
	CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
	// ConnectionLogger defaults to connectionlog.NewNop().
	ConnectionLogger connectionlog.ConnectionLogger

	// HealthcheckFunc, HealthcheckTimeout and HealthcheckRefresh are passed to
	// coderd unchanged.
	HealthcheckFunc    func(ctx context.Context, apiKey string, progress *healthcheck.Progress) *healthsdk.HealthcheckReport
	HealthcheckTimeout time.Duration
	HealthcheckRefresh time.Duration

	// All rate limits default to -1 (unlimited) in tests if not set.
	APIRateLimit   int
	LoginRateLimit int
	FilesRateLimit int

	// OneTimePasscodeValidityPeriod specifies how long a one time passcode should be valid for.
	// Defaults to testutil.WaitLong.
	OneTimePasscodeValidityPeriod time.Duration

	// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
	IncludeProvisionerDaemon      bool
	ChatdInstructionLookupTimeout time.Duration
	// ProvisionerDaemonVersion is passed to
	// coderd.MemoryProvisionerWithVersionOverride for the daemon started by
	// IncludeProvisionerDaemon.
	ProvisionerDaemonVersion string
	// ProvisionerDaemonTags are the tags of the daemon started by
	// IncludeProvisionerDaemon.
	ProvisionerDaemonTags       map[string]string
	MetricsCacheRefreshInterval time.Duration
	AgentStatsRefreshInterval   time.Duration
	// DeploymentValues defaults to DeploymentValues(t). NewOptions may modify
	// it, e.g. replacing the default STUN address with a local STUN server.
	DeploymentValues *codersdk.DeploymentValues

	// Set update check options to enable update check.
	UpdateCheckOptions *updatecheck.Options

	// Overriding the database is heavily discouraged.
	// It should only be used in cases where multiple Coder
	// test instances are running against the same database.
	// When nil, a fresh store and pubsub are created with dbtestutil.NewDB.
	// Either way the store handed to coderd is wrapped with dbauthz.
	Database database.Store
	// Pubsub is only replaced when Database is nil.
	Pubsub pubsub.Pubsub

	// APIMiddleware inserts middleware before api.RootHandler, this can be
	// useful in certain tests where you want to intercept requests before
	// passing them on to the API, e.g. for synchronization of execution.
	APIMiddleware func(http.Handler) http.Handler

	ConfigSSH codersdk.SSHConfigResponse

	SwaggerEndpoint bool
	// Logger should only be overridden if you expect errors
	// as part of your test.
	// Defaults to a debug-level slogtest logger that ignores errors.
	Logger *slog.Logger
	// StatsBatcher defaults to a workspacestats batcher backed by Database
	// that logs to io.Discard and is closed during test cleanup.
	StatsBatcher           workspacestats.Batcher
	MetadataBatcherOptions []metadatabatcher.Option

	// WebpushDispatcher defaults to a dispatcher created with webpush.New.
	WebpushDispatcher                  webpush.Dispatcher
	WorkspaceAppsStatsCollectorOptions workspaceapps.StatsCollectorOptions
	AllowWorkspaceRenames              bool
	NewTicker                          func(duration time.Duration) (<-chan time.Time, func())
	DatabaseRolluper                   *dbrollup.Rolluper
	// WorkspaceUsageTrackerFlush and WorkspaceUsageTrackerTick default to
	// channels with a buffer of 1. Workspace usage tracking is only triggered
	// when a test writes to the tick channel.
	WorkspaceUsageTrackerFlush chan int
	WorkspaceUsageTrackerTick  chan time.Time
	// NotificationsEnqueuer defaults to a notificationstest.FakeEnqueuer.
	NotificationsEnqueuer notifications.Enqueuer
	APIKeyEncryptionCache cryptokeys.EncryptionKeycache
	OIDCConvertKeyCache   cryptokeys.SigningKeycache
	Clock                 quartz.Clock
	// TelemetryReporter defaults to telemetry.NewNoop().
	TelemetryReporter telemetry.Reporter

	ProvisionerdServerMetrics *provisionerdserver.Metrics
	WorkspaceBuilderMetrics   *wsbuilder.Metrics
	// UsageInserter, when non-nil, is handed to coderd behind an atomic
	// pointer; when nil coderd receives a nil pointer.
	UsageInserter usage.Inserter
}
|
||
|
||
// New constructs a codersdk client connected to an in-memory API instance.
|
||
func New(t testing.TB, options *Options) *codersdk.Client {
|
||
client, _ := newWithCloser(t, options)
|
||
return client
|
||
}
|
||
|
||
// NewWithDatabase constructs a codersdk client connected to an in-memory API instance.
|
||
// The database is returned to provide direct data manipulation for tests.
|
||
func NewWithDatabase(t testing.TB, options *Options) (*codersdk.Client, database.Store) {
|
||
client, _, api := NewWithAPI(t, options)
|
||
return client, api.Database
|
||
}
|
||
|
||
// NewWithProvisionerCloser returns a client as well as a handle to close
|
||
// the provisioner. This is a temporary function while work is done to
|
||
// standardize how provisioners are registered with coderd. The option
|
||
// to include a provisioner is set to true for convenience.
|
||
func NewWithProvisionerCloser(t testing.TB, options *Options) (*codersdk.Client, io.Closer) {
|
||
if options == nil {
|
||
options = &Options{}
|
||
}
|
||
options.IncludeProvisionerDaemon = true
|
||
client, closer := newWithCloser(t, options)
|
||
return client, closer
|
||
}
|
||
|
||
// newWithCloser constructs a codersdk client connected to an in-memory API instance.
|
||
// The returned closer closes a provisioner if it was provided
|
||
// The API is intentionally not returned here because coderd tests should not
|
||
// require a handle to the API. Do not expose the API or wrath shall descend
|
||
// upon thee. Even the io.Closer that is exposed here shouldn't be exposed
|
||
// and is a temporary measure while the API to register provisioners is ironed
|
||
// out.
|
||
func newWithCloser(t testing.TB, options *Options) (*codersdk.Client, io.Closer) {
|
||
client, closer, _ := NewWithAPI(t, options)
|
||
return client, closer
|
||
}
|
||
|
||
func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.CancelFunc, *url.URL, *coderd.Options) {
|
||
t.Helper()
|
||
|
||
if options == nil {
|
||
options = &Options{}
|
||
}
|
||
if options.Logger == nil {
|
||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd")
|
||
options.Logger = &logger
|
||
}
|
||
if options.GoogleTokenValidator == nil {
|
||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||
t.Cleanup(cancelFunc)
|
||
var err error
|
||
options.GoogleTokenValidator, err = idtoken.NewValidator(ctx, option.WithoutAuthentication())
|
||
require.NoError(t, err)
|
||
}
|
||
if options.AutobuildTicker == nil {
|
||
ticker := make(chan time.Time)
|
||
options.AutobuildTicker = ticker
|
||
t.Cleanup(func() { close(ticker) })
|
||
}
|
||
if options.AutobuildStats != nil {
|
||
t.Cleanup(func() {
|
||
close(options.AutobuildStats)
|
||
})
|
||
}
|
||
|
||
if options.Authorizer == nil {
|
||
defAuth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||
if _, ok := t.(*testing.T); ok {
|
||
options.Authorizer = &RecordingAuthorizer{
|
||
Wrapped: defAuth,
|
||
}
|
||
} else {
|
||
// In benchmarks, the recording authorizer greatly skews results.
|
||
options.Authorizer = defAuth
|
||
}
|
||
}
|
||
|
||
var usageInserter *atomic.Pointer[usage.Inserter]
|
||
if options.UsageInserter != nil {
|
||
usageInserter = &atomic.Pointer[usage.Inserter]{}
|
||
usageInserter.Store(&options.UsageInserter)
|
||
}
|
||
if options.Database == nil {
|
||
options.Database, options.Pubsub = dbtestutil.NewDB(t)
|
||
}
|
||
if options.CoordinatorResumeTokenProvider == nil {
|
||
options.CoordinatorResumeTokenProvider = tailnet.NewInsecureTestResumeTokenProvider()
|
||
}
|
||
|
||
if options.NotificationsEnqueuer == nil {
|
||
options.NotificationsEnqueuer = ¬ificationstest.FakeEnqueuer{}
|
||
}
|
||
|
||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||
accessControlStore.Store(&acs)
|
||
|
||
runtimeManager := runtimeconfig.NewManager()
|
||
options.Database = dbauthz.New(options.Database, options.Authorizer, *options.Logger, accessControlStore)
|
||
|
||
// Some routes expect a deployment ID, so just make sure one exists.
|
||
// Check first incase the caller already set up this database.
|
||
// nolint:gocritic // Setting up unit test data inside test helper
|
||
depID, err := options.Database.GetDeploymentID(dbauthz.AsSystemRestricted(context.Background()))
|
||
if xerrors.Is(err, sql.ErrNoRows) || depID == "" {
|
||
// nolint:gocritic // Setting up unit test data inside test helper
|
||
err := options.Database.InsertDeploymentID(dbauthz.AsSystemRestricted(context.Background()), uuid.NewString())
|
||
require.NoError(t, err, "insert a deployment id")
|
||
}
|
||
|
||
if options.WebpushDispatcher == nil {
|
||
// nolint:gocritic // Gets/sets VAPID keys.
|
||
pushNotifier, err := webpush.New(dbauthz.AsNotifier(context.Background()), options.Logger, options.Database, "http://example.com")
|
||
if err != nil {
|
||
panic(xerrors.Errorf("failed to create web push notifier: %w", err))
|
||
}
|
||
options.WebpushDispatcher = pushNotifier
|
||
}
|
||
|
||
if options.DeploymentValues == nil {
|
||
options.DeploymentValues = DeploymentValues(t)
|
||
}
|
||
// DisableOwnerWorkspaceExec modifies the 'global' RBAC roles. Fast-fail tests if we detect this.
|
||
if !options.DeploymentValues.DisableOwnerWorkspaceExec.Value() {
|
||
ownerSubj := rbac.Subject{
|
||
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
|
||
Scope: rbac.ScopeAll,
|
||
}
|
||
if err := options.Authorizer.Authorize(context.Background(), ownerSubj, policy.ActionSSH, rbac.ResourceWorkspace); err != nil {
|
||
if rbac.IsUnauthorizedError(err) {
|
||
t.Fatal("Side-effect of DisableOwnerWorkspaceExec detected in unrelated test. Please move the test that requires DisableOwnerWorkspaceExec to its own package so that it does not impact other tests!")
|
||
}
|
||
require.NoError(t, err)
|
||
}
|
||
}
|
||
|
||
// If no ratelimits are set, disable all rate limiting for tests.
|
||
if options.APIRateLimit == 0 {
|
||
options.APIRateLimit = -1
|
||
}
|
||
if options.LoginRateLimit == 0 {
|
||
options.LoginRateLimit = -1
|
||
}
|
||
if options.FilesRateLimit == 0 {
|
||
options.FilesRateLimit = -1
|
||
}
|
||
if options.StatsBatcher == nil {
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
t.Cleanup(cancel)
|
||
batcher, closeBatcher, err := workspacestats.NewBatcher(ctx,
|
||
workspacestats.BatcherWithStore(options.Database),
|
||
// Avoid cluttering up test output.
|
||
workspacestats.BatcherWithLogger(slog.Make(sloghuman.Sink(io.Discard))),
|
||
)
|
||
require.NoError(t, err, "create stats batcher")
|
||
options.StatsBatcher = batcher
|
||
t.Cleanup(closeBatcher)
|
||
}
|
||
if options.NotificationsEnqueuer == nil {
|
||
options.NotificationsEnqueuer = ¬ificationstest.FakeEnqueuer{}
|
||
}
|
||
|
||
if options.OneTimePasscodeValidityPeriod == 0 {
|
||
options.OneTimePasscodeValidityPeriod = testutil.WaitLong
|
||
}
|
||
|
||
var templateScheduleStore atomic.Pointer[schedule.TemplateScheduleStore]
|
||
if options.TemplateScheduleStore == nil {
|
||
options.TemplateScheduleStore = schedule.NewAGPLTemplateScheduleStore()
|
||
}
|
||
templateScheduleStore.Store(&options.TemplateScheduleStore)
|
||
|
||
var auditor atomic.Pointer[audit.Auditor]
|
||
if options.Auditor == nil {
|
||
options.Auditor = audit.NewNop()
|
||
}
|
||
auditor.Store(&options.Auditor)
|
||
|
||
var connectionLogger atomic.Pointer[connectionlog.ConnectionLogger]
|
||
if options.ConnectionLogger == nil {
|
||
options.ConnectionLogger = connectionlog.NewNop()
|
||
}
|
||
connectionLogger.Store(&options.ConnectionLogger)
|
||
|
||
var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker]
|
||
var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
|
||
buildUsageChecker.Store(&noopUsageChecker)
|
||
|
||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||
experiments := coderd.ReadExperiments(*options.Logger, options.DeploymentValues.Experiments)
|
||
lifecycleExecutor := autobuild.NewExecutor(
|
||
ctx,
|
||
options.Database,
|
||
options.Pubsub,
|
||
files.New(prometheus.NewRegistry(), options.Authorizer),
|
||
prometheus.NewRegistry(),
|
||
&templateScheduleStore,
|
||
&auditor,
|
||
accessControlStore,
|
||
&buildUsageChecker,
|
||
*options.Logger,
|
||
options.AutobuildTicker,
|
||
options.NotificationsEnqueuer,
|
||
experiments,
|
||
options.WorkspaceBuilderMetrics,
|
||
).WithStatsChannel(options.AutobuildStats)
|
||
|
||
lifecycleExecutor.Run()
|
||
|
||
jobReaperTicker := time.NewTicker(options.DeploymentValues.JobReaperDetectorInterval.Value())
|
||
defer jobReaperTicker.Stop()
|
||
jobReaper := jobreaper.New(ctx, options.Database, options.Pubsub, options.Logger.Named("reaper.detector"), jobReaperTicker.C)
|
||
jobReaper.Start()
|
||
t.Cleanup(jobReaper.Close)
|
||
|
||
if options.TelemetryReporter == nil {
|
||
options.TelemetryReporter = telemetry.NewNoop()
|
||
}
|
||
|
||
// Did last_used_at not update? Scratching your noggin? Here's why.
|
||
// Workspace usage tracking must be triggered manually in tests.
|
||
// The vast majority of existing tests do not depend on last_used_at
|
||
// and adding an extra time-based background goroutine to all existing
|
||
// tests may lead to future flakes and goleak complaints.
|
||
// Instead, pass in your own flush and ticker like so:
|
||
//
|
||
// tickCh = make(chan time.Time)
|
||
// flushCh = make(chan int, 1)
|
||
// client = coderdtest.New(t, &coderdtest.Options{
|
||
// WorkspaceUsageTrackerFlush: flushCh,
|
||
// WorkspaceUsageTrackerTick: tickCh
|
||
// })
|
||
//
|
||
// Now to trigger a tick, just write to `tickCh`.
|
||
// Reading from `flushCh` will ensure that workspaceusage.Tracker flushed.
|
||
// See TestPortForward or TestTracker_MultipleInstances for how this works in practice.
|
||
if options.WorkspaceUsageTrackerFlush == nil {
|
||
options.WorkspaceUsageTrackerFlush = make(chan int, 1) // buffering just in case
|
||
}
|
||
if options.WorkspaceUsageTrackerTick == nil {
|
||
options.WorkspaceUsageTrackerTick = make(chan time.Time, 1) // buffering just in case
|
||
}
|
||
// Close is called by API.Close()
|
||
wuTracker := workspacestats.NewTracker(
|
||
options.Database,
|
||
workspacestats.TrackerWithLogger(options.Logger.Named("workspace_usage_tracker")),
|
||
workspacestats.TrackerWithTickFlush(options.WorkspaceUsageTrackerTick, options.WorkspaceUsageTrackerFlush),
|
||
)
|
||
|
||
// create the TempDir for the HTTP file cache BEFORE we start the server and set a t.Cleanup to close it. TempDir()
|
||
// registers a Cleanup function that deletes the directory, and Cleanup functions are called in reverse order. If
|
||
// we don't do this, then we could try to delete the directory before the HTTP server is done with all files in it,
|
||
// which on Windows will fail (can't delete files until all programs have closed handles to them).
|
||
cacheDir := t.TempDir()
|
||
|
||
var mutex sync.RWMutex
|
||
var handler http.Handler
|
||
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
mutex.RLock()
|
||
handler := handler
|
||
mutex.RUnlock()
|
||
if handler != nil {
|
||
handler.ServeHTTP(w, r)
|
||
}
|
||
}))
|
||
t.Logf("coderdtest server listening on %s", srv.Listener.Addr().String())
|
||
srv.Config.BaseContext = func(_ net.Listener) context.Context {
|
||
return ctx
|
||
}
|
||
if options.TLSCertificates != nil {
|
||
srv.TLS = &tls.Config{
|
||
Certificates: options.TLSCertificates,
|
||
MinVersion: tls.VersionTLS12,
|
||
}
|
||
srv.StartTLS()
|
||
} else {
|
||
srv.Start()
|
||
}
|
||
t.Logf("coderdtest server started on %s", srv.URL)
|
||
t.Cleanup(func() {
|
||
t.Logf("closing coderdtest server on %s", srv.Listener.Addr().String())
|
||
srv.Close()
|
||
t.Logf("closed coderdtest server on %s", srv.Listener.Addr().String())
|
||
})
|
||
|
||
tcpAddr, ok := srv.Listener.Addr().(*net.TCPAddr)
|
||
require.True(t, ok)
|
||
|
||
serverURL, err := url.Parse(srv.URL)
|
||
require.NoError(t, err)
|
||
serverURL.Host = fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port)
|
||
|
||
derpPort, err := strconv.Atoi(serverURL.Port())
|
||
require.NoError(t, err)
|
||
|
||
accessURL := options.AccessURL
|
||
if accessURL == nil {
|
||
accessURL = serverURL
|
||
}
|
||
|
||
// If the STUNAddresses setting is empty or the default, start a STUN
|
||
// server. Otherwise, use the value as is.
|
||
var (
|
||
stunAddresses []string
|
||
dvStunAddresses = options.DeploymentValues.DERP.Server.STUNAddresses.Value()
|
||
)
|
||
if len(dvStunAddresses) == 0 || dvStunAddresses[0] == "stun.l.google.com:19302" {
|
||
stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{})
|
||
stunAddr.IP = net.ParseIP("127.0.0.1")
|
||
t.Cleanup(stunCleanup)
|
||
stunAddresses = []string{stunAddr.String()}
|
||
options.DeploymentValues.DERP.Server.STUNAddresses = stunAddresses
|
||
} else if dvStunAddresses[0] != tailnet.DisableSTUN {
|
||
stunAddresses = options.DeploymentValues.DERP.Server.STUNAddresses.Value()
|
||
}
|
||
|
||
const derpMeshKey = "test-key"
|
||
// Technically AGPL coderd servers don't set this value, but it doesn't
|
||
// change any behavior. It's useful for enterprise tests.
|
||
err = options.Database.InsertDERPMeshKey(dbauthz.AsSystemRestricted(ctx), derpMeshKey) //nolint:gocritic // test
|
||
if !database.IsUniqueViolation(err, database.UniqueSiteConfigsKeyKey) {
|
||
require.NoError(t, err, "insert DERP mesh key")
|
||
}
|
||
var derpServer *derp.Server
|
||
if options.DeploymentValues.DERP.Server.Enable.Value() {
|
||
derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp").Leveled(slog.LevelDebug)))
|
||
derpServer.SetMeshKey(derpMeshKey)
|
||
}
|
||
|
||
// match default with cli default
|
||
if options.SSHKeygenAlgorithm == "" {
|
||
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
|
||
}
|
||
|
||
var appHostnameRegex *regexp.Regexp
|
||
if options.AppHostname != "" {
|
||
var err error
|
||
appHostnameRegex, err = appurl.CompileHostnamePattern(options.AppHostname)
|
||
require.NoError(t, err)
|
||
}
|
||
|
||
region := &tailcfg.DERPRegion{
|
||
EmbeddedRelay: true,
|
||
RegionID: int(options.DeploymentValues.DERP.Server.RegionID.Value()),
|
||
RegionCode: options.DeploymentValues.DERP.Server.RegionCode.String(),
|
||
RegionName: options.DeploymentValues.DERP.Server.RegionName.String(),
|
||
Nodes: []*tailcfg.DERPNode{{
|
||
Name: fmt.Sprintf("%db", options.DeploymentValues.DERP.Server.RegionID),
|
||
RegionID: int(options.DeploymentValues.DERP.Server.RegionID.Value()),
|
||
IPv4: "127.0.0.1",
|
||
DERPPort: derpPort,
|
||
// STUN port is added as a separate node by tailnet.NewDERPMap() if
|
||
// direct connections are enabled.
|
||
STUNPort: -1,
|
||
InsecureForTests: true,
|
||
ForceHTTP: options.TLSCertificates == nil,
|
||
}},
|
||
}
|
||
if !options.DeploymentValues.DERP.Server.Enable.Value() {
|
||
region = nil
|
||
}
|
||
derpMap, err := tailnet.NewDERPMap(ctx, region, stunAddresses,
|
||
options.DeploymentValues.DERP.Config.URL.Value(),
|
||
options.DeploymentValues.DERP.Config.Path.Value(),
|
||
options.DeploymentValues.DERP.Config.BlockDirect.Value(),
|
||
)
|
||
require.NoError(t, err)
|
||
|
||
return func(h http.Handler) {
|
||
mutex.Lock()
|
||
defer mutex.Unlock()
|
||
handler = h
|
||
}, cancelFunc, serverURL, &coderd.Options{
|
||
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
|
||
// Force a long disconnection timeout to ensure
|
||
// agents are not marked as disconnected during slow tests.
|
||
AgentInactiveDisconnectTimeout: testutil.WaitShort,
|
||
ChatdInstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
|
||
AccessURL: accessURL,
|
||
AppHostname: options.AppHostname,
|
||
AppHostnameRegex: appHostnameRegex,
|
||
Logger: *options.Logger,
|
||
CacheDir: cacheDir,
|
||
RuntimeConfig: runtimeManager,
|
||
Database: options.Database,
|
||
Pubsub: options.Pubsub,
|
||
ExternalAuthConfigs: options.ExternalAuthConfigs,
|
||
UsageInserter: usageInserter,
|
||
|
||
Auditor: options.Auditor,
|
||
ConnectionLogger: options.ConnectionLogger,
|
||
AWSCertificates: options.AWSCertificates,
|
||
AzureCertificates: options.AzureCertificates,
|
||
GithubOAuth2Config: options.GithubOAuth2Config,
|
||
RealIPConfig: options.RealIPConfig,
|
||
OIDCConfig: options.OIDCConfig,
|
||
GoogleTokenValidator: options.GoogleTokenValidator,
|
||
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
|
||
DERPServer: derpServer,
|
||
APIRateLimit: options.APIRateLimit,
|
||
LoginRateLimit: options.LoginRateLimit,
|
||
FilesRateLimit: options.FilesRateLimit,
|
||
Authorizer: options.Authorizer,
|
||
Telemetry: options.TelemetryReporter,
|
||
TemplateScheduleStore: &templateScheduleStore,
|
||
AccessControlStore: accessControlStore,
|
||
TLSCertificates: options.TLSCertificates,
|
||
TrialGenerator: options.TrialGenerator,
|
||
RefreshEntitlements: options.RefreshEntitlements,
|
||
TailnetCoordinator: options.Coordinator,
|
||
WebPushDispatcher: options.WebpushDispatcher,
|
||
BaseDERPMap: derpMap,
|
||
DERPMapUpdateFrequency: 150 * time.Millisecond,
|
||
CoordinatorResumeTokenProvider: options.CoordinatorResumeTokenProvider,
|
||
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
|
||
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
|
||
DeploymentValues: options.DeploymentValues,
|
||
DeploymentOptions: codersdk.DeploymentOptionsWithoutSecrets(options.DeploymentValues.Options()),
|
||
UpdateCheckOptions: options.UpdateCheckOptions,
|
||
SwaggerEndpoint: options.SwaggerEndpoint,
|
||
SSHConfig: options.ConfigSSH,
|
||
HealthcheckFunc: options.HealthcheckFunc,
|
||
HealthcheckTimeout: options.HealthcheckTimeout,
|
||
HealthcheckRefresh: options.HealthcheckRefresh,
|
||
StatsBatcher: options.StatsBatcher,
|
||
MetadataBatcherOptions: options.MetadataBatcherOptions,
|
||
WorkspaceAppsStatsCollectorOptions: options.WorkspaceAppsStatsCollectorOptions,
|
||
AllowWorkspaceRenames: options.AllowWorkspaceRenames,
|
||
NewTicker: options.NewTicker,
|
||
DatabaseRolluper: options.DatabaseRolluper,
|
||
WorkspaceUsageTracker: wuTracker,
|
||
NotificationsEnqueuer: options.NotificationsEnqueuer,
|
||
OneTimePasscodeValidityPeriod: options.OneTimePasscodeValidityPeriod,
|
||
Clock: options.Clock,
|
||
AppEncryptionKeyCache: options.APIKeyEncryptionCache,
|
||
OIDCConvertKeyCache: options.OIDCConvertKeyCache,
|
||
ProvisionerdServerMetrics: options.ProvisionerdServerMetrics,
|
||
WorkspaceBuilderMetrics: options.WorkspaceBuilderMetrics,
|
||
}
|
||
}
|
||
|
||
// NewWithAPI constructs an in-memory API instance and returns a client to talk to it.
|
||
// Most tests never need a reference to the API, but AuthorizationTest in this module uses it.
|
||
// Do not expose the API or wrath shall descend upon thee.
|
||
func NewWithAPI(t testing.TB, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
|
||
if options == nil {
|
||
options = &Options{}
|
||
}
|
||
setHandler, cancelFunc, serverURL, newOptions := NewOptions(t, options)
|
||
// We set the handler after server creation for the access URL.
|
||
coderAPI := coderd.New(newOptions)
|
||
rootHandler := coderAPI.RootHandler
|
||
if options.APIMiddleware != nil {
|
||
r := chi.NewRouter()
|
||
r.Use(options.APIMiddleware)
|
||
r.Mount("/", rootHandler)
|
||
rootHandler = r
|
||
}
|
||
setHandler(rootHandler)
|
||
var provisionerCloser io.Closer = nopcloser{}
|
||
if options.IncludeProvisionerDaemon {
|
||
provisionerCloser = NewTaggedProvisionerDaemon(t, coderAPI, defaultTestDaemonName, options.ProvisionerDaemonTags, coderd.MemoryProvisionerWithVersionOverride(options.ProvisionerDaemonVersion))
|
||
}
|
||
client := codersdk.New(serverURL)
|
||
t.Cleanup(func() {
|
||
cancelFunc()
|
||
_ = provisionerCloser.Close()
|
||
_ = coderAPI.Close()
|
||
client.HTTPClient.CloseIdleConnections()
|
||
})
|
||
return client, provisionerCloser, coderAPI
|
||
}
|
||
|
||
// ProvisionerdCloser wraps a provisioner daemon as an io.Closer that can be called multiple times
|
||
type ProvisionerdCloser struct {
|
||
mu sync.Mutex
|
||
closed bool
|
||
d *provisionerd.Server
|
||
}
|
||
|
||
func NewProvisionerDaemonCloser(d *provisionerd.Server) *ProvisionerdCloser {
|
||
return &ProvisionerdCloser{d: d}
|
||
}
|
||
|
||
func (c *ProvisionerdCloser) Close() error {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
if c.closed {
|
||
return nil
|
||
}
|
||
c.closed = true
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||
defer cancel()
|
||
shutdownErr := c.d.Shutdown(ctx, true)
|
||
closeErr := c.d.Close()
|
||
if shutdownErr != nil {
|
||
return shutdownErr
|
||
}
|
||
return closeErr
|
||
}
|
||
|
||
// NewProvisionerDaemon launches a provisionerd instance configured to work
|
||
// well with coderd testing. It registers the "echo" provisioner for
|
||
// quick testing.
|
||
func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer {
|
||
return NewTaggedProvisionerDaemon(t, coderAPI, defaultTestDaemonName, nil)
|
||
}
|
||
|
||
func NewTaggedProvisionerDaemon(t testing.TB, coderAPI *coderd.API, name string, provisionerTags map[string]string, opts ...coderd.MemoryProvisionerDaemonOption) io.Closer {
|
||
t.Helper()
|
||
|
||
// t.Cleanup runs in last added, first called order. t.TempDir() will delete
|
||
// the directory on cleanup, so we want to make sure the echoServer is closed
|
||
// before we go ahead an attempt to delete it's work directory.
|
||
// seems t.TempDir() is not safe to call from a different goroutine
|
||
workDir := t.TempDir()
|
||
|
||
echoClient, echoServer := drpcsdk.MemTransportPipe()
|
||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||
t.Cleanup(func() {
|
||
_ = echoClient.Close()
|
||
_ = echoServer.Close()
|
||
cancelFunc()
|
||
})
|
||
|
||
go func() {
|
||
err := echo.Serve(ctx, &provisionersdk.ServeOptions{
|
||
Listener: echoServer,
|
||
WorkDirectory: workDir,
|
||
Logger: coderAPI.Logger.Named("echo").Leveled(slog.LevelDebug),
|
||
})
|
||
assert.NoError(t, err)
|
||
}()
|
||
|
||
connectedCh := make(chan struct{})
|
||
daemon := provisionerd.New(func(dialCtx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
|
||
return coderAPI.CreateInMemoryTaggedProvisionerDaemon(dialCtx, name, []codersdk.ProvisionerType{codersdk.ProvisionerTypeEcho}, provisionerTags, opts...)
|
||
}, &provisionerd.Options{
|
||
Logger: coderAPI.Logger.Named("provisionerd").Leveled(slog.LevelDebug),
|
||
UpdateInterval: 250 * time.Millisecond,
|
||
ForceCancelInterval: 5 * time.Second,
|
||
Connector: provisionerd.LocalProvisioners{
|
||
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
|
||
},
|
||
InitConnectionCh: connectedCh,
|
||
})
|
||
// Wait for the provisioner daemon to connect before continuing.
|
||
// Users of this function tend to assume that the provisioner is connected
|
||
// and ready to use when that may not strictly be the case.
|
||
<-connectedCh
|
||
closer := NewProvisionerDaemonCloser(daemon)
|
||
t.Cleanup(func() {
|
||
_ = closer.Close()
|
||
})
|
||
return closer
|
||
}
|
||
|
||
var FirstUserParams = codersdk.CreateFirstUserRequest{
|
||
Email: "testuser@coder.com",
|
||
Username: "testuser",
|
||
Password: "SomeSecurePassword!",
|
||
Name: "Test User",
|
||
}
|
||
|
||
var TrialUserParams = codersdk.CreateFirstUserTrialInfo{
|
||
FirstName: "John",
|
||
LastName: "Doe",
|
||
PhoneNumber: "9999999999",
|
||
JobTitle: "Engineer",
|
||
CompanyName: "Acme Inc",
|
||
Country: "United States",
|
||
Developers: "10-50",
|
||
}
|
||
|
||
// CreateFirstUser creates a user with preset credentials and authenticates
|
||
// with the passed in codersdk client.
|
||
func CreateFirstUser(t testing.TB, client *codersdk.Client) codersdk.CreateFirstUserResponse {
|
||
resp, err := client.CreateFirstUser(context.Background(), FirstUserParams)
|
||
require.NoError(t, err)
|
||
|
||
login, err := client.LoginWithPassword(context.Background(), codersdk.LoginWithPasswordRequest{
|
||
Email: FirstUserParams.Email,
|
||
Password: FirstUserParams.Password,
|
||
})
|
||
require.NoError(t, err)
|
||
client.SetSessionToken(login.SessionToken)
|
||
return resp
|
||
}
|
||
|
||
// CreateAnotherUser creates and authenticates a new user.
|
||
// Roles can include org scoped roles with 'roleName:<organization_id>'
|
||
func CreateAnotherUser(t testing.TB, client *codersdk.Client, organizationID uuid.UUID, roles ...rbac.RoleIdentifier) (*codersdk.Client, codersdk.User) {
|
||
return createAnotherUserRetry(t, client, []uuid.UUID{organizationID}, 5, roles)
|
||
}
|
||
|
||
func CreateAnotherUserMutators(t testing.TB, client *codersdk.Client, organizationID uuid.UUID, roles []rbac.RoleIdentifier, mutators ...func(r *codersdk.CreateUserRequestWithOrgs)) (*codersdk.Client, codersdk.User) {
|
||
return createAnotherUserRetry(t, client, []uuid.UUID{organizationID}, 5, roles, mutators...)
|
||
}
|
||
|
||
// AuthzUserSubject does not include the user's groups or the org-member role
|
||
// (which is a db-backed system role).
|
||
func AuthzUserSubject(user codersdk.User) rbac.Subject {
|
||
roles := make(rbac.RoleIdentifiers, 0, len(user.Roles))
|
||
// Member role is always implied
|
||
roles = append(roles, rbac.RoleMember())
|
||
for _, r := range user.Roles {
|
||
orgID, _ := uuid.Parse(r.OrganizationID) // defaults to nil
|
||
roles = append(roles, rbac.RoleIdentifier{
|
||
Name: r.Name,
|
||
OrganizationID: orgID,
|
||
})
|
||
}
|
||
|
||
return rbac.Subject{
|
||
ID: user.ID.String(),
|
||
Roles: roles,
|
||
Groups: []string{},
|
||
Scope: rbac.ScopeAll,
|
||
}
|
||
}
|
||
|
||
// AuthzUserSubjectWithDB is like AuthzUserSubject but adds db-backed roles
|
||
// (like organization-member).
|
||
func AuthzUserSubjectWithDB(ctx context.Context, t testing.TB, db database.Store, user codersdk.User) rbac.Subject {
|
||
t.Helper()
|
||
|
||
roles := make(rbac.RoleIdentifiers, 0, len(user.Roles)+2)
|
||
// Member role is always implied
|
||
roles = append(roles, rbac.RoleMember())
|
||
for _, r := range user.Roles {
|
||
parsedOrgID, _ := uuid.Parse(r.OrganizationID) // defaults to nil
|
||
roles = append(roles, rbac.RoleIdentifier{
|
||
Name: r.Name,
|
||
OrganizationID: parsedOrgID,
|
||
})
|
||
}
|
||
|
||
//nolint:gocritic // We’re constructing the subject. The incoming ctx
|
||
// typically has no dbauthz actor yet, and using AuthzUserSubject(user)
|
||
// here would be circular (it lacks DB-backed org-member roles needed for
|
||
// organization:read). Use system-restricted ctx for the membership lookup.
|
||
orgs, err := db.GetOrganizationsByUserID(dbauthz.AsSystemRestricted(ctx), database.GetOrganizationsByUserIDParams{
|
||
UserID: user.ID,
|
||
Deleted: sql.NullBool{
|
||
Valid: true,
|
||
Bool: false,
|
||
},
|
||
})
|
||
require.NoError(t, err)
|
||
for _, org := range orgs {
|
||
roles = append(roles, rbac.ScopedRoleOrgMember(org.ID))
|
||
}
|
||
|
||
//nolint:gocritic // We need to expand DB-backed/system roles. The caller
|
||
// ctx may not have permission to read system roles, so use system-restricted
|
||
// context for the internal role lookup.
|
||
rbacRoles, err := rolestore.Expand(dbauthz.AsSystemRestricted(ctx), db, roles)
|
||
require.NoError(t, err)
|
||
|
||
return rbac.Subject{
|
||
ID: user.ID.String(),
|
||
Roles: rbacRoles,
|
||
Groups: []string{},
|
||
Scope: rbac.ScopeAll,
|
||
}.WithCachedASTValue()
|
||
}
|
||
|
||
func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationIDs []uuid.UUID, retries int, roles []rbac.RoleIdentifier, mutators ...func(r *codersdk.CreateUserRequestWithOrgs)) (*codersdk.Client, codersdk.User) {
|
||
req := codersdk.CreateUserRequestWithOrgs{
|
||
Email: namesgenerator.UniqueName() + "@coder.com",
|
||
Username: RandomUsername(t),
|
||
Name: RandomName(t),
|
||
Password: "SomeSecurePassword!",
|
||
OrganizationIDs: organizationIDs,
|
||
// Always create users as active in tests to ignore an extra audit log
|
||
// when logging in.
|
||
UserStatus: ptr.Ref(codersdk.UserStatusActive),
|
||
}
|
||
for _, m := range mutators {
|
||
m(&req)
|
||
}
|
||
|
||
// Service accounts cannot have a password or email and must
|
||
// use login_type=none. Enforce this after mutators so callers
|
||
// only need to set ServiceAccount=true.
|
||
if req.ServiceAccount {
|
||
req.Password = ""
|
||
req.Email = ""
|
||
req.UserLoginType = codersdk.LoginTypeNone
|
||
}
|
||
|
||
user, err := client.CreateUserWithOrgs(context.Background(), req)
|
||
var apiError *codersdk.Error
|
||
// If the user already exists by username or email conflict, try again up to "retries" times.
|
||
if err != nil && retries >= 0 && xerrors.As(err, &apiError) {
|
||
if apiError.StatusCode() == http.StatusConflict {
|
||
retries--
|
||
return createAnotherUserRetry(t, client, organizationIDs, retries, roles)
|
||
}
|
||
}
|
||
require.NoError(t, err)
|
||
|
||
var sessionToken string
|
||
switch req.UserLoginType {
|
||
case codersdk.LoginTypeNone, codersdk.LoginTypeGithub, codersdk.LoginTypeOIDC:
|
||
// Cannot log in with a non-password user. So make it an api key from the
|
||
// client making this user.
|
||
token, err := client.CreateToken(context.Background(), user.ID.String(), codersdk.CreateTokenRequest{
|
||
Lifetime: time.Hour * 24,
|
||
Scope: codersdk.APIKeyScopeAll,
|
||
TokenName: "no-password-user-token",
|
||
})
|
||
require.NoError(t, err)
|
||
sessionToken = token.Key
|
||
default:
|
||
login, err := client.LoginWithPassword(context.Background(), codersdk.LoginWithPasswordRequest{
|
||
Email: req.Email,
|
||
Password: req.Password,
|
||
})
|
||
require.NoError(t, err)
|
||
sessionToken = login.SessionToken
|
||
}
|
||
|
||
if user.Status == codersdk.UserStatusDormant {
|
||
// Use admin client so that user's LastSeenAt is not updated.
|
||
// In general we need to refresh the user status, which should
|
||
// transition from "dormant" to "active".
|
||
user, err = client.User(context.Background(), user.Username)
|
||
require.NoError(t, err)
|
||
}
|
||
|
||
other := codersdk.New(client.URL, codersdk.WithSessionToken(sessionToken))
|
||
t.Cleanup(func() {
|
||
other.HTTPClient.CloseIdleConnections()
|
||
})
|
||
|
||
if len(roles) > 0 {
|
||
// Find the roles for the org vs the site wide roles
|
||
orgRoles := make(map[uuid.UUID][]rbac.RoleIdentifier)
|
||
var siteRoles []rbac.RoleIdentifier
|
||
|
||
for _, roleName := range roles {
|
||
ok := roleName.IsOrgRole()
|
||
if ok {
|
||
orgRoles[roleName.OrganizationID] = append(orgRoles[roleName.OrganizationID], roleName)
|
||
} else {
|
||
siteRoles = append(siteRoles, roleName)
|
||
}
|
||
}
|
||
// Update the roles
|
||
for _, r := range user.Roles {
|
||
orgID, _ := uuid.Parse(r.OrganizationID)
|
||
siteRoles = append(siteRoles, rbac.RoleIdentifier{
|
||
Name: r.Name,
|
||
OrganizationID: orgID,
|
||
})
|
||
}
|
||
|
||
onlyName := func(role rbac.RoleIdentifier) string {
|
||
return role.Name
|
||
}
|
||
|
||
user, err = client.UpdateUserRoles(context.Background(), user.ID.String(), codersdk.UpdateRoles{Roles: slice.List(siteRoles, onlyName)})
|
||
require.NoError(t, err, "update site roles")
|
||
|
||
// isMember keeps track of which orgs the user was added to as a member
|
||
isMember := make(map[uuid.UUID]bool)
|
||
for _, orgID := range organizationIDs {
|
||
isMember[orgID] = true
|
||
}
|
||
|
||
// Update org roles
|
||
for orgID, roles := range orgRoles {
|
||
// The user must be an organization of any orgRoles, so insert
|
||
// the organization member, then assign the roles.
|
||
if !isMember[orgID] {
|
||
_, err = client.PostOrganizationMember(context.Background(), orgID, user.ID.String())
|
||
require.NoError(t, err, "add user to organization as member")
|
||
}
|
||
|
||
_, err = client.UpdateOrganizationMemberRoles(context.Background(), orgID, user.ID.String(),
|
||
codersdk.UpdateRoles{Roles: slice.List(roles, onlyName)})
|
||
require.NoError(t, err, "update org membership roles")
|
||
isMember[orgID] = true
|
||
}
|
||
}
|
||
|
||
user, err = client.User(context.Background(), user.Username)
|
||
require.NoError(t, err, "update final user")
|
||
|
||
return other, user
|
||
}
|
||
|
||
func CreateTemplateVersionMimeType(t testing.TB, client *codersdk.Client, mimeType string, organizationID uuid.UUID, res *echo.Responses, mutators ...func(*codersdk.CreateTemplateVersionRequest)) codersdk.TemplateVersion {
|
||
t.Helper()
|
||
data, err := echo.TarWithOptions(context.Background(), client.Logger(), res)
|
||
require.NoError(t, err)
|
||
|
||
switch mimeType {
|
||
case codersdk.ContentTypeTar:
|
||
// do nothing
|
||
case codersdk.ContentTypeZip:
|
||
data, err = archive.CreateZipFromTar(tar.NewReader(bytes.NewBuffer(data)), int64(len(data)))
|
||
require.NoError(t, err, "creating zip")
|
||
default:
|
||
t.Fatal("unexpected mime type", mimeType)
|
||
}
|
||
|
||
file, err := client.Upload(context.Background(), mimeType, bytes.NewReader(data))
|
||
require.NoError(t, err)
|
||
|
||
req := codersdk.CreateTemplateVersionRequest{
|
||
FileID: file.ID,
|
||
StorageMethod: codersdk.ProvisionerStorageMethodFile,
|
||
Provisioner: codersdk.ProvisionerTypeEcho,
|
||
}
|
||
for _, mut := range mutators {
|
||
mut(&req)
|
||
}
|
||
|
||
templateVersion, err := client.CreateTemplateVersion(context.Background(), organizationID, req)
|
||
require.NoError(t, err)
|
||
return templateVersion
|
||
}
|
||
|
||
// CreateTemplateVersion creates a template import provisioner job
|
||
// with the responses provided. It uses the "echo" provisioner for compatibility
|
||
// with testing.
|
||
func CreateTemplateVersion(t testing.TB, client *codersdk.Client, organizationID uuid.UUID, res *echo.Responses, mutators ...func(*codersdk.CreateTemplateVersionRequest)) codersdk.TemplateVersion {
|
||
return CreateTemplateVersionMimeType(t, client, codersdk.ContentTypeTar, organizationID, res, mutators...)
|
||
}
|
||
|
||
// CreateWorkspaceBuild creates a workspace build for the given workspace and transition.
|
||
func CreateWorkspaceBuild(
|
||
t *testing.T,
|
||
client *codersdk.Client,
|
||
workspace codersdk.Workspace,
|
||
transition database.WorkspaceTransition,
|
||
mutators ...func(*codersdk.CreateWorkspaceBuildRequest),
|
||
) codersdk.WorkspaceBuild {
|
||
t.Helper()
|
||
|
||
req := codersdk.CreateWorkspaceBuildRequest{
|
||
Transition: codersdk.WorkspaceTransition(transition),
|
||
}
|
||
for _, mut := range mutators {
|
||
mut(&req)
|
||
}
|
||
build, err := client.CreateWorkspaceBuild(context.Background(), workspace.ID, req)
|
||
require.NoError(t, err)
|
||
return build
|
||
}
|
||
|
||
// CreateTemplate creates a template with the "echo" provisioner for
|
||
// compatibility with testing. The name assigned is randomly generated.
|
||
func CreateTemplate(t testing.TB, client *codersdk.Client, organization uuid.UUID, version uuid.UUID, mutators ...func(*codersdk.CreateTemplateRequest)) codersdk.Template {
|
||
req := codersdk.CreateTemplateRequest{
|
||
Name: RandomUsername(t),
|
||
VersionID: version,
|
||
}
|
||
for _, mut := range mutators {
|
||
mut(&req)
|
||
}
|
||
template, err := client.CreateTemplate(context.Background(), organization, req)
|
||
require.NoError(t, err)
|
||
return template
|
||
}
|
||
|
||
// CreateGroup creates a group with the given name and members.
|
||
func CreateGroup(t testing.TB, client *codersdk.Client, organizationID uuid.UUID, name string, members ...codersdk.User) codersdk.Group {
|
||
t.Helper()
|
||
group, err := client.CreateGroup(context.Background(), organizationID, codersdk.CreateGroupRequest{
|
||
Name: name,
|
||
})
|
||
require.NoError(t, err, "failed to create group")
|
||
memberIDs := make([]string, 0)
|
||
for _, member := range members {
|
||
memberIDs = append(memberIDs, member.ID.String())
|
||
}
|
||
group, err = client.PatchGroup(context.Background(), group.ID, codersdk.PatchGroupRequest{
|
||
AddUsers: memberIDs,
|
||
})
|
||
|
||
require.NoError(t, err, "failed to add members to group")
|
||
return group
|
||
}
|
||
|
||
// UpdateTemplateVersion creates a new template version with the "echo" provisioner
|
||
// and associates it with the given templateID.
|
||
func UpdateTemplateVersion(t testing.TB, client *codersdk.Client, organizationID uuid.UUID, res *echo.Responses, templateID uuid.UUID) codersdk.TemplateVersion {
|
||
ctx := context.Background()
|
||
data, err := echo.Tar(res)
|
||
require.NoError(t, err)
|
||
file, err := client.Upload(ctx, codersdk.ContentTypeTar, bytes.NewReader(data))
|
||
require.NoError(t, err)
|
||
templateVersion, err := client.CreateTemplateVersion(ctx, organizationID, codersdk.CreateTemplateVersionRequest{
|
||
TemplateID: templateID,
|
||
FileID: file.ID,
|
||
StorageMethod: codersdk.ProvisionerStorageMethodFile,
|
||
Provisioner: codersdk.ProvisionerTypeEcho,
|
||
})
|
||
require.NoError(t, err)
|
||
return templateVersion
|
||
}
|
||
|
||
func UpdateActiveTemplateVersion(t testing.TB, client *codersdk.Client, templateID, versionID uuid.UUID) {
|
||
err := client.UpdateActiveTemplateVersion(context.Background(), templateID, codersdk.UpdateActiveTemplateVersion{
|
||
ID: versionID,
|
||
})
|
||
require.NoError(t, err)
|
||
}
|
||
|
||
// UpdateTemplateMeta updates the template meta for the given template.
|
||
func UpdateTemplateMeta(t testing.TB, client *codersdk.Client, templateID uuid.UUID, meta codersdk.UpdateTemplateMeta) codersdk.Template {
|
||
t.Helper()
|
||
updated, err := client.UpdateTemplateMeta(context.Background(), templateID, meta)
|
||
require.NoError(t, err)
|
||
return updated
|
||
}
|
||
|
||
// AwaitTemplateVersionJobRunning waits for the build to be picked up by a provisioner.
|
||
func AwaitTemplateVersionJobRunning(t testing.TB, client *codersdk.Client, version uuid.UUID) codersdk.TemplateVersion {
|
||
t.Helper()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||
defer cancel()
|
||
|
||
t.Logf("waiting for template version %s build job to start", version)
|
||
var templateVersion codersdk.TemplateVersion
|
||
require.Eventually(t, func() bool {
|
||
var err error
|
||
templateVersion, err = client.TemplateVersion(ctx, version)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
t.Logf("template version job status: %s", templateVersion.Job.Status)
|
||
switch templateVersion.Job.Status {
|
||
case codersdk.ProvisionerJobPending:
|
||
return false
|
||
case codersdk.ProvisionerJobRunning:
|
||
return true
|
||
default:
|
||
t.FailNow()
|
||
return false
|
||
}
|
||
}, testutil.WaitShort, testutil.IntervalFast, "make sure you set `IncludeProvisionerDaemon`!")
|
||
t.Logf("template version %s job has started", version)
|
||
return templateVersion
|
||
}
|
||
|
||
// AwaitTemplateVersionJobCompleted waits for the build to be completed. This may result
|
||
// from cancelation, an error, or from completing successfully.
|
||
func AwaitTemplateVersionJobCompleted(t testing.TB, client *codersdk.Client, version uuid.UUID) codersdk.TemplateVersion {
|
||
t.Helper()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||
defer cancel()
|
||
|
||
t.Logf("waiting for template version %s build job to complete", version)
|
||
var templateVersion codersdk.TemplateVersion
|
||
require.Eventually(t, func() bool {
|
||
var err error
|
||
templateVersion, err = client.TemplateVersion(ctx, version)
|
||
t.Logf("template version job status: %s", templateVersion.Job.Status)
|
||
return assert.NoError(t, err) && templateVersion.Job.CompletedAt != nil
|
||
}, testutil.WaitLong, testutil.IntervalFast, "make sure you set `IncludeProvisionerDaemon`!")
|
||
t.Logf("template version %s job has completed", version)
|
||
return templateVersion
|
||
}
|
||
|
||
// AwaitWorkspaceBuildJobCompleted waits for a workspace provision job to reach completed status.
|
||
func AwaitWorkspaceBuildJobCompleted(t testing.TB, client *codersdk.Client, build uuid.UUID) codersdk.WorkspaceBuild {
|
||
t.Helper()
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||
defer cancel()
|
||
|
||
t.Logf("waiting for workspace build job %s", build)
|
||
var workspaceBuild codersdk.WorkspaceBuild
|
||
require.Eventually(t, func() bool {
|
||
var err error
|
||
workspaceBuild, err = client.WorkspaceBuild(ctx, build)
|
||
if err != nil {
|
||
t.Logf("failed to get workspace build %s: %v", build, err)
|
||
return false
|
||
}
|
||
if workspaceBuild.Job.CompletedAt == nil {
|
||
t.Logf("workspace build job %s still running (status: %s)", build, workspaceBuild.Job.Status)
|
||
return false
|
||
}
|
||
return true
|
||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||
t.Logf("got workspace build job %s (status: %s)", build, workspaceBuild.Job.Status)
|
||
return workspaceBuild
|
||
}
|
||
|
||
// AwaitWorkspaceAgents waits for all resources with agents to be connected. If
|
||
// specific agents are provided, it will wait for those agents to be connected
|
||
// but will not fail if other agents are not connected.
|
||
//
|
||
// Deprecated: Use NewWorkspaceAgentWaiter
|
||
func AwaitWorkspaceAgents(t testing.TB, client *codersdk.Client, workspaceID uuid.UUID, agentNames ...string) []codersdk.WorkspaceResource {
|
||
return NewWorkspaceAgentWaiter(t, client, workspaceID).AgentNames(agentNames).Wait()
|
||
}
|
||
|
||
// WorkspaceAgentWaiter waits for all resources with agents to be connected. If
|
||
// specific agents are provided using AgentNames(), it will wait for those agents
|
||
// to be connected but will not fail if other agents are not connected.
|
||
type WorkspaceAgentWaiter struct {
|
||
t testing.TB
|
||
client *codersdk.Client
|
||
workspaceID uuid.UUID
|
||
agentNames []string
|
||
resourcesMatcher func([]codersdk.WorkspaceResource) bool
|
||
ctx context.Context
|
||
}
|
||
|
||
// NewWorkspaceAgentWaiter returns an object that waits for agents to connect when
|
||
// you call Wait() on it.
|
||
func NewWorkspaceAgentWaiter(t testing.TB, client *codersdk.Client, workspaceID uuid.UUID) WorkspaceAgentWaiter {
|
||
return WorkspaceAgentWaiter{
|
||
t: t,
|
||
client: client,
|
||
workspaceID: workspaceID,
|
||
}
|
||
}
|
||
|
||
// AgentNames instructs the waiter to wait for the given, named agents to be connected and will
|
||
// return even if other agents are not connected.
|
||
func (w WorkspaceAgentWaiter) AgentNames(names []string) WorkspaceAgentWaiter {
|
||
//nolint: revive // returns modified struct
|
||
w.agentNames = names
|
||
return w
|
||
}
|
||
|
||
// MatchResources instructs the waiter to wait until the workspace has resources that cause the
|
||
// provided matcher function to return true.
|
||
func (w WorkspaceAgentWaiter) MatchResources(m func([]codersdk.WorkspaceResource) bool) WorkspaceAgentWaiter {
|
||
//nolint: revive // returns modified struct
|
||
w.resourcesMatcher = m
|
||
return w
|
||
}
|
||
|
||
// WithContext instructs the waiter to use the provided context for all operations.
|
||
// If not specified, the waiter will create its own context with testutil.WaitLong timeout.
|
||
func (w WorkspaceAgentWaiter) WithContext(ctx context.Context) WorkspaceAgentWaiter {
|
||
//nolint: revive // returns modified struct
|
||
w.ctx = ctx
|
||
return w
|
||
}
|
||
|
||
// WaitForAgentFn represents a boolean assertion to be made against each agent
|
||
// that a given WorkspaceAgentWaited knows about. Each WaitForAgentFn should apply
|
||
// the check to a single agent, but it should be named for plural, because `func (w WorkspaceAgentWaiter) WaitFor`
|
||
// applies the check to all agents that it is aware of. This ensures that the public API of the waiter
|
||
// reads correctly. For example:
|
||
//
|
||
// waiter := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID)
|
||
// waiter.WaitFor(coderdtest.AgentsReady)
|
||
type WaitForAgentFn func(agent codersdk.WorkspaceAgent) bool
|
||
|
||
// AgentsReady checks that the latest lifecycle state of an agent is "Ready".
|
||
func AgentsReady(agent codersdk.WorkspaceAgent) bool {
|
||
return agent.LifecycleState == codersdk.WorkspaceAgentLifecycleReady
|
||
}
|
||
|
||
// AgentsNotReady checks that the latest lifecycle state of an agent is anything except "Ready".
|
||
func AgentsNotReady(agent codersdk.WorkspaceAgent) bool {
|
||
return !AgentsReady(agent)
|
||
}
|
||
|
||
// WaitFor waits for the given criteria and fails the test if they are not met before the
|
||
// waiter's context is canceled.
|
||
func (w WorkspaceAgentWaiter) WaitFor(criteria ...WaitForAgentFn) {
|
||
w.t.Helper()
|
||
|
||
agentNamesMap := make(map[string]struct{}, len(w.agentNames))
|
||
for _, name := range w.agentNames {
|
||
agentNamesMap[name] = struct{}{}
|
||
}
|
||
|
||
ctx := w.ctx
|
||
if w.ctx == nil {
|
||
ctx = testutil.Context(w.t, testutil.WaitLong)
|
||
}
|
||
|
||
w.t.Logf("waiting for workspace agents (workspace %s)", w.workspaceID)
|
||
testutil.Eventually(ctx, w.t, func(ctx context.Context) bool {
|
||
var err error
|
||
workspace, err := w.client.Workspace(ctx, w.workspaceID)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
if workspace.LatestBuild.Job.CompletedAt == nil {
|
||
return false
|
||
}
|
||
if workspace.LatestBuild.Job.CompletedAt.IsZero() {
|
||
return false
|
||
}
|
||
|
||
for _, resource := range workspace.LatestBuild.Resources {
|
||
for _, agent := range resource.Agents {
|
||
if len(w.agentNames) > 0 {
|
||
if _, ok := agentNamesMap[agent.Name]; !ok {
|
||
continue
|
||
}
|
||
}
|
||
for _, criterium := range criteria {
|
||
if !criterium(agent) {
|
||
return false
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return true
|
||
}, testutil.IntervalFast)
|
||
}
|
||
|
||
// Wait waits for the agent(s) to connect and fails the test if they do not connect before the
|
||
// waiter's context is canceled.
|
||
func (w WorkspaceAgentWaiter) Wait() []codersdk.WorkspaceResource {
|
||
w.t.Helper()
|
||
|
||
agentNamesMap := make(map[string]struct{}, len(w.agentNames))
|
||
for _, name := range w.agentNames {
|
||
agentNamesMap[name] = struct{}{}
|
||
}
|
||
|
||
ctx := w.ctx
|
||
if w.ctx == nil {
|
||
ctx = testutil.Context(w.t, testutil.WaitLong)
|
||
}
|
||
|
||
w.t.Logf("waiting for workspace agents (workspace %s)", w.workspaceID)
|
||
var resources []codersdk.WorkspaceResource
|
||
testutil.Eventually(ctx, w.t, func(ctx context.Context) bool {
|
||
var err error
|
||
workspace, err := w.client.Workspace(ctx, w.workspaceID)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
if workspace.LatestBuild.Job.CompletedAt == nil {
|
||
return false
|
||
}
|
||
if workspace.LatestBuild.Job.CompletedAt.IsZero() {
|
||
return false
|
||
}
|
||
|
||
for _, resource := range workspace.LatestBuild.Resources {
|
||
for _, agent := range resource.Agents {
|
||
if len(w.agentNames) > 0 {
|
||
if _, ok := agentNamesMap[agent.Name]; !ok {
|
||
continue
|
||
}
|
||
}
|
||
|
||
if agent.Status != codersdk.WorkspaceAgentConnected {
|
||
w.t.Logf("agent %s not connected yet", agent.Name)
|
||
return false
|
||
}
|
||
}
|
||
}
|
||
resources = workspace.LatestBuild.Resources
|
||
if w.resourcesMatcher == nil {
|
||
return true
|
||
}
|
||
return w.resourcesMatcher(resources)
|
||
}, testutil.IntervalFast)
|
||
w.t.Logf("got workspace agents (workspace %s)", w.workspaceID)
|
||
return resources
|
||
}
|
||
|
||
// CreateWorkspace creates a workspace for the user and template provided.
|
||
// A random name is generated for it.
|
||
// To customize the defaults, pass a mutator func.
|
||
func CreateWorkspace(t testing.TB, client *codersdk.Client, templateID uuid.UUID, mutators ...func(*codersdk.CreateWorkspaceRequest)) codersdk.Workspace {
|
||
t.Helper()
|
||
req := codersdk.CreateWorkspaceRequest{
|
||
TemplateID: templateID,
|
||
Name: RandomUsername(t),
|
||
AutostartSchedule: ptr.Ref("CRON_TZ=US/Central 30 9 * * 1-5"),
|
||
TTLMillis: ptr.Ref((8 * time.Hour).Milliseconds()),
|
||
AutomaticUpdates: codersdk.AutomaticUpdatesNever,
|
||
}
|
||
for _, mutator := range mutators {
|
||
mutator(&req)
|
||
}
|
||
workspace, err := client.CreateUserWorkspace(context.Background(), codersdk.Me, req)
|
||
require.NoError(t, err)
|
||
return workspace
|
||
}
|
||
|
||
// TransitionWorkspace is a convenience method for transitioning a workspace from one state to another.
|
||
func MustTransitionWorkspace(t testing.TB, client *codersdk.Client, workspaceID uuid.UUID, from, to codersdk.WorkspaceTransition, muts ...func(req *codersdk.CreateWorkspaceBuildRequest)) codersdk.Workspace {
|
||
t.Helper()
|
||
ctx := context.Background()
|
||
workspace, err := client.Workspace(ctx, workspaceID)
|
||
require.NoError(t, err, "unexpected error fetching workspace")
|
||
require.Equal(t, workspace.LatestBuild.Transition, from, "expected workspace state: %s got: %s", from, workspace.LatestBuild.Transition)
|
||
|
||
req := codersdk.CreateWorkspaceBuildRequest{
|
||
TemplateVersionID: workspace.LatestBuild.TemplateVersionID,
|
||
Transition: to,
|
||
}
|
||
|
||
for _, mut := range muts {
|
||
mut(&req)
|
||
}
|
||
|
||
build, err := client.CreateWorkspaceBuild(ctx, workspace.ID, req)
|
||
require.NoError(t, err, "unexpected error transitioning workspace to %s", to)
|
||
|
||
_ = AwaitWorkspaceBuildJobCompleted(t, client, build.ID)
|
||
|
||
updated := MustWorkspace(t, client, workspace.ID)
|
||
require.Equal(t, to, updated.LatestBuild.Transition, "expected workspace to be in state %s but got %s", to, updated.LatestBuild.Transition)
|
||
return updated
|
||
}
|
||
|
||
// MustWorkspace is a convenience method for fetching a workspace that should exist.
|
||
func MustWorkspace(t testing.TB, client *codersdk.Client, workspaceID uuid.UUID) codersdk.Workspace {
|
||
t.Helper()
|
||
ctx := context.Background()
|
||
ws, err := client.Workspace(ctx, workspaceID)
|
||
if err != nil && strings.Contains(err.Error(), "status code 410") {
|
||
ws, err = client.DeletedWorkspace(ctx, workspaceID)
|
||
}
|
||
require.NoError(t, err, "no workspace found with id %s", workspaceID)
|
||
return ws
|
||
}
|
||
|
||
// RequestExternalAuthCallback makes a request with the proper OAuth2 state cookie
|
||
// to the external auth callback endpoint.
|
||
func RequestExternalAuthCallback(t testing.TB, providerID string, client *codersdk.Client, opts ...func(*http.Request)) *http.Response {
|
||
client.HTTPClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
|
||
return http.ErrUseLastResponse
|
||
}
|
||
state := "somestate"
|
||
oauthURL, err := client.URL.Parse(fmt.Sprintf("/external-auth/%s/callback?code=asd&state=%s", providerID, state))
|
||
require.NoError(t, err)
|
||
req, err := http.NewRequestWithContext(context.Background(), "GET", oauthURL.String(), nil)
|
||
require.NoError(t, err)
|
||
req.AddCookie(&http.Cookie{
|
||
Name: codersdk.OAuth2StateCookie,
|
||
Value: state,
|
||
})
|
||
req.AddCookie(&http.Cookie{
|
||
Name: codersdk.SessionTokenCookie,
|
||
Value: client.SessionToken(),
|
||
})
|
||
for _, opt := range opts {
|
||
opt(req)
|
||
}
|
||
res, err := client.HTTPClient.Do(req)
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() {
|
||
_ = res.Body.Close()
|
||
})
|
||
return res
|
||
}
|
||
|
||
// NewGoogleInstanceIdentity returns a metadata client and ID token validator for faking
|
||
// instance authentication for Google Cloud.
|
||
// nolint:revive
|
||
func NewGoogleInstanceIdentity(t testing.TB, instanceID string, expired bool) (*idtoken.Validator, *metadata.Client) {
|
||
keyID, err := cryptorand.String(12)
|
||
require.NoError(t, err)
|
||
claims := jwt.MapClaims{
|
||
"google": map[string]interface{}{
|
||
"compute_engine": map[string]string{
|
||
"instance_id": instanceID,
|
||
},
|
||
},
|
||
}
|
||
if !expired {
|
||
claims["exp"] = time.Now().AddDate(1, 0, 0).Unix()
|
||
}
|
||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||
token.Header["kid"] = keyID
|
||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||
require.NoError(t, err)
|
||
signedKey, err := token.SignedString(privateKey)
|
||
require.NoError(t, err)
|
||
|
||
// Taken from: https://github.com/googleapis/google-api-go-client/blob/4bb729045d611fa77bdbeb971f6a1204ba23161d/idtoken/validate.go#L57-L75
|
||
type jwk struct {
|
||
Kid string `json:"kid"`
|
||
N string `json:"n"`
|
||
E string `json:"e"`
|
||
}
|
||
type certResponse struct {
|
||
Keys []jwk `json:"keys"`
|
||
}
|
||
|
||
validator, err := idtoken.NewValidator(context.Background(), option.WithHTTPClient(&http.Client{
|
||
Transport: roundTripper(func(r *http.Request) (*http.Response, error) {
|
||
data, err := json.Marshal(certResponse{
|
||
Keys: []jwk{{
|
||
Kid: keyID,
|
||
N: base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()),
|
||
E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(privateKey.E)).Bytes()),
|
||
}},
|
||
})
|
||
require.NoError(t, err)
|
||
return &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Body: io.NopCloser(bytes.NewReader(data)),
|
||
Header: make(http.Header),
|
||
}, nil
|
||
}),
|
||
}))
|
||
require.NoError(t, err)
|
||
|
||
return validator, metadata.NewClient(&http.Client{
|
||
Transport: roundTripper(func(r *http.Request) (*http.Response, error) {
|
||
return &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Body: io.NopCloser(bytes.NewReader([]byte(signedKey))),
|
||
Header: make(http.Header),
|
||
}, nil
|
||
}),
|
||
})
|
||
}
|
||
|
||
// NewAWSInstanceIdentity returns a metadata client and ID token validator for faking
|
||
// instance authentication for AWS.
|
||
func NewAWSInstanceIdentity(t testing.TB, instanceID string) (awsidentity.Certificates, *http.Client) {
|
||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||
require.NoError(t, err)
|
||
|
||
document := []byte(`{"instanceId":"` + instanceID + `"}`)
|
||
hashedDocument := sha256.Sum256(document)
|
||
|
||
signatureRaw, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, hashedDocument[:])
|
||
require.NoError(t, err)
|
||
signature := make([]byte, base64.StdEncoding.EncodedLen(len(signatureRaw)))
|
||
base64.StdEncoding.Encode(signature, signatureRaw)
|
||
|
||
certificate, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
|
||
SerialNumber: big.NewInt(2022),
|
||
}, &x509.Certificate{}, &privateKey.PublicKey, privateKey)
|
||
require.NoError(t, err)
|
||
|
||
certificatePEM := bytes.Buffer{}
|
||
err = pem.Encode(&certificatePEM, &pem.Block{
|
||
Type: "CERTIFICATE",
|
||
Bytes: certificate,
|
||
})
|
||
require.NoError(t, err)
|
||
|
||
return awsidentity.Certificates{
|
||
awsidentity.Other: certificatePEM.String(),
|
||
}, &http.Client{
|
||
Transport: roundTripper(func(r *http.Request) (*http.Response, error) {
|
||
// Only handle metadata server requests.
|
||
if r.URL.Host != "169.254.169.254" {
|
||
return http.DefaultTransport.RoundTrip(r)
|
||
}
|
||
switch r.URL.Path {
|
||
case "/latest/api/token":
|
||
return &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Body: io.NopCloser(bytes.NewReader([]byte("faketoken"))),
|
||
Header: make(http.Header),
|
||
}, nil
|
||
case "/latest/dynamic/instance-identity/signature":
|
||
return &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Body: io.NopCloser(bytes.NewReader(signature)),
|
||
Header: make(http.Header),
|
||
}, nil
|
||
case "/latest/dynamic/instance-identity/document":
|
||
return &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Body: io.NopCloser(bytes.NewReader(document)),
|
||
Header: make(http.Header),
|
||
}, nil
|
||
default:
|
||
panic("unhandled route: " + r.URL.Path)
|
||
}
|
||
}),
|
||
}
|
||
}
|
||
|
||
// NewAzureInstanceIdentity returns a metadata client and ID token
|
||
// validator for faking instance authentication for Azure. It builds
|
||
// a realistic 3-level certificate chain (Root CA -> Intermediate ->
|
||
// Signing Cert) to match the real Azure trust hierarchy.
|
||
func NewAzureInstanceIdentity(t testing.TB, instanceID string) (azureidentity.Options, *http.Client) {
|
||
// Root CA (self-signed, trusted).
|
||
rootKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||
require.NoError(t, err)
|
||
rootTmpl := &x509.Certificate{
|
||
SerialNumber: big.NewInt(1),
|
||
Subject: pkix.Name{CommonName: "Test Root CA"},
|
||
NotBefore: time.Now().Add(-time.Hour),
|
||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||
IsCA: true,
|
||
BasicConstraintsValid: true,
|
||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||
}
|
||
rootDER, err := x509.CreateCertificate(rand.Reader, rootTmpl, rootTmpl, &rootKey.PublicKey, rootKey)
|
||
require.NoError(t, err)
|
||
rootCert, err := x509.ParseCertificate(rootDER)
|
||
require.NoError(t, err)
|
||
|
||
// Intermediate CA (signed by root).
|
||
interKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||
require.NoError(t, err)
|
||
interTmpl := &x509.Certificate{
|
||
SerialNumber: big.NewInt(2),
|
||
Subject: pkix.Name{CommonName: "Test Intermediate CA"},
|
||
NotBefore: time.Now().Add(-time.Hour),
|
||
NotAfter: time.Now().AddDate(5, 0, 0),
|
||
IsCA: true,
|
||
BasicConstraintsValid: true,
|
||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||
}
|
||
interDER, err := x509.CreateCertificate(rand.Reader, interTmpl, rootCert, &interKey.PublicKey, rootKey)
|
||
require.NoError(t, err)
|
||
interCert, err := x509.ParseCertificate(interDER)
|
||
require.NoError(t, err)
|
||
|
||
// Signing cert (leaf, signed by intermediate).
|
||
signKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||
require.NoError(t, err)
|
||
signTmpl := &x509.Certificate{
|
||
SerialNumber: big.NewInt(3),
|
||
Subject: pkix.Name{CommonName: "metadata.azure.com"},
|
||
NotBefore: time.Now().Add(-time.Hour),
|
||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||
}
|
||
signDER, err := x509.CreateCertificate(rand.Reader, signTmpl, interCert, &signKey.PublicKey, interKey)
|
||
require.NoError(t, err)
|
||
signCert, err := x509.ParseCertificate(signDER)
|
||
require.NoError(t, err)
|
||
|
||
// Build PKCS7 signed data with only the signing cert.
|
||
signed, err := pkcs7.NewSignedData([]byte(`{"vmId":"` + instanceID + `"}`))
|
||
require.NoError(t, err)
|
||
err = signed.AddSigner(signCert, signKey, pkcs7.SignerInfoConfig{})
|
||
require.NoError(t, err)
|
||
signatureRaw, err := signed.Finish()
|
||
require.NoError(t, err)
|
||
signature := make([]byte, base64.StdEncoding.EncodedLen(len(signatureRaw)))
|
||
base64.StdEncoding.Encode(signature, signatureRaw)
|
||
|
||
payload, err := json.Marshal(agentsdk.AzureInstanceIdentityToken{
|
||
Signature: string(signature),
|
||
Encoding: "pkcs7",
|
||
})
|
||
require.NoError(t, err)
|
||
|
||
roots := x509.NewCertPool()
|
||
roots.AddCert(rootCert)
|
||
|
||
return azureidentity.Options{
|
||
Roots: roots,
|
||
Intermediates: []*x509.Certificate{interCert},
|
||
}, &http.Client{
|
||
Transport: roundTripper(func(r *http.Request) (*http.Response, error) {
|
||
// Only handle metadata server requests.
|
||
if r.URL.Host != "169.254.169.254" {
|
||
return http.DefaultTransport.RoundTrip(r)
|
||
}
|
||
switch r.URL.Path {
|
||
case "/metadata/attested/document":
|
||
return &http.Response{
|
||
StatusCode: http.StatusOK,
|
||
Body: io.NopCloser(bytes.NewReader(payload)),
|
||
Header: make(http.Header),
|
||
}, nil
|
||
default:
|
||
panic("unhandled route: " + r.URL.Path)
|
||
}
|
||
}),
|
||
}
|
||
}
|
||
|
||
func RandomUsername(_ testing.TB) string {
|
||
return namesgenerator.UniqueNameWith("-")
|
||
}
|
||
|
||
// RandomName returns a random name in title case (e.g. "Happy Einstein").
|
||
func RandomName(_ testing.TB) string {
|
||
return cases.Title(language.English).String(
|
||
namesgenerator.NameWith(" "),
|
||
)
|
||
}
|
||
|
||
// roundTripper adapts an ordinary function to http.RoundTripper so tests can
// build fake HTTP transports inline.
type roundTripper func(req *http.Request) (*http.Response, error)

// Compile-time check that roundTripper satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripper(nil)

// RoundTrip implements http.RoundTripper by calling the wrapped function and
// returning its response and error unchanged.
func (fn roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}
|
||
|
||
// nopcloser is a zero-size value whose Close method does nothing and always
// reports success.
type nopcloser struct{}

// Close has no effect and always returns nil.
func (nopcloser) Close() error {
	var noErr error
	return noErr
}
|
||
|
||
// SDKError coerces err into an SDK error.
|
||
func SDKError(t testing.TB, err error) *codersdk.Error {
|
||
var cerr *codersdk.Error
|
||
require.True(t, errors.As(err, &cerr), "should be SDK error, got %s", err)
|
||
return cerr
|
||
}
|
||
|
||
func DeploymentValues(t testing.TB, mut ...func(*codersdk.DeploymentValues)) *codersdk.DeploymentValues {
|
||
cfg := &codersdk.DeploymentValues{}
|
||
opts := cfg.Options()
|
||
err := opts.SetDefaults()
|
||
require.NoError(t, err)
|
||
for _, fn := range mut {
|
||
fn(cfg)
|
||
}
|
||
return cfg
|
||
}
|
||
|
||
// GetProvisionerForTags returns the first valid provisioner for a workspace + template tags.
|
||
func GetProvisionerForTags(tx database.Store, curTime time.Time, orgID uuid.UUID, tags map[string]string) (database.ProvisionerDaemon, error) {
|
||
if tags == nil {
|
||
tags = map[string]string{}
|
||
}
|
||
queryParams := database.GetProvisionerDaemonsByOrganizationParams{
|
||
OrganizationID: orgID,
|
||
WantTags: tags,
|
||
}
|
||
|
||
// nolint: gocritic // The user (in this case, the user/context for autostart builds) may not have the full
|
||
// permissions to read provisioner daemons, but we need to check if there's any for the job prior to the
|
||
// execution of the job via autostart to fix: https://github.com/coder/coder/issues/17941
|
||
provisionerDaemons, err := tx.GetProvisionerDaemonsByOrganization(dbauthz.AsSystemReadProvisionerDaemons(context.Background()), queryParams)
|
||
if err != nil {
|
||
return database.ProvisionerDaemon{}, xerrors.Errorf("get provisioner daemons: %w", err)
|
||
}
|
||
|
||
// Check if any provisioners are active (not stale)
|
||
for _, pd := range provisionerDaemons {
|
||
if pd.LastSeenAt.Valid {
|
||
age := curTime.Sub(pd.LastSeenAt.Time)
|
||
if age <= provisionerdserver.StaleInterval {
|
||
return pd, nil
|
||
}
|
||
}
|
||
}
|
||
return database.ProvisionerDaemon{}, xerrors.New("no available provisioners found")
|
||
}
|
||
|
||
func ctxWithProvisionerPermissions(ctx context.Context) context.Context {
|
||
// Use system restricted context which has permissions to update provisioner daemons
|
||
//nolint: gocritic // We need system context to modify this.
|
||
return dbauthz.AsSystemRestricted(ctx)
|
||
}
|
||
|
||
// UpdateProvisionerLastSeenAt updates the provisioner daemon's LastSeenAt timestamp
|
||
// to the specified time to prevent it from appearing stale during autobuild operations
|
||
func UpdateProvisionerLastSeenAt(t *testing.T, db database.Store, id uuid.UUID, tickTime time.Time) {
|
||
t.Helper()
|
||
ctx := ctxWithProvisionerPermissions(context.Background())
|
||
t.Logf("Updating provisioner %s LastSeenAt to %v", id, tickTime)
|
||
err := db.UpdateProvisionerDaemonLastSeenAt(ctx, database.UpdateProvisionerDaemonLastSeenAtParams{
|
||
ID: id,
|
||
LastSeenAt: sql.NullTime{Time: tickTime, Valid: true},
|
||
})
|
||
require.NoError(t, err)
|
||
t.Logf("Successfully updated provisioner LastSeenAt")
|
||
}
|
||
|
||
func MustWaitForAnyProvisioner(t *testing.T, db database.Store) {
|
||
t.Helper()
|
||
ctx := ctxWithProvisionerPermissions(testutil.Context(t, testutil.WaitShort))
|
||
// testutil.Eventually(t, func)
|
||
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
|
||
daemons, err := db.GetProvisionerDaemons(ctx)
|
||
return err == nil && len(daemons) > 0
|
||
}, testutil.IntervalFast, "no provisioner daemons found")
|
||
}
|
||
|
||
// MustWaitForProvisionersUnavailable waits for provisioners to become unavailable for a specific workspace
|
||
func MustWaitForProvisionersUnavailable(t *testing.T, db database.Store, workspace codersdk.Workspace, tags map[string]string, checkTime time.Time) {
|
||
t.Helper()
|
||
ctx := ctxWithProvisionerPermissions(testutil.Context(t, testutil.WaitMedium))
|
||
|
||
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
|
||
// Use the same logic as hasValidProvisioner but expect false
|
||
provisionerDaemons, err := db.GetProvisionerDaemonsByOrganization(ctx, database.GetProvisionerDaemonsByOrganizationParams{
|
||
OrganizationID: workspace.OrganizationID,
|
||
WantTags: tags,
|
||
})
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
// Check if NO provisioners are active (all are stale or gone)
|
||
for _, pd := range provisionerDaemons {
|
||
if pd.LastSeenAt.Valid {
|
||
age := checkTime.Sub(pd.LastSeenAt.Time)
|
||
if age <= provisionerdserver.StaleInterval {
|
||
return false // Found an active provisioner, keep waiting
|
||
}
|
||
}
|
||
}
|
||
return true // No active provisioners found
|
||
}, testutil.IntervalFast, "there are still provisioners available for workspace, expected none")
|
||
}
|
||
|
||
// MustWaitForProvisionersAvailable waits for provisioners to be available for a specific workspace.
|
||
func MustWaitForProvisionersAvailable(t *testing.T, db database.Store, workspace codersdk.Workspace, ts time.Time) uuid.UUID {
|
||
t.Helper()
|
||
ctx := ctxWithProvisionerPermissions(testutil.Context(t, testutil.WaitLong))
|
||
id := uuid.UUID{}
|
||
// Get the workspace from the database
|
||
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
|
||
ws, err := db.GetWorkspaceByID(ctx, workspace.ID)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
// Get the latest build
|
||
latestBuild, err := db.GetWorkspaceBuildByID(ctx, workspace.LatestBuild.ID)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
// Get the template version job
|
||
templateVersionJob, err := db.GetProvisionerJobByID(ctx, latestBuild.JobID)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
// Check if provisioners are available using the same logic as hasAvailableProvisioners
|
||
provisionerDaemons, err := db.GetProvisionerDaemonsByOrganization(ctx, database.GetProvisionerDaemonsByOrganizationParams{
|
||
OrganizationID: ws.OrganizationID,
|
||
WantTags: templateVersionJob.Tags,
|
||
})
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
// Check if any provisioners are active (not stale)
|
||
for _, pd := range provisionerDaemons {
|
||
if pd.LastSeenAt.Valid {
|
||
age := ts.Sub(pd.LastSeenAt.Time)
|
||
if age <= provisionerdserver.StaleInterval {
|
||
id = pd.ID
|
||
return true // Found an active provisioner
|
||
}
|
||
}
|
||
}
|
||
return false // No active provisioners found
|
||
}, testutil.IntervalFast, "no active provisioners available for workspace, expected at least one (non-stale)")
|
||
|
||
return id
|
||
}
|