feat: reinitialize agents when a prebuilt workspace is claimed (#17475)

This pull request allows coder workspace agents to be reinitialized when
a prebuilt workspace is claimed by a user. This facilitates the transfer
of ownership between the anonymous prebuilds system user and the new
owner of the workspace.

Only a single agent per prebuilt workspace is supported for now, but
plumbing has already been done to facilitate the seamless transition to
multi-agent support.

---------

Signed-off-by: Danny Kopping <dannykopping@gmail.com>
Co-authored-by: Danny Kopping <dannykopping@gmail.com>
This commit is contained in:
Sas Swart
2025-05-14 14:15:36 +02:00
committed by GitHub
parent fcbdd1a28e
commit 425ee6fa55
38 changed files with 2184 additions and 449 deletions
+7 -1
View File
@@ -368,9 +368,11 @@ func (a *agent) runLoop() {
if ctx.Err() != nil {
// Context canceled errors may come from websocket pings, so we
// don't want to use `errors.Is(err, context.Canceled)` here.
a.logger.Warn(ctx, "runLoop exited with error", slog.Error(ctx.Err()))
return
}
if a.isClosed() {
a.logger.Warn(ctx, "runLoop exited because agent is closed")
return
}
if errors.Is(err, io.EOF) {
@@ -1051,7 +1053,11 @@ func (a *agent) run() (retErr error) {
return a.statsReporter.reportLoop(ctx, aAPI)
})
return connMan.wait()
err = connMan.wait()
if err != nil {
a.logger.Info(context.Background(), "connection manager errored", slog.Error(err))
}
return err
}
// handleManifest returns a function that fetches and processes the manifest
+65 -39
View File
@@ -25,6 +25,8 @@ import (
"cdr.dev/slog/sloggers/sloghuman"
"cdr.dev/slog/sloggers/slogjson"
"cdr.dev/slog/sloggers/slogstackdriver"
"github.com/coder/serpent"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentssh"
@@ -33,7 +35,6 @@ import (
"github.com/coder/coder/v2/cli/clilog"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/serpent"
)
func (r *RootCmd) workspaceAgent() *serpent.Command {
@@ -63,8 +64,10 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
// This command isn't useful to manually execute.
Hidden: true,
Handler: func(inv *serpent.Invocation) error {
ctx, cancel := context.WithCancel(inv.Context())
defer cancel()
ctx, cancel := context.WithCancelCause(inv.Context())
defer func() {
cancel(xerrors.New("agent exited"))
}()
var (
ignorePorts = map[int]string{}
@@ -281,7 +284,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
return xerrors.Errorf("add executable to $PATH: %w", err)
}
prometheusRegistry := prometheus.NewRegistry()
subsystemsRaw := inv.Environ.Get(agent.EnvAgentSubsystem)
subsystems := []codersdk.AgentSubsystem{}
for _, s := range strings.Split(subsystemsRaw, ",") {
@@ -325,46 +327,70 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
logger.Info(ctx, "agent devcontainer detection not enabled")
}
agnt := agent.New(agent.Options{
Client: client,
Logger: logger,
LogDir: logDir,
ScriptDataDir: scriptDataDir,
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
TailnetListenPort: uint16(tailnetListenPort),
ExchangeToken: func(ctx context.Context) (string, error) {
if exchangeToken == nil {
return client.SDK.SessionToken(), nil
}
resp, err := exchangeToken(ctx)
if err != nil {
return "", err
}
client.SetSessionToken(resp.SessionToken)
return resp.SessionToken, nil
},
EnvironmentVariables: environmentVariables,
IgnorePorts: ignorePorts,
SSHMaxTimeout: sshMaxTimeout,
Subsystems: subsystems,
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
PrometheusRegistry: prometheusRegistry,
BlockFileTransfer: blockFileTransfer,
Execer: execer,
SubAgent: subAgent,
var (
lastErr error
mustExit bool
)
for {
prometheusRegistry := prometheus.NewRegistry()
ExperimentalDevcontainersEnabled: experimentalDevcontainersEnabled,
})
agnt := agent.New(agent.Options{
Client: client,
Logger: logger,
LogDir: logDir,
ScriptDataDir: scriptDataDir,
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
TailnetListenPort: uint16(tailnetListenPort),
ExchangeToken: func(ctx context.Context) (string, error) {
if exchangeToken == nil {
return client.SDK.SessionToken(), nil
}
resp, err := exchangeToken(ctx)
if err != nil {
return "", err
}
client.SetSessionToken(resp.SessionToken)
return resp.SessionToken, nil
},
EnvironmentVariables: environmentVariables,
IgnorePorts: ignorePorts,
SSHMaxTimeout: sshMaxTimeout,
Subsystems: subsystems,
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
prometheusSrvClose := ServeHandler(ctx, logger, promHandler, prometheusAddress, "prometheus")
defer prometheusSrvClose()
PrometheusRegistry: prometheusRegistry,
BlockFileTransfer: blockFileTransfer,
Execer: execer,
SubAgent: subAgent,
ExperimentalDevcontainersEnabled: experimentalDevcontainersEnabled,
})
debugSrvClose := ServeHandler(ctx, logger, agnt.HTTPDebug(), debugAddress, "debug")
defer debugSrvClose()
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
prometheusSrvClose := ServeHandler(ctx, logger, promHandler, prometheusAddress, "prometheus")
<-ctx.Done()
return agnt.Close()
debugSrvClose := ServeHandler(ctx, logger, agnt.HTTPDebug(), debugAddress, "debug")
select {
case <-ctx.Done():
logger.Info(ctx, "agent shutting down", slog.Error(context.Cause(ctx)))
mustExit = true
case event := <-reinitEvents:
logger.Info(ctx, "agent received instruction to reinitialize",
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
}
lastErr = agnt.Close()
debugSrvClose()
prometheusSrvClose()
if mustExit {
break
}
logger.Info(ctx, "agent reinitializing")
}
return lastErr
},
}
+45
View File
@@ -8446,6 +8446,31 @@ const docTemplate = `{
}
}
},
"/workspaceagents/me/reinit": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"Agents"
],
"summary": "Get workspace agent reinitialization",
"operationId": "get-workspace-agent-reinitialization",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
}
}
}
}
},
"/workspaceagents/me/rpc": {
"get": {
"security": [
@@ -10491,6 +10516,26 @@ const docTemplate = `{
}
}
},
"agentsdk.ReinitializationEvent": {
"type": "object",
"properties": {
"reason": {
"$ref": "#/definitions/agentsdk.ReinitializationReason"
},
"workspaceID": {
"type": "string"
}
}
},
"agentsdk.ReinitializationReason": {
"type": "string",
"enum": [
"prebuild_claimed"
],
"x-enum-varnames": [
"ReinitializeReasonPrebuildClaimed"
]
},
"aisdk.Attachment": {
"type": "object",
"properties": {
+37
View File
@@ -7463,6 +7463,27 @@
}
}
},
"/workspaceagents/me/reinit": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["Agents"],
"summary": "Get workspace agent reinitialization",
"operationId": "get-workspace-agent-reinitialization",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
}
}
}
}
},
"/workspaceagents/me/rpc": {
"get": {
"security": [
@@ -9302,6 +9323,22 @@
}
}
},
"agentsdk.ReinitializationEvent": {
"type": "object",
"properties": {
"reason": {
"$ref": "#/definitions/agentsdk.ReinitializationReason"
},
"workspaceID": {
"type": "string"
}
}
},
"agentsdk.ReinitializationReason": {
"type": "string",
"enum": ["prebuild_claimed"],
"x-enum-varnames": ["ReinitializeReasonPrebuildClaimed"]
},
"aisdk.Attachment": {
"type": "object",
"properties": {
+3 -1
View File
@@ -19,6 +19,8 @@ import (
"sync/atomic"
"time"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/andybalholm/brotli"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
@@ -47,7 +49,6 @@ import (
"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/coderd/files"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/coderd/webpush"
@@ -1299,6 +1300,7 @@ func New(options *Options) *API {
r.Get("/external-auth", api.workspaceAgentsExternalAuth)
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Post("/log-source", api.workspaceAgentPostLogSource)
r.Get("/reinit", api.workspaceAgentReinit)
})
r.Route("/{workspaceagent}", func(r chi.Router) {
r.Use(
+63
View File
@@ -1105,6 +1105,69 @@ func (w WorkspaceAgentWaiter) MatchResources(m func([]codersdk.WorkspaceResource
return w
}
// WaitForAgentFn represents a boolean assertion to be made against each agent
// that a given WorkspaceAgentWaiter knows about. Each WaitForAgentFn should apply
// the check to a single agent, but it should be named in the plural, because
// `func (w WorkspaceAgentWaiter) WaitFor` applies the check to all agents that it
// is aware of. This ensures that the public API of the waiter reads correctly.
// For example:
//
//	waiter := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID)
//	waiter.WaitFor(coderdtest.AgentsReady)
type WaitForAgentFn func(agent codersdk.WorkspaceAgent) bool
// AgentsReady reports whether the agent's latest lifecycle state is "Ready".
func AgentsReady(agent codersdk.WorkspaceAgent) bool {
	switch agent.LifecycleState {
	case codersdk.WorkspaceAgentLifecycleReady:
		return true
	default:
		return false
	}
}
// AgentsNotReady reports whether the agent's latest lifecycle state is
// anything other than "Ready".
func AgentsNotReady(agent codersdk.WorkspaceAgent) bool {
	return agent.LifecycleState != codersdk.WorkspaceAgentLifecycleReady
}
// WaitFor polls the workspace until its latest build job has completed and
// every considered agent satisfies all of the given criteria, failing the test
// if that does not happen within testutil.WaitLong.
//
// When the waiter was configured with agent names, only agents with those
// names are checked; otherwise every agent of the latest build is checked.
// If no agent is considered at all, the condition is trivially satisfied.
func (w WorkspaceAgentWaiter) WaitFor(criteria ...WaitForAgentFn) {
	w.t.Helper()

	// Set of agent names to restrict the checks to.
	wanted := make(map[string]struct{}, len(w.agentNames))
	for _, agentName := range w.agentNames {
		wanted[agentName] = struct{}{}
	}

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	w.t.Logf("waiting for workspace agents (workspace %s)", w.workspaceID)

	// meetsAll reports whether a single agent passes every criterion.
	meetsAll := func(agent codersdk.WorkspaceAgent) bool {
		for _, check := range criteria {
			if !check(agent) {
				return false
			}
		}
		return true
	}

	// poll performs one evaluation of the wait condition.
	poll := func() bool {
		workspace, err := w.client.Workspace(ctx, w.workspaceID)
		if err != nil {
			return false
		}
		// The build must have finished before agent state is meaningful.
		completedAt := workspace.LatestBuild.Job.CompletedAt
		if completedAt == nil || completedAt.IsZero() {
			return false
		}

		for _, resource := range workspace.LatestBuild.Resources {
			for _, agent := range resource.Agents {
				if len(w.agentNames) > 0 {
					if _, ok := wanted[agent.Name]; !ok {
						continue
					}
				}
				if !meetsAll(agent) {
					return false
				}
			}
		}
		return true
	}

	require.Eventually(w.t, poll, testutil.WaitLong, testutil.IntervalMedium)
}
// Wait waits for the agent(s) to connect and fails the test if they do not within testutil.WaitLong
func (w WorkspaceAgentWaiter) Wait() []codersdk.WorkspaceResource {
w.t.Helper()
+9
View File
@@ -3020,6 +3020,15 @@ func (q *querier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uui
return q.db.GetWorkspaceAgentsByResourceIDs(ctx, ids)
}
// GetWorkspaceAgentsByWorkspaceAndBuildNumber returns the agents of the given
// workspace build. The workspace is first fetched through q.GetWorkspaceByID
// and any error from that lookup is returned before the underlying store is
// queried.
// NOTE(review): presumably the workspace lookup is what enforces read access
// on the workspace — confirm against GetWorkspaceByID.
func (q *querier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
	if _, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID); err != nil {
		return nil, err
	}
	return q.db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, arg)
}
func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
+32
View File
@@ -2009,6 +2009,38 @@ func (s *MethodTestSuite) TestWorkspace() {
agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID})
check.Args(agt.ID).Asserts(w, policy.ActionRead).Returns(agt)
}))
s.Run("GetWorkspaceAgentsByWorkspaceAndBuildNumber", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
o := dbgen.Organization(s.T(), db, database.Organization{})
tpl := dbgen.Template(s.T(), db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
OrganizationID: o.ID,
CreatedBy: u.ID,
})
w := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
TemplateID: tpl.ID,
OrganizationID: o.ID,
OwnerID: u.ID,
})
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{
JobID: j.ID,
WorkspaceID: w.ID,
TemplateVersionID: tv.ID,
})
res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: b.JobID})
agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID})
check.Args(database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
WorkspaceID: w.ID,
BuildNumber: 1,
}).Asserts(w, policy.ActionRead).Returns([]database.WorkspaceAgent{agt})
}))
s.Run("GetWorkspaceAgentLifecycleStateByID", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
o := dbgen.Organization(s.T(), db, database.Organization{})
+28
View File
@@ -294,6 +294,8 @@ type TemplateVersionBuilder struct {
ps pubsub.Pubsub
resources []*sdkproto.Resource
params []database.TemplateVersionParameter
presets []database.TemplateVersionPreset
presetParams []database.TemplateVersionPresetParameter
promote bool
autoCreateTemplate bool
}
@@ -339,6 +341,13 @@ func (t TemplateVersionBuilder) Params(ps ...database.TemplateVersionParameter)
return t
}
// Preset records a preset, together with any of its parameters, on the
// builder and returns the updated copy. The receiver is a value, so the
// caller's builder is only changed through the returned struct.
func (t TemplateVersionBuilder) Preset(preset database.TemplateVersionPreset, params ...database.TemplateVersionPresetParameter) TemplateVersionBuilder {
	// nolint: revive // returns modified struct
	t.presets, t.presetParams = append(t.presets, preset), append(t.presetParams, params...)
	return t
}
func (t TemplateVersionBuilder) SkipCreateTemplate() TemplateVersionBuilder {
// nolint: revive // returns modified struct
t.autoCreateTemplate = false
@@ -378,6 +387,25 @@ func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
require.NoError(t.t, err)
}
for _, preset := range t.presets {
dbgen.Preset(t.t, t.db, database.InsertPresetParams{
ID: preset.ID,
TemplateVersionID: version.ID,
Name: preset.Name,
CreatedAt: version.CreatedAt,
DesiredInstances: preset.DesiredInstances,
InvalidateAfterSecs: preset.InvalidateAfterSecs,
})
}
for _, presetParam := range t.presetParams {
dbgen.PresetParameter(t.t, t.db, database.InsertPresetParametersParams{
TemplateVersionPresetID: presetParam.TemplateVersionPresetID,
Names: []string{presetParam.Name},
Values: []string{presetParam.Value},
})
}
payload, err := json.Marshal(provisionerdserver.TemplateVersionImportJob{
TemplateVersionID: t.seed.ID,
})
+1
View File
@@ -1224,6 +1224,7 @@ func TelemetryItem(t testing.TB, db database.Store, seed database.TelemetryItem)
func Preset(t testing.TB, db database.Store, seed database.InsertPresetParams) database.TemplateVersionPreset {
preset, err := db.InsertPreset(genCtx, database.InsertPresetParams{
ID: takeFirst(seed.ID, uuid.New()),
TemplateVersionID: takeFirst(seed.TemplateVersionID, uuid.New()),
Name: takeFirst(seed.Name, testutil.GetRandomName(t)),
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
+24
View File
@@ -7654,6 +7654,30 @@ func (q *FakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resou
return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs)
}
// GetWorkspaceAgentsByWorkspaceAndBuildNumber returns the agents attached to
// the resources of the build identified by arg.WorkspaceID and
// arg.BuildNumber.
//
// It returns an error if arg fails validation, if the build cannot be found,
// or if the resource or agent lookups fail.
func (q *FakeQuerier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
	if err := validateDatabaseType(arg); err != nil {
		return nil, err
	}

	// This exported lookup is called before q.mutex is taken below, so that
	// the read lock is never acquired recursively.
	// NOTE(review): assumes it manages q.mutex itself — confirm.
	build, err := q.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams(arg))
	if err != nil {
		return nil, err
	}

	// The *NoLock helpers expect the caller to hold q.mutex; calling them
	// unguarded races with concurrent writers to the in-memory store.
	q.mutex.RLock()
	defer q.mutex.RUnlock()

	resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, build.JobID)
	if err != nil {
		return nil, err
	}

	resourceIDs := make([]uuid.UUID, 0, len(resources))
	for _, resource := range resources {
		resourceIDs = append(resourceIDs, resource.ID)
	}

	return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs)
}
func (q *FakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceAgent, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@@ -1754,6 +1754,13 @@ func (m queryMetricsStore) GetWorkspaceAgentsByResourceIDs(ctx context.Context,
return agents, err
}
// GetWorkspaceAgentsByWorkspaceAndBuildNumber delegates to the wrapped store
// and records the call's latency under the query's name.
func (m queryMetricsStore) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
	began := time.Now()
	agents, err := m.s.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, arg)
	elapsed := time.Since(began).Seconds()
	m.queryLatencies.WithLabelValues("GetWorkspaceAgentsByWorkspaceAndBuildNumber").Observe(elapsed)
	return agents, err
}
func (m queryMetricsStore) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
start := time.Now()
agents, err := m.s.GetWorkspaceAgentsCreatedAfter(ctx, createdAt)
+15
View File
@@ -3678,6 +3678,21 @@ func (mr *MockStoreMockRecorder) GetWorkspaceAgentsByResourceIDs(ctx, ids any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentsByResourceIDs", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentsByResourceIDs), ctx, ids)
}
// GetWorkspaceAgentsByWorkspaceAndBuildNumber mocks base method.
func (m *MockStore) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GetWorkspaceAgentsByWorkspaceAndBuildNumber", ctx, arg)
	// Failed type assertions deliberately fall back to the zero value.
	agents, _ := results[0].([]database.WorkspaceAgent)
	err, _ := results[1].(error)
	return agents, err
}
// GetWorkspaceAgentsByWorkspaceAndBuildNumber indicates an expected call of GetWorkspaceAgentsByWorkspaceAndBuildNumber.
func (mr *MockStoreMockRecorder) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, arg any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentsByWorkspaceAndBuildNumber)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentsByWorkspaceAndBuildNumber", methodType, ctx, arg)
}
// GetWorkspaceAgentsCreatedAfter mocks base method.
func (m *MockStore) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
m.ctrl.T.Helper()
+1
View File
@@ -400,6 +400,7 @@ type sqlcQuerier interface {
GetWorkspaceAgentUsageStats(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsRow, error)
GetWorkspaceAgentUsageStatsAndLabels(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsAndLabelsRow, error)
GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error)
GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]WorkspaceAgent, error)
GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceAgent, error)
GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgent, error)
GetWorkspaceAppByAgentIDAndSlug(ctx context.Context, arg GetWorkspaceAppByAgentIDAndSlugParams) (WorkspaceApp, error)
+80 -1
View File
@@ -6678,6 +6678,7 @@ func (q *sqlQuerier) GetPresetsByTemplateVersionID(ctx context.Context, template
const insertPreset = `-- name: InsertPreset :one
INSERT INTO template_version_presets (
id,
template_version_id,
name,
created_at,
@@ -6689,11 +6690,13 @@ VALUES (
$2,
$3,
$4,
$5
$5,
$6
) RETURNING id, template_version_id, name, created_at, desired_instances, invalidate_after_secs
`
type InsertPresetParams struct {
ID uuid.UUID `db:"id" json:"id"`
TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"`
Name string `db:"name" json:"name"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
@@ -6703,6 +6706,7 @@ type InsertPresetParams struct {
func (q *sqlQuerier) InsertPreset(ctx context.Context, arg InsertPresetParams) (TemplateVersionPreset, error) {
row := q.db.QueryRowContext(ctx, insertPreset,
arg.ID,
arg.TemplateVersionID,
arg.Name,
arg.CreatedAt,
@@ -14416,6 +14420,81 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []
return items, nil
}
// getWorkspaceAgentsByWorkspaceAndBuildNumber is the SQL text of the
// GetWorkspaceAgentsByWorkspaceAndBuildNumber query. It selects workspace_agents
// rows joined through workspace_resources to the workspace_builds row matching
// the workspace ID ($1) and build number ($2).
const getWorkspaceAgentsByWorkspaceAndBuildNumber = `-- name: GetWorkspaceAgentsByWorkspaceAndBuildNumber :many
SELECT
workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.expanded_directory, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems, workspace_agents.display_apps, workspace_agents.api_version, workspace_agents.display_order, workspace_agents.parent_id
FROM
workspace_agents
JOIN
workspace_resources ON workspace_agents.resource_id = workspace_resources.id
JOIN
workspace_builds ON workspace_resources.job_id = workspace_builds.job_id
WHERE
workspace_builds.workspace_id = $1 :: uuid AND
workspace_builds.build_number = $2 :: int
`
// GetWorkspaceAgentsByWorkspaceAndBuildNumberParams holds the arguments of the
// GetWorkspaceAgentsByWorkspaceAndBuildNumber query.
type GetWorkspaceAgentsByWorkspaceAndBuildNumberParams struct {
// WorkspaceID is bound to the query's workspace_id parameter ($1).
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
// BuildNumber is bound to the query's build_number parameter ($2).
BuildNumber int32 `db:"build_number" json:"build_number"`
}
// GetWorkspaceAgentsByWorkspaceAndBuildNumber runs the
// getWorkspaceAgentsByWorkspaceAndBuildNumber query with arg.WorkspaceID and
// arg.BuildNumber and scans each resulting row into a WorkspaceAgent.
// It returns the query, scan, close or iteration error if any occurs; the
// returned slice is nil when the query yields no rows.
func (q *sqlQuerier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]WorkspaceAgent, error) {
rows, err := q.db.QueryContext(ctx, getWorkspaceAgentsByWorkspaceAndBuildNumber, arg.WorkspaceID, arg.BuildNumber)
if err != nil {
return nil, err
}
defer rows.Close()
var items []WorkspaceAgent
for rows.Next() {
var i WorkspaceAgent
// The destinations below must stay in the same order as the columns
// listed in the query's SELECT clause.
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.Architecture,
&i.EnvironmentVariables,
&i.OperatingSystem,
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.Version,
&i.LastConnectedReplicaID,
&i.ConnectionTimeoutSeconds,
&i.TroubleshootingURL,
&i.MOTDFile,
&i.LifecycleState,
&i.ExpandedDirectory,
&i.LogsLength,
&i.LogsOverflowed,
&i.StartedAt,
&i.ReadyAt,
pq.Array(&i.Subsystems),
pq.Array(&i.DisplayApps),
&i.APIVersion,
&i.DisplayOrder,
&i.ParentID,
); err != nil {
return nil, err
}
items = append(items, i)
}
// Close explicitly so that a close error is reported to the caller; the
// deferred Close above covers the early-return paths.
if err := rows.Close(); err != nil {
return nil, err
}
// Surface any error encountered while iterating over the rows.
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getWorkspaceAgentsCreatedAfter = `-- name: GetWorkspaceAgentsCreatedAfter :many
SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, expanded_directory, logs_length, logs_overflowed, started_at, ready_at, subsystems, display_apps, api_version, display_order, parent_id FROM workspace_agents WHERE created_at > $1
`
+2
View File
@@ -1,5 +1,6 @@
-- name: InsertPreset :one
INSERT INTO template_version_presets (
id,
template_version_id,
name,
created_at,
@@ -7,6 +8,7 @@ INSERT INTO template_version_presets (
invalidate_after_secs
)
VALUES (
@id,
@template_version_id,
@name,
@created_at,
@@ -253,6 +253,19 @@ WHERE
wb.workspace_id = @workspace_id :: uuid
);
-- name: GetWorkspaceAgentsByWorkspaceAndBuildNumber :many
-- Returns the workspace agents attached to the resources of the build of
-- @workspace_id whose build number is @build_number.
SELECT
workspace_agents.*
FROM
workspace_agents
JOIN
workspace_resources ON workspace_agents.resource_id = workspace_resources.id
JOIN
workspace_builds ON workspace_resources.job_id = workspace_builds.job_id
WHERE
workspace_builds.workspace_id = @workspace_id :: uuid AND
workspace_builds.build_number = @build_number :: int;
-- name: GetWorkspaceAgentAndLatestBuildByAuthToken :one
SELECT
sqlc.embed(workspaces),
+82
View File
@@ -0,0 +1,82 @@
package prebuilds
import (
"context"
"sync"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
// NewPubsubWorkspaceClaimPublisher returns a publisher that sends workspace
// claim events over ps.
func NewPubsubWorkspaceClaimPublisher(ps pubsub.Pubsub) *PubsubWorkspaceClaimPublisher {
	publisher := new(PubsubWorkspaceClaimPublisher)
	publisher.ps = ps
	return publisher
}
// PubsubWorkspaceClaimPublisher publishes workspace claim events on a pubsub.
type PubsubWorkspaceClaimPublisher struct {
// ps is the pubsub on which claim events are published.
ps pubsub.Pubsub
}
// PublishWorkspaceClaim publishes claim.Reason on the prebuild-claimed channel
// of claim.WorkspaceID. It returns a wrapped error if publishing fails.
func (p PubsubWorkspaceClaimPublisher) PublishWorkspaceClaim(claim agentsdk.ReinitializationEvent) error {
	var (
		topic   = agentsdk.PrebuildClaimedChannel(claim.WorkspaceID)
		payload = []byte(claim.Reason)
	)
	err := p.ps.Publish(topic, payload)
	if err == nil {
		return nil
	}
	return xerrors.Errorf("failed to trigger prebuilt workspace agent reinitialization: %w", err)
}
// NewPubsubWorkspaceClaimListener returns a listener that receives workspace
// claim events from ps, keeping logger for its own use.
func NewPubsubWorkspaceClaimListener(ps pubsub.Pubsub, logger slog.Logger) *PubsubWorkspaceClaimListener {
	listener := new(PubsubWorkspaceClaimListener)
	listener.logger = logger
	listener.ps = ps
	return listener
}
// PubsubWorkspaceClaimListener subscribes to workspace claim events on a pubsub.
type PubsubWorkspaceClaimListener struct {
// logger is the logger supplied at construction.
logger slog.Logger
// ps is the pubsub from which claim events are received.
ps pubsub.Pubsub
}
// ListenForWorkspaceClaims subscribes to the prebuild-claimed pubsub channel of
// workspaceID and forwards every received event on reinitEvents.
//
// pubsub.Pubsub does not communicate when its last callback has been called
// after it has been closed. As such reinitEvents is never closed by this
// method. Call the returned cancel function to end the subscription when it is
// no longer needed; it is safe to call more than once, and it is called
// automatically when ctx expires or is canceled.
//
// If ctx is already done, or subscribing fails, a no-op cancel function and a
// non-nil error are returned.
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID, reinitEvents chan<- agentsdk.ReinitializationEvent) (func(), error) {
	if err := ctx.Err(); err != nil {
		return func() {}, err
	}

	// stopped is closed exactly once by cancel. It releases both the watcher
	// goroutine below and any callback that is blocked sending an event, so
	// neither can outlive the subscription when ctx is never canceled.
	stopped := make(chan struct{})

	cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, reason []byte) {
		claim := agentsdk.ReinitializationEvent{
			WorkspaceID: workspaceID,
			Reason:      agentsdk.ReinitializationReason(reason),
		}

		// Never block forever on a receiver that has gone away.
		select {
		case <-ctx.Done():
		case <-inner.Done():
		case <-stopped:
		case reinitEvents <- claim:
		}
	})
	if err != nil {
		return func() {}, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
	}

	var once sync.Once
	cancel := func() {
		once.Do(func() {
			// Signal first so that a callback blocked on a send is released
			// before the subscription itself is torn down.
			close(stopped)
			cancelSub()
		})
	}

	go func() {
		select {
		case <-ctx.Done():
			cancel()
		case <-stopped:
			// cancel was already called explicitly; nothing left to do.
		}
	}()

	return cancel, nil
}
+141
View File
@@ -0,0 +1,141 @@
package prebuilds_test
import (
"context"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
)
// TestPubsubWorkspaceClaimPublisher verifies that a published claim reaches a
// listener for the same workspace and that publish failures are reported.
func TestPubsubWorkspaceClaimPublisher(t *testing.T) {
	t.Parallel()

	t.Run("published claim is received by a listener for the same workspace", func(t *testing.T) {
		t.Parallel()

		var (
			ctx    = testutil.Context(t, testutil.WaitShort)
			log    = testutil.Logger(t)
			bus    = pubsub.NewInMemory()
			wsID   = uuid.New()
			events = make(chan agentsdk.ReinitializationEvent, 1)
		)
		pub := prebuilds.NewPubsubWorkspaceClaimPublisher(bus)
		sub := prebuilds.NewPubsubWorkspaceClaimListener(bus, log)

		stop, err := sub.ListenForWorkspaceClaims(ctx, wsID, events)
		require.NoError(t, err)
		defer stop()

		sent := agentsdk.ReinitializationEvent{
			WorkspaceID: wsID,
			Reason:      agentsdk.ReinitializeReasonPrebuildClaimed,
		}
		require.NoError(t, pub.PublishWorkspaceClaim(sent))

		received := testutil.RequireReceive(ctx, t, events)
		require.Equal(t, wsID, received.WorkspaceID)
		require.Equal(t, sent.Reason, received.Reason)
	})

	t.Run("fail to publish claim", func(t *testing.T) {
		t.Parallel()

		// brokenPubsub rejects every Publish call.
		pub := prebuilds.NewPubsubWorkspaceClaimPublisher(&brokenPubsub{})
		err := pub.PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
			WorkspaceID: uuid.New(),
			Reason:      agentsdk.ReinitializeReasonPrebuildClaimed,
		})
		require.ErrorContains(t, err, "failed to trigger prebuilt workspace agent reinitialization")
	})
}
func TestPubsubWorkspaceClaimListener(t *testing.T) {
t.Parallel()
t.Run("finds claim events for its workspace", func(t *testing.T) {
t.Parallel()
ps := pubsub.NewInMemory()
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
claims := make(chan agentsdk.ReinitializationEvent, 1) // Buffer to avoid messing with goroutines in the rest of the test
workspaceID := uuid.New()
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
require.NoError(t, err)
defer cancelFunc()
// Publish a claim
channel := agentsdk.PrebuildClaimedChannel(workspaceID)
reason := agentsdk.ReinitializeReasonPrebuildClaimed
err = ps.Publish(channel, []byte(reason))
require.NoError(t, err)
// Verify we receive the claim
ctx := testutil.Context(t, testutil.WaitShort)
claim := testutil.RequireReceive(ctx, t, claims)
require.Equal(t, workspaceID, claim.WorkspaceID)
require.Equal(t, reason, claim.Reason)
})
t.Run("ignores claim events for other workspaces", func(t *testing.T) {
t.Parallel()
ps := pubsub.NewInMemory()
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
claims := make(chan agentsdk.ReinitializationEvent)
workspaceID := uuid.New()
otherWorkspaceID := uuid.New()
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
require.NoError(t, err)
defer cancelFunc()
// Publish a claim for a different workspace
channel := agentsdk.PrebuildClaimedChannel(otherWorkspaceID)
err = ps.Publish(channel, []byte(agentsdk.ReinitializeReasonPrebuildClaimed))
require.NoError(t, err)
// Verify we don't receive the claim
select {
case <-claims:
t.Fatal("received claim for wrong workspace")
case <-time.After(100 * time.Millisecond):
// Expected - no claim received
}
})
t.Run("communicates the error if it can't subscribe", func(t *testing.T) {
t.Parallel()
claims := make(chan agentsdk.ReinitializationEvent)
ps := &brokenPubsub{}
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
_, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New(), claims)
require.ErrorContains(t, err, "failed to subscribe to prebuild claimed channel")
})
}
// brokenPubsub is a test double whose Subscribe and Publish methods always
// fail. It embeds pubsub.Pubsub so that it satisfies the interface; the
// embedded value is left nil in these tests, so only the overridden methods
// may be called.
type brokenPubsub struct {
pubsub.Pubsub
}
// Subscribe always fails, returning a nil cancel function and a "broken" error.
func (brokenPubsub) Subscribe(_ string, _ pubsub.Listener) (func(), error) {
	var cancel func()
	err := xerrors.New("broken")
	return cancel, err
}
// Publish always fails with a "broken" error.
func (brokenPubsub) Publish(_ string, _ []byte) error {
	err := xerrors.New("broken")
	return err
}
@@ -40,12 +40,14 @@ import (
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/coderd/schedule"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/provisioner"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
@@ -647,6 +649,30 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
}
}
runningAgentAuthTokens := []*sdkproto.RunningAgentAuthToken{}
if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
// runningAgentAuthTokens are *only* used for prebuilds. We fetch them when we want to rebuild a prebuilt workspace
// without generating new agent tokens. The provisionerdserver will push them down to
// the provisioner (and ultimately to the `coder_agent` resource in the Terraform provider) where they will be
// reused. Context: the agent token is often used in immutable attributes of a workspace resource (e.g. VM/container)
// to initialize the agent, so if that value changes it will necessitate a replacement of that resource, thus
// defeating the whole point of the prebuild.
agents, err := s.Database.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
WorkspaceID: workspace.ID,
BuildNumber: 1,
})
if err != nil {
s.Logger.Error(ctx, "failed to retrieve running agents of claimed prebuilt workspace",
slog.F("workspace_id", workspace.ID), slog.Error(err))
}
for _, agent := range agents {
runningAgentAuthTokens = append(runningAgentAuthTokens, &sdkproto.RunningAgentAuthToken{
AgentId: agent.ID.String(),
Token: agent.AuthToken.String(),
})
}
}
protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
WorkspaceBuildId: workspaceBuild.ID.String(),
@@ -676,6 +702,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
WorkspaceBuildId: workspaceBuild.ID.String(),
WorkspaceOwnerLoginType: string(owner.LoginType),
WorkspaceOwnerRbacRoles: ownerRbacRoles,
RunningAgentAuthTokens: runningAgentAuthTokens,
PrebuiltWorkspaceBuildStage: input.PrebuiltWorkspaceBuildStage,
},
LogLevel: input.LogLevel,
@@ -1812,6 +1839,19 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
if err != nil {
return nil, xerrors.Errorf("update workspace: %w", err)
}
if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
s.Logger.Info(ctx, "workspace prebuild successfully claimed by user",
slog.F("workspace_id", workspace.ID))
err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
WorkspaceID: workspace.ID,
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
})
if err != nil {
s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
}
}
case *proto.CompletedJob_TemplateDryRun_:
for _, resource := range jobType.TemplateDryRun.Resources {
s.Logger.Info(ctx, "inserting template dry-run job resource",
@@ -1955,6 +1995,7 @@ func InsertWorkspacePresetAndParameters(ctx context.Context, db database.Store,
}
}
dbPreset, err := tx.InsertPreset(ctx, database.InsertPresetParams{
ID: uuid.New(),
TemplateVersionID: templateVersionID,
Name: protoPreset.Name,
CreatedAt: t,
@@ -26,7 +26,10 @@ import (
"github.com/coder/quartz"
"github.com/coder/serpent"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/audit"
@@ -39,7 +42,6 @@ import (
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/schedule"
"github.com/coder/coder/v2/coderd/schedule/cron"
"github.com/coder/coder/v2/coderd/telemetry"
@@ -167,8 +169,12 @@ func TestAcquireJob(t *testing.T) {
_, err = tc.acquire(ctx, srv)
require.ErrorContains(t, err, "sql: no rows in result set")
})
for _, prebuiltWorkspace := range []bool{false, true} {
prebuiltWorkspace := prebuiltWorkspace
for _, prebuiltWorkspaceBuildStage := range []sdkproto.PrebuiltWorkspaceBuildStage{
sdkproto.PrebuiltWorkspaceBuildStage_NONE,
sdkproto.PrebuiltWorkspaceBuildStage_CREATE,
sdkproto.PrebuiltWorkspaceBuildStage_CLAIM,
} {
prebuiltWorkspaceBuildStage := prebuiltWorkspaceBuildStage
t.Run(tc.name+"_WorkspaceBuildJob", func(t *testing.T) {
t.Parallel()
// Set the max session token lifetime so we can assert we
@@ -212,7 +218,7 @@ func TestAcquireJob(t *testing.T) {
Roles: []string{rbac.RoleOrgAuditor()},
})
// Add extra erronous roles
// Add extra erroneous roles
secondOrg := dbgen.Organization(t, db, database.Organization{})
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
@@ -287,36 +293,74 @@ func TestAcquireJob(t *testing.T) {
Required: true,
Sensitive: false,
})
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
workspace := database.WorkspaceTable{
TemplateID: template.ID,
OwnerID: user.ID,
OrganizationID: pd.OrganizationID,
})
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
}
workspace = dbgen.Workspace(t, db, workspace)
build := database.WorkspaceBuild{
WorkspaceID: workspace.ID,
BuildNumber: 1,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
})
var buildState sdkproto.PrebuiltWorkspaceBuildStage
if prebuiltWorkspace {
buildState = sdkproto.PrebuiltWorkspaceBuildStage_CREATE
}
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
ID: build.ID,
build = dbgen.WorkspaceBuild(t, db, build)
input := provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
}
dbJob := database.ProvisionerJob{
ID: build.JobID,
OrganizationID: pd.OrganizationID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
Input: must(json.Marshal(input)),
}
dbJob = dbgen.ProvisionerJob(t, db, ps, dbJob)
var agent database.WorkspaceAgent
if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: dbJob.ID,
})
agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resource.ID,
AuthToken: uuid.New(),
})
// At this point we have an unclaimed workspace and build, now we need to setup the claim
// build
build = database.WorkspaceBuild{
WorkspaceID: workspace.ID,
BuildNumber: 2,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
InitiatorID: user.ID,
}
build = dbgen.WorkspaceBuild(t, db, build)
input = provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
PrebuiltWorkspaceBuildStage: buildState,
})),
})
PrebuiltWorkspaceBuildStage: prebuiltWorkspaceBuildStage,
}
dbJob = database.ProvisionerJob{
ID: build.JobID,
OrganizationID: pd.OrganizationID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(input)),
}
dbJob = dbgen.ProvisionerJob(t, db, ps, dbJob)
}
startPublished := make(chan struct{})
var closed bool
@@ -350,6 +394,19 @@ func TestAcquireJob(t *testing.T) {
<-startPublished
if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
for {
// In the case of a prebuild claim, there is a second build, which is the
// one that we're interested in.
job, err = tc.acquire(ctx, srv)
require.NoError(t, err)
if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok {
break
}
}
<-startPublished
}
got, err := json.Marshal(job.Type)
require.NoError(t, err)
@@ -384,8 +441,14 @@ func TestAcquireJob(t *testing.T) {
WorkspaceOwnerLoginType: string(user.LoginType),
WorkspaceOwnerRbacRoles: []*sdkproto.Role{{Name: rbac.RoleOrgMember(), OrgId: pd.OrganizationID.String()}, {Name: "member", OrgId: ""}, {Name: rbac.RoleOrgAuditor(), OrgId: pd.OrganizationID.String()}},
}
if prebuiltWorkspace {
wantedMetadata.PrebuiltWorkspaceBuildStage = sdkproto.PrebuiltWorkspaceBuildStage_CREATE
if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
// For claimed prebuilds, we expect the prebuild state to be set to CLAIM
// and we expect tokens from the first build to be set for reuse
wantedMetadata.PrebuiltWorkspaceBuildStage = prebuiltWorkspaceBuildStage
wantedMetadata.RunningAgentAuthTokens = append(wantedMetadata.RunningAgentAuthTokens, &sdkproto.RunningAgentAuthToken{
AgentId: agent.ID.String(),
Token: agent.AuthToken.String(),
})
}
slices.SortFunc(wantedMetadata.WorkspaceOwnerRbacRoles, func(a, b *sdkproto.Role) int {
@@ -1750,6 +1813,110 @@ func TestCompleteJob(t *testing.T) {
})
}
})
t.Run("ReinitializePrebuiltAgents", func(t *testing.T) {
t.Parallel()
type testcase struct {
name string
shouldReinitializeAgent bool
}
for _, tc := range []testcase{
// Whether or not there are presets and those presets define prebuilds, etc
// are all irrelevant at this level. Those factors are useful earlier in the process.
// Everything relevant to this test is determined by the value of `PrebuiltWorkspaceBuildStage`
// in the provisioner job input. As such, there are only two significant test cases:
{
name: "claimed prebuild",
shouldReinitializeAgent: true,
},
{
name: "not a claimed prebuild",
shouldReinitializeAgent: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// GIVEN an enqueued provisioner job and its dependencies:
srv, db, ps, pd := setup(t, false, &overrides{})
buildID := uuid.New()
jobInput := provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: buildID,
}
if tc.shouldReinitializeAgent { // This is the key lever in the test
// GIVEN the enqueued provisioner job is for a workspace being claimed by a user:
jobInput.PrebuiltWorkspaceBuildStage = sdkproto.PrebuiltWorkspaceBuildStage_CLAIM
}
input, err := json.Marshal(jobInput)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitShort)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
Input: input,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
require.NoError(t, err)
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: pd.OrganizationID,
})
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
JobID: job.ID,
})
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
TemplateID: tpl.ID,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
ID: buildID,
JobID: job.ID,
WorkspaceID: workspace.ID,
TemplateVersionID: tv.ID,
})
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{
UUID: pd.ID,
Valid: true,
},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
})
require.NoError(t, err)
// GIVEN something is listening to process workspace reinitialization:
reinitChan := make(chan agentsdk.ReinitializationEvent, 1) // Buffered to simplify test structure
cancel, err := prebuilds.NewPubsubWorkspaceClaimListener(ps, testutil.Logger(t)).ListenForWorkspaceClaims(ctx, workspace.ID, reinitChan)
require.NoError(t, err)
defer cancel()
// WHEN the job is completed
completedJob := proto.CompletedJob{
JobId: job.ID.String(),
Type: &proto.CompletedJob_WorkspaceBuild_{
WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{},
},
}
_, err = srv.CompleteJob(ctx, &completedJob)
require.NoError(t, err)
if tc.shouldReinitializeAgent {
event := testutil.RequireReceive(ctx, t, reinitChan)
require.Equal(t, workspace.ID, event.WorkspaceID)
} else {
select {
case <-reinitChan:
t.Fatal("unexpected reinitialization event published")
default:
// OK
}
}
})
}
})
}
func TestInsertWorkspacePresetsAndParameters(t *testing.T) {
+55
View File
@@ -35,6 +35,7 @@ import (
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/telemetry"
@@ -1183,6 +1184,60 @@ func (api *API) workspaceAgentPostLogSource(rw http.ResponseWriter, r *http.Requ
httpapi.Write(ctx, rw, http.StatusCreated, apiSource)
}
// @Summary Get workspace agent reinitialization
// @ID get-workspace-agent-reinitialization
// @Security CoderSessionToken
// @Produce json
// @Tags Agents
// @Success 200 {object} agentsdk.ReinitializationEvent
// @Router /workspaceagents/me/reinit [get]
func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
	// Allow us to interrupt watch via cancel.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()
	r = r.WithContext(ctx) // Rewire context for SSE cancellation.

	workspaceAgent := httpmw.WorkspaceAgent(r)
	log := api.Logger.Named("workspace_agent_reinit_watcher").With(
		slog.F("workspace_agent_id", workspaceAgent.ID),
	)

	workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
	if err != nil {
		log.Error(ctx, "failed to retrieve workspace from agent token", slog.Error(err))
		httpapi.InternalServerError(rw, xerrors.New("failed to determine workspace from agent token"))
		// Stop here: without a workspace there is no claim channel to listen
		// on, and a response has already been written.
		return
	}

	log.Info(ctx, "agent waiting for reinit instruction")

	// Claim events for this workspace are delivered on reinitEvents until the
	// listener is canceled.
	reinitEvents := make(chan agentsdk.ReinitializationEvent)
	cancelListener, err := prebuilds.NewPubsubWorkspaceClaimListener(api.Pubsub, log).ListenForWorkspaceClaims(ctx, workspace.ID, reinitEvents)
	if err != nil {
		log.Error(ctx, "subscribe to prebuild claimed channel", slog.Error(err))
		httpapi.InternalServerError(rw, xerrors.New("failed to subscribe to prebuild claimed channel"))
		return
	}
	defer cancelListener()

	// Relay events to the agent over server-sent events until one side goes away.
	transmitter := agentsdk.NewSSEAgentReinitTransmitter(log, rw, r)

	err = transmitter.Transmit(ctx, reinitEvents)
	switch {
	case errors.Is(err, agentsdk.ErrTransmissionSourceClosed):
		log.Info(ctx, "agent reinitialization subscription closed", slog.F("workspace_agent_id", workspaceAgent.ID))
	case errors.Is(err, agentsdk.ErrTransmissionTargetClosed):
		log.Info(ctx, "agent connection closed", slog.F("workspace_agent_id", workspaceAgent.ID))
	case errors.Is(err, context.Canceled):
		log.Info(ctx, "agent reinitialization", slog.Error(err))
	case err != nil:
		log.Error(ctx, "failed to stream agent reinit events", slog.Error(err))
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error streaming agent reinitialization events.",
			Detail:  err.Error(),
		})
	}
}
// convertProvisionedApps converts applications that are in the middle of provisioning process.
// It means that they may not have an agent or workspace assigned (dry-run job).
func convertProvisionedApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
+70
View File
@@ -11,6 +11,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -44,10 +45,12 @@ import (
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/util/ptr"
@@ -2641,3 +2644,70 @@ func TestAgentConnectionInfo(t *testing.T) {
require.True(t, info.DisableDirectConnections)
require.True(t, info.DERPForceWebSockets)
}
// TestReinit verifies that an agent blocked in WaitForReinit receives the
// reinitialization event that is published for its workspace.
func TestReinit(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	pubsubSpy := pubsubReinitSpy{
		Pubsub:     ps,
		subscribed: make(chan string),
	}
	client := coderdtest.New(t, &coderdtest.Options{
		Database: db,
		Pubsub:   &pubsubSpy,
	})
	user := coderdtest.CreateFirstUser(t, client)

	r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OrganizationID: user.OrganizationID,
		OwnerID:        user.UserID,
	}).WithAgent().Do()

	// Tell the spy which channel subscription we are waiting for.
	pubsubSpy.Mutex.Lock()
	pubsubSpy.expectedEvent = agentsdk.PrebuildClaimedChannel(r.Workspace.ID)
	pubsubSpy.Mutex.Unlock()

	agentCtx := testutil.Context(t, testutil.WaitShort)
	agentClient := agentsdk.New(client.URL)
	agentClient.SetSessionToken(r.AgentToken)

	// Buffered so the goroutine below can always deliver its result and exit,
	// even if the test has already stopped waiting for it.
	agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent, 1)
	go func() {
		reinitEvent, err := agentClient.WaitForReinit(agentCtx)
		assert.NoError(t, err)
		agentReinitializedCh <- reinitEvent
	}()

	// We need to subscribe before we publish, lest we miss the event
	ctx := testutil.Context(t, testutil.WaitShort)
	testutil.TryReceive(ctx, t, pubsubSpy.subscribed) // Wait for the appropriate subscription

	// Now that we're subscribed, publish the event
	err := prebuilds.NewPubsubWorkspaceClaimPublisher(ps).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
		WorkspaceID: r.Workspace.ID,
		Reason:      agentsdk.ReinitializeReasonPrebuildClaimed,
	})
	require.NoError(t, err)

	ctx = testutil.Context(t, testutil.WaitShort)
	reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
	require.NotNil(t, reinitEvent)
	require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
}
// pubsubReinitSpy wraps a Pubsub and closes the subscribed channel once a
// subscription to expectedEvent has been made, so a test can wait for that
// subscription before publishing.
type pubsubReinitSpy struct {
	pubsub.Pubsub
	sync.Mutex
	// subscribed is closed once a subscription to expectedEvent is in place.
	subscribed chan string
	// expectedEvent is the channel name to watch for. It is guarded by the
	// embedded Mutex and cleared after the first match.
	expectedEvent string
}

// Subscribe forwards to the wrapped Pubsub and, only after that subscription
// has succeeded, signals via subscribed if event matches expectedEvent.
// Signaling after the real subscription is registered means a waiter that then
// publishes cannot race ahead of the listener.
func (p *pubsubReinitSpy) Subscribe(event string, listener pubsub.Listener) (cancel func(), err error) {
	cancel, err = p.Pubsub.Subscribe(event, listener)
	if err != nil {
		// Nothing was subscribed, so there is nothing to signal.
		return cancel, err
	}

	p.Lock()
	defer p.Unlock()
	if p.expectedEvent != "" && event == p.expectedEvent {
		close(p.subscribed)
		// Signal only once: a second matching subscription would otherwise
		// close an already-closed channel and panic.
		p.expectedEvent = ""
	}
	return cancel, nil
}
+2 -3
View File
@@ -628,9 +628,9 @@ func createWorkspace(
err = api.Database.InTx(func(db database.Store) error {
var (
prebuildsClaimer = *api.PrebuildsClaimer.Load()
workspaceID uuid.UUID
claimedWorkspace *database.Workspace
prebuildsClaimer = *api.PrebuildsClaimer.Load()
)
// If a template preset was chosen, try claim a prebuilt workspace.
@@ -704,8 +704,7 @@ func createWorkspace(
Reason(database.BuildReasonInitiator).
Initiator(initiatorID).
ActiveVersion().
RichParameterValues(req.RichParameterValues).
TemplateVersionPresetID(req.TemplateVersionPresetID)
RichParameterValues(req.RichParameterValues)
if req.TemplateVersionID != uuid.Nil {
builder = builder.VersionID(req.TemplateVersionID)
}
+1 -2
View File
@@ -77,8 +77,7 @@ type Builder struct {
parameterValues *[]string
templateVersionPresetParameterValues []database.TemplateVersionPresetParameter
prebuiltWorkspaceBuildStage sdkproto.PrebuiltWorkspaceBuildStage
prebuiltWorkspaceBuildStage sdkproto.PrebuiltWorkspaceBuildStage
verifyNoLegacyParametersOnce bool
}
+189 -1
View File
@@ -19,12 +19,15 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/retry"
"github.com/coder/websocket"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/websocket"
)
// ExternalLogSourceID is the statically-defined ID of a log-source that
@@ -686,3 +689,188 @@ func LogsNotifyChannel(agentID uuid.UUID) string {
type LogsNotifyMessage struct {
CreatedAfter int64 `json:"created_after"`
}
// ReinitializationReason identifies why an agent is being asked to
// reinitialize.
type ReinitializationReason string
const (
// ReinitializeReasonPrebuildClaimed is the reason carried by events that are
// published when a prebuilt workspace is claimed.
ReinitializeReasonPrebuildClaimed ReinitializationReason = "prebuild_claimed"
)
// ReinitializationEvent is the payload sent to an agent to tell it to
// reinitialize.
type ReinitializationEvent struct {
// WorkspaceID is the workspace the event applies to.
// NOTE(review): this field has no json tag, so encoding/json emits it as
// "WorkspaceID", while the generated API docs show "workspaceID" — confirm
// which spelling is intended before relying on it.
WorkspaceID uuid.UUID
// Reason states why reinitialization is requested.
Reason ReinitializationReason `json:"reason"`
}
// PrebuildClaimedChannel returns the pubsub channel name used for the
// workspace with the given ID: the fixed prefix "prebuild_claimed_" followed
// by the ID's string form.
func PrebuildClaimedChannel(id uuid.UUID) string {
	const prefix = "prebuild_claimed_"
	return prefix + id.String()
}
// WaitForReinit issues a GET against the agent reinitialization SSE endpoint
// and returns the first data event read from the response stream. Events on
// the stream are treated as follows:
//   - ping: ignored, keepalive
//   - prebuild claimed: a prebuilt workspace is claimed, so the agent must reinitialize.
//
// A non-200 response is converted to an error with codersdk.ReadBodyAsError.
func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, error) {
	endpoint, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/reinit")
	if err != nil {
		return nil, xerrors.Errorf("parse url: %w", err)
	}

	// Authenticate by presenting the session token as a cookie for the endpoint.
	cookies, err := cookiejar.New(nil)
	if err != nil {
		return nil, xerrors.Errorf("create cookie jar: %w", err)
	}
	sessionCookie := &http.Cookie{
		Name:  codersdk.SessionTokenCookie,
		Value: c.SDK.SessionToken(),
	}
	cookies.SetCookies(endpoint, []*http.Cookie{sessionCookie})

	// Use a dedicated client that shares the SDK's transport.
	streamClient := http.Client{
		Jar:       cookies,
		Transport: c.SDK.HTTPClient.Transport,
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
	if err != nil {
		return nil, xerrors.Errorf("build request: %w", err)
	}

	response, err := streamClient.Do(request)
	if err != nil {
		return nil, xerrors.Errorf("execute request: %w", err)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return nil, codersdk.ReadBodyAsError(response)
	}

	receiver := NewSSEAgentReinitReceiver(response.Body)
	event, err := receiver.Receive(ctx)
	if err != nil {
		return nil, xerrors.Errorf("listening for reinitialization events: %w", err)
	}
	return event, nil
}
// WaitForReinitLoop starts a goroutine that repeatedly calls
// client.WaitForReinit and forwards every event it receives on the returned
// channel. Failed attempts are logged and retried with backoff (100ms up to
// 10s); the backoff is reset after each successful receive.
//
// The goroutine stops once ctx is done, and the returned channel is closed
// whenever the goroutine exits, so receivers should also watch ctx or check
// the receive's ok value.
func WaitForReinitLoop(ctx context.Context, logger slog.Logger, client *Client) <-chan ReinitializationEvent {
	reinitEvents := make(chan ReinitializationEvent)

	go func() {
		// Close on every exit path, including when retrier.Wait reports that
		// ctx is done, so the channel's state does not depend on where in the
		// loop cancellation was observed.
		defer close(reinitEvents)

		for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
			logger.Debug(ctx, "waiting for agent reinitialization instructions")
			reinitEvent, err := client.WaitForReinit(ctx)
			if err != nil {
				logger.Error(ctx, "failed to wait for agent reinitialization instructions", slog.Error(err))
				continue
			}
			retrier.Reset()
			select {
			case <-ctx.Done():
				return
			case reinitEvents <- *reinitEvent:
			}
		}
	}()

	return reinitEvents
}
// NewSSEAgentReinitTransmitter builds a transmitter that writes to rw on
// behalf of request r and keeps logger for its own use.
func NewSSEAgentReinitTransmitter(logger slog.Logger, rw http.ResponseWriter, r *http.Request) *SSEAgentReinitTransmitter {
	transmitter := new(SSEAgentReinitTransmitter)
	transmitter.rw = rw
	transmitter.r = r
	transmitter.logger = logger
	return transmitter
}

// SSEAgentReinitTransmitter sends reinitialization events to the client of a
// single HTTP request as server-sent events.
type SSEAgentReinitTransmitter struct {
	rw     http.ResponseWriter
	r      *http.Request
	logger slog.Logger
}
var (
// ErrTransmissionSourceClosed is returned by Transmit when the channel of
// events it reads from has been closed.
ErrTransmissionSourceClosed = xerrors.New("transmission source closed")
// ErrTransmissionTargetClosed is returned by Transmit when the server-sent
// event sender reports that it has closed.
ErrTransmissionTargetClosed = xerrors.New("transmission target closed")
)
// Transmit forwards every event read from reinitEvents to the client as a
// server-sent data event. It keeps going until one of the following happens:
//   - ctx is done (returns ctx.Err()); this is also checked before anything is written
//   - the server-sent event sender closes (returns ErrTransmissionTargetClosed)
//   - reinitEvents is closed (returns ErrTransmissionSourceClosed)
//   - sending an event fails (returns that error)
func (s *SSEAgentReinitTransmitter) Transmit(ctx context.Context, reinitEvents <-chan ReinitializationEvent) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	send, senderClosed, err := httpapi.ServerSentEventSender(s.rw, s.r)
	if err != nil {
		return xerrors.Errorf("failed to create sse transmitter: %w", err)
	}
	// Block returning until the ServerSentEventSender is closed
	// to avoid a race condition where we might write or flush to rw after the handler returns.
	defer func() { <-senderClosed }()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-senderClosed:
			return ErrTransmissionTargetClosed
		case event, open := <-reinitEvents:
			if !open {
				return ErrTransmissionSourceClosed
			}
			payload := codersdk.ServerSentEvent{
				Type: codersdk.ServerSentEventTypeData,
				Data: event,
			}
			if sendErr := send(payload); sendErr != nil {
				return sendErr
			}
		}
	}
}
// NewSSEAgentReinitReceiver builds a receiver that reads server-sent events
// from r.
func NewSSEAgentReinitReceiver(r io.ReadCloser) *SSEAgentReinitReceiver {
	receiver := new(SSEAgentReinitReceiver)
	receiver.r = r
	return receiver
}

// SSEAgentReinitReceiver decodes reinitialization events from a server-sent
// event stream.
type SSEAgentReinitReceiver struct {
	r io.ReadCloser
}
// Receive reads server-sent events until it can return the first data event,
// decoded as a ReinitializationEvent. Ping events are skipped. An error is
// returned when ctx is done, a read fails, an error-type or unknown-type event
// arrives, or the data payload is not JSON-decodable bytes.
func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*ReinitializationEvent, error) {
	readEvent := codersdk.ServerSentEventReader(ctx, s.r)
	for {
		// Check for cancellation before each read.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}

		sse, err := readEvent()
		if err != nil {
			return nil, xerrors.Errorf("failed to read server-sent event: %w", err)
		}

		switch sse.Type {
		case codersdk.ServerSentEventTypeError:
			return nil, xerrors.Errorf("unexpected server sent event type error")
		case codersdk.ServerSentEventTypePing:
			// Keepalive only; wait for the next event.
			continue
		case codersdk.ServerSentEventTypeData:
			// Decoded below.
		default:
			return nil, xerrors.Errorf("unexpected server sent event type: %s", sse.Type)
		}

		raw, isBytes := sse.Data.([]byte)
		if !isBytes {
			return nil, xerrors.Errorf("expected data as []byte, got %T", sse.Data)
		}

		event := new(ReinitializationEvent)
		if err := json.Unmarshal(raw, event); err != nil {
			return nil, xerrors.Errorf("unmarshal reinit response: %w", err)
		}
		return event, nil
	}
}
+122
View File
@@ -0,0 +1,122 @@
package agentsdk_test
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
)
// TestStreamAgentReinitEvents exercises the SSE transmitter and receiver
// end-to-end over a real HTTP connection.
func TestStreamAgentReinitEvents(t *testing.T) {
	t.Parallel()

	t.Run("transmitted events are received", func(t *testing.T) {
		t.Parallel()

		eventToSend, events := queuedReinitEvent()
		transmitCtx := testutil.Context(t, testutil.WaitShort)
		resp := openReinitStream(transmitCtx, t, events)

		receiveCtx := testutil.Context(t, testutil.WaitShort)
		receiver := agentsdk.NewSSEAgentReinitReceiver(resp.Body)
		sentEvent, receiveErr := receiver.Receive(receiveCtx)
		require.NoError(t, receiveErr)
		require.Equal(t, eventToSend, *sentEvent)
	})

	t.Run("doesn't transmit events if the transmitter context is canceled", func(t *testing.T) {
		t.Parallel()

		_, events := queuedReinitEvent()
		// Cancel before the handler ever runs so Transmit bails out immediately.
		transmitCtx, cancelTransmit := context.WithCancel(testutil.Context(t, testutil.WaitShort))
		cancelTransmit()
		resp := openReinitStream(transmitCtx, t, events)

		receiveCtx := testutil.Context(t, testutil.WaitShort)
		receiver := agentsdk.NewSSEAgentReinitReceiver(resp.Body)
		sentEvent, receiveErr := receiver.Receive(receiveCtx)
		require.Nil(t, sentEvent)
		require.ErrorIs(t, receiveErr, io.EOF)
	})

	t.Run("does not receive events if the receiver context is canceled", func(t *testing.T) {
		t.Parallel()

		_, events := queuedReinitEvent()
		transmitCtx := testutil.Context(t, testutil.WaitShort)
		resp := openReinitStream(transmitCtx, t, events)

		receiveCtx, cancelReceive := context.WithCancel(context.Background())
		cancelReceive()
		receiver := agentsdk.NewSSEAgentReinitReceiver(resp.Body)
		sentEvent, receiveErr := receiver.Receive(receiveCtx)
		require.Nil(t, sentEvent)
		require.ErrorIs(t, receiveErr, context.Canceled)
	})
}

// queuedReinitEvent returns a fresh prebuild-claimed event together with a
// buffered channel that already holds it, ready to be handed to a transmitter.
func queuedReinitEvent() (agentsdk.ReinitializationEvent, chan agentsdk.ReinitializationEvent) {
	event := agentsdk.ReinitializationEvent{
		WorkspaceID: uuid.New(),
		Reason:      agentsdk.ReinitializeReasonPrebuildClaimed,
	}
	events := make(chan agentsdk.ReinitializationEvent, 1)
	events <- event
	return event, events
}

// openReinitStream starts a test server whose handler transmits events using
// transmitCtx, performs a GET against it, and returns the response. The
// response body and the server are closed, in that order, when the test ends.
func openReinitStream(transmitCtx context.Context, t *testing.T, events <-chan agentsdk.ReinitializationEvent) *http.Response {
	t.Helper()

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		transmitter := agentsdk.NewSSEAgentReinitTransmitter(slogtest.Make(t, nil), w, r)
		// The transmit result is deliberately ignored: every subtest asserts
		// on what the receiving side observes.
		_ = transmitter.Transmit(transmitCtx, events)
	}))
	t.Cleanup(srv.Close)

	requestCtx := testutil.Context(t, testutil.WaitShort)
	req, err := http.NewRequestWithContext(requestCtx, http.MethodGet, srv.URL, nil)
	require.NoError(t, err)
	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	// Registered after srv.Close, so it runs first and lets the handler finish.
	t.Cleanup(func() { _ = resp.Body.Close() })

	return resp
}
+1 -1
View File
@@ -631,7 +631,7 @@ func (h *HeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
}
if h.Transport == nil {
h.Transport = http.DefaultTransport
return http.DefaultTransport.RoundTrip(req)
}
return h.Transport.RoundTrip(req)
}
+32
View File
@@ -470,6 +470,38 @@ curl -X PATCH http://coder-server:8080/api/v2/workspaceagents/me/logs \
To perform this operation, you must be authenticated. [Learn more](authentication.md).
## Get workspace agent reinitialization
### Code samples
```shell
# Example request using curl
curl -X GET http://coder-server:8080/api/v2/workspaceagents/me/reinit \
-H 'Accept: application/json' \
-H 'Coder-Session-Token: API_KEY'
```
`GET /workspaceagents/me/reinit`
### Example responses
> 200 Response
```json
{
"reason": "prebuild_claimed",
"workspaceID": "string"
}
```
### Responses
| Status | Meaning | Description | Schema |
|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------------------|
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [agentsdk.ReinitializationEvent](schemas.md#agentsdkreinitializationevent) |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
## Get workspace agent by ID
### Code samples
+30
View File
@@ -182,6 +182,36 @@
| `icon` | string | false | | |
| `id` | string | false | | ID is a unique identifier for the log source. It is scoped to a workspace agent, and can be statically defined inside code to prevent duplicate sources from being created for the same agent. |
## agentsdk.ReinitializationEvent
```json
{
"reason": "prebuild_claimed",
"workspaceID": "string"
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
|---------------|--------------------------------------------------------------------|----------|--------------|-------------|
| `reason` | [agentsdk.ReinitializationReason](#agentsdkreinitializationreason) | false | | |
| `workspaceID` | string | false | | |
## agentsdk.ReinitializationReason
```json
"prebuild_claimed"
```
### Properties
#### Enumerated Values
| Value |
|--------------------|
| `prebuild_claimed` |
## aisdk.Attachment
```json
+169
View File
@@ -5,12 +5,19 @@ import (
"crypto/tls"
"fmt"
"net/http"
"os"
"regexp"
"testing"
"time"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/serpent"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
@@ -73,6 +80,168 @@ func TestBlockNonBrowser(t *testing.T) {
})
}
func TestReinitializeAgent(t *testing.T) {
t.Parallel()
tempAgentLog := testutil.CreateTemp(t, "", "testReinitializeAgent")
if !dbtestutil.WillUsePostgres() {
t.Skip("dbmem cannot currently claim a workspace")
}
db, ps := dbtestutil.NewDB(t)
// GIVEN a live enterprise API with the prebuilds feature enabled
client, user := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: ps,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Prebuilds.ReconciliationInterval = serpent.Duration(time.Second)
dv.Experiments.Append(string(codersdk.ExperimentWorkspacePrebuilds))
}),
IncludeProvisionerDaemon: true,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureWorkspacePrebuilds: 1,
},
},
})
// GIVEN a template, template version, preset and a prebuilt workspace that uses them all
agentToken := uuid.UUID{3}
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: []*proto.Response{
{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{
Presets: []*proto.Preset{
{
Name: "test-preset",
Prebuild: &proto.Prebuild{
Instances: 1,
},
},
},
Resources: []*proto.Resource{
{
Agents: []*proto.Agent{
{
Name: "smith",
OperatingSystem: "linux",
Architecture: "i386",
},
},
},
},
},
},
},
},
ProvisionApply: []*proto.Response{
{
Type: &proto.Response_Apply{
Apply: &proto.ApplyComplete{
Resources: []*proto.Resource{
{
Type: "compute",
Name: "main",
Agents: []*proto.Agent{
{
Name: "smith",
OperatingSystem: "linux",
Architecture: "i386",
Scripts: []*proto.Script{
{
RunOnStart: true,
Script: fmt.Sprintf("printenv >> %s; echo '---\n' >> %s", tempAgentLog.Name(), tempAgentLog.Name()), // Make reinitialization take long enough to assert that it happened
},
},
Auth: &proto.Agent_Token{
Token: agentToken.String(),
},
},
},
},
},
},
},
},
},
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
// Wait for prebuilds to create a prebuilt workspace
ctx := context.Background()
// ctx := testutil.Context(t, testutil.WaitLong)
var (
prebuildID uuid.UUID
)
require.Eventually(t, func() bool {
agentAndBuild, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, agentToken)
if err != nil {
return false
}
prebuildID = agentAndBuild.WorkspaceBuild.ID
return true
}, testutil.WaitLong, testutil.IntervalFast)
prebuild := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, prebuildID)
preset, err := db.GetPresetByWorkspaceBuildID(ctx, prebuildID)
require.NoError(t, err)
// GIVEN a running agent
logDir := t.TempDir()
inv, _ := clitest.New(t,
"agent",
"--auth", "token",
"--agent-token", agentToken.String(),
"--agent-url", client.URL.String(),
"--log-dir", logDir,
)
clitest.Start(t, inv)
// GIVEN the agent is in a happy steady state
waiter := coderdtest.NewWorkspaceAgentWaiter(t, client, prebuild.WorkspaceID)
waiter.WaitFor(coderdtest.AgentsReady)
// WHEN a workspace is created that can benefit from prebuilds
anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
workspace, err := anotherClient.CreateUserWorkspace(ctx, anotherUser.ID.String(), codersdk.CreateWorkspaceRequest{
TemplateVersionID: version.ID,
TemplateVersionPresetID: preset.ID,
Name: "claimed-workspace",
})
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// THEN reinitialization completes
waiter.WaitFor(coderdtest.AgentsReady)
var matches [][]byte
require.Eventually(t, func() bool {
// THEN the agent script ran again and reused the same agent token
contents, err := os.ReadFile(tempAgentLog.Name())
if err != nil {
return false
}
// Matches each CODER_AGENT_TOKEN=<value> entry that the startup script's printenv appended to the temp file.
uuidRegex := regexp.MustCompile(`\bCODER_AGENT_TOKEN=(.+)\b`)
matches = uuidRegex.FindAll(contents, -1)
// When an agent reinitializes, we expect it to run startup scripts again.
// As such, we expect to have written the agent environment to the temp file twice.
// Once on initial startup and then once on reinitialization.
return len(matches) == 2
}, testutil.WaitLong, testutil.IntervalMedium)
require.Equal(t, matches[0], matches[1])
}
type setupResp struct {
workspace codersdk.Workspace
sdkAgent codersdk.WorkspaceAgent
+78
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"os"
@@ -13,6 +14,7 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -30,6 +32,8 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
agplschedule "github.com/coder/coder/v2/coderd/schedule"
@@ -43,6 +47,7 @@ import (
"github.com/coder/coder/v2/enterprise/coderd/schedule"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
@@ -459,6 +464,79 @@ func TestCreateUserWorkspace(t *testing.T) {
_, err = client1.CreateUserWorkspace(ctx, user1.ID.String(), req)
require.Error(t, err)
})
// ClaimPrebuild checks that creating a workspace from a preset which already
// has a prebuilt workspace schedules a new build on that prebuilt workspace,
// and that the new build's provisioner job input carries the CLAIM stage.
t.Run("ClaimPrebuild", func(t *testing.T) {
	t.Parallel()

	if !dbtestutil.WillUsePostgres() {
		t.Skip("dbmem cannot currently claim a workspace")
	}

	client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
		Options: &coderdtest.Options{
			DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
				err := dv.Experiments.Append(string(codersdk.ExperimentWorkspacePrebuilds))
				require.NoError(t, err)
			}),
		},
		LicenseOptions: &coderdenttest.LicenseOptions{
			Features: license.Features{
				codersdk.FeatureWorkspacePrebuilds: 1,
			},
		},
	})

	// GIVEN a template, template version, preset and a prebuilt workspace that uses them all
	presetID := uuid.New()
	tv := dbfake.TemplateVersion(t, db).Seed(database.TemplateVersion{
		OrganizationID: user.OrganizationID,
		CreatedBy:      user.UserID,
	}).Preset(database.TemplateVersionPreset{
		ID: presetID,
	}).Do()

	// The prebuilt workspace is owned by the prebuilds system user until it is claimed.
	r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
		OwnerID:    prebuilds.SystemUserID,
		TemplateID: tv.Template.ID,
	}).Seed(database.WorkspaceBuild{
		TemplateVersionID: tv.TemplateVersion.ID,
		TemplateVersionPresetID: uuid.NullUUID{
			UUID:  presetID,
			Valid: true,
		},
	}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
		return a
	}).Do()

	// nolint:gocritic // this is a test
	ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong))
	agent, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, uuid.MustParse(r.AgentToken))
	require.NoError(t, err)

	// Mark the prebuilt workspace's agent as ready before attempting the claim.
	// NOTE(review): presumably only prebuilds with a ready agent are claimable — confirm.
	err = db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
		ID:             agent.WorkspaceAgent.ID,
		LifecycleState: database.WorkspaceAgentLifecycleStateReady,
	})
	require.NoError(t, err)

	// WHEN a workspace is created that matches the available prebuilt workspace
	_, err = client.CreateUserWorkspace(ctx, user.UserID.String(), codersdk.CreateWorkspaceRequest{
		TemplateVersionID:       tv.TemplateVersion.ID,
		TemplateVersionPresetID: presetID,
		Name:                    "claimed-workspace",
	})
	require.NoError(t, err)

	// THEN a new build is scheduled with the build stage specified
	build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, r.Workspace.ID)
	require.NoError(t, err)
	// testify takes (expected, actual); the seeded build is the reference value.
	require.NotEqual(t, r.Build.ID, build.ID)

	job, err := db.GetProvisionerJobByID(ctx, build.JobID)
	require.NoError(t, err)
	var metadata provisionerdserver.WorkspaceProvisionJob
	require.NoError(t, json.Unmarshal(job.Input, &metadata))
	require.Equal(t, proto.PrebuiltWorkspaceBuildStage_CLAIM, metadata.PrebuiltWorkspaceBuildStage)
})
}
func TestWorkspaceAutobuild(t *testing.T) {
+64
View File
@@ -350,6 +350,68 @@ func onlyDataResources(sm tfjson.StateModule) tfjson.StateModule {
return filtered
}
// logResourceReplacements emits warnings for resources that the given plan
// intends to replace, together with the attribute paths that trigger each
// replacement.
//
// A resource change is ignored when it has no change, is a no-op, lists no
// replacement paths, or has a type beginning with "coder_". Nothing is logged
// for a nil plan, a plan without resource changes, or when no resource change
// qualifies.
func (e *executor) logResourceReplacements(ctx context.Context, plan *tfjson.Plan) {
	if plan == nil || len(plan.ResourceChanges) == 0 {
		return
	}

	var (
		count int
		// addresses records each resource address the first time it is seen, so
		// the warnings below follow plan order instead of Go's randomized map
		// iteration order.
		addresses    = make([]string, 0, len(plan.ResourceChanges))
		replacements = make(map[string][]string, len(plan.ResourceChanges))
	)
	for _, ch := range plan.ResourceChanges {
		// No change, no problem!
		if ch.Change == nil {
			continue
		}
		// No-op change, no problem!
		if ch.Change.Actions.NoOp() {
			continue
		}
		// No replacements, no problem!
		if len(ch.Change.ReplacePaths) == 0 {
			continue
		}
		// Replacing our resources, no problem!
		if strings.HasPrefix(ch.Type, "coder_") {
			continue
		}

		if _, seen := replacements[ch.Address]; !seen {
			addresses = append(addresses, ch.Address)
		}
		for _, p := range ch.Change.ReplacePaths {
			var path string
			switch p := p.(type) {
			case []any:
				// A path given as a list of segments is rendered dot-separated.
				segs := make([]string, 0, len(p))
				for _, s := range p {
					segs = append(segs, fmt.Sprintf("%v", s))
				}
				path = strings.Join(segs, ".")
			default:
				path = fmt.Sprintf("%v", p)
			}
			replacements[ch.Address] = append(replacements[ch.Address], path)
		}
		count++
	}

	if count == 0 {
		return
	}

	e.server.logger.Warn(ctx, "plan introduces resource changes", slog.F("count", count))
	for _, addr := range addresses {
		e.server.logger.Warn(ctx, "resource will be replaced", slog.F("name", addr), slog.F("replacement_paths", strings.Join(replacements[addr], ",")))
	}
}
// planResources must only be called while the lock is held.
func (e *executor) planResources(ctx, killCtx context.Context, planfilePath string) (*State, json.RawMessage, error) {
ctx, span := e.server.startTrace(ctx, tracing.FuncName())
@@ -360,6 +422,8 @@ func (e *executor) planResources(ctx, killCtx context.Context, planfilePath stri
return nil, nil, xerrors.Errorf("show terraform plan file: %w", err)
}
e.logResourceReplacements(ctx, plan)
rawGraph, err := e.graph(ctx, killCtx)
if err != nil {
return nil, nil, xerrors.Errorf("graph: %w", err)
+11
View File
@@ -273,6 +273,17 @@ func provisionEnv(
if metadata.GetPrebuiltWorkspaceBuildStage().IsPrebuild() {
env = append(env, provider.IsPrebuildEnvironmentVariable()+"=true")
}
tokens := metadata.GetRunningAgentAuthTokens()
if len(tokens) == 1 {
env = append(env, provider.RunningAgentTokenEnvironmentVariable("")+"="+tokens[0].Token)
} else {
// Not currently supported, but added for forward-compatibility
for _, t := range tokens {
// If there are multiple agents, provide all the tokens to terraform so that it can
// choose the correct one for each agent ID.
env = append(env, provider.RunningAgentTokenEnvironmentVariable(t.AgentId)+"="+t.Token)
}
}
if metadata.GetPrebuiltWorkspaceBuildStage().IsPrebuiltWorkspaceClaim() {
env = append(env, provider.IsPrebuildClaimEnvironmentVariable()+"=true")
}
+1
View File
@@ -19,6 +19,7 @@ import "github.com/coder/coder/v2/apiversion"
// - Add previous parameter values to 'WorkspaceBuild' jobs. Provisioner passes
// the previous values for the `terraform apply` to enforce monotonicity
// in the terraform provider.
// - Add new field named `running_agent_auth_tokens` to provisioner job metadata
const (
CurrentMajor = 1
CurrentMinor = 5
+452 -376
View File
File diff suppressed because it is too large Load Diff
+5 -1
View File
@@ -273,6 +273,10 @@ message Role {
string org_id = 2;
}
// RunningAgentAuthToken pairs an agent ID with an auth token.
// Metadata carries a list of these in its running_agent_auth_tokens field.
message RunningAgentAuthToken {
// ID of the agent that the token belongs to.
string agent_id = 1;
// Auth token associated with that agent.
string token = 2;
}
enum PrebuiltWorkspaceBuildStage {
NONE = 0; // Default value for builds unrelated to prebuilds.
CREATE = 1; // A prebuilt workspace is being provisioned.
@@ -301,7 +305,7 @@ message Metadata {
string workspace_owner_login_type = 18;
repeated Role workspace_owner_rbac_roles = 19;
PrebuiltWorkspaceBuildStage prebuilt_workspace_build_stage = 20; // Indicates that a prebuilt workspace is being built.
string running_workspace_agent_token = 21; // Preserves the running agent token of a prebuilt workspace so it can reinitialize.
repeated RunningAgentAuthToken running_agent_auth_tokens = 21;
}
// Config represents execution configuration shared by all subsequent requests in the Session
+20 -4
View File
@@ -297,6 +297,11 @@ export interface Role {
orgId: string;
}
/** RunningAgentAuthToken pairs an agent ID with an auth token. */
export interface RunningAgentAuthToken {
/** ID of the agent that the token belongs to. */
agentId: string;
/** Auth token associated with that agent. */
token: string;
}
/** Metadata is information about a workspace used in the execution of a build */
export interface Metadata {
coderUrl: string;
@@ -320,8 +325,7 @@ export interface Metadata {
workspaceOwnerRbacRoles: Role[];
/** Indicates that a prebuilt workspace is being built. */
prebuiltWorkspaceBuildStage: PrebuiltWorkspaceBuildStage;
/** Preserves the running agent token of a prebuilt workspace so it can reinitialize. */
runningWorkspaceAgentToken: string;
runningAgentAuthTokens: RunningAgentAuthToken[];
}
/** Config represents execution configuration shared by all subsequent requests in the Session */
@@ -986,6 +990,18 @@ export const Role = {
},
};
/** Wire encoder for RunningAgentAuthToken messages. */
export const RunningAgentAuthToken = {
  /**
   * Appends `message` to `writer` and returns that same writer.
   * Fields holding the empty string are omitted; agentId is written under
   * tag 10 and token under tag 18.
   */
  encode(message: RunningAgentAuthToken, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
    const { agentId, token } = message;
    if (agentId !== "") {
      writer.uint32(10).string(agentId);
    }
    if (token !== "") {
      writer.uint32(18).string(token);
    }
    return writer;
  },
};
export const Metadata = {
encode(message: Metadata, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
if (message.coderUrl !== "") {
@@ -1048,8 +1064,8 @@ export const Metadata = {
if (message.prebuiltWorkspaceBuildStage !== 0) {
writer.uint32(160).int32(message.prebuiltWorkspaceBuildStage);
}
if (message.runningWorkspaceAgentToken !== "") {
writer.uint32(170).string(message.runningWorkspaceAgentToken);
for (const v of message.runningAgentAuthTokens) {
RunningAgentAuthToken.encode(v!, writer.uint32(170).fork()).ldelim();
}
return writer;
},