mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: reinitialize agents when a prebuilt workspace is claimed (#17475)
This pull request allows coder workspace agents to be reinitialized when a prebuilt workspace is claimed by a user. This facilitates the transfer of ownership between the anonymous prebuilds system user and the new owner of the workspace. Only a single agent per prebuilt workspace is supported for now, but plumbing has already been done to facilitate the seamless transition to multi-agent support. --------- Signed-off-by: Danny Kopping <dannykopping@gmail.com> Co-authored-by: Danny Kopping <dannykopping@gmail.com>
This commit is contained in:
+7
-1
@@ -368,9 +368,11 @@ func (a *agent) runLoop() {
|
||||
if ctx.Err() != nil {
|
||||
// Context canceled errors may come from websocket pings, so we
|
||||
// don't want to use `errors.Is(err, context.Canceled)` here.
|
||||
a.logger.Warn(ctx, "runLoop exited with error", slog.Error(ctx.Err()))
|
||||
return
|
||||
}
|
||||
if a.isClosed() {
|
||||
a.logger.Warn(ctx, "runLoop exited because agent is closed")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
@@ -1051,7 +1053,11 @@ func (a *agent) run() (retErr error) {
|
||||
return a.statsReporter.reportLoop(ctx, aAPI)
|
||||
})
|
||||
|
||||
return connMan.wait()
|
||||
err = connMan.wait()
|
||||
if err != nil {
|
||||
a.logger.Info(context.Background(), "connection manager errored", slog.Error(err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handleManifest returns a function that fetches and processes the manifest
|
||||
|
||||
+65
-39
@@ -25,6 +25,8 @@ import (
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"cdr.dev/slog/sloggers/slogjson"
|
||||
"cdr.dev/slog/sloggers/slogstackdriver"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
@@ -33,7 +35,6 @@ import (
|
||||
"github.com/coder/coder/v2/cli/clilog"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
@@ -63,8 +64,10 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
// This command isn't useful to manually execute.
|
||||
Hidden: true,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx, cancel := context.WithCancel(inv.Context())
|
||||
defer cancel()
|
||||
ctx, cancel := context.WithCancelCause(inv.Context())
|
||||
defer func() {
|
||||
cancel(xerrors.New("agent exited"))
|
||||
}()
|
||||
|
||||
var (
|
||||
ignorePorts = map[int]string{}
|
||||
@@ -281,7 +284,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
return xerrors.Errorf("add executable to $PATH: %w", err)
|
||||
}
|
||||
|
||||
prometheusRegistry := prometheus.NewRegistry()
|
||||
subsystemsRaw := inv.Environ.Get(agent.EnvAgentSubsystem)
|
||||
subsystems := []codersdk.AgentSubsystem{}
|
||||
for _, s := range strings.Split(subsystemsRaw, ",") {
|
||||
@@ -325,46 +327,70 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
logger.Info(ctx, "agent devcontainer detection not enabled")
|
||||
}
|
||||
|
||||
agnt := agent.New(agent.Options{
|
||||
Client: client,
|
||||
Logger: logger,
|
||||
LogDir: logDir,
|
||||
ScriptDataDir: scriptDataDir,
|
||||
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
|
||||
TailnetListenPort: uint16(tailnetListenPort),
|
||||
ExchangeToken: func(ctx context.Context) (string, error) {
|
||||
if exchangeToken == nil {
|
||||
return client.SDK.SessionToken(), nil
|
||||
}
|
||||
resp, err := exchangeToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
client.SetSessionToken(resp.SessionToken)
|
||||
return resp.SessionToken, nil
|
||||
},
|
||||
EnvironmentVariables: environmentVariables,
|
||||
IgnorePorts: ignorePorts,
|
||||
SSHMaxTimeout: sshMaxTimeout,
|
||||
Subsystems: subsystems,
|
||||
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
|
||||
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
Execer: execer,
|
||||
SubAgent: subAgent,
|
||||
var (
|
||||
lastErr error
|
||||
mustExit bool
|
||||
)
|
||||
for {
|
||||
prometheusRegistry := prometheus.NewRegistry()
|
||||
|
||||
ExperimentalDevcontainersEnabled: experimentalDevcontainersEnabled,
|
||||
})
|
||||
agnt := agent.New(agent.Options{
|
||||
Client: client,
|
||||
Logger: logger,
|
||||
LogDir: logDir,
|
||||
ScriptDataDir: scriptDataDir,
|
||||
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
|
||||
TailnetListenPort: uint16(tailnetListenPort),
|
||||
ExchangeToken: func(ctx context.Context) (string, error) {
|
||||
if exchangeToken == nil {
|
||||
return client.SDK.SessionToken(), nil
|
||||
}
|
||||
resp, err := exchangeToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
client.SetSessionToken(resp.SessionToken)
|
||||
return resp.SessionToken, nil
|
||||
},
|
||||
EnvironmentVariables: environmentVariables,
|
||||
IgnorePorts: ignorePorts,
|
||||
SSHMaxTimeout: sshMaxTimeout,
|
||||
Subsystems: subsystems,
|
||||
|
||||
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promHandler, prometheusAddress, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
Execer: execer,
|
||||
SubAgent: subAgent,
|
||||
ExperimentalDevcontainersEnabled: experimentalDevcontainersEnabled,
|
||||
})
|
||||
|
||||
debugSrvClose := ServeHandler(ctx, logger, agnt.HTTPDebug(), debugAddress, "debug")
|
||||
defer debugSrvClose()
|
||||
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promHandler, prometheusAddress, "prometheus")
|
||||
|
||||
<-ctx.Done()
|
||||
return agnt.Close()
|
||||
debugSrvClose := ServeHandler(ctx, logger, agnt.HTTPDebug(), debugAddress, "debug")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info(ctx, "agent shutting down", slog.Error(context.Cause(ctx)))
|
||||
mustExit = true
|
||||
case event := <-reinitEvents:
|
||||
logger.Info(ctx, "agent received instruction to reinitialize",
|
||||
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
|
||||
}
|
||||
|
||||
lastErr = agnt.Close()
|
||||
debugSrvClose()
|
||||
prometheusSrvClose()
|
||||
|
||||
if mustExit {
|
||||
break
|
||||
}
|
||||
|
||||
logger.Info(ctx, "agent reinitializing")
|
||||
}
|
||||
return lastErr
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Generated
+45
@@ -8446,6 +8446,31 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/workspaceagents/me/reinit": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Agents"
|
||||
],
|
||||
"summary": "Get workspace agent reinitialization",
|
||||
"operationId": "get-workspace-agent-reinitialization",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/workspaceagents/me/rpc": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -10491,6 +10516,26 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"agentsdk.ReinitializationEvent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reason": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationReason"
|
||||
},
|
||||
"workspaceID": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"agentsdk.ReinitializationReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"prebuild_claimed"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"ReinitializeReasonPrebuildClaimed"
|
||||
]
|
||||
},
|
||||
"aisdk.Attachment": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
Generated
+37
@@ -7463,6 +7463,27 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/workspaceagents/me/reinit": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Agents"],
|
||||
"summary": "Get workspace agent reinitialization",
|
||||
"operationId": "get-workspace-agent-reinitialization",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationEvent"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/workspaceagents/me/rpc": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -9302,6 +9323,22 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"agentsdk.ReinitializationEvent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reason": {
|
||||
"$ref": "#/definitions/agentsdk.ReinitializationReason"
|
||||
},
|
||||
"workspaceID": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"agentsdk.ReinitializationReason": {
|
||||
"type": "string",
|
||||
"enum": ["prebuild_claimed"],
|
||||
"x-enum-varnames": ["ReinitializeReasonPrebuildClaimed"]
|
||||
},
|
||||
"aisdk.Attachment": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
+3
-1
@@ -19,6 +19,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
@@ -47,7 +49,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/entitlements"
|
||||
"github.com/coder/coder/v2/coderd/files"
|
||||
"github.com/coder/coder/v2/coderd/idpsync"
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
|
||||
@@ -1299,6 +1300,7 @@ func New(options *Options) *API {
|
||||
r.Get("/external-auth", api.workspaceAgentsExternalAuth)
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
r.Post("/log-source", api.workspaceAgentPostLogSource)
|
||||
r.Get("/reinit", api.workspaceAgentReinit)
|
||||
})
|
||||
r.Route("/{workspaceagent}", func(r chi.Router) {
|
||||
r.Use(
|
||||
|
||||
@@ -1105,6 +1105,69 @@ func (w WorkspaceAgentWaiter) MatchResources(m func([]codersdk.WorkspaceResource
|
||||
return w
|
||||
}
|
||||
|
||||
// WaitForAgentFn represents a boolean assertion to be made against each agent
|
||||
// that a given WorkspaceAgentWaited knows about. Each WaitForAgentFn should apply
|
||||
// the check to a single agent, but it should be named for plural, because `func (w WorkspaceAgentWaiter) WaitFor`
|
||||
// applies the check to all agents that it is aware of. This ensures that the public API of the waiter
|
||||
// reads correctly. For example:
|
||||
//
|
||||
// waiter := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID)
|
||||
// waiter.WaitFor(coderdtest.AgentsReady)
|
||||
type WaitForAgentFn func(agent codersdk.WorkspaceAgent) bool
|
||||
|
||||
// AgentsReady checks that the latest lifecycle state of an agent is "Ready".
|
||||
func AgentsReady(agent codersdk.WorkspaceAgent) bool {
|
||||
return agent.LifecycleState == codersdk.WorkspaceAgentLifecycleReady
|
||||
}
|
||||
|
||||
// AgentsNotReady checks that the latest lifecycle state of an agent is anything except "Ready".
|
||||
func AgentsNotReady(agent codersdk.WorkspaceAgent) bool {
|
||||
return !AgentsReady(agent)
|
||||
}
|
||||
|
||||
func (w WorkspaceAgentWaiter) WaitFor(criteria ...WaitForAgentFn) {
|
||||
w.t.Helper()
|
||||
|
||||
agentNamesMap := make(map[string]struct{}, len(w.agentNames))
|
||||
for _, name := range w.agentNames {
|
||||
agentNamesMap[name] = struct{}{}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
w.t.Logf("waiting for workspace agents (workspace %s)", w.workspaceID)
|
||||
require.Eventually(w.t, func() bool {
|
||||
var err error
|
||||
workspace, err := w.client.Workspace(ctx, w.workspaceID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if workspace.LatestBuild.Job.CompletedAt == nil {
|
||||
return false
|
||||
}
|
||||
if workspace.LatestBuild.Job.CompletedAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, resource := range workspace.LatestBuild.Resources {
|
||||
for _, agent := range resource.Agents {
|
||||
if len(w.agentNames) > 0 {
|
||||
if _, ok := agentNamesMap[agent.Name]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, criterium := range criteria {
|
||||
if !criterium(agent) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, testutil.WaitLong, testutil.IntervalMedium)
|
||||
}
|
||||
|
||||
// Wait waits for the agent(s) to connect and fails the test if they do not within testutil.WaitLong
|
||||
func (w WorkspaceAgentWaiter) Wait() []codersdk.WorkspaceResource {
|
||||
w.t.Helper()
|
||||
|
||||
@@ -3020,6 +3020,15 @@ func (q *querier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uui
|
||||
return q.db.GetWorkspaceAgentsByResourceIDs(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
|
||||
_, err := q.GetWorkspaceByID(ctx, arg.WorkspaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return q.db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -2009,6 +2009,38 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID})
|
||||
check.Args(agt.ID).Asserts(w, policy.ActionRead).Returns(agt)
|
||||
}))
|
||||
s.Run("GetWorkspaceAgentsByWorkspaceAndBuildNumber", s.Subtest(func(db database.Store, check *expects) {
|
||||
u := dbgen.User(s.T(), db, database.User{})
|
||||
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
tpl := dbgen.Template(s.T(), db, database.Template{
|
||||
OrganizationID: o.ID,
|
||||
CreatedBy: u.ID,
|
||||
})
|
||||
tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
|
||||
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
|
||||
OrganizationID: o.ID,
|
||||
CreatedBy: u.ID,
|
||||
})
|
||||
w := dbgen.Workspace(s.T(), db, database.WorkspaceTable{
|
||||
TemplateID: tpl.ID,
|
||||
OrganizationID: o.ID,
|
||||
OwnerID: u.ID,
|
||||
})
|
||||
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
})
|
||||
b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{
|
||||
JobID: j.ID,
|
||||
WorkspaceID: w.ID,
|
||||
TemplateVersionID: tv.ID,
|
||||
})
|
||||
res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: b.JobID})
|
||||
agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{ResourceID: res.ID})
|
||||
check.Args(database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
|
||||
WorkspaceID: w.ID,
|
||||
BuildNumber: 1,
|
||||
}).Asserts(w, policy.ActionRead).Returns([]database.WorkspaceAgent{agt})
|
||||
}))
|
||||
s.Run("GetWorkspaceAgentLifecycleStateByID", s.Subtest(func(db database.Store, check *expects) {
|
||||
u := dbgen.User(s.T(), db, database.User{})
|
||||
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
|
||||
@@ -294,6 +294,8 @@ type TemplateVersionBuilder struct {
|
||||
ps pubsub.Pubsub
|
||||
resources []*sdkproto.Resource
|
||||
params []database.TemplateVersionParameter
|
||||
presets []database.TemplateVersionPreset
|
||||
presetParams []database.TemplateVersionPresetParameter
|
||||
promote bool
|
||||
autoCreateTemplate bool
|
||||
}
|
||||
@@ -339,6 +341,13 @@ func (t TemplateVersionBuilder) Params(ps ...database.TemplateVersionParameter)
|
||||
return t
|
||||
}
|
||||
|
||||
func (t TemplateVersionBuilder) Preset(preset database.TemplateVersionPreset, params ...database.TemplateVersionPresetParameter) TemplateVersionBuilder {
|
||||
// nolint: revive // returns modified struct
|
||||
t.presets = append(t.presets, preset)
|
||||
t.presetParams = append(t.presetParams, params...)
|
||||
return t
|
||||
}
|
||||
|
||||
func (t TemplateVersionBuilder) SkipCreateTemplate() TemplateVersionBuilder {
|
||||
// nolint: revive // returns modified struct
|
||||
t.autoCreateTemplate = false
|
||||
@@ -378,6 +387,25 @@ func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
|
||||
require.NoError(t.t, err)
|
||||
}
|
||||
|
||||
for _, preset := range t.presets {
|
||||
dbgen.Preset(t.t, t.db, database.InsertPresetParams{
|
||||
ID: preset.ID,
|
||||
TemplateVersionID: version.ID,
|
||||
Name: preset.Name,
|
||||
CreatedAt: version.CreatedAt,
|
||||
DesiredInstances: preset.DesiredInstances,
|
||||
InvalidateAfterSecs: preset.InvalidateAfterSecs,
|
||||
})
|
||||
}
|
||||
|
||||
for _, presetParam := range t.presetParams {
|
||||
dbgen.PresetParameter(t.t, t.db, database.InsertPresetParametersParams{
|
||||
TemplateVersionPresetID: presetParam.TemplateVersionPresetID,
|
||||
Names: []string{presetParam.Name},
|
||||
Values: []string{presetParam.Value},
|
||||
})
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(provisionerdserver.TemplateVersionImportJob{
|
||||
TemplateVersionID: t.seed.ID,
|
||||
})
|
||||
|
||||
@@ -1224,6 +1224,7 @@ func TelemetryItem(t testing.TB, db database.Store, seed database.TelemetryItem)
|
||||
|
||||
func Preset(t testing.TB, db database.Store, seed database.InsertPresetParams) database.TemplateVersionPreset {
|
||||
preset, err := db.InsertPreset(genCtx, database.InsertPresetParams{
|
||||
ID: takeFirst(seed.ID, uuid.New()),
|
||||
TemplateVersionID: takeFirst(seed.TemplateVersionID, uuid.New()),
|
||||
Name: takeFirst(seed.Name, testutil.GetRandomName(t)),
|
||||
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
|
||||
|
||||
@@ -7654,6 +7654,30 @@ func (q *FakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resou
|
||||
return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs)
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
|
||||
err := validateDatabaseType(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
build, err := q.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams(arg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resourceIDs []uuid.UUID
|
||||
for _, resource := range resources {
|
||||
resourceIDs = append(resourceIDs, resource.ID)
|
||||
}
|
||||
|
||||
return q.GetWorkspaceAgentsByResourceIDs(ctx, resourceIDs)
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceAgent, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
@@ -1754,6 +1754,13 @@ func (m queryMetricsStore) GetWorkspaceAgentsByResourceIDs(ctx context.Context,
|
||||
return agents, err
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetWorkspaceAgentsByWorkspaceAndBuildNumber").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
|
||||
start := time.Now()
|
||||
agents, err := m.s.GetWorkspaceAgentsCreatedAfter(ctx, createdAt)
|
||||
|
||||
@@ -3678,6 +3678,21 @@ func (mr *MockStoreMockRecorder) GetWorkspaceAgentsByResourceIDs(ctx, ids any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentsByResourceIDs", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentsByResourceIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetWorkspaceAgentsByWorkspaceAndBuildNumber mocks base method.
|
||||
func (m *MockStore) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]database.WorkspaceAgent, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetWorkspaceAgentsByWorkspaceAndBuildNumber", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.WorkspaceAgent)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetWorkspaceAgentsByWorkspaceAndBuildNumber indicates an expected call of GetWorkspaceAgentsByWorkspaceAndBuildNumber.
|
||||
func (mr *MockStoreMockRecorder) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentsByWorkspaceAndBuildNumber", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentsByWorkspaceAndBuildNumber), ctx, arg)
|
||||
}
|
||||
|
||||
// GetWorkspaceAgentsCreatedAfter mocks base method.
|
||||
func (m *MockStore) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -400,6 +400,7 @@ type sqlcQuerier interface {
|
||||
GetWorkspaceAgentUsageStats(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsRow, error)
|
||||
GetWorkspaceAgentUsageStatsAndLabels(ctx context.Context, createdAt time.Time) ([]GetWorkspaceAgentUsageStatsAndLabelsRow, error)
|
||||
GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error)
|
||||
GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]WorkspaceAgent, error)
|
||||
GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceAgent, error)
|
||||
GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgent, error)
|
||||
GetWorkspaceAppByAgentIDAndSlug(ctx context.Context, arg GetWorkspaceAppByAgentIDAndSlugParams) (WorkspaceApp, error)
|
||||
|
||||
@@ -6678,6 +6678,7 @@ func (q *sqlQuerier) GetPresetsByTemplateVersionID(ctx context.Context, template
|
||||
|
||||
const insertPreset = `-- name: InsertPreset :one
|
||||
INSERT INTO template_version_presets (
|
||||
id,
|
||||
template_version_id,
|
||||
name,
|
||||
created_at,
|
||||
@@ -6689,11 +6690,13 @@ VALUES (
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5
|
||||
$5,
|
||||
$6
|
||||
) RETURNING id, template_version_id, name, created_at, desired_instances, invalidate_after_secs
|
||||
`
|
||||
|
||||
type InsertPresetParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
TemplateVersionID uuid.UUID `db:"template_version_id" json:"template_version_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
@@ -6703,6 +6706,7 @@ type InsertPresetParams struct {
|
||||
|
||||
func (q *sqlQuerier) InsertPreset(ctx context.Context, arg InsertPresetParams) (TemplateVersionPreset, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertPreset,
|
||||
arg.ID,
|
||||
arg.TemplateVersionID,
|
||||
arg.Name,
|
||||
arg.CreatedAt,
|
||||
@@ -14416,6 +14420,81 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getWorkspaceAgentsByWorkspaceAndBuildNumber = `-- name: GetWorkspaceAgentsByWorkspaceAndBuildNumber :many
|
||||
SELECT
|
||||
workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.expanded_directory, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems, workspace_agents.display_apps, workspace_agents.api_version, workspace_agents.display_order, workspace_agents.parent_id
|
||||
FROM
|
||||
workspace_agents
|
||||
JOIN
|
||||
workspace_resources ON workspace_agents.resource_id = workspace_resources.id
|
||||
JOIN
|
||||
workspace_builds ON workspace_resources.job_id = workspace_builds.job_id
|
||||
WHERE
|
||||
workspace_builds.workspace_id = $1 :: uuid AND
|
||||
workspace_builds.build_number = $2 :: int
|
||||
`
|
||||
|
||||
type GetWorkspaceAgentsByWorkspaceAndBuildNumberParams struct {
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
BuildNumber int32 `db:"build_number" json:"build_number"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx context.Context, arg GetWorkspaceAgentsByWorkspaceAndBuildNumberParams) ([]WorkspaceAgent, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getWorkspaceAgentsByWorkspaceAndBuildNumber, arg.WorkspaceID, arg.BuildNumber)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []WorkspaceAgent
|
||||
for rows.Next() {
|
||||
var i WorkspaceAgent
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
&i.FirstConnectedAt,
|
||||
&i.LastConnectedAt,
|
||||
&i.DisconnectedAt,
|
||||
&i.ResourceID,
|
||||
&i.AuthToken,
|
||||
&i.AuthInstanceID,
|
||||
&i.Architecture,
|
||||
&i.EnvironmentVariables,
|
||||
&i.OperatingSystem,
|
||||
&i.InstanceMetadata,
|
||||
&i.ResourceMetadata,
|
||||
&i.Directory,
|
||||
&i.Version,
|
||||
&i.LastConnectedReplicaID,
|
||||
&i.ConnectionTimeoutSeconds,
|
||||
&i.TroubleshootingURL,
|
||||
&i.MOTDFile,
|
||||
&i.LifecycleState,
|
||||
&i.ExpandedDirectory,
|
||||
&i.LogsLength,
|
||||
&i.LogsOverflowed,
|
||||
&i.StartedAt,
|
||||
&i.ReadyAt,
|
||||
pq.Array(&i.Subsystems),
|
||||
pq.Array(&i.DisplayApps),
|
||||
&i.APIVersion,
|
||||
&i.DisplayOrder,
|
||||
&i.ParentID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getWorkspaceAgentsCreatedAfter = `-- name: GetWorkspaceAgentsCreatedAfter :many
|
||||
SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, expanded_directory, logs_length, logs_overflowed, started_at, ready_at, subsystems, display_apps, api_version, display_order, parent_id FROM workspace_agents WHERE created_at > $1
|
||||
`
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
-- name: InsertPreset :one
|
||||
INSERT INTO template_version_presets (
|
||||
id,
|
||||
template_version_id,
|
||||
name,
|
||||
created_at,
|
||||
@@ -7,6 +8,7 @@ INSERT INTO template_version_presets (
|
||||
invalidate_after_secs
|
||||
)
|
||||
VALUES (
|
||||
@id,
|
||||
@template_version_id,
|
||||
@name,
|
||||
@created_at,
|
||||
|
||||
@@ -253,6 +253,19 @@ WHERE
|
||||
wb.workspace_id = @workspace_id :: uuid
|
||||
);
|
||||
|
||||
-- name: GetWorkspaceAgentsByWorkspaceAndBuildNumber :many
|
||||
SELECT
|
||||
workspace_agents.*
|
||||
FROM
|
||||
workspace_agents
|
||||
JOIN
|
||||
workspace_resources ON workspace_agents.resource_id = workspace_resources.id
|
||||
JOIN
|
||||
workspace_builds ON workspace_resources.job_id = workspace_builds.job_id
|
||||
WHERE
|
||||
workspace_builds.workspace_id = @workspace_id :: uuid AND
|
||||
workspace_builds.build_number = @build_number :: int;
|
||||
|
||||
-- name: GetWorkspaceAgentAndLatestBuildByAuthToken :one
|
||||
SELECT
|
||||
sqlc.embed(workspaces),
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package prebuilds
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func NewPubsubWorkspaceClaimPublisher(ps pubsub.Pubsub) *PubsubWorkspaceClaimPublisher {
|
||||
return &PubsubWorkspaceClaimPublisher{ps: ps}
|
||||
}
|
||||
|
||||
type PubsubWorkspaceClaimPublisher struct {
|
||||
ps pubsub.Pubsub
|
||||
}
|
||||
|
||||
func (p PubsubWorkspaceClaimPublisher) PublishWorkspaceClaim(claim agentsdk.ReinitializationEvent) error {
|
||||
channel := agentsdk.PrebuildClaimedChannel(claim.WorkspaceID)
|
||||
if err := p.ps.Publish(channel, []byte(claim.Reason)); err != nil {
|
||||
return xerrors.Errorf("failed to trigger prebuilt workspace agent reinitialization: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewPubsubWorkspaceClaimListener(ps pubsub.Pubsub, logger slog.Logger) *PubsubWorkspaceClaimListener {
|
||||
return &PubsubWorkspaceClaimListener{ps: ps, logger: logger}
|
||||
}
|
||||
|
||||
type PubsubWorkspaceClaimListener struct {
|
||||
logger slog.Logger
|
||||
ps pubsub.Pubsub
|
||||
}
|
||||
|
||||
// ListenForWorkspaceClaims subscribes to a pubsub channel and sends any received events on the chan that it returns.
|
||||
// pubsub.Pubsub does not communicate when its last callback has been called after it has been closed. As such the chan
|
||||
// returned by this method is never closed. Call the returned cancel() function to close the subscription when it is no longer needed.
|
||||
// cancel() will be called if ctx expires or is canceled.
|
||||
func (p PubsubWorkspaceClaimListener) ListenForWorkspaceClaims(ctx context.Context, workspaceID uuid.UUID, reinitEvents chan<- agentsdk.ReinitializationEvent) (func(), error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return func() {}, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
cancelSub, err := p.ps.Subscribe(agentsdk.PrebuildClaimedChannel(workspaceID), func(inner context.Context, reason []byte) {
|
||||
claim := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspaceID,
|
||||
Reason: agentsdk.ReinitializationReason(reason),
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-inner.Done():
|
||||
return
|
||||
case reinitEvents <- claim:
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return func() {}, xerrors.Errorf("failed to subscribe to prebuild claimed channel: %w", err)
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
cancel := func() {
|
||||
once.Do(func() {
|
||||
cancelSub()
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
return cancel, nil
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package prebuilds_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestPubsubWorkspaceClaimPublisher(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("published claim is received by a listener for the same workspace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := testutil.Logger(t)
|
||||
ps := pubsub.NewInMemory()
|
||||
workspaceID := uuid.New()
|
||||
reinitEvents := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
publisher := prebuilds.NewPubsubWorkspaceClaimPublisher(ps)
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, logger)
|
||||
|
||||
cancel, err := listener.ListenForWorkspaceClaims(ctx, workspaceID, reinitEvents)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
claim := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspaceID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
}
|
||||
err = publisher.PublishWorkspaceClaim(claim)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotEvent := testutil.RequireReceive(ctx, t, reinitEvents)
|
||||
require.Equal(t, workspaceID, gotEvent.WorkspaceID)
|
||||
require.Equal(t, claim.Reason, gotEvent.Reason)
|
||||
})
|
||||
|
||||
t.Run("fail to publish claim", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := &brokenPubsub{}
|
||||
|
||||
publisher := prebuilds.NewPubsubWorkspaceClaimPublisher(ps)
|
||||
claim := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: uuid.New(),
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
}
|
||||
|
||||
err := publisher.PublishWorkspaceClaim(claim)
|
||||
require.ErrorContains(t, err, "failed to trigger prebuilt workspace agent reinitialization")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPubsubWorkspaceClaimListener(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("finds claim events for its workspace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := pubsub.NewInMemory()
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
|
||||
|
||||
claims := make(chan agentsdk.ReinitializationEvent, 1) // Buffer to avoid messing with goroutines in the rest of the test
|
||||
|
||||
workspaceID := uuid.New()
|
||||
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
|
||||
// Publish a claim
|
||||
channel := agentsdk.PrebuildClaimedChannel(workspaceID)
|
||||
reason := agentsdk.ReinitializeReasonPrebuildClaimed
|
||||
err = ps.Publish(channel, []byte(reason))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify we receive the claim
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
claim := testutil.RequireReceive(ctx, t, claims)
|
||||
require.Equal(t, workspaceID, claim.WorkspaceID)
|
||||
require.Equal(t, reason, claim.Reason)
|
||||
})
|
||||
|
||||
t.Run("ignores claim events for other workspaces", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := pubsub.NewInMemory()
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
|
||||
|
||||
claims := make(chan agentsdk.ReinitializationEvent)
|
||||
workspaceID := uuid.New()
|
||||
otherWorkspaceID := uuid.New()
|
||||
cancelFunc, err := listener.ListenForWorkspaceClaims(context.Background(), workspaceID, claims)
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
|
||||
// Publish a claim for a different workspace
|
||||
channel := agentsdk.PrebuildClaimedChannel(otherWorkspaceID)
|
||||
err = ps.Publish(channel, []byte(agentsdk.ReinitializeReasonPrebuildClaimed))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify we don't receive the claim
|
||||
select {
|
||||
case <-claims:
|
||||
t.Fatal("received claim for wrong workspace")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected - no claim received
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("communicates the error if it can't subscribe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
claims := make(chan agentsdk.ReinitializationEvent)
|
||||
ps := &brokenPubsub{}
|
||||
listener := prebuilds.NewPubsubWorkspaceClaimListener(ps, slogtest.Make(t, nil))
|
||||
|
||||
_, err := listener.ListenForWorkspaceClaims(context.Background(), uuid.New(), claims)
|
||||
require.ErrorContains(t, err, "failed to subscribe to prebuild claimed channel")
|
||||
})
|
||||
}
|
||||
|
||||
type brokenPubsub struct {
|
||||
pubsub.Pubsub
|
||||
}
|
||||
|
||||
func (brokenPubsub) Subscribe(_ string, _ pubsub.Listener) (func(), error) {
|
||||
return nil, xerrors.New("broken")
|
||||
}
|
||||
|
||||
func (brokenPubsub) Publish(_ string, _ []byte) error {
|
||||
return xerrors.New("broken")
|
||||
}
|
||||
@@ -40,12 +40,14 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
"github.com/coder/coder/v2/coderd/promoauth"
|
||||
"github.com/coder/coder/v2/coderd/schedule"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/provisioner"
|
||||
"github.com/coder/coder/v2/provisionerd/proto"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
@@ -647,6 +649,30 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
}
|
||||
}
|
||||
|
||||
runningAgentAuthTokens := []*sdkproto.RunningAgentAuthToken{}
|
||||
if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
// runningAgentAuthTokens are *only* used for prebuilds. We fetch them when we want to rebuild a prebuilt workspace
|
||||
// but not generate new agent tokens. The provisionerdserver will push them down to
|
||||
// the provisioner (and ultimately to the `coder_agent` resource in the Terraform provider) where they will be
|
||||
// reused. Context: the agent token is often used in immutable attributes of workspace resource (e.g. VM/container)
|
||||
// to initialize the agent, so if that value changes it will necessitate a replacement of that resource, thus
|
||||
// obviating the whole point of the prebuild.
|
||||
agents, err := s.Database.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: 1,
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "failed to retrieve running agents of claimed prebuilt workspace",
|
||||
slog.F("workspace_id", workspace.ID), slog.Error(err))
|
||||
}
|
||||
for _, agent := range agents {
|
||||
runningAgentAuthTokens = append(runningAgentAuthTokens, &sdkproto.RunningAgentAuthToken{
|
||||
AgentId: agent.ID.String(),
|
||||
Token: agent.AuthToken.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
|
||||
WorkspaceBuildId: workspaceBuild.ID.String(),
|
||||
@@ -676,6 +702,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
WorkspaceBuildId: workspaceBuild.ID.String(),
|
||||
WorkspaceOwnerLoginType: string(owner.LoginType),
|
||||
WorkspaceOwnerRbacRoles: ownerRbacRoles,
|
||||
RunningAgentAuthTokens: runningAgentAuthTokens,
|
||||
PrebuiltWorkspaceBuildStage: input.PrebuiltWorkspaceBuildStage,
|
||||
},
|
||||
LogLevel: input.LogLevel,
|
||||
@@ -1812,6 +1839,19 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update workspace: %w", err)
|
||||
}
|
||||
|
||||
if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
s.Logger.Info(ctx, "workspace prebuild successfully claimed by user",
|
||||
slog.F("workspace_id", workspace.ID))
|
||||
|
||||
err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
|
||||
}
|
||||
}
|
||||
case *proto.CompletedJob_TemplateDryRun_:
|
||||
for _, resource := range jobType.TemplateDryRun.Resources {
|
||||
s.Logger.Info(ctx, "inserting template dry-run job resource",
|
||||
@@ -1955,6 +1995,7 @@ func InsertWorkspacePresetAndParameters(ctx context.Context, db database.Store,
|
||||
}
|
||||
}
|
||||
dbPreset, err := tx.InsertPreset(ctx, database.InsertPresetParams{
|
||||
ID: uuid.New(),
|
||||
TemplateVersionID: templateVersionID,
|
||||
Name: protoPreset.Name,
|
||||
CreatedAt: t,
|
||||
|
||||
@@ -26,7 +26,10 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
@@ -39,7 +42,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
|
||||
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
||||
"github.com/coder/coder/v2/coderd/schedule"
|
||||
"github.com/coder/coder/v2/coderd/schedule/cron"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
@@ -167,8 +169,12 @@ func TestAcquireJob(t *testing.T) {
|
||||
_, err = tc.acquire(ctx, srv)
|
||||
require.ErrorContains(t, err, "sql: no rows in result set")
|
||||
})
|
||||
for _, prebuiltWorkspace := range []bool{false, true} {
|
||||
prebuiltWorkspace := prebuiltWorkspace
|
||||
for _, prebuiltWorkspaceBuildStage := range []sdkproto.PrebuiltWorkspaceBuildStage{
|
||||
sdkproto.PrebuiltWorkspaceBuildStage_NONE,
|
||||
sdkproto.PrebuiltWorkspaceBuildStage_CREATE,
|
||||
sdkproto.PrebuiltWorkspaceBuildStage_CLAIM,
|
||||
} {
|
||||
prebuiltWorkspaceBuildStage := prebuiltWorkspaceBuildStage
|
||||
t.Run(tc.name+"_WorkspaceBuildJob", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Set the max session token lifetime so we can assert we
|
||||
@@ -212,7 +218,7 @@ func TestAcquireJob(t *testing.T) {
|
||||
Roles: []string{rbac.RoleOrgAuditor()},
|
||||
})
|
||||
|
||||
// Add extra erronous roles
|
||||
// Add extra erroneous roles
|
||||
secondOrg := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
@@ -287,36 +293,74 @@ func TestAcquireJob(t *testing.T) {
|
||||
Required: true,
|
||||
Sensitive: false,
|
||||
})
|
||||
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
workspace := database.WorkspaceTable{
|
||||
TemplateID: template.ID,
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: pd.OrganizationID,
|
||||
})
|
||||
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
}
|
||||
workspace = dbgen.Workspace(t, db, workspace)
|
||||
build := database.WorkspaceBuild{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: 1,
|
||||
JobID: uuid.New(),
|
||||
TemplateVersionID: version.ID,
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
Reason: database.BuildReasonInitiator,
|
||||
})
|
||||
var buildState sdkproto.PrebuiltWorkspaceBuildStage
|
||||
if prebuiltWorkspace {
|
||||
buildState = sdkproto.PrebuiltWorkspaceBuildStage_CREATE
|
||||
}
|
||||
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
||||
ID: build.ID,
|
||||
build = dbgen.WorkspaceBuild(t, db, build)
|
||||
input := provisionerdserver.WorkspaceProvisionJob{
|
||||
WorkspaceBuildID: build.ID,
|
||||
}
|
||||
dbJob := database.ProvisionerJob{
|
||||
ID: build.JobID,
|
||||
OrganizationID: pd.OrganizationID,
|
||||
InitiatorID: user.ID,
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
FileID: file.ID,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
||||
Input: must(json.Marshal(input)),
|
||||
}
|
||||
dbJob = dbgen.ProvisionerJob(t, db, ps, dbJob)
|
||||
|
||||
var agent database.WorkspaceAgent
|
||||
if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
JobID: dbJob.ID,
|
||||
})
|
||||
agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: resource.ID,
|
||||
AuthToken: uuid.New(),
|
||||
})
|
||||
// At this point we have an unclaimed workspace and build, now we need to setup the claim
|
||||
// build
|
||||
build = database.WorkspaceBuild{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: 2,
|
||||
JobID: uuid.New(),
|
||||
TemplateVersionID: version.ID,
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
Reason: database.BuildReasonInitiator,
|
||||
InitiatorID: user.ID,
|
||||
}
|
||||
build = dbgen.WorkspaceBuild(t, db, build)
|
||||
|
||||
input = provisionerdserver.WorkspaceProvisionJob{
|
||||
WorkspaceBuildID: build.ID,
|
||||
PrebuiltWorkspaceBuildStage: buildState,
|
||||
})),
|
||||
})
|
||||
PrebuiltWorkspaceBuildStage: prebuiltWorkspaceBuildStage,
|
||||
}
|
||||
dbJob = database.ProvisionerJob{
|
||||
ID: build.JobID,
|
||||
OrganizationID: pd.OrganizationID,
|
||||
InitiatorID: user.ID,
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
FileID: file.ID,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
Input: must(json.Marshal(input)),
|
||||
}
|
||||
dbJob = dbgen.ProvisionerJob(t, db, ps, dbJob)
|
||||
}
|
||||
|
||||
startPublished := make(chan struct{})
|
||||
var closed bool
|
||||
@@ -350,6 +394,19 @@ func TestAcquireJob(t *testing.T) {
|
||||
|
||||
<-startPublished
|
||||
|
||||
if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
for {
|
||||
// In the case of a prebuild claim, there is a second build, which is the
|
||||
// one that we're interested in.
|
||||
job, err = tc.acquire(ctx, srv)
|
||||
require.NoError(t, err)
|
||||
if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
<-startPublished
|
||||
}
|
||||
|
||||
got, err := json.Marshal(job.Type)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -384,8 +441,14 @@ func TestAcquireJob(t *testing.T) {
|
||||
WorkspaceOwnerLoginType: string(user.LoginType),
|
||||
WorkspaceOwnerRbacRoles: []*sdkproto.Role{{Name: rbac.RoleOrgMember(), OrgId: pd.OrganizationID.String()}, {Name: "member", OrgId: ""}, {Name: rbac.RoleOrgAuditor(), OrgId: pd.OrganizationID.String()}},
|
||||
}
|
||||
if prebuiltWorkspace {
|
||||
wantedMetadata.PrebuiltWorkspaceBuildStage = sdkproto.PrebuiltWorkspaceBuildStage_CREATE
|
||||
if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
// For claimed prebuilds, we expect the prebuild state to be set to CLAIM
|
||||
// and we expect tokens from the first build to be set for reuse
|
||||
wantedMetadata.PrebuiltWorkspaceBuildStage = prebuiltWorkspaceBuildStage
|
||||
wantedMetadata.RunningAgentAuthTokens = append(wantedMetadata.RunningAgentAuthTokens, &sdkproto.RunningAgentAuthToken{
|
||||
AgentId: agent.ID.String(),
|
||||
Token: agent.AuthToken.String(),
|
||||
})
|
||||
}
|
||||
|
||||
slices.SortFunc(wantedMetadata.WorkspaceOwnerRbacRoles, func(a, b *sdkproto.Role) int {
|
||||
@@ -1750,6 +1813,110 @@ func TestCompleteJob(t *testing.T) {
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ReinitializePrebuiltAgents", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
type testcase struct {
|
||||
name string
|
||||
shouldReinitializeAgent bool
|
||||
}
|
||||
|
||||
for _, tc := range []testcase{
|
||||
// Whether or not there are presets and those presets define prebuilds, etc
|
||||
// are all irrelevant at this level. Those factors are useful earlier in the process.
|
||||
// Everything relevant to this test is determined by the value of `PrebuildClaimedByUser`
|
||||
// on the provisioner job. As such, there are only two significant test cases:
|
||||
{
|
||||
name: "claimed prebuild",
|
||||
shouldReinitializeAgent: true,
|
||||
},
|
||||
{
|
||||
name: "not a claimed prebuild",
|
||||
shouldReinitializeAgent: false,
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// GIVEN an enqueued provisioner job and its dependencies:
|
||||
|
||||
srv, db, ps, pd := setup(t, false, &overrides{})
|
||||
|
||||
buildID := uuid.New()
|
||||
jobInput := provisionerdserver.WorkspaceProvisionJob{
|
||||
WorkspaceBuildID: buildID,
|
||||
}
|
||||
if tc.shouldReinitializeAgent { // This is the key lever in the test
|
||||
// GIVEN the enqueued provisioner job is for a workspace being claimed by a user:
|
||||
jobInput.PrebuiltWorkspaceBuildStage = sdkproto.PrebuiltWorkspaceBuildStage_CLAIM
|
||||
}
|
||||
input, err := json.Marshal(jobInput)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
Input: input,
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tpl := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: pd.OrganizationID,
|
||||
})
|
||||
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
|
||||
JobID: job.ID,
|
||||
})
|
||||
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
TemplateID: tpl.ID,
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
ID: buildID,
|
||||
JobID: job.ID,
|
||||
WorkspaceID: workspace.ID,
|
||||
TemplateVersionID: tv.ID,
|
||||
})
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// GIVEN something is listening to process workspace reinitialization:
|
||||
reinitChan := make(chan agentsdk.ReinitializationEvent, 1) // Buffered to simplify test structure
|
||||
cancel, err := prebuilds.NewPubsubWorkspaceClaimListener(ps, testutil.Logger(t)).ListenForWorkspaceClaims(ctx, workspace.ID, reinitChan)
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
// WHEN the job is completed
|
||||
completedJob := proto.CompletedJob{
|
||||
JobId: job.ID.String(),
|
||||
Type: &proto.CompletedJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{},
|
||||
},
|
||||
}
|
||||
_, err = srv.CompleteJob(ctx, &completedJob)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.shouldReinitializeAgent {
|
||||
event := testutil.RequireReceive(ctx, t, reinitChan)
|
||||
require.Equal(t, workspace.ID, event.WorkspaceID)
|
||||
} else {
|
||||
select {
|
||||
case <-reinitChan:
|
||||
t.Fatal("unexpected reinitialization event published")
|
||||
default:
|
||||
// OK
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertWorkspacePresetsAndParameters(t *testing.T) {
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
@@ -1183,6 +1184,60 @@ func (api *API) workspaceAgentPostLogSource(rw http.ResponseWriter, r *http.Requ
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, apiSource)
|
||||
}
|
||||
|
||||
// @Summary Get workspace agent reinitialization
|
||||
// @ID get-workspace-agent-reinitialization
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Agents
|
||||
// @Success 200 {object} agentsdk.ReinitializationEvent
|
||||
// @Router /workspaceagents/me/reinit [get]
|
||||
func (api *API) workspaceAgentReinit(rw http.ResponseWriter, r *http.Request) {
|
||||
// Allow us to interrupt watch via cancel.
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx) // Rewire context for SSE cancellation.
|
||||
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
log := api.Logger.Named("workspace_agent_reinit_watcher").With(
|
||||
slog.F("workspace_agent_id", workspaceAgent.ID),
|
||||
)
|
||||
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
log.Error(ctx, "failed to retrieve workspace from agent token", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to determine workspace from agent token"))
|
||||
}
|
||||
|
||||
log.Info(ctx, "agent waiting for reinit instruction")
|
||||
|
||||
reinitEvents := make(chan agentsdk.ReinitializationEvent)
|
||||
cancel, err = prebuilds.NewPubsubWorkspaceClaimListener(api.Pubsub, log).ListenForWorkspaceClaims(ctx, workspace.ID, reinitEvents)
|
||||
if err != nil {
|
||||
log.Error(ctx, "subscribe to prebuild claimed channel", slog.Error(err))
|
||||
httpapi.InternalServerError(rw, xerrors.New("failed to subscribe to prebuild claimed channel"))
|
||||
return
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
transmitter := agentsdk.NewSSEAgentReinitTransmitter(log, rw, r)
|
||||
|
||||
err = transmitter.Transmit(ctx, reinitEvents)
|
||||
switch {
|
||||
case errors.Is(err, agentsdk.ErrTransmissionSourceClosed):
|
||||
log.Info(ctx, "agent reinitialization subscription closed", slog.F("workspace_agent_id", workspaceAgent.ID))
|
||||
case errors.Is(err, agentsdk.ErrTransmissionTargetClosed):
|
||||
log.Info(ctx, "agent connection closed", slog.F("workspace_agent_id", workspaceAgent.ID))
|
||||
case errors.Is(err, context.Canceled):
|
||||
log.Info(ctx, "agent reinitialization", slog.Error(err))
|
||||
case err != nil:
|
||||
log.Error(ctx, "failed to stream agent reinit events", slog.Error(err))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error streaming agent reinitialization events.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// convertProvisionedApps converts applications that are in the middle of provisioning process.
|
||||
// It means that they may not have an agent or workspace assigned (dry-run job).
|
||||
func convertProvisionedApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -44,10 +45,12 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmem"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
@@ -2641,3 +2644,70 @@ func TestAgentConnectionInfo(t *testing.T) {
|
||||
require.True(t, info.DisableDirectConnections)
|
||||
require.True(t, info.DERPForceWebSockets)
|
||||
}
|
||||
|
||||
func TestReinit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
pubsubSpy := pubsubReinitSpy{
|
||||
Pubsub: ps,
|
||||
subscribed: make(chan string),
|
||||
}
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: &pubsubSpy,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
pubsubSpy.Mutex.Lock()
|
||||
pubsubSpy.expectedEvent = agentsdk.PrebuildClaimedChannel(r.Workspace.ID)
|
||||
pubsubSpy.Mutex.Unlock()
|
||||
|
||||
agentCtx := testutil.Context(t, testutil.WaitShort)
|
||||
agentClient := agentsdk.New(client.URL)
|
||||
agentClient.SetSessionToken(r.AgentToken)
|
||||
|
||||
agentReinitializedCh := make(chan *agentsdk.ReinitializationEvent)
|
||||
go func() {
|
||||
reinitEvent, err := agentClient.WaitForReinit(agentCtx)
|
||||
assert.NoError(t, err)
|
||||
agentReinitializedCh <- reinitEvent
|
||||
}()
|
||||
|
||||
// We need to subscribe before we publish, lest we miss the event
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
testutil.TryReceive(ctx, t, pubsubSpy.subscribed) // Wait for the appropriate subscription
|
||||
|
||||
// Now that we're subscribed, publish the event
|
||||
err := prebuilds.NewPubsubWorkspaceClaimPublisher(ps).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: r.Workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
reinitEvent := testutil.TryReceive(ctx, t, agentReinitializedCh)
|
||||
require.NotNil(t, reinitEvent)
|
||||
require.Equal(t, r.Workspace.ID, reinitEvent.WorkspaceID)
|
||||
}
|
||||
|
||||
type pubsubReinitSpy struct {
|
||||
pubsub.Pubsub
|
||||
sync.Mutex
|
||||
subscribed chan string
|
||||
expectedEvent string
|
||||
}
|
||||
|
||||
func (p *pubsubReinitSpy) Subscribe(event string, listener pubsub.Listener) (cancel func(), err error) {
|
||||
p.Lock()
|
||||
if p.expectedEvent != "" && event == p.expectedEvent {
|
||||
close(p.subscribed)
|
||||
}
|
||||
p.Unlock()
|
||||
return p.Pubsub.Subscribe(event, listener)
|
||||
}
|
||||
|
||||
@@ -628,9 +628,9 @@ func createWorkspace(
|
||||
|
||||
err = api.Database.InTx(func(db database.Store) error {
|
||||
var (
|
||||
prebuildsClaimer = *api.PrebuildsClaimer.Load()
|
||||
workspaceID uuid.UUID
|
||||
claimedWorkspace *database.Workspace
|
||||
prebuildsClaimer = *api.PrebuildsClaimer.Load()
|
||||
)
|
||||
|
||||
// If a template preset was chosen, try claim a prebuilt workspace.
|
||||
@@ -704,8 +704,7 @@ func createWorkspace(
|
||||
Reason(database.BuildReasonInitiator).
|
||||
Initiator(initiatorID).
|
||||
ActiveVersion().
|
||||
RichParameterValues(req.RichParameterValues).
|
||||
TemplateVersionPresetID(req.TemplateVersionPresetID)
|
||||
RichParameterValues(req.RichParameterValues)
|
||||
if req.TemplateVersionID != uuid.Nil {
|
||||
builder = builder.VersionID(req.TemplateVersionID)
|
||||
}
|
||||
|
||||
@@ -77,8 +77,7 @@ type Builder struct {
|
||||
parameterValues *[]string
|
||||
templateVersionPresetParameterValues []database.TemplateVersionPresetParameter
|
||||
|
||||
prebuiltWorkspaceBuildStage sdkproto.PrebuiltWorkspaceBuildStage
|
||||
|
||||
prebuiltWorkspaceBuildStage sdkproto.PrebuiltWorkspaceBuildStage
|
||||
verifyNoLegacyParametersOnce bool
|
||||
}
|
||||
|
||||
|
||||
@@ -19,12 +19,15 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// ExternalLogSourceID is the statically-defined ID of a log-source that
|
||||
@@ -686,3 +689,188 @@ func LogsNotifyChannel(agentID uuid.UUID) string {
|
||||
type LogsNotifyMessage struct {
|
||||
CreatedAfter int64 `json:"created_after"`
|
||||
}
|
||||
|
||||
type ReinitializationReason string
|
||||
|
||||
const (
|
||||
ReinitializeReasonPrebuildClaimed ReinitializationReason = "prebuild_claimed"
|
||||
)
|
||||
|
||||
type ReinitializationEvent struct {
|
||||
WorkspaceID uuid.UUID
|
||||
Reason ReinitializationReason `json:"reason"`
|
||||
}
|
||||
|
||||
func PrebuildClaimedChannel(id uuid.UUID) string {
|
||||
return fmt.Sprintf("prebuild_claimed_%s", id)
|
||||
}
|
||||
|
||||
// WaitForReinit polls a SSE endpoint, and receives an event back under the following conditions:
|
||||
// - ping: ignored, keepalive
|
||||
// - prebuild claimed: a prebuilt workspace is claimed, so the agent must reinitialize.
|
||||
func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, error) {
|
||||
rpcURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/reinit")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
||||
}
|
||||
jar.SetCookies(rpcURL, []*http.Cookie{{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: c.SDK.SessionToken(),
|
||||
}})
|
||||
httpClient := &http.Client{
|
||||
Jar: jar,
|
||||
Transport: c.SDK.HTTPClient.Transport,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rpcURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("build request: %w", err)
|
||||
}
|
||||
|
||||
res, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
reinitEvent, err := NewSSEAgentReinitReceiver(res.Body).Receive(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listening for reinitialization events: %w", err)
|
||||
}
|
||||
return reinitEvent, nil
|
||||
}
|
||||
|
||||
func WaitForReinitLoop(ctx context.Context, logger slog.Logger, client *Client) <-chan ReinitializationEvent {
|
||||
reinitEvents := make(chan ReinitializationEvent)
|
||||
|
||||
go func() {
|
||||
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
|
||||
logger.Debug(ctx, "waiting for agent reinitialization instructions")
|
||||
reinitEvent, err := client.WaitForReinit(ctx)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to wait for agent reinitialization instructions", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
retrier.Reset()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
close(reinitEvents)
|
||||
return
|
||||
case reinitEvents <- *reinitEvent:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return reinitEvents
|
||||
}
|
||||
|
||||
func NewSSEAgentReinitTransmitter(logger slog.Logger, rw http.ResponseWriter, r *http.Request) *SSEAgentReinitTransmitter {
|
||||
return &SSEAgentReinitTransmitter{logger: logger, rw: rw, r: r}
|
||||
}
|
||||
|
||||
type SSEAgentReinitTransmitter struct {
|
||||
rw http.ResponseWriter
|
||||
r *http.Request
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTransmissionSourceClosed = xerrors.New("transmission source closed")
|
||||
ErrTransmissionTargetClosed = xerrors.New("transmission target closed")
|
||||
)
|
||||
|
||||
// Transmit will read from the given chan and send events for as long as:
|
||||
// * the chan remains open
|
||||
// * the context has not been canceled
|
||||
// * not timed out
|
||||
// * the connection to the receiver remains open
|
||||
func (s *SSEAgentReinitTransmitter) Transmit(ctx context.Context, reinitEvents <-chan ReinitializationEvent) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(s.rw, s.r)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to create sse transmitter: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Block returning until the ServerSentEventSender is closed
|
||||
// to avoid a race condition where we might write or flush to rw after the handler returns.
|
||||
<-sseSenderClosed
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-sseSenderClosed:
|
||||
return ErrTransmissionTargetClosed
|
||||
case reinitEvent, ok := <-reinitEvents:
|
||||
if !ok {
|
||||
return ErrTransmissionSourceClosed
|
||||
}
|
||||
err := sseSendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: reinitEvent,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewSSEAgentReinitReceiver(r io.ReadCloser) *SSEAgentReinitReceiver {
|
||||
return &SSEAgentReinitReceiver{r: r}
|
||||
}
|
||||
|
||||
type SSEAgentReinitReceiver struct {
|
||||
r io.ReadCloser
|
||||
}
|
||||
|
||||
func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*ReinitializationEvent, error) {
|
||||
nextEvent := codersdk.ServerSentEventReader(ctx, s.r)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
sse, err := nextEvent()
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, xerrors.Errorf("failed to read server-sent event: %w", err)
|
||||
case sse.Type == codersdk.ServerSentEventTypeError:
|
||||
return nil, xerrors.Errorf("unexpected server sent event type error")
|
||||
case sse.Type == codersdk.ServerSentEventTypePing:
|
||||
continue
|
||||
case sse.Type != codersdk.ServerSentEventTypeData:
|
||||
return nil, xerrors.Errorf("unexpected server sent event type: %s", sse.Type)
|
||||
}
|
||||
|
||||
// At this point we know that the sent event is of type codersdk.ServerSentEventTypeData
|
||||
var reinitEvent ReinitializationEvent
|
||||
b, ok := sse.Data.([]byte)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("expected data as []byte, got %T", sse.Data)
|
||||
}
|
||||
err = json.Unmarshal(b, &reinitEvent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal reinit response: %w", err)
|
||||
}
|
||||
return &reinitEvent, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
package agentsdk_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestStreamAgentReinitEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("transmitted events are received", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
eventToSend := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: uuid.New(),
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
}
|
||||
|
||||
events := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
events <- eventToSend
|
||||
|
||||
transmitCtx := testutil.Context(t, testutil.WaitShort)
|
||||
transmitErrCh := make(chan error, 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
transmitter := agentsdk.NewSSEAgentReinitTransmitter(slogtest.Make(t, nil), w, r)
|
||||
transmitErrCh <- transmitter.Transmit(transmitCtx, events)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
requestCtx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(requestCtx, "GET", srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
receiveCtx := testutil.Context(t, testutil.WaitShort)
|
||||
receiver := agentsdk.NewSSEAgentReinitReceiver(resp.Body)
|
||||
sentEvent, receiveErr := receiver.Receive(receiveCtx)
|
||||
require.Nil(t, receiveErr)
|
||||
require.Equal(t, eventToSend, *sentEvent)
|
||||
})
|
||||
|
||||
t.Run("doesn't transmit events if the transmitter context is canceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
eventToSend := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: uuid.New(),
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
}
|
||||
|
||||
events := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
events <- eventToSend
|
||||
|
||||
transmitCtx, cancelTransmit := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
cancelTransmit()
|
||||
transmitErrCh := make(chan error, 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
transmitter := agentsdk.NewSSEAgentReinitTransmitter(slogtest.Make(t, nil), w, r)
|
||||
transmitErrCh <- transmitter.Transmit(transmitCtx, events)
|
||||
}))
|
||||
|
||||
defer srv.Close()
|
||||
|
||||
requestCtx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(requestCtx, "GET", srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
receiveCtx := testutil.Context(t, testutil.WaitShort)
|
||||
receiver := agentsdk.NewSSEAgentReinitReceiver(resp.Body)
|
||||
sentEvent, receiveErr := receiver.Receive(receiveCtx)
|
||||
require.Nil(t, sentEvent)
|
||||
require.ErrorIs(t, receiveErr, io.EOF)
|
||||
})
|
||||
|
||||
t.Run("does not receive events if the receiver context is canceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
eventToSend := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: uuid.New(),
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
}
|
||||
|
||||
events := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
events <- eventToSend
|
||||
|
||||
transmitCtx := testutil.Context(t, testutil.WaitShort)
|
||||
transmitErrCh := make(chan error, 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
transmitter := agentsdk.NewSSEAgentReinitTransmitter(slogtest.Make(t, nil), w, r)
|
||||
transmitErrCh <- transmitter.Transmit(transmitCtx, events)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
requestCtx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(requestCtx, "GET", srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
receiveCtx, cancelReceive := context.WithCancel(context.Background())
|
||||
cancelReceive()
|
||||
receiver := agentsdk.NewSSEAgentReinitReceiver(resp.Body)
|
||||
sentEvent, receiveErr := receiver.Receive(receiveCtx)
|
||||
require.Nil(t, sentEvent)
|
||||
require.ErrorIs(t, receiveErr, context.Canceled)
|
||||
})
|
||||
}
|
||||
+1
-1
@@ -631,7 +631,7 @@ func (h *HeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
}
|
||||
if h.Transport == nil {
|
||||
h.Transport = http.DefaultTransport
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
return h.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
Generated
+32
@@ -470,6 +470,38 @@ curl -X PATCH http://coder-server:8080/api/v2/workspaceagents/me/logs \
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Get workspace agent reinitialization
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X GET http://coder-server:8080/api/v2/workspaceagents/me/reinit \
|
||||
-H 'Accept: application/json' \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`GET /workspaceagents/me/reinit`
|
||||
|
||||
### Example responses
|
||||
|
||||
> 200 Response
|
||||
|
||||
```json
|
||||
{
|
||||
"reason": "prebuild_claimed",
|
||||
"workspaceID": "string"
|
||||
}
|
||||
```
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------------------|
|
||||
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [agentsdk.ReinitializationEvent](schemas.md#agentsdkreinitializationevent) |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Get workspace agent by ID
|
||||
|
||||
### Code samples
|
||||
|
||||
Generated
+30
@@ -182,6 +182,36 @@
|
||||
| `icon` | string | false | | |
|
||||
| `id` | string | false | | ID is a unique identifier for the log source. It is scoped to a workspace agent, and can be statically defined inside code to prevent duplicate sources from being created for the same agent. |
|
||||
|
||||
## agentsdk.ReinitializationEvent
|
||||
|
||||
```json
|
||||
{
|
||||
"reason": "prebuild_claimed",
|
||||
"workspaceID": "string"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|---------------|--------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `reason` | [agentsdk.ReinitializationReason](#agentsdkreinitializationreason) | false | | |
|
||||
| `workspaceID` | string | false | | |
|
||||
|
||||
## agentsdk.ReinitializationReason
|
||||
|
||||
```json
|
||||
"prebuild_claimed"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value |
|
||||
|--------------------|
|
||||
| `prebuild_claimed` |
|
||||
|
||||
## aisdk.Attachment
|
||||
|
||||
```json
|
||||
|
||||
@@ -5,12 +5,19 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
@@ -73,6 +80,168 @@ func TestBlockNonBrowser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestReinitializeAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempAgentLog := testutil.CreateTemp(t, "", "testReinitializeAgent")
|
||||
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("dbmem cannot currently claim a workspace")
|
||||
}
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
// GIVEN a live enterprise API with the prebuilds feature enabled
|
||||
client, user := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Prebuilds.ReconciliationInterval = serpent.Duration(time.Second)
|
||||
dv.Experiments.Append(string(codersdk.ExperimentWorkspacePrebuilds))
|
||||
}),
|
||||
IncludeProvisionerDaemon: true,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureWorkspacePrebuilds: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// GIVEN a template, template version, preset and a prebuilt workspace that uses them all
|
||||
agentToken := uuid.UUID{3}
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{
|
||||
Type: &proto.Response_Plan{
|
||||
Plan: &proto.PlanComplete{
|
||||
Presets: []*proto.Preset{
|
||||
{
|
||||
Name: "test-preset",
|
||||
Prebuild: &proto.Prebuild{
|
||||
Instances: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
Resources: []*proto.Resource{
|
||||
{
|
||||
Agents: []*proto.Agent{
|
||||
{
|
||||
Name: "smith",
|
||||
OperatingSystem: "linux",
|
||||
Architecture: "i386",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ProvisionApply: []*proto.Response{
|
||||
{
|
||||
Type: &proto.Response_Apply{
|
||||
Apply: &proto.ApplyComplete{
|
||||
Resources: []*proto.Resource{
|
||||
{
|
||||
Type: "compute",
|
||||
Name: "main",
|
||||
Agents: []*proto.Agent{
|
||||
{
|
||||
Name: "smith",
|
||||
OperatingSystem: "linux",
|
||||
Architecture: "i386",
|
||||
Scripts: []*proto.Script{
|
||||
{
|
||||
RunOnStart: true,
|
||||
Script: fmt.Sprintf("printenv >> %s; echo '---\n' >> %s", tempAgentLog.Name(), tempAgentLog.Name()), // Make reinitialization take long enough to assert that it happened
|
||||
},
|
||||
},
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: agentToken.String(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
|
||||
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
// Wait for prebuilds to create a prebuilt workspace
|
||||
ctx := context.Background()
|
||||
// ctx := testutil.Context(t, testutil.WaitLong)
|
||||
var (
|
||||
prebuildID uuid.UUID
|
||||
)
|
||||
require.Eventually(t, func() bool {
|
||||
agentAndBuild, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, agentToken)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
prebuildID = agentAndBuild.WorkspaceBuild.ID
|
||||
return true
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
prebuild := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, prebuildID)
|
||||
|
||||
preset, err := db.GetPresetByWorkspaceBuildID(ctx, prebuildID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// GIVEN a running agent
|
||||
logDir := t.TempDir()
|
||||
inv, _ := clitest.New(t,
|
||||
"agent",
|
||||
"--auth", "token",
|
||||
"--agent-token", agentToken.String(),
|
||||
"--agent-url", client.URL.String(),
|
||||
"--log-dir", logDir,
|
||||
)
|
||||
clitest.Start(t, inv)
|
||||
|
||||
// GIVEN the agent is in a happy steady state
|
||||
waiter := coderdtest.NewWorkspaceAgentWaiter(t, client, prebuild.WorkspaceID)
|
||||
waiter.WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// WHEN a workspace is created that can benefit from prebuilds
|
||||
anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
|
||||
workspace, err := anotherClient.CreateUserWorkspace(ctx, anotherUser.ID.String(), codersdk.CreateWorkspaceRequest{
|
||||
TemplateVersionID: version.ID,
|
||||
TemplateVersionPresetID: preset.ID,
|
||||
Name: "claimed-workspace",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// THEN reinitialization completes
|
||||
waiter.WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
var matches [][]byte
|
||||
require.Eventually(t, func() bool {
|
||||
// THEN the agent script ran again and reused the same agent token
|
||||
contents, err := os.ReadFile(tempAgentLog.Name())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// UUID regex pattern (matches UUID v4-like strings)
|
||||
uuidRegex := regexp.MustCompile(`\bCODER_AGENT_TOKEN=(.+)\b`)
|
||||
|
||||
matches = uuidRegex.FindAll(contents, -1)
|
||||
// When an agent reinitializes, we expect it to run startup scripts again.
|
||||
// As such, we expect to have written the agent environment to the temp file twice.
|
||||
// Once on initial startup and then once on reinitialization.
|
||||
return len(matches) == 2
|
||||
}, testutil.WaitLong, testutil.IntervalMedium)
|
||||
require.Equal(t, matches[0], matches[1])
|
||||
}
|
||||
|
||||
type setupResp struct {
|
||||
workspace codersdk.Workspace
|
||||
sdkAgent codersdk.WorkspaceAgent
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -30,6 +32,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/prebuilds"
|
||||
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
agplschedule "github.com/coder/coder/v2/coderd/schedule"
|
||||
@@ -43,6 +47,7 @@ import (
|
||||
"github.com/coder/coder/v2/enterprise/coderd/schedule"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
@@ -459,6 +464,79 @@ func TestCreateUserWorkspace(t *testing.T) {
|
||||
_, err = client1.CreateUserWorkspace(ctx, user1.ID.String(), req)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("ClaimPrebuild", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("dbmem cannot currently claim a workspace")
|
||||
}
|
||||
|
||||
client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
err := dv.Experiments.Append(string(codersdk.ExperimentWorkspacePrebuilds))
|
||||
require.NoError(t, err)
|
||||
}),
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureWorkspacePrebuilds: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// GIVEN a template, template version, preset and a prebuilt workspace that uses them all
|
||||
presetID := uuid.New()
|
||||
tv := dbfake.TemplateVersion(t, db).Seed(database.TemplateVersion{
|
||||
OrganizationID: user.OrganizationID,
|
||||
CreatedBy: user.UserID,
|
||||
}).Preset(database.TemplateVersionPreset{
|
||||
ID: presetID,
|
||||
}).Do()
|
||||
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: prebuilds.SystemUserID,
|
||||
TemplateID: tv.Template.ID,
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
TemplateVersionID: tv.TemplateVersion.ID,
|
||||
TemplateVersionPresetID: uuid.NullUUID{
|
||||
UUID: presetID,
|
||||
Valid: true,
|
||||
},
|
||||
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
|
||||
return a
|
||||
}).Do()
|
||||
|
||||
// nolint:gocritic // this is a test
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitLong))
|
||||
agent, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, uuid.MustParse(r.AgentToken))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: agent.WorkspaceAgent.ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// WHEN a workspace is created that matches the available prebuilt workspace
|
||||
_, err = client.CreateUserWorkspace(ctx, user.UserID.String(), codersdk.CreateWorkspaceRequest{
|
||||
TemplateVersionID: tv.TemplateVersion.ID,
|
||||
TemplateVersionPresetID: presetID,
|
||||
Name: "claimed-workspace",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// THEN a new build is scheduled with the build stage specified
|
||||
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, build.ID, r.Build.ID)
|
||||
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
|
||||
require.NoError(t, err)
|
||||
var metadata provisionerdserver.WorkspaceProvisionJob
|
||||
require.NoError(t, json.Unmarshal(job.Input, &metadata))
|
||||
require.Equal(t, metadata.PrebuiltWorkspaceBuildStage, proto.PrebuiltWorkspaceBuildStage_CLAIM)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkspaceAutobuild(t *testing.T) {
|
||||
|
||||
@@ -350,6 +350,68 @@ func onlyDataResources(sm tfjson.StateModule) tfjson.StateModule {
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (e *executor) logResourceReplacements(ctx context.Context, plan *tfjson.Plan) {
|
||||
if plan == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(plan.ResourceChanges) == 0 {
|
||||
return
|
||||
}
|
||||
var (
|
||||
count int
|
||||
replacements = make(map[string][]string, len(plan.ResourceChanges))
|
||||
)
|
||||
|
||||
for _, ch := range plan.ResourceChanges {
|
||||
// No change, no problem!
|
||||
if ch.Change == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// No-op change, no problem!
|
||||
if ch.Change.Actions.NoOp() {
|
||||
continue
|
||||
}
|
||||
|
||||
// No replacements, no problem!
|
||||
if len(ch.Change.ReplacePaths) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Replacing our resources, no problem!
|
||||
if strings.Index(ch.Type, "coder_") == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range ch.Change.ReplacePaths {
|
||||
var path string
|
||||
switch p := p.(type) {
|
||||
case []interface{}:
|
||||
segs := p
|
||||
list := make([]string, 0, len(segs))
|
||||
for _, s := range segs {
|
||||
list = append(list, fmt.Sprintf("%v", s))
|
||||
}
|
||||
path = strings.Join(list, ".")
|
||||
default:
|
||||
path = fmt.Sprintf("%v", p)
|
||||
}
|
||||
|
||||
replacements[ch.Address] = append(replacements[ch.Address], path)
|
||||
}
|
||||
|
||||
count++
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
e.server.logger.Warn(ctx, "plan introduces resource changes", slog.F("count", count))
|
||||
for n, p := range replacements {
|
||||
e.server.logger.Warn(ctx, "resource will be replaced", slog.F("name", n), slog.F("replacement_paths", strings.Join(p, ",")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// planResources must only be called while the lock is held.
|
||||
func (e *executor) planResources(ctx, killCtx context.Context, planfilePath string) (*State, json.RawMessage, error) {
|
||||
ctx, span := e.server.startTrace(ctx, tracing.FuncName())
|
||||
@@ -360,6 +422,8 @@ func (e *executor) planResources(ctx, killCtx context.Context, planfilePath stri
|
||||
return nil, nil, xerrors.Errorf("show terraform plan file: %w", err)
|
||||
}
|
||||
|
||||
e.logResourceReplacements(ctx, plan)
|
||||
|
||||
rawGraph, err := e.graph(ctx, killCtx)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("graph: %w", err)
|
||||
|
||||
@@ -273,6 +273,17 @@ func provisionEnv(
|
||||
if metadata.GetPrebuiltWorkspaceBuildStage().IsPrebuild() {
|
||||
env = append(env, provider.IsPrebuildEnvironmentVariable()+"=true")
|
||||
}
|
||||
tokens := metadata.GetRunningAgentAuthTokens()
|
||||
if len(tokens) == 1 {
|
||||
env = append(env, provider.RunningAgentTokenEnvironmentVariable("")+"="+tokens[0].Token)
|
||||
} else {
|
||||
// Not currently supported, but added for forward-compatibility
|
||||
for _, t := range tokens {
|
||||
// If there are multiple agents, provide all the tokens to terraform so that it can
|
||||
// choose the correct one for each agent ID.
|
||||
env = append(env, provider.RunningAgentTokenEnvironmentVariable(t.AgentId)+"="+t.Token)
|
||||
}
|
||||
}
|
||||
if metadata.GetPrebuiltWorkspaceBuildStage().IsPrebuiltWorkspaceClaim() {
|
||||
env = append(env, provider.IsPrebuildClaimEnvironmentVariable()+"=true")
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import "github.com/coder/coder/v2/apiversion"
|
||||
// - Add previous parameter values to 'WorkspaceBuild' jobs. Provisioner passes
|
||||
// the previous values for the `terraform apply` to enforce monotonicity
|
||||
// in the terraform provider.
|
||||
// - Add new field named `running_agent_auth_tokens` to provisioner job metadata
|
||||
const (
|
||||
CurrentMajor = 1
|
||||
CurrentMinor = 5
|
||||
|
||||
Generated
+452
-376
File diff suppressed because it is too large
Load Diff
@@ -273,6 +273,10 @@ message Role {
|
||||
string org_id = 2;
|
||||
}
|
||||
|
||||
message RunningAgentAuthToken {
|
||||
string agent_id = 1;
|
||||
string token = 2;
|
||||
}
|
||||
enum PrebuiltWorkspaceBuildStage {
|
||||
NONE = 0; // Default value for builds unrelated to prebuilds.
|
||||
CREATE = 1; // A prebuilt workspace is being provisioned.
|
||||
@@ -301,7 +305,7 @@ message Metadata {
|
||||
string workspace_owner_login_type = 18;
|
||||
repeated Role workspace_owner_rbac_roles = 19;
|
||||
PrebuiltWorkspaceBuildStage prebuilt_workspace_build_stage = 20; // Indicates that a prebuilt workspace is being built.
|
||||
string running_workspace_agent_token = 21; // Preserves the running agent token of a prebuilt workspace so it can reinitialize.
|
||||
repeated RunningAgentAuthToken running_agent_auth_tokens = 21;
|
||||
}
|
||||
|
||||
// Config represents execution configuration shared by all subsequent requests in the Session
|
||||
|
||||
Generated
+20
-4
@@ -297,6 +297,11 @@ export interface Role {
|
||||
orgId: string;
|
||||
}
|
||||
|
||||
export interface RunningAgentAuthToken {
|
||||
agentId: string;
|
||||
token: string;
|
||||
}
|
||||
|
||||
/** Metadata is information about a workspace used in the execution of a build */
|
||||
export interface Metadata {
|
||||
coderUrl: string;
|
||||
@@ -320,8 +325,7 @@ export interface Metadata {
|
||||
workspaceOwnerRbacRoles: Role[];
|
||||
/** Indicates that a prebuilt workspace is being built. */
|
||||
prebuiltWorkspaceBuildStage: PrebuiltWorkspaceBuildStage;
|
||||
/** Preserves the running agent token of a prebuilt workspace so it can reinitialize. */
|
||||
runningWorkspaceAgentToken: string;
|
||||
runningAgentAuthTokens: RunningAgentAuthToken[];
|
||||
}
|
||||
|
||||
/** Config represents execution configuration shared by all subsequent requests in the Session */
|
||||
@@ -986,6 +990,18 @@ export const Role = {
|
||||
},
|
||||
};
|
||||
|
||||
export const RunningAgentAuthToken = {
|
||||
encode(message: RunningAgentAuthToken, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
|
||||
if (message.agentId !== "") {
|
||||
writer.uint32(10).string(message.agentId);
|
||||
}
|
||||
if (message.token !== "") {
|
||||
writer.uint32(18).string(message.token);
|
||||
}
|
||||
return writer;
|
||||
},
|
||||
};
|
||||
|
||||
export const Metadata = {
|
||||
encode(message: Metadata, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
|
||||
if (message.coderUrl !== "") {
|
||||
@@ -1048,8 +1064,8 @@ export const Metadata = {
|
||||
if (message.prebuiltWorkspaceBuildStage !== 0) {
|
||||
writer.uint32(160).int32(message.prebuiltWorkspaceBuildStage);
|
||||
}
|
||||
if (message.runningWorkspaceAgentToken !== "") {
|
||||
writer.uint32(170).string(message.runningWorkspaceAgentToken);
|
||||
for (const v of message.runningAgentAuthTokens) {
|
||||
RunningAgentAuthToken.encode(v!, writer.uint32(170).fork()).ldelim();
|
||||
}
|
||||
return writer;
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user