package provisionerdserver

import (
	"context"
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"reflect"
	"slices"
	"sort"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/sqlc-dev/pqtype"
	semconv "go.opentelemetry.io/otel/semconv/v1.14.0"
	"go.opentelemetry.io/otel/trace"
	"golang.org/x/exp/maps"
	"golang.org/x/oauth2"
	"golang.org/x/xerrors"
	protobuf "google.golang.org/protobuf/proto"

	"cdr.dev/slog/v3"

	"github.com/coder/coder/v2/coderd/aiseats"
	"github.com/coder/coder/v2/coderd/apikey"
	"github.com/coder/coder/v2/coderd/audit"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/coder/v2/coderd/database/dbtime"
	"github.com/coder/coder/v2/coderd/database/pubsub"
	"github.com/coder/coder/v2/coderd/externalauth"
	"github.com/coder/coder/v2/coderd/notifications"
	"github.com/coder/coder/v2/coderd/prebuilds"
	"github.com/coder/coder/v2/coderd/promoauth"
	"github.com/coder/coder/v2/coderd/schedule"
	"github.com/coder/coder/v2/coderd/telemetry"
	"github.com/coder/coder/v2/coderd/tracing"
	"github.com/coder/coder/v2/coderd/usage"
	"github.com/coder/coder/v2/coderd/usage/usagetypes"
	"github.com/coder/coder/v2/coderd/util/slice"
	"github.com/coder/coder/v2/coderd/wspubsub"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/agentsdk"
	"github.com/coder/coder/v2/codersdk/drpcsdk"
	"github.com/coder/coder/v2/provisioner"
	"github.com/coder/coder/v2/provisionerd/proto"
	"github.com/coder/coder/v2/provisionersdk"
	sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
	"github.com/coder/quartz"
)

const (
	// tarMimeType is the MIME type of a tar archive.
	tarMimeType = "application/x-tar"
)

const (
	// DefaultAcquireJobLongPollDur is how long the (deprecated) AcquireJob RPC
	// waits to obtain a job before giving up and returning an empty job.
	// NewServer uses it when Options.AcquireJobLongPollDur is zero.
	DefaultAcquireJobLongPollDur = time.Second * 5
	// DefaultHeartbeatInterval is the interval at which the server updates the
	// provisioner daemon's last seen at timestamp in the database. NewServer
	// uses it when Options.HeartbeatInterval is zero.
	DefaultHeartbeatInterval = time.Minute
	// StaleInterval is the amount of time after the last heartbeat for which
	// the provisioner will be reported as 'stale'.
	StaleInterval = 90 * time.Second
)

// Options holds optional configuration for NewServer. NewServer replaces
// unset values with defaults where noted below.
type Options struct {
	// OIDCConfig is used to obtain the workspace owner's OIDC access token for
	// workspace builds, unless it holds a typed nil pointer.
	// NOTE(review): acquireProtoJob checks it with
	// reflect.ValueOf(OIDCConfig).IsNil(), which panics when the interface
	// itself is nil — confirm every caller sets this field.
	OIDCConfig promoauth.OAuth2Config
	// ExternalAuthConfigs are matched by ID against a template version's
	// external auth providers. The owner's linked tokens are refreshed
	// through them and passed to workspace builds.
	ExternalAuthConfigs []*externalauth.Config
	// AISeatTracker is stored on the server. NewServer substitutes
	// aiseats.Noop{} when it is nil.
	AISeatTracker aiseats.SeatTracker

	// Clock for testing. NewServer defaults it to quartz.NewReal().
	Clock quartz.Clock
	// AcquireJobLongPollDur is used in tests. When zero, NewServer uses
	// DefaultAcquireJobLongPollDur.
	AcquireJobLongPollDur time.Duration

	// HeartbeatInterval is the interval at which the server updates the
	// provisioner daemon's last seen at timestamp in the database. When zero,
	// NewServer uses DefaultHeartbeatInterval.
	HeartbeatInterval time.Duration
	// HeartbeatFn is the function that will be called at the interval
	// specified by HeartbeatInterval.
	// The default function just calls UpdateProvisionerDaemonLastSeenAt.
	// This is mainly used for testing.
	HeartbeatFn func(context.Context) error
}

// server is the proto.DRPCProvisionerDaemonServer returned by NewServer. It
// serves a single provisioner daemon (ID) within a single organization
// (OrganizationID).
type server struct {
	apiVersion string
	// lifecycleCtx must be tied to the API server's lifecycle
	// as when the API server shuts down, we want to cancel any
	// long-running operations.
	lifecycleCtx context.Context

	// AccessURL is sent to provisioners as the Coder URL in job metadata.
	AccessURL *url.URL
	// ID is the provisioner daemon's ID. It is passed to the Acquirer when
	// locking jobs and compared with a job's WorkerID before the server
	// accepts calls about that job.
	ID             uuid.UUID
	OrganizationID uuid.UUID
	Logger         slog.Logger
	// Provisioners and Tags are passed to Acquirer.AcquireJob when locking a
	// job for this daemon.
	Provisioners        []database.ProvisionerType
	ExternalAuthConfigs []*externalauth.Config
	Tags                Tags
	Database            database.Store
	Pubsub              pubsub.Pubsub
	Acquirer            *Acquirer
	Telemetry           telemetry.Reporter
	Tracer              trace.Tracer
	// QuotaCommitter is loaded on every CommitQuota call. When it holds nil,
	// CommitQuota approves the request (community edition or tests).
	QuotaCommitter              *atomic.Pointer[proto.QuotaCommitter]
	Auditor                     *atomic.Pointer[audit.Auditor]
	TemplateScheduleStore       *atomic.Pointer[schedule.TemplateScheduleStore]
	UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
	DeploymentValues            *codersdk.DeploymentValues
	NotificationsEnqueuer       notifications.Enqueuer
	PrebuildsOrchestrator       *atomic.Pointer[prebuilds.ReconciliationOrchestrator]
	UsageInserter               *atomic.Pointer[usage.Inserter]
	AISeatTracker               aiseats.SeatTracker
	Experiments                 codersdk.Experiments
	// OIDCConfig may hold a typed nil pointer; see Options.OIDCConfig.
	OIDCConfig promoauth.OAuth2Config
	// Clock is the source of time for timeNow, so tests can mock it.
	Clock quartz.Clock

	// acquireJobLongPollDur bounds how long the deprecated AcquireJob RPC
	// waits for a job.
	acquireJobLongPollDur time.Duration
	// heartbeatInterval is the target period of heartbeatLoop. It is also
	// the timeout applied to each heartbeat call.
	heartbeatInterval time.Duration
	// heartbeatFn is invoked by heartbeatLoop. NewServer defaults it to
	// defaultHeartbeat.
	heartbeatFn func(ctx context.Context) error
	// metrics may be nil, in which case job queue-wait time is not recorded.
	metrics *Metrics
}

// We use the null byte (0x00) in generating a canonical map key for tags, so
// it cannot be used in the tag keys or values.
var ErrTagsContainNullByte = xerrors.New("tags cannot contain the null byte (0x00)") type Tags map[string]string func (t Tags) ToJSON() (json.RawMessage, error) { r, err := json.Marshal(t) if err != nil { return nil, err } return r, err } func (t Tags) Valid() error { for k, v := range t { if slices.Contains([]byte(k), 0x00) || slices.Contains([]byte(v), 0x00) { return ErrTagsContainNullByte } } return nil } func NewServer( lifecycleCtx context.Context, apiVersion string, accessURL *url.URL, id uuid.UUID, organizationID uuid.UUID, logger slog.Logger, provisioners []database.ProvisionerType, tags Tags, db database.Store, ps pubsub.Pubsub, acquirer *Acquirer, tel telemetry.Reporter, tracer trace.Tracer, quotaCommitter *atomic.Pointer[proto.QuotaCommitter], auditor *atomic.Pointer[audit.Auditor], templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore], userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore], usageInserter *atomic.Pointer[usage.Inserter], deploymentValues *codersdk.DeploymentValues, options Options, enqueuer notifications.Enqueuer, prebuildsOrchestrator *atomic.Pointer[prebuilds.ReconciliationOrchestrator], metrics *Metrics, experiments codersdk.Experiments, ) (proto.DRPCProvisionerDaemonServer, error) { // Fail-fast if pointers are nil if lifecycleCtx == nil { return nil, xerrors.New("ctx is nil") } if quotaCommitter == nil { return nil, xerrors.New("quotaCommitter is nil") } if auditor == nil { return nil, xerrors.New("auditor is nil") } if templateScheduleStore == nil { return nil, xerrors.New("templateScheduleStore is nil") } if userQuietHoursScheduleStore == nil { return nil, xerrors.New("userQuietHoursScheduleStore is nil") } if usageInserter == nil { return nil, xerrors.New("usageCollector is nil") } if deploymentValues == nil { return nil, xerrors.New("deploymentValues is nil") } if acquirer == nil { return nil, xerrors.New("acquirer is nil") } if tags == nil { return nil, xerrors.Errorf("tags is nil") } 
if err := tags.Valid(); err != nil { return nil, xerrors.Errorf("invalid tags: %w", err) } if options.AISeatTracker == nil { options.AISeatTracker = aiseats.Noop{} } if options.AcquireJobLongPollDur == 0 { options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur } if options.HeartbeatInterval == 0 { options.HeartbeatInterval = DefaultHeartbeatInterval } if options.Clock == nil { options.Clock = quartz.NewReal() } s := &server{ lifecycleCtx: lifecycleCtx, apiVersion: apiVersion, AccessURL: accessURL, ID: id, OrganizationID: organizationID, Logger: logger, Provisioners: provisioners, ExternalAuthConfigs: options.ExternalAuthConfigs, Tags: tags, Database: db, Pubsub: ps, Acquirer: acquirer, NotificationsEnqueuer: enqueuer, Telemetry: tel, Tracer: tracer, QuotaCommitter: quotaCommitter, Auditor: auditor, TemplateScheduleStore: templateScheduleStore, UserQuietHoursScheduleStore: userQuietHoursScheduleStore, DeploymentValues: deploymentValues, OIDCConfig: options.OIDCConfig, Clock: options.Clock, acquireJobLongPollDur: options.AcquireJobLongPollDur, heartbeatInterval: options.HeartbeatInterval, heartbeatFn: options.HeartbeatFn, PrebuildsOrchestrator: prebuildsOrchestrator, UsageInserter: usageInserter, AISeatTracker: options.AISeatTracker, metrics: metrics, Experiments: experiments, } if s.heartbeatFn == nil { s.heartbeatFn = s.defaultHeartbeat } go s.heartbeatLoop() return s, nil } // timeNow should be used when trying to get the current time for math // calculations regarding workspace start and stop time. func (s *server) timeNow(tags ...string) time.Time { return dbtime.Time(s.Clock.Now(tags...)) } // heartbeatLoop runs heartbeatOnce at the interval specified by HeartbeatInterval // until the lifecycle context is canceled. 
func (s *server) heartbeatLoop() { tick := time.NewTicker(time.Nanosecond) defer tick.Stop() for { select { case <-s.lifecycleCtx.Done(): s.Logger.Debug(s.lifecycleCtx, "heartbeat loop canceled") return case <-tick.C: if s.lifecycleCtx.Err() != nil { return } start := s.timeNow() hbCtx, hbCancel := context.WithTimeout(s.lifecycleCtx, s.heartbeatInterval) if err := s.heartbeat(hbCtx); err != nil && !database.IsQueryCanceledError(err) { s.Logger.Warn(hbCtx, "heartbeat failed", slog.Error(err)) } hbCancel() elapsed := s.timeNow().Sub(start) nextBeat := s.heartbeatInterval - elapsed // avoid negative interval if nextBeat <= 0 { nextBeat = time.Nanosecond } tick.Reset(nextBeat) } } } // heartbeat updates the last seen at timestamp in the database. // If HeartbeatFn is set, it will be called instead. func (s *server) heartbeat(ctx context.Context) error { select { case <-ctx.Done(): return nil default: return s.heartbeatFn(ctx) } } func (s *server) defaultHeartbeat(ctx context.Context) error { //nolint:gocritic // This is specifically for updating the last seen at timestamp. return s.Database.UpdateProvisionerDaemonLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateProvisionerDaemonLastSeenAtParams{ ID: s.ID, LastSeenAt: sql.NullTime{Time: s.timeNow(), Valid: true}, }) } // AcquireJob queries the database to lock a job. // // Deprecated: This method is only available for back-level provisioner daemons. func (s *server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) { //nolint:gocritic // Provisionerd has specific authz rules. ctx = dbauthz.AsProvisionerd(ctx) // Since AcquireJob blocks until a job is available, we set a long (5s by default) timeout. This allows back-level // provisioner daemons to gracefully shut down within a few seconds, but keeps them from rapidly polling the // database. 
acqCtx, acqCancel := context.WithTimeout(ctx, s.acquireJobLongPollDur) defer acqCancel() job, err := s.Acquirer.AcquireJob(acqCtx, s.OrganizationID, s.ID, s.Provisioners, s.Tags) if database.IsQueryCanceledError(err) { s.Logger.Debug(ctx, "successful cancel") return &proto.AcquiredJob{}, nil } if err != nil { return nil, xerrors.Errorf("acquire job: %w", err) } s.Logger.Debug(ctx, "locked job from database", slog.F("job_id", job.ID)) return s.acquireProtoJob(ctx, job) } type jobAndErr struct { job database.ProvisionerJob err error } // AcquireJobWithCancel queries the database to lock a job. func (s *server) AcquireJobWithCancel(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) (retErr error) { //nolint:gocritic // Provisionerd has specific authz rules. streamCtx := dbauthz.AsProvisionerd(stream.Context()) defer func() { closeErr := stream.Close() s.Logger.Debug(streamCtx, "closed stream", slog.Error(closeErr)) if retErr == nil { retErr = closeErr } }() acqCtx, acqCancel := context.WithCancel(streamCtx) defer acqCancel() recvCh := make(chan error, 1) go func() { _, err := stream.Recv() // cancel is the only message recvCh <- err }() jec := make(chan jobAndErr, 1) go func() { job, err := s.Acquirer.AcquireJob(acqCtx, s.OrganizationID, s.ID, s.Provisioners, s.Tags) jec <- jobAndErr{job: job, err: err} }() var recvErr error var je jobAndErr select { case recvErr = <-recvCh: acqCancel() je = <-jec case je = <-jec: } if database.IsQueryCanceledError(je.err) { s.Logger.Debug(streamCtx, "successful cancel") err := stream.Send(&proto.AcquiredJob{}) if err != nil { // often this is just because the other side hangs up and doesn't wait for the cancel, so log at INFO s.Logger.Info(streamCtx, "failed to send empty job", slog.Error(err)) return err } return nil } if je.err != nil { return xerrors.Errorf("acquire job: %w", je.err) } logger := s.Logger.With(slog.F("job_id", je.job.ID)) logger.Debug(streamCtx, "locked job from database") if recvErr != nil { 
logger.Error(streamCtx, "recv error and failed to cancel acquire job", slog.Error(recvErr)) // Well, this is awkward. We hit an error receiving from the stream, but didn't cancel before we locked a job // in the database. We need to mark this job as failed so the end user can retry if they want to. now := s.timeNow() err := s.Database.UpdateProvisionerJobWithCompleteByID( //nolint:gocritic // Provisionerd has specific authz rules. dbauthz.AsProvisionerd(context.Background()), database.UpdateProvisionerJobWithCompleteByIDParams{ ID: je.job.ID, CompletedAt: sql.NullTime{ Time: now, Valid: true, }, UpdatedAt: now, Error: sql.NullString{ String: "connection to provisioner daemon broken", Valid: true, }, ErrorCode: sql.NullString{}, }) if err != nil { logger.Error(streamCtx, "error updating failed job", slog.Error(err)) } return recvErr } pj, err := s.acquireProtoJob(streamCtx, je.job) if err != nil { return err } err = stream.Send(pj) if err != nil { s.Logger.Error(streamCtx, "failed to send job", slog.Error(err)) return err } return nil } func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJob) (*proto.AcquiredJob, error) { // Marks the acquired job as failed with the error message provided. 
failJob := func(errorMessage string) error { err := s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: job.ID, CompletedAt: sql.NullTime{ Time: s.timeNow(), Valid: true, }, Error: sql.NullString{ String: errorMessage, Valid: true, }, ErrorCode: job.ErrorCode, UpdatedAt: s.timeNow(), }) if err != nil { return xerrors.Errorf("update provisioner job: %w", err) } return xerrors.Errorf("request job was invalidated: %s", errorMessage) } user, err := s.Database.GetUserByID(ctx, job.InitiatorID) if err != nil { return nil, failJob(fmt.Sprintf("get user: %s", err)) } jobTraceMetadata := map[string]string{} if job.TraceMetadata.Valid { err := json.Unmarshal(job.TraceMetadata.RawMessage, &jobTraceMetadata) if err != nil { return nil, failJob(fmt.Sprintf("unmarshal metadata: %s", err)) } } protoJob := &proto.AcquiredJob{ JobId: job.ID.String(), CreatedAt: job.CreatedAt.UnixMilli(), Provisioner: string(job.Provisioner), UserName: user.Username, TraceMetadata: jobTraceMetadata, } // jobTransition and jobBuildReason are used for metrics; only set for workspace builds. 
var jobTransition string var jobBuildReason string switch job.Type { case database.ProvisionerJobTypeWorkspaceBuild: var input WorkspaceProvisionJob err = json.Unmarshal(job.Input, &input) if err != nil { return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) } workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID) if err != nil { return nil, failJob(fmt.Sprintf("get workspace build: %s", err)) } workspace, err := s.Database.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID) if err != nil { return nil, failJob(fmt.Sprintf("get workspace: %s", err)) } templateVersion, err := s.Database.GetTemplateVersionByID(ctx, workspaceBuild.TemplateVersionID) if err != nil { return nil, failJob(fmt.Sprintf("get template version: %s", err)) } templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, templateVersion.ID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { return nil, failJob(fmt.Sprintf("get template version variables: %s", err)) } template, err := s.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) if err != nil { return nil, failJob(fmt.Sprintf("get template: %s", err)) } owner, err := s.Database.GetUserByID(ctx, workspace.OwnerID) if err != nil { return nil, failJob(fmt.Sprintf("get owner: %s", err)) } // Fetch the file id of the cached module files if it exists. versionModulesFile := "" if !template.DisableModuleCache { tfvals, err := s.Database.GetTemplateVersionTerraformValues(ctx, templateVersion.ID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { // Older templates (before dynamic parameters) will not have cached module files. 
return nil, failJob(fmt.Sprintf("get template version terraform values: %s", err)) } if err == nil && tfvals.CachedModuleFiles.Valid { versionModulesFile = tfvals.CachedModuleFiles.UUID.String() } } var ownerSSHPublicKey, ownerSSHPrivateKey string if ownerSSHKey, err := s.Database.GetGitSSHKey(ctx, owner.ID); err != nil { if !xerrors.Is(err, sql.ErrNoRows) { return nil, failJob(fmt.Sprintf("get owner ssh key: %s", err)) } } else { ownerSSHPublicKey = ownerSSHKey.PublicKey ownerSSHPrivateKey = ownerSSHKey.PrivateKey } ownerGroups, err := s.Database.GetGroups(ctx, database.GetGroupsParams{ HasMemberID: owner.ID, OrganizationID: s.OrganizationID, }) if err != nil { return nil, failJob(fmt.Sprintf("get owner group names: %s", err)) } ownerGroupNames := []string{} for _, group := range ownerGroups { ownerGroupNames = append(ownerGroupNames, group.Group.Name) } msg, err := json.Marshal(wspubsub.WorkspaceEvent{ Kind: wspubsub.WorkspaceEventKindStateChange, WorkspaceID: workspace.ID, }) if err != nil { return nil, failJob(fmt.Sprintf("marshal workspace update event: %s", err)) } err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg) if err != nil { return nil, failJob(fmt.Sprintf("publish workspace update: %s", err)) } var workspaceOwnerOIDCAccessToken string // The check `s.OIDCConfig != nil` is not as strict, since it can be an interface // pointing to a typed nil. 
if !reflect.ValueOf(s.OIDCConfig).IsNil() { workspaceOwnerOIDCAccessToken, err = ObtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID) if err != nil { return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err)) } } var sessionToken string switch workspaceBuild.Transition { case database.WorkspaceTransitionStart: sessionToken, err = s.regenerateSessionToken(ctx, owner, workspace) if err != nil { return nil, failJob(fmt.Sprintf("regenerate session token: %s", err)) } case database.WorkspaceTransitionStop, database.WorkspaceTransitionDelete: err = deleteSessionToken(ctx, s.Database, workspace) if err != nil { return nil, failJob(fmt.Sprintf("delete session token: %s", err)) } } transition, err := convertWorkspaceTransition(workspaceBuild.Transition) if err != nil { return nil, failJob(fmt.Sprintf("convert workspace transition: %s", err)) } jobTransition = string(workspaceBuild.Transition) // Prebuilds use BuildReasonInitiator in the database but we want to // track them separately in metrics. Check the initiator ID to detect // prebuild jobs. if job.InitiatorID == database.PrebuildsSystemUserID { jobBuildReason = BuildReasonPrebuild } else { jobBuildReason = string(workspaceBuild.Reason) } // A previous workspace build exists var lastWorkspaceBuildParameters []database.WorkspaceBuildParameter if workspaceBuild.BuildNumber > 1 { // TODO: Should we fetch the last build that succeeded? This fetches the // previous build regardless of the status of the build. buildNum := workspaceBuild.BuildNumber - 1 previous, err := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{ WorkspaceID: workspaceBuild.WorkspaceID, BuildNumber: buildNum, }) // If the error is ErrNoRows, then assume previous values are empty. 
if err != nil && !xerrors.Is(err, sql.ErrNoRows) { return nil, xerrors.Errorf("get last build with number=%d: %w", buildNum, err) } if err == nil { lastWorkspaceBuildParameters, err = s.Database.GetWorkspaceBuildParameters(ctx, previous.ID) if err != nil { return nil, xerrors.Errorf("get last build parameters %q: %w", previous.ID, err) } } } workspaceBuildParameters, err := s.Database.GetWorkspaceBuildParameters(ctx, workspaceBuild.ID) if err != nil { return nil, failJob(fmt.Sprintf("get workspace build parameters: %s", err)) } task, err := s.Database.GetTaskByWorkspaceID(ctx, workspaceBuild.WorkspaceID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, xerrors.Errorf("get task by workspace id: %w", err) } dbExternalAuthProviders := []database.ExternalAuthProvider{} err = json.Unmarshal(templateVersion.ExternalAuthProviders, &dbExternalAuthProviders) if err != nil { return nil, xerrors.Errorf("failed to deserialize external_auth_providers value: %w", err) } externalAuthProviders := make([]*sdkproto.ExternalAuthProvider, 0, len(dbExternalAuthProviders)) for _, p := range dbExternalAuthProviders { link, err := s.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{ ProviderID: p.ID, UserID: owner.ID, }) if errors.Is(err, sql.ErrNoRows) { continue } if err != nil { return nil, failJob(fmt.Sprintf("acquire external auth link: %s", err)) } var config *externalauth.Config for _, c := range s.ExternalAuthConfigs { if c.ID != p.ID { continue } config = c break } // We weren't able to find a matching config for the ID! 
if config == nil { s.Logger.Warn(ctx, "workspace build job is missing external auth provider", slog.F("provider_id", p.ID), slog.F("template_version_id", templateVersion.ID), slog.F("workspace_id", workspaceBuild.WorkspaceID)) continue } refreshed, err := config.RefreshToken(ctx, s.Database, link) if err != nil && !externalauth.IsInvalidTokenError(err) { return nil, failJob(fmt.Sprintf("refresh external auth link %q: %s", p.ID, err)) } if err != nil { // Invalid tokens are skipped continue } externalAuthProviders = append(externalAuthProviders, &sdkproto.ExternalAuthProvider{ Id: p.ID, AccessToken: refreshed.OAuthAccessToken, }) } allUserRoles, err := s.Database.GetAuthorizationUserRoles(ctx, owner.ID) if err != nil { return nil, failJob(fmt.Sprintf("get owner authorization roles: %s", err)) } ownerRbacRoles := []*sdkproto.Role{} roles, err := allUserRoles.RoleNames() if err == nil { for _, role := range roles { if role.OrganizationID != uuid.Nil && role.OrganizationID != s.OrganizationID { continue // Only include site wide and org specific roles } orgID := role.OrganizationID.String() if role.OrganizationID == uuid.Nil { orgID = "" } ownerRbacRoles = append(ownerRbacRoles, &sdkproto.Role{Name: role.Name, OrgId: orgID}) } } runningAgentAuthTokens := []*sdkproto.RunningAgentAuthToken{} if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM { // runningAgentAuthTokens are *only* used for prebuilds. We fetch them when we want to rebuild a prebuilt workspace // but not generate new agent tokens. The provisionerdserver will push them down to // the provisioner (and ultimately to the `coder_agent` resource in the Terraform provider) where they will be // reused. Context: the agent token is often used in immutable attributes of workspace resource (e.g. VM/container) // to initialize the agent, so if that value changes it will necessitate a replacement of that resource, thus // obviating the whole point of the prebuild. 
agents, err := s.Database.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{ WorkspaceID: workspace.ID, BuildNumber: 1, }) if err != nil { s.Logger.Error(ctx, "failed to retrieve running agents of claimed prebuilt workspace", slog.F("workspace_id", workspace.ID), slog.Error(err)) } for _, agent := range agents { runningAgentAuthTokens = append(runningAgentAuthTokens, &sdkproto.RunningAgentAuthToken{ AgentId: agent.ID.String(), Token: agent.AuthToken.String(), }) } } provisionerStateRow, err := s.Database.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuild.ID) if err != nil { return nil, failJob(fmt.Sprintf("get workspace build provisioner state: %s", err)) } protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{ WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{ WorkspaceBuildId: workspaceBuild.ID.String(), WorkspaceName: workspace.Name, State: provisionerStateRow.ProvisionerState, RichParameterValues: convertRichParameterValues(workspaceBuildParameters), PreviousParameterValues: convertRichParameterValues(lastWorkspaceBuildParameters), VariableValues: asVariableValues(templateVariables), ExternalAuthProviders: externalAuthProviders, Metadata: &sdkproto.Metadata{ CoderUrl: s.AccessURL.String(), WorkspaceTransition: transition, WorkspaceName: workspace.Name, WorkspaceOwner: owner.Username, WorkspaceOwnerEmail: owner.Email, WorkspaceOwnerName: owner.Name, WorkspaceOwnerGroups: ownerGroupNames, WorkspaceOwnerOidcAccessToken: workspaceOwnerOIDCAccessToken, WorkspaceId: workspace.ID.String(), WorkspaceOwnerId: owner.ID.String(), TemplateId: template.ID.String(), TemplateName: template.Name, TemplateVersionId: templateVersion.ID.String(), TemplateVersion: templateVersion.Name, WorkspaceOwnerSessionToken: sessionToken, WorkspaceOwnerSshPublicKey: ownerSSHPublicKey, WorkspaceOwnerSshPrivateKey: ownerSSHPrivateKey, WorkspaceBuildId: workspaceBuild.ID.String(), WorkspaceOwnerLoginType: string(owner.LoginType), 
WorkspaceOwnerRbacRoles: ownerRbacRoles, RunningAgentAuthTokens: runningAgentAuthTokens, PrebuiltWorkspaceBuildStage: input.PrebuiltWorkspaceBuildStage, TaskId: task.ID.String(), TaskPrompt: task.Prompt, TemplateVersionModulesFile: versionModulesFile, }, LogLevel: input.LogLevel, }, } case database.ProvisionerJobTypeTemplateVersionDryRun: var input TemplateVersionDryRunJob err = json.Unmarshal(job.Input, &input) if err != nil { return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) } templateVersion, err := s.Database.GetTemplateVersionByID(ctx, input.TemplateVersionID) if err != nil { return nil, failJob(fmt.Sprintf("get template version: %s", err)) } templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, templateVersion.ID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { return nil, failJob(fmt.Sprintf("get template version variables: %s", err)) } protoJob.Type = &proto.AcquiredJob_TemplateDryRun_{ TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{ RichParameterValues: convertRichParameterValues(input.RichParameterValues), VariableValues: asVariableValues(templateVariables), Metadata: &sdkproto.Metadata{ CoderUrl: s.AccessURL.String(), WorkspaceName: input.WorkspaceName, // There is no owner for a template import, but we can assume // the "Everyone" group as a placeholder. 
WorkspaceOwnerGroups: []string{database.EveryoneGroup}, }, }, } case database.ProvisionerJobTypeTemplateVersionImport: var input TemplateVersionImportJob err = json.Unmarshal(job.Input, &input) if err != nil { return nil, failJob(fmt.Sprintf("unmarshal job input %q: %s", job.Input, err)) } userVariableValues, err := s.includeLastVariableValues(ctx, input.TemplateVersionID, input.UserVariableValues) if err != nil { return nil, failJob(err.Error()) } templateID := "" if input.TemplateID.Valid { templateID = input.TemplateID.UUID.String() } protoJob.Type = &proto.AcquiredJob_TemplateImport_{ TemplateImport: &proto.AcquiredJob_TemplateImport{ UserVariableValues: convertVariableValues(userVariableValues), Metadata: &sdkproto.Metadata{ CoderUrl: s.AccessURL.String(), // There is no owner for a template import, but we can assume // the "Everyone" group as a placeholder. WorkspaceOwnerGroups: []string{database.EveryoneGroup}, TemplateId: templateID, TemplateVersionId: input.TemplateVersionID.String(), }, }, } } switch job.StorageMethod { case database.ProvisionerStorageMethodFile: file, err := s.Database.GetFileByID(ctx, job.FileID) if err != nil { return nil, failJob(fmt.Sprintf("get file by id: %s", err)) } protoJob.TemplateSourceArchive = file.Data default: return nil, failJob(fmt.Sprintf("unsupported storage method: %s", job.StorageMethod)) } if protobuf.Size(protoJob) > drpcsdk.MaxMessageSize { return nil, failJob(fmt.Sprintf("payload was too big: %d > %d", protobuf.Size(protoJob), drpcsdk.MaxMessageSize)) } // Record the time the job spent waiting in the queue. if s.metrics != nil && job.StartedAt.Valid && job.Provisioner.Valid() { // These timestamps lose their monotonic clock component after a Postgres // round-trip, so the subtraction is based purely on wall-clock time. Floor at // 1ms as a defensive measure against clock adjustments producing a negative // delta while acknowledging there's a non-zero queue time. 
queueWaitSeconds := max(job.StartedAt.Time.Sub(job.CreatedAt).Seconds(), 0.001) s.metrics.ObserveJobQueueWait(string(job.Provisioner), string(job.Type), jobTransition, jobBuildReason, queueWaitSeconds) } return protoJob, err } func (s *server) includeLastVariableValues(ctx context.Context, templateVersionID uuid.UUID, userVariableValues []codersdk.VariableValue) ([]codersdk.VariableValue, error) { var values []codersdk.VariableValue values = append(values, userVariableValues...) if templateVersionID == uuid.Nil { return values, nil } templateVersion, err := s.Database.GetTemplateVersionByID(ctx, templateVersionID) if err != nil { return nil, xerrors.Errorf("get template version: %w", err) } if templateVersion.TemplateID.UUID == uuid.Nil { return values, nil } template, err := s.Database.GetTemplateByID(ctx, templateVersion.TemplateID.UUID) if err != nil { return nil, xerrors.Errorf("get template: %w", err) } if template.ActiveVersionID == uuid.Nil { return values, nil } templateVariables, err := s.Database.GetTemplateVersionVariables(ctx, template.ActiveVersionID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { return nil, xerrors.Errorf("get template version variables: %w", err) } for _, templateVariable := range templateVariables { var alreadyAdded bool for _, uvv := range userVariableValues { if uvv.Name == templateVariable.Name { alreadyAdded = true break } } if alreadyAdded { continue } values = append(values, codersdk.VariableValue{ Name: templateVariable.Name, Value: templateVariable.Value, }) } return values, nil } func (s *server) CommitQuota(ctx context.Context, request *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) { ctx, span := s.startTrace(ctx, tracing.FuncName()) defer span.End() //nolint:gocritic // Provisionerd has specific authz rules. 
ctx = dbauthz.AsProvisionerd(ctx)
	jobID, err := uuid.Parse(request.JobId)
	if err != nil {
		return nil, xerrors.Errorf("parse job id: %w", err)
	}
	job, err := s.Database.GetProvisionerJobByID(ctx, jobID)
	if err != nil {
		return nil, xerrors.Errorf("get job: %w", err)
	}
	if !job.WorkerID.Valid {
		return nil, xerrors.New("job isn't running yet")
	}
	if job.WorkerID.UUID.String() != s.ID.String() {
		return nil, xerrors.New("you don't own this job")
	}
	q := s.QuotaCommitter.Load()
	if q == nil {
		// We're probably in community edition or a test.
		return &proto.CommitQuotaResponse{
			Budget: -1,
			Ok:     true,
		}, nil
	}
	return (*q).CommitQuota(ctx, request)
}

// UpdateJob handles a progress update for a job that is currently running on
// this daemon (the job's worker must be s.ID). It:
//   - touches the job's updated_at timestamp,
//   - appends any streamed logs, enforcing a 1 MiB per-job log budget and
//     notifying log-stream subscribers via pubsub,
//   - stores workspace tags, the README and template variables reported by
//     the provisioner.
//
// The response reports whether the job has been canceled. When template
// variables were reported, it also carries the resolved variable values, or an
// error if a required variable has no value.
func (s *server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
	ctx, span := s.startTrace(ctx, tracing.FuncName())
	defer span.End()

	//nolint:gocritic // Provisionerd has specific authz rules.
	ctx = dbauthz.AsProvisionerd(ctx)
	parsedID, err := uuid.Parse(request.JobId)
	if err != nil {
		return nil, xerrors.Errorf("parse job id: %w", err)
	}
	s.Logger.Debug(ctx, "stage UpdateJob starting", slog.F("job_id", parsedID))
	job, err := s.Database.GetProvisionerJobByID(ctx, parsedID)
	if err != nil {
		return nil, xerrors.Errorf("get job: %w", err)
	}
	if !job.WorkerID.Valid {
		return nil, xerrors.New("job isn't running yet")
	}
	if job.WorkerID.UUID.String() != s.ID.String() {
		return nil, xerrors.New("you don't own this job")
	}
	err = s.Database.UpdateProvisionerJobByID(ctx, database.UpdateProvisionerJobByIDParams{
		ID:        parsedID,
		UpdatedAt: s.timeNow(),
	})
	if err != nil {
		return nil, xerrors.Errorf("update job: %w", err)
	}

	// Logs are dropped entirely once the job has been flagged as overflowed.
	if len(request.Logs) > 0 && !job.LogsOverflowed {
		// maxLogsLength is the per-job log budget in bytes (1 MiB).
		const maxLogsLength = 1 << 20

		//nolint:exhaustruct // We append to the additional fields below.
		insertParams := database.InsertProvisionerJobLogsParams{
			JobID: parsedID,
		}

		newLogSize := 0
		overflowedErrorMsg := "Provisioner logs exceeded the max size of 1MB. Will not continue to write provisioner logs for workspace build."
		lenErrMsg := len(overflowedErrorMsg)

		var (
			createdAt time.Time
			level     database.LogLevel
			stage     string
			source    database.LogSource
			output    string
		)

		for _, log := range request.Logs {
			// Build our log params.
			level, err = convertLogLevel(log.Level)
			if err != nil {
				return nil, xerrors.Errorf("convert log level: %w", err)
			}
			source, err = convertLogSource(log.Source)
			if err != nil {
				return nil, xerrors.Errorf("convert log source: %w", err)
			}
			createdAt = time.UnixMilli(log.CreatedAt)
			stage = log.Stage
			output = log.Output

			// Check if we would overflow the job logs, always leaving enough
			// room for the overflow message itself.
			willOverflow := int64(job.LogsLength)+int64(newLogSize)+int64(lenErrMsg)+int64(len(output)) > maxLogsLength
			if willOverflow {
				s.Logger.Debug(ctx, "provisioner job logs overflowed 1MB size limit in database", slog.F("job_id", parsedID))
				err = s.Database.UpdateProvisionerJobLogsOverflowed(ctx, database.UpdateProvisionerJobLogsOverflowedParams{
					ID:             parsedID,
					LogsOverflowed: true,
				})
				if err != nil {
					s.Logger.Error(ctx, "failed to set logs overflowed flag", slog.F("job_id", parsedID), slog.Error(err))
				}
				// Replace the offending line with a warning explaining why
				// the log stream stops here.
				level = database.LogLevelWarn
				output = overflowedErrorMsg
			}

			newLogSize += len(output)

			insertParams.CreatedAt = append(insertParams.CreatedAt, createdAt)
			insertParams.Level = append(insertParams.Level, level)
			insertParams.Stage = append(insertParams.Stage, stage)
			insertParams.Source = append(insertParams.Source, source)
			insertParams.Output = append(insertParams.Output, output)
			s.Logger.Debug(ctx, "job log",
				slog.F("job_id", parsedID),
				slog.F("stage", stage),
				slog.F("output", output))

			// Don't write any more logs because there's no room.
			if willOverflow {
				break
			}
		}

		err = s.Database.UpdateProvisionerJobLogsLength(ctx, database.UpdateProvisionerJobLogsLengthParams{
			ID:         parsedID,
			LogsLength: int32(newLogSize), // #nosec G115 - Log output length is limited to 1MB (2^20) which fits in an int32.
		})
		if err != nil {
			// Even though we do the runtime check for the overflow, we still check for the database error
			// as well.
			if database.IsProvisionerJobLogsLimitError(err) {
				err = s.Database.UpdateProvisionerJobLogsOverflowed(ctx, database.UpdateProvisionerJobLogsOverflowedParams{
					ID:             parsedID,
					LogsOverflowed: true,
				})
				if err != nil {
					s.Logger.Error(ctx, "failed to set logs overflowed flag", slog.F("job_id", parsedID), slog.Error(err))
				}
				return &proto.UpdateJobResponse{
					Canceled: job.CanceledAt.Valid,
				}, nil
			}
			s.Logger.Error(ctx, "failed to update logs length", slog.F("job_id", parsedID), slog.Error(err))
			return nil, xerrors.Errorf("update logs length: %w", err)
		}

		logs, err := s.Database.InsertProvisionerJobLogs(ctx, insertParams)
		if err != nil {
			s.Logger.Error(ctx, "failed to insert job logs", slog.F("job_id", parsedID), slog.Error(err))
			return nil, xerrors.Errorf("insert job logs: %w", err)
		}
		// Publish by the lowest log ID inserted so the log stream will fetch
		// everything from that point.
		lowestID := logs[0].ID
		s.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
		data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{
			CreatedAfter: lowestID - 1,
		})
		if err != nil {
			return nil, xerrors.Errorf("marshal: %w", err)
		}
		err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(parsedID), data)
		if err != nil {
			s.Logger.Error(ctx, "failed to publish job logs", slog.F("job_id", parsedID), slog.Error(err))
			return nil, xerrors.Errorf("publish job logs: %w", err)
		}
		s.Logger.Debug(ctx, "published job logs", slog.F("job_id", parsedID))
	}

	if len(request.WorkspaceTags) > 0 {
		templateVersion, err := s.Database.GetTemplateVersionByJobID(ctx, job.ID)
		if err != nil {
			s.Logger.Error(ctx, "failed to get the template version", slog.F("job_id", parsedID), slog.Error(err))
			return nil, xerrors.Errorf("get template version by job id: %w", err)
		}

		for key, value := range request.WorkspaceTags {
			_, err := s.Database.InsertTemplateVersionWorkspaceTag(ctx, database.InsertTemplateVersionWorkspaceTagParams{
				TemplateVersionID: templateVersion.ID,
				Key:               key,
				Value:             value,
			})
			if err != nil {
				return nil, xerrors.Errorf("update template version workspace tags: %w", err)
			}
		}
	}

	if len(request.Readme) > 0 {
		err := s.Database.UpdateTemplateVersionDescriptionByJobID(ctx, database.UpdateTemplateVersionDescriptionByJobIDParams{
			JobID:     job.ID,
			Readme:    string(request.Readme),
			UpdatedAt: s.timeNow(),
		})
		if err != nil {
			return nil, xerrors.Errorf("update template version description: %w", err)
		}
	}

	if len(request.TemplateVariables) > 0 {
		templateVersion, err := s.Database.GetTemplateVersionByJobID(ctx, job.ID)
		if err != nil {
			s.Logger.Error(ctx, "failed to get the template version", slog.F("job_id", parsedID), slog.Error(err))
			return nil, xerrors.Errorf("get template version by job id: %w", err)
		}

		var variableValues []*sdkproto.VariableValue
		var variablesWithMissingValues []string
		for _, templateVariable := range request.TemplateVariables {
			s.Logger.Debug(ctx, "insert template variable", slog.F("template_version_id", templateVersion.ID), slog.F("template_variable", redactTemplateVariable(templateVariable)))

			// A user-supplied value takes precedence over the variable's default.
			value := templateVariable.DefaultValue
			for _, v := range request.UserVariableValues {
				if v.Name == templateVariable.Name {
					value = v.Value
					break
				}
			}

			if templateVariable.Required && value == "" {
				variablesWithMissingValues = append(variablesWithMissingValues, templateVariable.Name)
			}

			variableValues = append(variableValues, &sdkproto.VariableValue{
				Name:      templateVariable.Name,
				Value:     value,
				Sensitive: templateVariable.Sensitive,
			})

			_, err = s.Database.InsertTemplateVersionVariable(ctx, database.InsertTemplateVersionVariableParams{
				TemplateVersionID: templateVersion.ID,
				Name:              templateVariable.Name,
				Description:       templateVariable.Description,
				Type:              templateVariable.Type,
				DefaultValue:      templateVariable.DefaultValue,
				Required:          templateVariable.Required,
				Sensitive:         templateVariable.Sensitive,
				Value:             value,
			})
			if err != nil {
				return nil, xerrors.Errorf("insert parameter schema: %w", err)
			}
		}

		if len(variablesWithMissingValues) > 0 {
			return nil, xerrors.Errorf("required template variables need values: %s", strings.Join(variablesWithMissingValues, ", "))
		}

		return &proto.UpdateJobResponse{
			Canceled:       job.CanceledAt.Valid,
			VariableValues: variableValues,
		}, nil
	}

	return &proto.UpdateJobResponse{
		Canceled: job.CanceledAt.Valid,
	}, nil
}

// FailJob marks a job owned by this daemon as failed, storing the reported
// error and error code.
//
// For workspace builds it also:
//   - persists the returned provisioner state and deadlines (cleared for
//     prebuilds),
//   - sends failure notifications,
//   - publishes workspace events,
//   - records an audit log entry.
//
// It finally signals end-of-logs to log-stream subscribers.
func (s *server) FailJob(ctx context.Context, failJob *proto.FailedJob) (*proto.Empty, error) {
	ctx, span := s.startTrace(ctx, tracing.FuncName())
	defer span.End()

	//nolint:gocritic // Provisionerd has specific authz rules.
	ctx = dbauthz.AsProvisionerd(ctx)
	jobID, err := uuid.Parse(failJob.JobId)
	if err != nil {
		return nil, xerrors.Errorf("parse job id: %w", err)
	}
	s.Logger.Debug(ctx, "stage FailJob starting", slog.F("job_id", jobID))
	job, err := s.Database.GetProvisionerJobByID(ctx, jobID)
	if err != nil {
		return nil, xerrors.Errorf("get provisioner job: %w", err)
	}
	if job.WorkerID.UUID.String() != s.ID.String() {
		return nil, xerrors.New("you don't own this job")
	}
	if job.CompletedAt.Valid {
		return nil, xerrors.New("job already completed")
	}
	job.CompletedAt = sql.NullTime{
		Time:  s.timeNow(),
		Valid: true,
	}
	job.Error = sql.NullString{
		String: failJob.Error,
		Valid:  failJob.Error != "",
	}
	job.ErrorCode = sql.NullString{
		String: failJob.ErrorCode,
		Valid:  failJob.ErrorCode != "",
	}

	err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
		ID:          jobID,
		CompletedAt: job.CompletedAt,
		UpdatedAt:   s.timeNow(),
		Error:       job.Error,
		ErrorCode:   job.ErrorCode,
	})
	if err != nil {
		return nil, xerrors.Errorf("update provisioner job: %w", err)
	}
	s.Telemetry.Report(&telemetry.Snapshot{
		ProvisionerJobs: []telemetry.ProvisionerJob{telemetry.ConvertProvisionerJob(job)},
	})

	switch jobType := failJob.Type.(type) {
	case *proto.FailedJob_WorkspaceBuild_:
		var input WorkspaceProvisionJob
		err = json.Unmarshal(job.Input, &input)
		if err != nil {
			return nil, xerrors.Errorf("unmarshal workspace provision input: %w", err)
		}

		var build database.WorkspaceBuild
		var workspace database.Workspace
		err = s.Database.InTx(func(db database.Store) error {
			build, err = db.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
			if err != nil {
				return xerrors.Errorf("get workspace build: %w", err)
			}

			workspace, err = db.GetWorkspaceByID(ctx, build.WorkspaceID)
			if err != nil {
				return xerrors.Errorf("get workspace: %w", err)
			}

			if jobType.WorkspaceBuild.State != nil {
				err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
					ID:               input.WorkspaceBuildID,
					UpdatedAt:        s.timeNow(),
					ProvisionerState: jobType.WorkspaceBuild.State,
				})
				if err != nil {
					return xerrors.Errorf("update workspace build state: %w", err)
				}

				// Prebuilt workspaces get zero deadlines.
				deadline := build.Deadline
				maxDeadline := build.MaxDeadline
				if workspace.IsPrebuild() {
					deadline = time.Time{}
					maxDeadline = time.Time{}
				}

				err = db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{
					ID:          input.WorkspaceBuildID,
					UpdatedAt:   s.timeNow(),
					Deadline:    deadline,
					MaxDeadline: maxDeadline,
				})
				if err != nil {
					return xerrors.Errorf("update workspace build deadline: %w", err)
				}
			}

			return nil
		}, nil)
		if err != nil {
			return nil, err
		}

		s.notifyWorkspaceBuildFailed(ctx, workspace, build)

		msg, err := json.Marshal(wspubsub.WorkspaceEvent{
			Kind:        wspubsub.WorkspaceEventKindStateChange,
			WorkspaceID: workspace.ID,
		})
		if err != nil {
			// %w (not %s) so callers can still unwrap the underlying error.
			return nil, xerrors.Errorf("marshal workspace update event: %w", err)
		}
		err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg)
		if err != nil {
			return nil, xerrors.Errorf("publish workspace update: %w", err)
		}

		// Publish workspace build update to the all builds channel if the experiment is enabled.
		if s.Experiments.Enabled(codersdk.ExperimentWorkspaceBuildUpdates) {
			err = wspubsub.PublishWorkspaceBuildUpdate(ctx, s.Pubsub, codersdk.WorkspaceBuildUpdate{
				WorkspaceID:   workspace.ID,
				WorkspaceName: workspace.Name,
				BuildID:       build.ID,
				Transition:    string(build.Transition),
				JobStatus:     string(database.ProvisionerJobStatusFailed),
				BuildNumber:   build.BuildNumber,
			})
			if err != nil {
				s.Logger.Warn(ctx, "failed to publish workspace build update", slog.Error(err))
			}
		}
	case *proto.FailedJob_TemplateImport_:
	}

	// If the failed job is a workspace build, audit the outcome. Audit
	// failures are logged only; they never fail the RPC.
	if job.Type == database.ProvisionerJobTypeWorkspaceBuild {
		auditor := s.Auditor.Load()
		build, err := s.Database.GetWorkspaceBuildByJobID(ctx, job.ID)
		if err != nil {
			s.Logger.Error(ctx, "audit log - get build", slog.Error(err))
		} else {
			auditAction := auditActionFromTransition(build.Transition)
			workspace, err := s.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
			if err != nil {
				s.Logger.Error(ctx, "audit log - get workspace", slog.Error(err))
			} else {
				previousBuildNumber := build.BuildNumber - 1
				previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
					WorkspaceID: workspace.ID,
					BuildNumber: previousBuildNumber,
				})
				if prevBuildErr != nil {
					previousBuild = database.WorkspaceBuild{}
				}

				// We pass the below information to the Auditor so that it
				// can form a friendly string for the user to view in the UI.
				buildResourceInfo := audit.AdditionalFields{
					WorkspaceName: workspace.Name,
					BuildNumber:   strconv.FormatInt(int64(build.BuildNumber), 10),
					BuildReason:   database.BuildReason(string(build.Reason)),
					WorkspaceID:   workspace.ID,
				}

				wriBytes, err := json.Marshal(buildResourceInfo)
				if err != nil {
					s.Logger.Error(ctx, "marshal workspace resource info for failed job", slog.Error(err))
					wriBytes = []byte("{}")
				}

				bag := audit.BaggageFromContext(ctx)

				audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceBuild]{
					Audit:            *auditor,
					Log:              s.Logger,
					UserID:           job.InitiatorID,
					OrganizationID:   workspace.OrganizationID,
					RequestID:        job.ID,
					IP:               bag.IP,
					Action:           auditAction,
					Old:              previousBuild,
					New:              build,
					Status:           http.StatusInternalServerError,
					AdditionalFields: wriBytes,
				})
			}
		}
	}

	data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
	if err != nil {
		return nil, xerrors.Errorf("marshal job log: %w", err)
	}
	err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
	if err != nil {
		s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
		return nil, xerrors.Errorf("publish end of job logs: %w", err)
	}
	return &proto.Empty{}, nil
}

// notifyWorkspaceBuildFailed enqueues a notification about a failed workspace
// build.
//
// Builds whose reason is BuildReasonInitiator are delegated to
// notifyWorkspaceManualBuildFailed. Every other reason enqueues a
// TemplateWorkspaceAutobuildFailed notification addressed to
// workspace.OwnerID. Enqueue errors are logged, not returned.
func (s *server) notifyWorkspaceBuildFailed(ctx context.Context, workspace database.Workspace, build database.WorkspaceBuild) {
	if build.Reason.Valid() && build.Reason == database.BuildReasonInitiator {
		s.notifyWorkspaceManualBuildFailed(ctx, workspace, build)
		return
	}

	if _, err := s.NotificationsEnqueuer.Enqueue(ctx, workspace.OwnerID, notifications.TemplateWorkspaceAutobuildFailed,
		map[string]string{
			"name":   workspace.Name,
			"reason": string(build.Reason),
		}, "provisionerdserver",
		// Associate this notification with all the related entities.
		workspace.ID, workspace.OwnerID, workspace.TemplateID, workspace.OrganizationID,
	); err != nil {
		s.Logger.Warn(ctx, "failed to notify of failed workspace autobuild", slog.Error(err))
	}
}

// notifyWorkspaceManualBuildFailed enqueues a TemplateWorkspaceManualBuildFailed
// notification for every template admin returned by
// prepareForNotifyWorkspaceManualBuildFailed. Errors are logged, not returned.
func (s *server) notifyWorkspaceManualBuildFailed(ctx context.Context, workspace database.Workspace, build database.WorkspaceBuild) {
	templateAdmins, template, templateVersion, workspaceOwner, err := s.prepareForNotifyWorkspaceManualBuildFailed(ctx, workspace, build)
	if err != nil {
		s.Logger.Error(ctx, "unable to collect data for manual build failed notification", slog.Error(err))
		return
	}

	// Prefer the template's display name, falling back to its name. This is
	// the same for every recipient, so it is computed once.
	templateNameLabel := template.DisplayName
	if templateNameLabel == "" {
		templateNameLabel = template.Name
	}

	for _, templateAdmin := range templateAdmins {
		// A fresh map per recipient, so no two Enqueue calls share one.
		labels := map[string]string{
			"name":                     workspace.Name,
			"template_name":            templateNameLabel,
			"template_version_name":    templateVersion.Name,
			"initiator":                build.InitiatorByUsername,
			"workspace_owner_username": workspaceOwner.Username,
			"workspace_build_number":   strconv.Itoa(int(build.BuildNumber)),
		}
		if _, err := s.NotificationsEnqueuer.Enqueue(ctx, templateAdmin.ID, notifications.TemplateWorkspaceManualBuildFailed,
			labels, "provisionerdserver",
			// Associate this notification with all the related entities.
			workspace.ID, workspace.OwnerID, workspace.TemplateID, workspace.OrganizationID,
		); err != nil {
			s.Logger.Warn(ctx, "failed to notify of failed workspace manual build", slog.Error(err))
		}
	}
}

// prepareForNotifyWorkspaceManualBuildFailed collects data required to build notifications for template admins.
// The template `notifications.TemplateWorkspaceManualBuildFailed` is quite detailed as it requires information about the template,
// template version, workspace, workspace build, etc.
//
// The returned admins are users with the template-admin role who are members
// of the workspace's organization, sorted by username.
func (s *server) prepareForNotifyWorkspaceManualBuildFailed(ctx context.Context, workspace database.Workspace, build database.WorkspaceBuild) ([]database.GetUsersRow,
	database.Template, database.TemplateVersion, database.User, error,
) {
	users, err := s.Database.GetUsers(ctx, database.GetUsersParams{
		RbacRole: []string{codersdk.RoleTemplateAdmin},
	})
	if err != nil {
		return nil, database.Template{}, database.TemplateVersion{}, database.User{}, xerrors.Errorf("unable to fetch template admins: %w", err)
	}

	usersByIDs := make(map[uuid.UUID]database.GetUsersRow, len(users))
	userIDs := make([]uuid.UUID, 0, len(users))
	for _, user := range users {
		usersByIDs[user.ID] = user
		userIDs = append(userIDs, user.ID)
	}

	// Keep only admins that belong to the workspace's organization.
	var templateAdmins []database.GetUsersRow
	if len(userIDs) > 0 {
		orgIDsByMemberIDs, err := s.Database.GetOrganizationIDsByMemberIDs(ctx, userIDs)
		if err != nil {
			return nil, database.Template{}, database.TemplateVersion{}, database.User{}, xerrors.Errorf("unable to fetch organization IDs by member IDs: %w", err)
		}

		for _, entry := range orgIDsByMemberIDs {
			if slices.Contains(entry.OrganizationIDs, workspace.OrganizationID) {
				templateAdmins = append(templateAdmins, usersByIDs[entry.UserID])
			}
		}
	}
	sort.Slice(templateAdmins, func(i, j int) bool {
		return templateAdmins[i].Username < templateAdmins[j].Username
	})

	template, err := s.Database.GetTemplateByID(ctx, workspace.TemplateID)
	if err != nil {
		return nil, database.Template{}, database.TemplateVersion{}, database.User{}, xerrors.Errorf("unable to fetch template: %w", err)
	}

	templateVersion, err := s.Database.GetTemplateVersionByID(ctx, build.TemplateVersionID)
	if err != nil {
		return nil, database.Template{}, database.TemplateVersion{}, database.User{}, xerrors.Errorf("unable to fetch template version: %w", err)
	}

	workspaceOwner, err := s.Database.GetUserByID(ctx, workspace.OwnerID)
	if err != nil {
		return nil, database.Template{}, database.TemplateVersion{}, database.User{}, xerrors.Errorf("unable to fetch workspace owner: %w", err)
	}
	return templateAdmins, template, templateVersion, workspaceOwner, nil
}

// UploadFile receives a chunked data upload on the stream and stores it in the
// files table.
//
// Only UPLOAD_TYPE_MODULE_FILES is accepted. Such files are stored as tar
// archives created by uuid.Nil, keyed by their SHA-256 hash. Uploading an
// archive that already exists (a unique violation on hash/creator) is not
// treated as an error.
func (s *server) UploadFile(stream proto.DRPCProvisionerDaemon_UploadFileStream) error {
	var file *sdkproto.DataBuilder
	// Always terminate the stream with an empty response.
	//nolint:errcheck // We can't do much about send errors here.
	defer stream.SendAndClose(&proto.Empty{})

	file, err := provisionersdk.HandleReceivingDataUpload(stream)
	if err != nil {
		return err
	}

	fileData, err := file.Complete()
	if err != nil {
		return xerrors.Errorf("complete file upload: %w", err)
	}

	// Just rehash the data to be sure it is correct.
	hashBytes := sha256.Sum256(fileData)
	hash := hex.EncodeToString(hashBytes[:])

	var insert database.InsertFileParams

	switch file.Type {
	case sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES:
		insert = database.InsertFileParams{
			ID:        uuid.New(),
			Hash:      hash,
			CreatedAt: dbtime.Now(),
			CreatedBy: uuid.Nil,
			Mimetype:  tarMimeType,
			Data:      fileData,
		}
	default:
		return xerrors.Errorf("unsupported file upload type: %s", file.Type)
	}

	//nolint:gocritic // Provisionerd actor
	_, err = s.Database.InsertFile(dbauthz.AsProvisionerd(s.lifecycleCtx), insert)
	if err != nil {
		// Duplicated files already exist in the database, so we can ignore this error.
		if !database.IsUniqueViolation(err, database.UniqueFilesHashCreatedByKey) {
			return xerrors.Errorf("insert file: %w", err)
		}
	}

	s.Logger.Info(s.lifecycleCtx, "file uploaded to database",
		slog.F("type", file.Type.String()),
		slog.F("hash", hash),
		slog.F("size", len(fileData)),
		// new_insert indicates whether the file was newly inserted or already existed.
		slog.F("new_insert", err == nil),
	)

	return nil
}

// DownloadFile pulls the requested file from the database and sends it over the protobuf stream in chunks.
func (s *server) DownloadFile(request *proto.FileRequest, stream proto.DRPCProvisionerDaemon_DownloadFileStream) error {
	//nolint:errcheck
	defer stream.CloseSend()

	//nolint:gocritic // Provisionerd is the actor here.
	ctx := dbauthz.AsProvisionerd(stream.Context())

	// A graceful error message will help debugging.
	fail := func(err error) error {
		if sendErr := stream.Send(&sdkproto.FileUpload{
			Type: &sdkproto.FileUpload_Error{
				Error: &sdkproto.FailedFile{
					Error: err.Error(),
				},
			},
		}); sendErr != nil {
			s.Logger.Warn(ctx, "failed to send error response on download stream",
				slog.Error(sendErr),
				slog.F("original_error", err.Error()),
			)
		}
		return err
	}

	if request.FileId == "" || request.FileId == uuid.Nil.String() {
		return fail(xerrors.New("file id is required"))
	}

	fid, err := uuid.Parse(request.FileId)
	if err != nil {
		return fail(xerrors.Errorf("invalid file id: %w", err))
	}

	file, err := s.Database.GetFileByID(ctx, fid)
	if err != nil {
		return fail(xerrors.Errorf("get file: %w", err))
	}

	switch request.UploadType {
	case sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES:
		// This check is not perfect. If these conditions are not true, then the file is not a modules file.
		if file.CreatedBy != uuid.Nil || file.Mimetype != tarMimeType {
			return fail(xerrors.Errorf("file %s is not a modules file", fid))
		}
	default:
		return fail(xerrors.Errorf("unsupported file upload type: %s", request.UploadType))
	}

	upload, chunks := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, file.Data)
	err = stream.Send(&sdkproto.FileUpload{
		Type: &sdkproto.FileUpload_DataUpload{DataUpload: upload},
	})
	if err != nil {
		return fail(xerrors.Errorf("send file upload: %w", err))
	}

	for i, c := range chunks {
		// Stop early if the stream's context has been canceled.
		if ctx.Err() != nil {
			return fail(ctx.Err())
		}
		err = stream.Send(&sdkproto.FileUpload{
			Type: &sdkproto.FileUpload_ChunkPiece{ChunkPiece: c},
		})
		if err != nil {
			return fail(xerrors.Errorf("send chunk piece %d: %w", i, err))
		}
	}

	return nil
}

// CompleteJob is triggered by a provisioner daemon to mark a provisioner job as completed.
func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob) (*proto.Empty, error) { ctx, span := s.startTrace(ctx, tracing.FuncName()) defer span.End() //nolint:gocritic // Provisionerd has specific authz rules. ctx = dbauthz.AsProvisionerd(ctx) jobID, err := uuid.Parse(completed.JobId) if err != nil { return nil, xerrors.Errorf("parse job id: %w", err) } s.Logger.Debug(ctx, "stage CompleteJob starting", slog.F("job_id", jobID)) job, err := s.Database.GetProvisionerJobByID(ctx, jobID) if err != nil { return nil, xerrors.Errorf("get job by id: %w", err) } if job.WorkerID.UUID.String() != s.ID.String() { return nil, xerrors.Errorf("you don't own this job") } telemetrySnapshot := &telemetry.Snapshot{} // Items are added to this snapshot as they complete! defer s.Telemetry.Report(telemetrySnapshot) switch jobType := completed.Type.(type) { case *proto.CompletedJob_TemplateImport_: err = s.completeTemplateImportJob(ctx, job, jobID, jobType, telemetrySnapshot) if err != nil { return nil, err } case *proto.CompletedJob_WorkspaceBuild_: err = s.completeWorkspaceBuildJob(ctx, job, jobID, jobType, telemetrySnapshot) if err != nil { return nil, err } case *proto.CompletedJob_TemplateDryRun_: err = s.completeTemplateDryRunJob(ctx, job, jobID, jobType, telemetrySnapshot) if err != nil { return nil, err } default: if completed.Type == nil { return nil, xerrors.Errorf("type payload must be provided") } return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match", reflect.TypeOf(completed.Type).String()) } data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true}) if err != nil { return nil, xerrors.Errorf("marshal job log: %w", err) } err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data) if err != nil { s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err)) return nil, xerrors.Errorf("publish end of job logs: %w", err) } 
s.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID)) return &proto.Empty{}, nil } // completeTemplateImportJob handles completion of a template import job. // All database operations are performed within a transaction. func (s *server) completeTemplateImportJob(ctx context.Context, job database.ProvisionerJob, jobID uuid.UUID, jobType *proto.CompletedJob_TemplateImport_, telemetrySnapshot *telemetry.Snapshot) error { var input TemplateVersionImportJob err := json.Unmarshal(job.Input, &input) if err != nil { return xerrors.Errorf("template version ID is expected: %w", err) } // Execute all database operations in a transaction return s.Database.InTx(func(db database.Store) error { now := s.timeNow() // Process resources for transition, resources := range map[database.WorkspaceTransition][]*sdkproto.Resource{ database.WorkspaceTransitionStart: jobType.TemplateImport.StartResources, database.WorkspaceTransitionStop: jobType.TemplateImport.StopResources, } { for _, resource := range resources { s.Logger.Info(ctx, "inserting template import job resource", slog.F("job_id", job.ID.String()), slog.F("resource_name", resource.Name), slog.F("resource_type", resource.Type), slog.F("transition", transition)) if err := InsertWorkspaceResource(ctx, db, jobID, transition, resource, telemetrySnapshot); err != nil { return xerrors.Errorf("insert resource: %w", err) } } } // Process modules for transition, modules := range map[database.WorkspaceTransition][]*sdkproto.Module{ database.WorkspaceTransitionStart: jobType.TemplateImport.StartModules, } { for _, module := range modules { s.Logger.Info(ctx, "inserting template import job module", slog.F("job_id", job.ID.String()), slog.F("module_source", module.Source), slog.F("module_version", module.Version), slog.F("module_key", module.Key), slog.F("transition", transition)) if err := InsertWorkspaceModule(ctx, db, jobID, transition, module, telemetrySnapshot); err != nil { return xerrors.Errorf("insert module: %w", 
err) } } } // Process rich parameters for _, richParameter := range jobType.TemplateImport.RichParameters { s.Logger.Info(ctx, "inserting template import job parameter", slog.F("job_id", job.ID.String()), slog.F("parameter_name", richParameter.Name), slog.F("type", richParameter.Type), slog.F("ephemeral", richParameter.Ephemeral), ) options, err := json.Marshal(richParameter.Options) if err != nil { return xerrors.Errorf("marshal parameter options: %w", err) } var validationMin, validationMax sql.NullInt32 if richParameter.ValidationMin != nil { validationMin = sql.NullInt32{ Int32: *richParameter.ValidationMin, Valid: true, } } if richParameter.ValidationMax != nil { validationMax = sql.NullInt32{ Int32: *richParameter.ValidationMax, Valid: true, } } pft, err := sdkproto.ProviderFormType(richParameter.FormType) if err != nil { return xerrors.Errorf("parameter %q: %w", richParameter.Name, err) } dft := database.ParameterFormType(pft) if !dft.Valid() { list := strings.Join(slice.ToStrings(database.AllParameterFormTypeValues()), ", ") return xerrors.Errorf("parameter %q field 'form_type' not valid, currently supported: %s", richParameter.Name, list) } _, err = db.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{ TemplateVersionID: input.TemplateVersionID, Name: richParameter.Name, DisplayName: richParameter.DisplayName, Description: richParameter.Description, Type: richParameter.Type, FormType: dft, Mutable: richParameter.Mutable, DefaultValue: richParameter.DefaultValue, Icon: richParameter.Icon, Options: options, ValidationRegex: richParameter.ValidationRegex, ValidationError: richParameter.ValidationError, ValidationMin: validationMin, ValidationMax: validationMax, ValidationMonotonic: richParameter.ValidationMonotonic, Required: richParameter.Required, DisplayOrder: richParameter.Order, Ephemeral: richParameter.Ephemeral, }) if err != nil { return xerrors.Errorf("insert parameter: %w", err) } } // Process presets and parameters 
err := InsertWorkspacePresetsAndParameters(ctx, s.Logger, db, jobID, input.TemplateVersionID, jobType.TemplateImport.Presets, now) if err != nil { return xerrors.Errorf("insert workspace presets and parameters: %w", err) } // Process external auth providers var completedError sql.NullString for _, externalAuthProvider := range jobType.TemplateImport.ExternalAuthProviders { contains := false for _, configuredProvider := range s.ExternalAuthConfigs { if configuredProvider.ID == externalAuthProvider.Id { contains = true break } } if !contains { completedError = sql.NullString{ String: fmt.Sprintf("external auth provider %q is not configured", externalAuthProvider.Id), Valid: true, } break } } // Fallback to `ExternalAuthProvidersNames` if it was specified and `ExternalAuthProviders` // was not. Gives us backwards compatibility with custom provisioners that haven't been // updated to use the new field yet. var externalAuthProviders []database.ExternalAuthProvider if providersLen := len(jobType.TemplateImport.ExternalAuthProviders); providersLen > 0 { externalAuthProviders = make([]database.ExternalAuthProvider, 0, providersLen) for _, provider := range jobType.TemplateImport.ExternalAuthProviders { externalAuthProviders = append(externalAuthProviders, database.ExternalAuthProvider{ ID: provider.Id, Optional: provider.Optional, }) } } else if namesLen := len(jobType.TemplateImport.ExternalAuthProvidersNames); namesLen > 0 { externalAuthProviders = make([]database.ExternalAuthProvider, 0, namesLen) for _, providerID := range jobType.TemplateImport.ExternalAuthProvidersNames { externalAuthProviders = append(externalAuthProviders, database.ExternalAuthProvider{ ID: providerID, }) } } externalAuthProvidersMessage, err := json.Marshal(externalAuthProviders) if err != nil { return xerrors.Errorf("failed to serialize external_auth_providers value: %w", err) } err = db.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, 
database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{ JobID: jobID, ExternalAuthProviders: externalAuthProvidersMessage, UpdatedAt: now, }) if err != nil { return xerrors.Errorf("update template version external auth providers: %w", err) } err = db.UpdateTemplateVersionFlagsByJobID(ctx, database.UpdateTemplateVersionFlagsByJobIDParams{ JobID: jobID, HasAITask: sql.NullBool{ Bool: jobType.TemplateImport.HasAiTasks, Valid: true, }, HasExternalAgent: sql.NullBool{ Bool: jobType.TemplateImport.HasExternalAgents, Valid: true, }, UpdatedAt: now, }) if err != nil { return xerrors.Errorf("update template version ai task and external agent: %w", err) } // Process terraform values plan := jobType.TemplateImport.Plan moduleFiles := jobType.TemplateImport.ModuleFiles // If there is a plan, or a module files archive we need to insert a // template_version_terraform_values row. if len(plan) > 0 || len(moduleFiles) > 0 { // ...but the plan and the module files archive are both optional! So // we need to fallback to a valid JSON object if the plan was omitted. if len(plan) == 0 { plan = []byte("{}") } // ...and we only want to insert a files row if an archive was provided. 
var fileID uuid.NullUUID if len(moduleFiles) > 0 { hashBytes := sha256.Sum256(moduleFiles) hash := hex.EncodeToString(hashBytes[:]) //nolint:gocritic // Acting as provisionerd file, err := db.GetFileByHashAndCreator(dbauthz.AsProvisionerd(ctx), database.GetFileByHashAndCreatorParams{Hash: hash, CreatedBy: uuid.Nil}) switch { case err == nil: // This set of modules is already cached, which means we can reuse them fileID = uuid.NullUUID{ Valid: true, UUID: file.ID, } case !xerrors.Is(err, sql.ErrNoRows): return xerrors.Errorf("check for cached modules: %w", err) default: //nolint:gocritic // Acting as provisionerd file, err = db.InsertFile(dbauthz.AsProvisionerd(ctx), database.InsertFileParams{ ID: uuid.New(), Hash: hash, CreatedBy: uuid.Nil, CreatedAt: dbtime.Now(), Mimetype: tarMimeType, Data: moduleFiles, }) if err != nil { return xerrors.Errorf("insert template version terraform modules: %w", err) } fileID = uuid.NullUUID{ Valid: true, UUID: file.ID, } } } if len(jobType.TemplateImport.ModuleFilesHash) > 0 { hashString := hex.EncodeToString(jobType.TemplateImport.ModuleFilesHash) //nolint:gocritic // Acting as provisioner file, err := db.GetFileByHashAndCreator(dbauthz.AsProvisionerd(ctx), database.GetFileByHashAndCreatorParams{Hash: hashString, CreatedBy: uuid.Nil}) if err != nil { return xerrors.Errorf("get file by hash, it should have been uploaded: %w", err) } fileID = uuid.NullUUID{ Valid: true, UUID: file.ID, } } err = db.InsertTemplateVersionTerraformValuesByJobID(ctx, database.InsertTemplateVersionTerraformValuesByJobIDParams{ JobID: jobID, UpdatedAt: now, CachedPlan: plan, CachedModuleFiles: fileID, ProvisionerdVersion: s.apiVersion, }) if err != nil { return xerrors.Errorf("insert template version terraform data: %w", err) } } // Mark job as completed err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: jobID, UpdatedAt: now, CompletedAt: sql.NullTime{ Time: now, Valid: true, }, Error: 
completedError,
			ErrorCode: sql.NullString{},
		})
		if err != nil {
			return xerrors.Errorf("update provisioner job: %w", err)
		}
		s.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID))

		return nil
	}, nil) // End of transaction
}

// completeWorkspaceBuildJob handles completion of a workspace build job.
//
// Inside a single transaction it: computes the autostop deadline (skipped
// for prebuilt workspaces) and the next autostart time, marks the provisioner
// job as complete, stores the provisioner state and deadlines on the build,
// inserts the build's resources and modules, soft-deletes agents from prior
// builds, links any AI task to the build, records build flags and timings,
// and — for delete transitions — marks the workspace (and its agents/task)
// as deleted.
//
// After the transaction commits it performs side effects that do not need
// to be transactional: deletion notifications, audit logging, AI seat usage,
// prebuild resource-replacement tracking, timing metrics, and pubsub events.
func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.ProvisionerJob, jobID uuid.UUID, jobType *proto.CompletedJob_WorkspaceBuild_, telemetrySnapshot *telemetry.Snapshot) error {
	var input WorkspaceProvisionJob
	err := json.Unmarshal(job.Input, &input)
	if err != nil {
		return xerrors.Errorf("unmarshal job data: %w", err)
	}

	workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
	if err != nil {
		return xerrors.Errorf("get workspace build: %w", err)
	}

	var workspace database.Workspace
	var getWorkspaceError error

	// Execute all database modifications in a transaction
	err = s.Database.InTx(func(db database.Store) error {
		// It's important we use s.timeNow() here because we want to be
		// able to customize the current time from within tests.
		now := s.timeNow()

		workspace, getWorkspaceError = db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
		if getWorkspaceError != nil {
			s.Logger.Error(ctx,
				"fetch workspace for build",
				slog.F("workspace_build_id", workspaceBuild.ID),
				slog.F("workspace_id", workspaceBuild.WorkspaceID),
				slog.Error(getWorkspaceError),
			)
			return getWorkspaceError
		}

		// Prebuilt workspaces must not have Deadline or MaxDeadline set,
		// as they are managed by the prebuild reconciliation loop, not the lifecycle executor
		deadline := time.Time{}
		maxDeadline := time.Time{}
		if !workspace.IsPrebuild() {
			templateScheduleStore := *s.TemplateScheduleStore.Load()

			autoStop, err := schedule.CalculateAutostop(ctx, schedule.CalculateAutostopParams{
				Database:                    db,
				TemplateScheduleStore:       templateScheduleStore,
				UserQuietHoursScheduleStore: *s.UserQuietHoursScheduleStore.Load(),
				// `now` is used below to set the build completion time.
				WorkspaceBuildCompletedAt: now,
				Workspace:                 workspace.WorkspaceTable(),
				// Allowed to be the empty string.
				WorkspaceAutostart: workspace.AutostartSchedule.String,
			})
			if err != nil {
				return xerrors.Errorf("calculate auto stop: %w", err)
			}

			if workspace.AutostartSchedule.Valid {
				templateScheduleOptions, err := templateScheduleStore.Get(ctx, db, workspace.TemplateID)
				if err != nil {
					return xerrors.Errorf("get template schedule options: %w", err)
				}

				// An error here means no next start time could be computed;
				// in that case the stored next_start_at is left untouched.
				nextStartAt, err := schedule.NextAllowedAutostart(now, workspace.AutostartSchedule.String, templateScheduleOptions)
				if err == nil {
					err = db.UpdateWorkspaceNextStartAt(ctx, database.UpdateWorkspaceNextStartAtParams{
						ID:          workspace.ID,
						NextStartAt: sql.NullTime{Valid: true, Time: nextStartAt.UTC()},
					})
					if err != nil {
						return xerrors.Errorf("update workspace next start at: %w", err)
					}
				}
			}

			deadline = autoStop.Deadline
			maxDeadline = autoStop.MaxDeadline
		}

		err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
			ID:        jobID,
			UpdatedAt: now,
			CompletedAt: sql.NullTime{
				Time:  now,
				Valid: true,
			},
			Error:     sql.NullString{},
			ErrorCode: sql.NullString{},
		})
		if err != nil {
			return xerrors.Errorf("update provisioner job: %w", err)
		}
		err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
			ID:               workspaceBuild.ID,
			ProvisionerState: jobType.WorkspaceBuild.State,
			UpdatedAt:        now,
		})
		if err != nil {
			return xerrors.Errorf("update workspace build provisioner state: %w", err)
		}
		err = db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{
			ID:          workspaceBuild.ID,
			Deadline:    deadline,
			MaxDeadline: maxDeadline,
			UpdatedAt:   now,
		})
		if err != nil {
			return xerrors.Errorf("update workspace build deadline: %w", err)
		}

		appIDs := make([]string, 0)
		agentIDByAppID := make(map[string]uuid.UUID)
		agentTimeouts := make(map[time.Duration]bool) // A set of agent timeouts.
		// This could be a bulk insert to improve performance.
		for _, protoResource := range jobType.WorkspaceBuild.Resources {
			for _, protoAgent := range protoResource.GetAgents() {
				if protoAgent == nil {
					continue
				}
				// By default InsertWorkspaceResource ignores the protoAgent.Id
				// and generates a new one, but we will insert these using the
				// InsertWorkspaceResourceWithAgentIDsFromProto option so that
				// we can properly map agent IDs to app IDs. This is needed for
				// task linking.
				agentID := uuid.New()
				protoAgent.Id = agentID.String()

				dur := time.Duration(protoAgent.GetConnectionTimeoutSeconds()) * time.Second
				agentTimeouts[dur] = true
				for _, app := range protoAgent.GetApps() {
					appIDs = append(appIDs, app.GetId())
					agentIDByAppID[app.GetId()] = agentID
				}

				// Subagents in devcontainers can also have apps that need
				// tracking for task linking, just like the parent agent's
				// apps above.
				for _, dc := range protoAgent.GetDevcontainers() {
					dc.Id = uuid.New().String()

					if dc.GetSubagentId() != "" {
						subAgentID := uuid.New()
						dc.SubagentId = subAgentID.String()

						for _, app := range dc.GetApps() {
							appIDs = append(appIDs, app.GetId())
							agentIDByAppID[app.GetId()] = subAgentID
						}
					}
				}
			}

			err = InsertWorkspaceResource(
				ctx,
				db,
				job.ID,
				workspaceBuild.Transition,
				protoResource,
				telemetrySnapshot,
				// Ensure that the agent IDs we set previously
				// are written to the database.
				InsertWorkspaceResourceWithAgentIDsFromProto(),
			)
			if err != nil {
				return xerrors.Errorf("insert provisioner job: %w", err)
			}
		}

		// Soft-delete agents from prior builds now that this build's
		// agents have been inserted. Waiting until completion (rather
		// than build creation) avoids bricking running workspaces
		// whose agents would otherwise be deleted while the new build
		// is still queued or provisioning. See #25155.
		err = db.SoftDeletePriorWorkspaceAgents(ctx, database.SoftDeletePriorWorkspaceAgentsParams{
			WorkspaceID:    workspaceBuild.WorkspaceID,
			CurrentBuildID: workspaceBuild.ID,
		})
		if err != nil {
			return xerrors.Errorf("soft delete prior workspace agents: %w", err)
		}

		for _, module := range jobType.WorkspaceBuild.Modules {
			if err := InsertWorkspaceModule(ctx, db, job.ID, workspaceBuild.Transition, module, telemetrySnapshot); err != nil {
				return xerrors.Errorf("insert provisioner job module: %w", err)
			}
		}

		var (
			unknownAppID string
			taskAppID    uuid.NullUUID
			taskAgentID  uuid.NullUUID
		)
		// Only the first AI task reported by the provisioner is considered.
		if tasks := jobType.WorkspaceBuild.GetAiTasks(); len(tasks) > 0 {
			task := tasks[0]
			if task == nil {
				return xerrors.Errorf("update ai task: task is nil")
			}

			appID := task.GetAppId()
			if appID == "" && task.GetSidebarApp() != nil {
				appID = task.GetSidebarApp().GetId()
			}
			if appID == "" {
				return xerrors.Errorf("update ai task: app id is empty")
			}

			if !slices.Contains(appIDs, appID) {
				unknownAppID = appID
			} else {
				// Only parse for valid app and agent to avoid fk violation.
				id, err := uuid.Parse(appID)
				if err != nil {
					return xerrors.Errorf("parse app id: %w", err)
				}
				taskAppID = uuid.NullUUID{UUID: id, Valid: true}

				agentID, ok := agentIDByAppID[appID]
				taskAgentID = uuid.NullUUID{UUID: agentID, Valid: ok}
			}
		}

		if unknownAppID != "" && workspaceBuild.Transition == database.WorkspaceTransitionStart {
			// Ref: https://github.com/coder/coder/issues/18776
			// This can happen for a number of reasons:
			// 1. Misconfigured template
			// 2. Count=0 on the agent due to stop transition, meaning the associated coder_app was not inserted.
			// Failing the build at this point is not ideal, so log a warning instead.
			s.Logger.Warn(ctx, "unknown ai_task_app_id",
				slog.F("ai_task_app_id", unknownAppID),
				slog.F("job_id", job.ID.String()),
				slog.F("workspace_id", workspace.ID),
				slog.F("workspace_build_id", workspaceBuild.ID),
				slog.F("transition", string(workspaceBuild.Transition)),
			)
			// In order to surface this to the user, we will also insert a warning into the build logs.
			if _, err := db.InsertProvisionerJobLogs(ctx, database.InsertProvisionerJobLogsParams{
				JobID:     jobID,
				CreatedAt: []time.Time{now, now, now, now},
				Source:    []database.LogSource{database.LogSourceProvisionerDaemon, database.LogSourceProvisionerDaemon, database.LogSourceProvisionerDaemon, database.LogSourceProvisionerDaemon},
				Level:     []database.LogLevel{database.LogLevelWarn, database.LogLevelWarn, database.LogLevelWarn, database.LogLevelWarn},
				Stage:     []string{"Cleaning Up", "Cleaning Up", "Cleaning Up", "Cleaning Up"},
				Output: []string{
					fmt.Sprintf("Unknown ai_task_app_id %q. This workspace will be unable to run AI tasks. This may be due to a template configuration issue, please check with the template author.", unknownAppID),
					"Template author: double-check the following:",
					" - You have associated the coder_ai_task with a valid coder_app in your template (ref: https://registry.terraform.io/providers/coder/coder/latest/docs/resources/ai_task).",
					" - You have associated the coder_agent with at least one other compute resource. Agents with no other associated resources are not inserted into the database.",
				},
			}); err != nil {
				// Best effort: a failure to write the warning must not fail the build,
				// but the cause should still be visible in the server logs.
				s.Logger.Error(ctx, "insert provisioner job log for ai task app id warning",
					slog.F("job_id", jobID),
					slog.F("workspace_id", workspace.ID),
					slog.F("workspace_build_id", workspaceBuild.ID),
					slog.F("transition", string(workspaceBuild.Transition)),
					slog.Error(err),
				)
			}
		}

		var hasAITask bool
		if task, err := db.GetTaskByWorkspaceID(ctx, workspace.ID); err == nil {
			hasAITask = true
			if workspaceBuild.Transition == database.WorkspaceTransitionStart {
				// Insert usage event for managed agents.
				usageInserter := s.UsageInserter.Load()
				if usageInserter != nil {
					event := usagetypes.DCManagedAgentsV1{
						Count: 1,
					}
					err = (*usageInserter).InsertDiscreteUsageEvent(ctx, db, event)
					if err != nil {
						return xerrors.Errorf("insert %q event: %w", event.EventType(), err)
					}
				}
			}

			// Irrespective of whether the agent or sidebar app is present,
			// perform the upsert to ensure a link between the task and
			// workspace build. Linking the task to the build is typically
			// already established by wsbuilder.
			_, err = db.UpsertTaskWorkspaceApp(
				ctx,
				database.UpsertTaskWorkspaceAppParams{
					TaskID:               task.ID,
					WorkspaceBuildNumber: workspaceBuild.BuildNumber,
					WorkspaceAgentID:     taskAgentID,
					WorkspaceAppID:       taskAppID,
				},
			)
			if err != nil {
				return xerrors.Errorf("upsert task workspace app: %w", err)
			}
		} else if !errors.Is(err, sql.ErrNoRows) {
			return xerrors.Errorf("get task by workspace id: %w", err)
		}

		_, hasExternalAgent := slice.Find(jobType.WorkspaceBuild.Resources, func(resource *sdkproto.Resource) bool {
			return resource.Type == "coder_external_agent"
		})
		if err := db.UpdateWorkspaceBuildFlagsByID(ctx, database.UpdateWorkspaceBuildFlagsByIDParams{
			ID: workspaceBuild.ID,
			HasAITask: sql.NullBool{
				Bool:  hasAITask,
				Valid: true,
			},
			HasExternalAgent: sql.NullBool{
				Bool:  hasExternalAgent,
				Valid: true,
			},
			UpdatedAt: now,
		}); err != nil {
			return xerrors.Errorf("update workspace build ai tasks and external agent flag: %w", err)
		}

		// Insert timings inside the transaction now
		// nolint:exhaustruct // The other fields are set further down.
		params := database.InsertProvisionerJobTimingsParams{
			JobID: jobID,
		}
		for _, t := range jobType.WorkspaceBuild.Timings {
			start := t.GetStart()
			if !start.IsValid() || start.AsTime().IsZero() {
				s.Logger.Warn(ctx, "timings entry has nil or zero start time", slog.F("job_id", job.ID.String()), slog.F("workspace_id", workspace.ID), slog.F("workspace_build_id", workspaceBuild.ID), slog.F("user_id", workspace.OwnerID))
				continue
			}

			end := t.GetEnd()
			if !end.IsValid() || end.AsTime().IsZero() {
				s.Logger.Warn(ctx, "timings entry has nil or zero end time, skipping", slog.F("job_id", job.ID.String()), slog.F("workspace_id", workspace.ID), slog.F("workspace_build_id", workspaceBuild.ID), slog.F("user_id", workspace.OwnerID))
				continue
			}

			var stg database.ProvisionerJobTimingStage
			if err := stg.Scan(t.Stage); err != nil {
				s.Logger.Warn(ctx, "failed to parse timings stage, skipping", slog.F("value", t.Stage))
				continue
			}

			// Scan does not guarantee validity
			if !stg.Valid() {
				s.Logger.Warn(ctx, "invalid stage, will fail insert based one enum", slog.F("value", t.Stage))
				continue
			}

			params.Stage = append(params.Stage, stg)
			params.Source = append(params.Source, t.Source)
			params.Resource = append(params.Resource, t.Resource)
			params.Action = append(params.Action, t.Action)
			params.StartedAt = append(params.StartedAt, t.Start.AsTime())
			params.EndedAt = append(params.EndedAt, t.End.AsTime())
		}
		_, err = db.InsertProvisionerJobTimings(ctx, params)
		if err != nil {
			// A database error here will "fail" this transaction. Making this error fatal.
			// If this error is seen, add checks above to validate the insert parameters. In
			// production, timings should not be a fatal error.
			s.Logger.Warn(ctx, "failed to update provisioner job timings", slog.F("job_id", jobID), slog.Error(err))
			return xerrors.Errorf("update provisioner job timings: %w", err)
		}

		// On start, we want to ensure that workspace agents timeout statuses
		// are propagated. This method is simple and does not protect against
		// notifying in edge cases like when a workspace is stopped soon
		// after being started.
		//
		// Agent timeouts could be minutes apart, resulting in an unresponsive
		// experience, so we'll notify after every unique timeout seconds
		if !input.DryRun && workspaceBuild.Transition == database.WorkspaceTransitionStart && len(agentTimeouts) > 0 {
			timeouts := maps.Keys(agentTimeouts)
			slices.Sort(timeouts)

			var updates []<-chan time.Time
			for _, d := range timeouts {
				s.Logger.Debug(ctx, "triggering workspace notification after agent timeout",
					slog.F("workspace_build_id", workspaceBuild.ID),
					slog.F("timeout", d),
				)
				// Agents are inserted with `dbtime.Now()`, this triggers a
				// workspace event approximately after created + timeout seconds.
				updates = append(updates, time.After(d))
			}
			go func() {
				for _, wait := range updates {
					select {
					case <-s.lifecycleCtx.Done():
						// If the server is shutting down, we don't want to wait around.
						s.Logger.Debug(ctx, "stopping notifications due to server shutdown",
							slog.F("workspace_build_id", workspaceBuild.ID),
						)
						return
					case <-wait:
						// Wait for the next potential timeout to occur.
						msg, err := json.Marshal(wspubsub.WorkspaceEvent{
							Kind:        wspubsub.WorkspaceEventKindAgentTimeout,
							WorkspaceID: workspace.ID,
						})
						if err != nil {
							s.Logger.Error(ctx, "marshal workspace update event", slog.Error(err))
							// Only exits the select; continue with the next timeout.
							break
						}
						if err := s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg); err != nil {
							if s.lifecycleCtx.Err() != nil {
								// If the server is shutting down, we don't want to log this error, nor wait around.
								s.Logger.Debug(ctx, "stopping notifications due to server shutdown",
									slog.F("workspace_build_id", workspaceBuild.ID),
								)
								return
							}
							s.Logger.Error(ctx, "workspace notification after agent timeout failed",
								slog.F("workspace_build_id", workspaceBuild.ID),
								slog.Error(err),
							)
						}
					}
				}
			}()
		}

		// Everything below only applies to delete transitions.
		if workspaceBuild.Transition != database.WorkspaceTransitionDelete {
			return nil
		}

		err = db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{
			ID:      workspaceBuild.WorkspaceID,
			Deleted: true,
		})
		if err != nil {
			return xerrors.Errorf("update workspace deleted: %w", err)
		}

		// Soft-delete any agents tied to this workspace so the
		// aws-instance-identity handler (which filters on
		// workspace_agents.deleted) doesn't keep seeing orphaned rows
		// after the workspace itself is deleted. See #25155.
		if err := db.SoftDeleteWorkspaceAgentsByWorkspaceID(ctx, workspaceBuild.WorkspaceID); err != nil {
			return xerrors.Errorf("soft delete workspace agents: %w", err)
		}

		// A user might delete their task workspace directly, instead of
		// deleting the task. To avoid leaving the Task in a scenario where
		// it has no workspace, we also attempt to delete the task.
		//
		// Deleting the task may fail if it has already been deleted as part
		// of the typical task deletion workflow, so we explicitly allow that.
		if workspace.TaskID.Valid {
			if _, err := db.DeleteTask(ctx, database.DeleteTaskParams{
				ID:        workspace.TaskID.UUID,
				DeletedAt: dbtime.Now(),
			}); err != nil && !errors.Is(err, sql.ErrNoRows) {
				return xerrors.Errorf("delete task related to workspace: %w", err)
			}
		}

		return nil
	}, nil)
	if err != nil {
		return xerrors.Errorf("complete job: %w", err)
	}

	// Post-transaction operations (operations that do not require transactions or
	// are external to the database, like audit logging, notifications, etc.)

	// audit the outcome of the workspace build
	if getWorkspaceError == nil {
		// If the workspace has been deleted, notify the owner about it.
		if workspaceBuild.Transition == database.WorkspaceTransitionDelete {
			s.notifyWorkspaceDeleted(ctx, workspace, workspaceBuild)
		}

		auditor := s.Auditor.Load()
		auditAction := auditActionFromTransition(workspaceBuild.Transition)

		previousBuildNumber := workspaceBuild.BuildNumber - 1
		previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
			WorkspaceID: workspace.ID,
			BuildNumber: previousBuildNumber,
		})
		if prevBuildErr != nil {
			previousBuild = database.WorkspaceBuild{}
		}

		// We pass the below information to the Auditor so that it
		// can form a friendly string for the user to view in the UI.
		buildResourceInfo := audit.AdditionalFields{
			WorkspaceName: workspace.Name,
			BuildNumber:   strconv.FormatInt(int64(workspaceBuild.BuildNumber), 10),
			BuildReason:   database.BuildReason(string(workspaceBuild.Reason)),
			WorkspaceID:   workspace.ID,
		}

		wriBytes, err := json.Marshal(buildResourceInfo)
		if err != nil {
			s.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err))
		}

		bag := audit.BaggageFromContext(ctx)

		audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceBuild]{
			Audit:            *auditor,
			Log:              s.Logger,
			UserID:           job.InitiatorID,
			OrganizationID:   workspace.OrganizationID,
			RequestID:        job.ID,
			IP:               bag.IP,
			Action:           auditAction,
			Old:              previousBuild,
			New:              workspaceBuild,
			Status:           http.StatusOK,
			AdditionalFields: wriBytes,
		})
	}

	// Record AI seat usage for successful task workspace builds.
	if workspaceBuild.Transition == database.WorkspaceTransitionStart && workspace.TaskID.Valid {
		s.AISeatTracker.RecordUsage(ctx, workspace.OwnerID, aiseats.ReasonTask("task workspace build succeeded"))
	}

	if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
		// Track resource replacements, if there are any.
		orchestrator := s.PrebuildsOrchestrator.Load()
		if resourceReplacements := jobType.WorkspaceBuild.ResourceReplacements; orchestrator != nil && len(resourceReplacements) > 0 {
			// Fire and forget. Bind to the lifecycle of the server so shutdowns are handled gracefully.
			go (*orchestrator).TrackResourceReplacement(s.lifecycleCtx, workspace.ID, workspaceBuild.ID, resourceReplacements)
		}
	}

	// Update workspace (regular and prebuild) timing metrics
	// Only consider 'start' workspace builds
	if s.metrics != nil && workspaceBuild.Transition == database.WorkspaceTransitionStart {
		// Get the updated job to report the metrics with correct data
		updatedJob, err := s.Database.GetProvisionerJobByID(ctx, jobID)
		if err != nil {
			s.Logger.Error(ctx, "get updated job from database", slog.Error(err))
		} else if updatedJob.JobStatus == database.ProvisionerJobStatusSucceeded {
			// Only consider 'succeeded' provisioner jobs
			presetName := ""
			if workspaceBuild.TemplateVersionPresetID.Valid {
				preset, err := s.Database.GetPresetByID(ctx, workspaceBuild.TemplateVersionPresetID.UUID)
				if err != nil {
					if !errors.Is(err, sql.ErrNoRows) {
						s.Logger.Error(ctx, "get preset by ID for workspace timing metrics", slog.Error(err))
					}
				} else {
					presetName = preset.Name
				}
			}

			buildTime := updatedJob.CompletedAt.Time.Sub(updatedJob.StartedAt.Time).Seconds()
			flags := WorkspaceTimingFlags{
				// Is a prebuilt workspace creation build
				IsPrebuild: input.PrebuiltWorkspaceBuildStage.IsPrebuild(),
				// Is a prebuilt workspace claim build
				IsClaim: input.PrebuiltWorkspaceBuildStage.IsPrebuiltWorkspaceClaim(),
				// Is a regular workspace creation build
				// Only consider the first build number for regular workspaces
				IsFirstBuild: workspaceBuild.BuildNumber == 1,
			}
			// Only track metrics for prebuild creation, prebuild claims and workspace creation
			if flags.IsTrackable() {
				s.metrics.UpdateWorkspaceTimingsMetrics(
					ctx,
					flags,
					workspace.OrganizationName,
					workspace.TemplateName,
					presetName,
					buildTime,
				)
			}
		}
	}

	msg, err := json.Marshal(wspubsub.WorkspaceEvent{
		Kind:        wspubsub.WorkspaceEventKindStateChange,
		WorkspaceID: workspace.ID,
	})
	if err != nil {
		return xerrors.Errorf("marshal workspace update event: %w", err)
	}
	err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg)
	if err != nil {
		return xerrors.Errorf("update workspace: %w", err)
	}

	// Publish workspace build update to the all builds channel if the experiment is enabled.
	if s.Experiments.Enabled(codersdk.ExperimentWorkspaceBuildUpdates) {
		err = wspubsub.PublishWorkspaceBuildUpdate(ctx, s.Pubsub, codersdk.WorkspaceBuildUpdate{
			WorkspaceID:   workspace.ID,
			WorkspaceName: workspace.Name,
			BuildID:       workspaceBuild.ID,
			Transition:    string(workspaceBuild.Transition),
			JobStatus:     string(database.ProvisionerJobStatusSucceeded),
			BuildNumber:   workspaceBuild.BuildNumber,
		})
		if err != nil {
			s.Logger.Warn(ctx, "failed to publish workspace build update", slog.Error(err))
		}
	}

	if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
		s.Logger.Info(ctx, "workspace prebuild successfully claimed by user",
			slog.F("workspace_id", workspace.ID))

		err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
			WorkspaceID: workspace.ID,
			Reason:      agentsdk.ReinitializeReasonPrebuildClaimed,
			OwnerID:     workspace.OwnerID,
		})
		if err != nil {
			s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
		}
	}

	return nil
}

// completeTemplateDryRunJob handles completion of a template dry-run job.
// All database operations are performed within a transaction.
func (s *server) completeTemplateDryRunJob(ctx context.Context, job database.ProvisionerJob, jobID uuid.UUID, jobType *proto.CompletedJob_TemplateDryRun_, telemetrySnapshot *telemetry.Snapshot) error { // Execute all database operations in a transaction return s.Database.InTx(func(db database.Store) error { now := s.timeNow() // Process resources for _, resource := range jobType.TemplateDryRun.Resources { s.Logger.Info(ctx, "inserting template dry-run job resource", slog.F("job_id", job.ID.String()), slog.F("resource_name", resource.Name), slog.F("resource_type", resource.Type)) err := InsertWorkspaceResource(ctx, db, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot) if err != nil { return xerrors.Errorf("insert resource: %w", err) } } // Process modules for _, module := range jobType.TemplateDryRun.Modules { s.Logger.Info(ctx, "inserting template dry-run job module", slog.F("job_id", job.ID.String()), slog.F("module_source", module.Source), ) if err := InsertWorkspaceModule(ctx, db, jobID, database.WorkspaceTransitionStart, module, telemetrySnapshot); err != nil { return xerrors.Errorf("insert module: %w", err) } } // Mark job as complete err := db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ ID: jobID, UpdatedAt: now, CompletedAt: sql.NullTime{ Time: now, Valid: true, }, Error: sql.NullString{}, ErrorCode: sql.NullString{}, }) if err != nil { return xerrors.Errorf("update provisioner job: %w", err) } s.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID)) return nil }, nil) // End of transaction } func (s *server) notifyWorkspaceDeleted(ctx context.Context, workspace database.Workspace, build database.WorkspaceBuild) { var reason string initiator := build.InitiatorByUsername if build.Reason.Valid() { switch build.Reason { case database.BuildReasonInitiator: if build.InitiatorID == workspace.OwnerID { // Deletions initiated by self should not notify. 
return } reason = "initiated by user" case database.BuildReasonAutodelete: reason = "autodeleted due to dormancy" initiator = "autobuild" default: reason = string(build.Reason) } } else { reason = string(build.Reason) s.Logger.Warn(ctx, "invalid build reason when sending deletion notification", slog.F("reason", reason), slog.F("workspace_id", workspace.ID), slog.F("build_id", build.ID)) } if _, err := s.NotificationsEnqueuer.Enqueue(ctx, workspace.OwnerID, notifications.TemplateWorkspaceDeleted, map[string]string{ "name": workspace.Name, "reason": reason, "initiator": initiator, }, "provisionerdserver", // Associate this notification with all the related entities. workspace.ID, workspace.OwnerID, workspace.TemplateID, workspace.OrganizationID, ); err != nil { s.Logger.Warn(ctx, "failed to notify of workspace deletion", slog.Error(err)) } } func (s *server) startTrace(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { return s.Tracer.Start(ctx, name, append(opts, trace.WithAttributes( semconv.ServiceNameKey.String("coderd.provisionerd"), ))...) 
} func InsertWorkspaceModule(ctx context.Context, db database.Store, jobID uuid.UUID, transition database.WorkspaceTransition, protoModule *sdkproto.Module, snapshot *telemetry.Snapshot) error { module, err := db.InsertWorkspaceModule(ctx, database.InsertWorkspaceModuleParams{ ID: uuid.New(), CreatedAt: dbtime.Now(), JobID: jobID, Transition: transition, Source: protoModule.Source, Version: protoModule.Version, Key: protoModule.Key, }) if err != nil { return xerrors.Errorf("insert provisioner job module %q: %w", protoModule.Source, err) } snapshot.WorkspaceModules = append(snapshot.WorkspaceModules, telemetry.ConvertWorkspaceModule(module)) return nil } func InsertWorkspacePresetsAndParameters(ctx context.Context, logger slog.Logger, db database.Store, jobID uuid.UUID, templateVersionID uuid.UUID, protoPresets []*sdkproto.Preset, t time.Time) error { for _, preset := range protoPresets { logger.Info(ctx, "inserting template import job preset", slog.F("job_id", jobID.String()), slog.F("preset_name", preset.Name), ) if err := InsertWorkspacePresetAndParameters(ctx, db, templateVersionID, preset, t); err != nil { return xerrors.Errorf("insert workspace preset: %w", err) } } return nil } func InsertWorkspacePresetAndParameters(ctx context.Context, db database.Store, templateVersionID uuid.UUID, protoPreset *sdkproto.Preset, t time.Time) error { err := db.InTx(func(tx database.Store) error { var ( desiredInstances sql.NullInt32 ttl sql.NullInt32 schedulingEnabled bool schedulingTimezone string prebuildSchedules []*sdkproto.Schedule ) if protoPreset != nil && protoPreset.Prebuild != nil { desiredInstances = sql.NullInt32{ Int32: protoPreset.Prebuild.Instances, Valid: true, } if protoPreset.Prebuild.ExpirationPolicy != nil { ttl = sql.NullInt32{ Int32: protoPreset.Prebuild.ExpirationPolicy.Ttl, Valid: true, } } if protoPreset.Prebuild.Scheduling != nil { schedulingEnabled = true schedulingTimezone = protoPreset.Prebuild.Scheduling.Timezone prebuildSchedules = 
protoPreset.Prebuild.Scheduling.Schedule } } dbPreset, err := tx.InsertPreset(ctx, database.InsertPresetParams{ ID: uuid.New(), TemplateVersionID: templateVersionID, Name: protoPreset.Name, CreatedAt: t, DesiredInstances: desiredInstances, InvalidateAfterSecs: ttl, SchedulingTimezone: schedulingTimezone, IsDefault: protoPreset.GetDefault(), Description: protoPreset.Description, Icon: protoPreset.Icon, LastInvalidatedAt: sql.NullTime{}, }) if err != nil { return xerrors.Errorf("insert preset: %w", err) } if schedulingEnabled { for _, schedule := range prebuildSchedules { _, err := tx.InsertPresetPrebuildSchedule(ctx, database.InsertPresetPrebuildScheduleParams{ PresetID: dbPreset.ID, CronExpression: schedule.Cron, DesiredInstances: schedule.Instances, }) if err != nil { return xerrors.Errorf("failed to insert preset prebuild schedule: %w", err) } } } var presetParameterNames []string var presetParameterValues []string for _, parameter := range protoPreset.Parameters { presetParameterNames = append(presetParameterNames, parameter.Name) presetParameterValues = append(presetParameterValues, parameter.Value) } _, err = tx.InsertPresetParameters(ctx, database.InsertPresetParametersParams{ TemplateVersionPresetID: dbPreset.ID, Names: presetParameterNames, Values: presetParameterValues, }) if err != nil { return xerrors.Errorf("insert preset parameters: %w", err) } return nil }, nil) if err != nil { return xerrors.Errorf("insert preset and parameters: %w", err) } return nil } type insertWorkspaceResourceOptions struct { useAgentIDsFromProto bool } // InsertWorkspaceResourceOption represents a functional option for // InsertWorkspaceResource. type InsertWorkspaceResourceOption func(*insertWorkspaceResourceOptions) // InsertWorkspaceResourceWithAgentIDsFromProto allows inserting agents into the // database using the agent IDs defined in the proto resource. 
func InsertWorkspaceResourceWithAgentIDsFromProto() InsertWorkspaceResourceOption { return func(opts *insertWorkspaceResourceOptions) { opts.useAgentIDsFromProto = true } } func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.UUID, transition database.WorkspaceTransition, protoResource *sdkproto.Resource, snapshot *telemetry.Snapshot, opt ...InsertWorkspaceResourceOption) error { opts := &insertWorkspaceResourceOptions{} for _, o := range opt { o(opts) } resource, err := db.InsertWorkspaceResource(ctx, database.InsertWorkspaceResourceParams{ ID: uuid.New(), CreatedAt: dbtime.Now(), JobID: jobID, Transition: transition, Type: protoResource.Type, Name: protoResource.Name, Hide: protoResource.Hide, Icon: protoResource.Icon, DailyCost: protoResource.DailyCost, InstanceType: sql.NullString{ String: protoResource.InstanceType, Valid: protoResource.InstanceType != "", }, ModulePath: sql.NullString{ String: protoResource.ModulePath, // empty string is root module Valid: true, }, }) if err != nil { return xerrors.Errorf("insert provisioner job resource %q: %w", protoResource.Name, err) } snapshot.WorkspaceResources = append(snapshot.WorkspaceResources, telemetry.ConvertWorkspaceResource(resource)) var ( agentNames = make(map[string]struct{}) appSlugs = make(map[string]struct{}) ) for _, prAgent := range protoResource.Agents { // Similar logic is duplicated in terraform/resources.go. if prAgent.Name == "" { return xerrors.Errorf("agent name cannot be empty") } // In 2025-02 we removed support for underscores in agent names. To // provide a nicer error message, we check the regex first and check // for underscores if it fails. 
if !provisioner.AgentNameRegex.MatchString(prAgent.Name) { if strings.Contains(prAgent.Name, "_") { return xerrors.Errorf("agent name %q contains underscores which are no longer supported, please use hyphens instead (regex: %q)", prAgent.Name, provisioner.AgentNameRegex.String()) } return xerrors.Errorf("agent name %q does not match regex %q", prAgent.Name, provisioner.AgentNameRegex.String()) } // Agent names must be case-insensitive-unique, to be unambiguous in // `coder_app`s and CoderVPN DNS names. if _, ok := agentNames[strings.ToLower(prAgent.Name)]; ok { return xerrors.Errorf("duplicate agent name %q", prAgent.Name) } agentNames[strings.ToLower(prAgent.Name)] = struct{}{} var instanceID sql.NullString if prAgent.GetInstanceId() != "" { instanceID = sql.NullString{ String: prAgent.GetInstanceId(), Valid: true, } } env := make(map[string]string) // Apply extra envs with merge strategy support. // When multiple coder_env resources define the same name, // the merge_strategy controls how values are combined. if err := MergeExtraEnvs(env, prAgent.ExtraEnvs); err != nil { return err } // Allow the agent defined envs to override extra envs. 
for k, v := range prAgent.Env { env[k] = v } var envJSON pqtype.NullRawMessage if len(env) > 0 { data, err := json.Marshal(env) if err != nil { return xerrors.Errorf("marshal env: %w", err) } envJSON = pqtype.NullRawMessage{ RawMessage: data, Valid: true, } } authToken := uuid.New() if prAgent.GetToken() != "" { authToken, err = uuid.Parse(prAgent.GetToken()) if err != nil { return xerrors.Errorf("invalid auth token format; must be uuid: %w", err) } } apiKeyScope := database.AgentKeyScopeEnumAll if prAgent.ApiKeyScope == string(database.AgentKeyScopeEnumNoUserData) { apiKeyScope = database.AgentKeyScopeEnumNoUserData } agentID := uuid.New() if opts.useAgentIDsFromProto { agentID, err = uuid.Parse(prAgent.Id) if err != nil { return xerrors.Errorf("invalid agent ID format; must be uuid: %w", err) } } dbAgent, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{ ID: agentID, ParentID: uuid.NullUUID{}, CreatedAt: dbtime.Now(), UpdatedAt: dbtime.Now(), ResourceID: resource.ID, Name: prAgent.Name, AuthToken: authToken, AuthInstanceID: instanceID, Architecture: prAgent.Architecture, EnvironmentVariables: envJSON, Directory: prAgent.Directory, OperatingSystem: prAgent.OperatingSystem, ConnectionTimeoutSeconds: prAgent.GetConnectionTimeoutSeconds(), TroubleshootingURL: prAgent.GetTroubleshootingUrl(), MOTDFile: prAgent.GetMotdFile(), DisplayApps: convertDisplayApps(prAgent.GetDisplayApps()), InstanceMetadata: pqtype.NullRawMessage{}, ResourceMetadata: pqtype.NullRawMessage{}, // #nosec G115 - Order represents a display order value that's always small and fits in int32 DisplayOrder: int32(prAgent.Order), APIKeyScope: apiKeyScope, }) if err != nil { return xerrors.Errorf("insert agent: %w", err) } snapshot.WorkspaceAgents = append(snapshot.WorkspaceAgents, telemetry.ConvertWorkspaceAgent(dbAgent)) for _, md := range prAgent.Metadata { p := database.InsertWorkspaceAgentMetadataParams{ WorkspaceAgentID: agentID, DisplayName: md.DisplayName, Script: 
md.Script, Key: md.Key, Timeout: md.Timeout, Interval: md.Interval, // #nosec G115 - Order represents a display order value that's always small and fits in int32 DisplayOrder: int32(md.Order), } err := db.InsertWorkspaceAgentMetadata(ctx, p) if err != nil { return xerrors.Errorf("insert agent metadata: %w, params: %+v", err, p) } } if prAgent.ResourcesMonitoring != nil { if prAgent.ResourcesMonitoring.Memory != nil { _, err = db.InsertMemoryResourceMonitor(ctx, database.InsertMemoryResourceMonitorParams{ AgentID: agentID, Enabled: prAgent.ResourcesMonitoring.Memory.Enabled, Threshold: prAgent.ResourcesMonitoring.Memory.Threshold, State: database.WorkspaceAgentMonitorStateOK, CreatedAt: dbtime.Now(), UpdatedAt: dbtime.Now(), DebouncedUntil: time.Time{}, }) if err != nil { return xerrors.Errorf("failed to insert agent memory resource monitor into db: %w", err) } } for _, volume := range prAgent.ResourcesMonitoring.Volumes { _, err = db.InsertVolumeResourceMonitor(ctx, database.InsertVolumeResourceMonitorParams{ AgentID: agentID, Path: volume.Path, Enabled: volume.Enabled, Threshold: volume.Threshold, State: database.WorkspaceAgentMonitorStateOK, CreatedAt: dbtime.Now(), UpdatedAt: dbtime.Now(), DebouncedUntil: time.Time{}, }) if err != nil { return xerrors.Errorf("failed to insert agent volume resource monitor into db: %w", err) } } } scriptsParams := agentScriptsFromProto(prAgent.Scripts) // Dev Containers require a script and log/source, so we do this before // the logs insert below. 
if devcontainers := prAgent.GetDevcontainers(); len(devcontainers) > 0 { var ( devcontainerIDs = make([]uuid.UUID, 0, len(devcontainers)) devcontainerNames = make([]string, 0, len(devcontainers)) devcontainerWorkspaceFolders = make([]string, 0, len(devcontainers)) devcontainerConfigPaths = make([]string, 0, len(devcontainers)) devcontainerSubagentIDs = make([]uuid.UUID, 0, len(devcontainers)) ) for _, dc := range devcontainers { id := uuid.New() if opts.useAgentIDsFromProto { id, err = uuid.Parse(dc.GetId()) if err != nil { return xerrors.Errorf("invalid devcontainer ID format; must be uuid: %w", err) } } subAgentID, err := insertDevcontainerSubagent(ctx, db, dc, dbAgent, resource.ID, appSlugs, snapshot, opts) if err != nil { return xerrors.Errorf("insert devcontainer %q subagent: %w", dc.GetName(), err) } devcontainerIDs = append(devcontainerIDs, id) devcontainerNames = append(devcontainerNames, dc.GetName()) devcontainerWorkspaceFolders = append(devcontainerWorkspaceFolders, dc.GetWorkspaceFolder()) devcontainerConfigPaths = append(devcontainerConfigPaths, dc.GetConfigPath()) devcontainerSubagentIDs = append(devcontainerSubagentIDs, subAgentID) // Add a log source and script for each devcontainer so we can // track logs and timings for each devcontainer. displayName := fmt.Sprintf("Dev Container (%s)", dc.GetName()) scriptsParams.LogSourceIDs = append(scriptsParams.LogSourceIDs, uuid.New()) scriptsParams.LogSourceDisplayNames = append(scriptsParams.LogSourceDisplayNames, displayName) scriptsParams.LogSourceIcons = append(scriptsParams.LogSourceIcons, "/emojis/1f4e6.png") // Emoji package. Or perhaps /icon/container.svg? scriptsParams.ScriptIDs = append(scriptsParams.ScriptIDs, id) // Re-use the devcontainer ID as the script ID for identification. 
scriptsParams.ScriptDisplayNames = append(scriptsParams.ScriptDisplayNames, displayName) scriptsParams.ScriptLogPaths = append(scriptsParams.ScriptLogPaths, "") scriptsParams.ScriptSources = append(scriptsParams.ScriptSources, "") scriptsParams.ScriptCron = append(scriptsParams.ScriptCron, "") scriptsParams.ScriptTimeout = append(scriptsParams.ScriptTimeout, 0) scriptsParams.ScriptStartBlocksLogin = append(scriptsParams.ScriptStartBlocksLogin, false) scriptsParams.ScriptRunOnStart = append(scriptsParams.ScriptRunOnStart, false) scriptsParams.ScriptRunOnStop = append(scriptsParams.ScriptRunOnStop, false) } _, err = db.InsertWorkspaceAgentDevcontainers(ctx, database.InsertWorkspaceAgentDevcontainersParams{ WorkspaceAgentID: agentID, CreatedAt: dbtime.Now(), ID: devcontainerIDs, Name: devcontainerNames, WorkspaceFolder: devcontainerWorkspaceFolders, ConfigPath: devcontainerConfigPaths, SubagentID: devcontainerSubagentIDs, }) if err != nil { return xerrors.Errorf("insert agent devcontainer: %w", err) } } if err := insertAgentScriptsAndLogSources(ctx, db, agentID, scriptsParams); err != nil { return xerrors.Errorf("insert agent scripts and log sources: %w", err) } for _, app := range prAgent.Apps { if err := insertAgentApp(ctx, db, dbAgent.ID, app, appSlugs, snapshot); err != nil { return xerrors.Errorf("insert agent app: %w", err) } } } arg := database.InsertWorkspaceResourceMetadataParams{ WorkspaceResourceID: resource.ID, Key: []string{}, Value: []string{}, Sensitive: []bool{}, } for _, metadatum := range protoResource.Metadata { if metadatum.IsNull { continue } arg.Key = append(arg.Key, metadatum.Key) arg.Value = append(arg.Value, metadatum.Value) arg.Sensitive = append(arg.Sensitive, metadatum.Sensitive) } _, err = db.InsertWorkspaceResourceMetadata(ctx, arg) if err != nil { return xerrors.Errorf("insert workspace resource metadata: %w", err) } return nil } func WorkspaceSessionTokenName(ownerID, workspaceID uuid.UUID) string { return 
fmt.Sprintf("%s_%s_session_token", ownerID, workspaceID) } func (s *server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) { // NOTE(Cian): Once a workspace is claimed, there's no reason for the session token to be valid any longer. // Not generating any session token at all for a system user may unintentionally break existing templates, // which we want to avoid. If there's no session token for the workspace belonging to the prebuilds user, // then there's nothing for us to worry about here. // TODO(Cian): Update this to handle _all_ system users. At the time of writing, only one system user exists. if err := deleteSessionTokenForUserAndWorkspace(ctx, s.Database, database.PrebuildsSystemUserID, workspace.ID); err != nil && !errors.Is(err, sql.ErrNoRows) { s.Logger.Error(ctx, "failed to delete prebuilds session token", slog.Error(err), slog.F("workspace_id", workspace.ID)) } newkey, sessionToken, err := apikey.Generate(apikey.CreateParams{ UserID: user.ID, LoginType: user.LoginType, TokenName: WorkspaceSessionTokenName(workspace.OwnerID, workspace.ID), DefaultLifetime: s.DeploymentValues.Sessions.DefaultTokenDuration.Value(), LifetimeSeconds: int64(s.DeploymentValues.Sessions.MaximumTokenDuration.Value().Seconds()), }) if err != nil { return "", xerrors.Errorf("generate API key: %w", err) } err = s.Database.InTx(func(tx database.Store) error { err := deleteSessionToken(ctx, tx, workspace) if err != nil { return xerrors.Errorf("delete session token: %w", err) } _, err = tx.InsertAPIKey(ctx, newkey) if err != nil { return xerrors.Errorf("insert API key: %w", err) } return nil }, nil) if err != nil { return "", xerrors.Errorf("create API key: %w", err) } return sessionToken, nil } func deleteSessionToken(ctx context.Context, db database.Store, workspace database.Workspace) error { return deleteSessionTokenForUserAndWorkspace(ctx, db, workspace.OwnerID, workspace.ID) } func 
deleteSessionTokenForUserAndWorkspace(ctx context.Context, db database.Store, userID, workspaceID uuid.UUID) error { err := db.InTx(func(tx database.Store) error { key, err := tx.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{ UserID: userID, TokenName: WorkspaceSessionTokenName(userID, workspaceID), }) if err == nil { err = tx.DeleteAPIKeyByID(ctx, key.ID) } if err != nil && !xerrors.Is(err, sql.ErrNoRows) { return xerrors.Errorf("get api key by name: %w", err) } return nil }, nil) if err != nil { return xerrors.Errorf("in tx: %w", err) } return nil } func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) { if link.OAuthRefreshToken == "" { // We cannot refresh even if we wanted to return false, link.OAuthExpiry } if link.OAuthExpiry.IsZero() { // 0 expire means the token never expires, so we shouldn't refresh return false, link.OAuthExpiry } // This handles an edge case where the token is about to expire. A workspace // build takes a non-trivial amount of time. If the token is to expire during the // build, then the build risks failure. To mitigate this, refresh the token // prematurely. // // If an OIDC provider issues short-lived tokens less than our defined period, // the token will always be refreshed on every workspace build. // // By setting the expiration backwards, we are effectively shortening the // time a token can be alive for by 10 minutes. // Note: This is how it is done in the oauth2 package's own token refreshing logic. expiresAt := link.OAuthExpiry.Add(-time.Minute * 10) // Return if the token is assumed to be expired. return expiresAt.Before(dbtime.Now()), expiresAt } // ObtainOIDCAccessToken returns a valid OpenID Connect access token // for the user if it's able to obtain one, otherwise it returns an empty string. 
func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) { link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{ UserID: userID, LoginType: database.LoginTypeOIDC, }) if errors.Is(err, sql.ErrNoRows) { return "", nil } if err != nil { return "", xerrors.Errorf("get owner oidc link: %w", err) } if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh { token, err := oidcConfig.TokenSource(ctx, &oauth2.Token{ AccessToken: link.OAuthAccessToken, RefreshToken: link.OAuthRefreshToken, // Use the expiresAt returned by shouldRefreshOIDCToken. // It will force a refresh with an expired time. Expiry: expiresAt, }).Token() if err != nil { // If OIDC fails to refresh, we return an empty string and don't fail. // There isn't a way to hard-opt in to OIDC from a template, so we don't // want to fail builds if users haven't authenticated for a while or something. 
return "", nil } link.OAuthAccessToken = token.AccessToken link.OAuthRefreshToken = token.RefreshToken link.OAuthExpiry = token.Expiry link, err = db.UpdateUserLink(ctx, database.UpdateUserLinkParams{ UserID: userID, LoginType: database.LoginTypeOIDC, OAuthAccessToken: link.OAuthAccessToken, OAuthAccessTokenKeyID: sql.NullString{}, // set by dbcrypt if required OAuthRefreshToken: link.OAuthRefreshToken, OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required OAuthExpiry: link.OAuthExpiry, Claims: link.Claims, }) if err != nil { return "", xerrors.Errorf("update user link: %w", err) } logger.Info(ctx, "refreshed expired OIDC token for user during workspace build", slog.F("user_id", userID)) } return link.OAuthAccessToken, nil } func convertLogLevel(logLevel sdkproto.LogLevel) (database.LogLevel, error) { switch logLevel { case sdkproto.LogLevel_TRACE: return database.LogLevelTrace, nil case sdkproto.LogLevel_DEBUG: return database.LogLevelDebug, nil case sdkproto.LogLevel_INFO: return database.LogLevelInfo, nil case sdkproto.LogLevel_WARN: return database.LogLevelWarn, nil case sdkproto.LogLevel_ERROR: return database.LogLevelError, nil default: return database.LogLevel(""), xerrors.Errorf("unknown log level: %d", logLevel) } } func convertLogSource(logSource proto.LogSource) (database.LogSource, error) { switch logSource { case proto.LogSource_PROVISIONER_DAEMON: return database.LogSourceProvisionerDaemon, nil case proto.LogSource_PROVISIONER: return database.LogSourceProvisioner, nil default: return database.LogSource(""), xerrors.Errorf("unknown log source: %d", logSource) } } func convertRichParameterValues(workspaceBuildParameters []database.WorkspaceBuildParameter) []*sdkproto.RichParameterValue { protoParameters := make([]*sdkproto.RichParameterValue, len(workspaceBuildParameters)) for i, buildParameter := range workspaceBuildParameters { protoParameters[i] = &sdkproto.RichParameterValue{ Name: buildParameter.Name, Value: buildParameter.Value, 
} } return protoParameters } func convertVariableValues(variableValues []codersdk.VariableValue) []*sdkproto.VariableValue { protoVariableValues := make([]*sdkproto.VariableValue, len(variableValues)) for i, variableValue := range variableValues { protoVariableValues[i] = &sdkproto.VariableValue{ Name: variableValue.Name, Value: variableValue.Value, Sensitive: true, // Without the template variable schema we have to assume that every variable may be sensitive. } } return protoVariableValues } func convertWorkspaceTransition(transition database.WorkspaceTransition) (sdkproto.WorkspaceTransition, error) { switch transition { case database.WorkspaceTransitionStart: return sdkproto.WorkspaceTransition_START, nil case database.WorkspaceTransitionStop: return sdkproto.WorkspaceTransition_STOP, nil case database.WorkspaceTransitionDelete: return sdkproto.WorkspaceTransition_DESTROY, nil default: return 0, xerrors.Errorf("unrecognized transition: %q", transition) } } func auditActionFromTransition(transition database.WorkspaceTransition) database.AuditAction { switch transition { case database.WorkspaceTransitionStart: return database.AuditActionStart case database.WorkspaceTransitionStop: return database.AuditActionStop case database.WorkspaceTransitionDelete: return database.AuditActionDelete default: return database.AuditActionWrite } } type TemplateVersionImportJob struct { // TemplateID is not guaranteed to be set. Template versions can be created // without being associated with a template. Resulting in a template id of // `uuid.Nil` TemplateID uuid.NullUUID `json:"template_id"` TemplateVersionID uuid.UUID `json:"template_version_id"` UserVariableValues []codersdk.VariableValue `json:"user_variable_values"` } // WorkspaceProvisionJob is the payload for the "workspace_provision" job type. 
type WorkspaceProvisionJob struct { WorkspaceBuildID uuid.UUID `json:"workspace_build_id"` DryRun bool `json:"dry_run"` LogLevel string `json:"log_level,omitempty"` PrebuiltWorkspaceBuildStage sdkproto.PrebuiltWorkspaceBuildStage `json:"prebuilt_workspace_stage,omitempty"` } // TemplateVersionDryRunJob is the payload for the "template_version_dry_run" job type. type TemplateVersionDryRunJob struct { TemplateVersionID uuid.UUID `json:"template_version_id"` WorkspaceName string `json:"workspace_name"` RichParameterValues []database.WorkspaceBuildParameter `json:"rich_parameter_values"` } func asVariableValues(templateVariables []database.TemplateVersionVariable) []*sdkproto.VariableValue { var apiVariableValues []*sdkproto.VariableValue for _, v := range templateVariables { value := v.Value if value == "" && v.DefaultValue != "" { value = v.DefaultValue } if value != "" || v.Required { apiVariableValues = append(apiVariableValues, &sdkproto.VariableValue{ Name: v.Name, Value: v.Value, Sensitive: v.Sensitive, }) } } return apiVariableValues } func redactTemplateVariable(templateVariable *sdkproto.TemplateVariable) *sdkproto.TemplateVariable { if templateVariable == nil { return nil } maybeRedacted := &sdkproto.TemplateVariable{ Name: templateVariable.Name, Description: templateVariable.Description, Type: templateVariable.Type, DefaultValue: templateVariable.DefaultValue, Required: templateVariable.Required, Sensitive: templateVariable.Sensitive, } if maybeRedacted.Sensitive { maybeRedacted.DefaultValue = "*redacted*" } return maybeRedacted } func convertDisplayApps(apps *sdkproto.DisplayApps) []database.DisplayApp { // This shouldn't happen but let's avoid panicking. It also makes // writing tests a bit easier. 
if apps == nil { return nil } dapps := make([]database.DisplayApp, 0, 5) if apps.Vscode { dapps = append(dapps, database.DisplayAppVscode) } if apps.VscodeInsiders { dapps = append(dapps, database.DisplayAppVscodeInsiders) } if apps.SshHelper { dapps = append(dapps, database.DisplayAppSSHHelper) } if apps.PortForwardingHelper { dapps = append(dapps, database.DisplayAppPortForwardingHelper) } if apps.WebTerminal { dapps = append(dapps, database.DisplayAppWebTerminal) } return dapps } // insertDevcontainerSubagent creates a workspace agent for a devcontainer's // subagent if one is defined. It returns the subagent ID (zero UUID if no // subagent is defined). func insertDevcontainerSubagent( ctx context.Context, db database.Store, dc *sdkproto.Devcontainer, parentAgent database.WorkspaceAgent, resourceID uuid.UUID, appSlugs map[string]struct{}, snapshot *telemetry.Snapshot, opts *insertWorkspaceResourceOptions, ) (uuid.UUID, error) { // If there are no attached resources, we don't need to pre-create the // subagent. This preserves backwards compatibility where devcontainers // without resources can have their agents recreated dynamically. 
if len(dc.GetApps()) == 0 && len(dc.GetScripts()) == 0 && len(dc.GetEnvs()) == 0 { return uuid.UUID{}, nil } subAgentID := uuid.New() if opts.useAgentIDsFromProto { var err error subAgentID, err = uuid.Parse(dc.GetSubagentId()) if err != nil { return uuid.UUID{}, xerrors.Errorf("parse subagent id: %w", err) } } envJSON, err := encodeSubagentEnvs(dc.GetEnvs()) if err != nil { return uuid.UUID{}, err } _, err = db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{ ID: subAgentID, ParentID: uuid.NullUUID{Valid: true, UUID: parentAgent.ID}, CreatedAt: dbtime.Now(), UpdatedAt: dbtime.Now(), ResourceID: resourceID, Name: dc.GetName(), AuthToken: uuid.New(), AuthInstanceID: sql.NullString{}, Architecture: parentAgent.Architecture, EnvironmentVariables: envJSON, Directory: dc.GetWorkspaceFolder(), InstanceMetadata: pqtype.NullRawMessage{}, ResourceMetadata: pqtype.NullRawMessage{}, OperatingSystem: parentAgent.OperatingSystem, ConnectionTimeoutSeconds: parentAgent.ConnectionTimeoutSeconds, TroubleshootingURL: parentAgent.TroubleshootingURL, MOTDFile: "", DisplayApps: []database.DisplayApp{}, DisplayOrder: 0, APIKeyScope: parentAgent.APIKeyScope, }) if err != nil { return uuid.UUID{}, xerrors.Errorf("insert subagent: %w", err) } for _, app := range dc.GetApps() { if err := insertAgentApp(ctx, db, subAgentID, app, appSlugs, snapshot); err != nil { return uuid.UUID{}, xerrors.Errorf("insert agent app: %w", err) } } if err := insertAgentScriptsAndLogSources(ctx, db, subAgentID, agentScriptsFromProto(dc.GetScripts())); err != nil { return uuid.UUID{}, xerrors.Errorf("insert agent scripts and log sources: %w", err) } return subAgentID, nil } // MergeExtraEnvs applies extra environment variables to the given map, // respecting the merge_strategy field on each env. When merge_strategy // is empty or "replace", the value overwrites any existing entry. // "append" and "prepend" join values with a ":" separator (PATH-style). 
// "error" causes a failure if the key already exists. func MergeExtraEnvs(env map[string]string, extraEnvs []*sdkproto.Env) error { for _, e := range extraEnvs { strategy := e.GetMergeStrategy() if strategy == "" { strategy = "replace" } existing, exists := env[e.GetName()] switch strategy { case "error": if exists { return xerrors.Errorf( "duplicate env var %q: merge_strategy is %q but variable is already defined", e.GetName(), strategy, ) } env[e.GetName()] = e.GetValue() case "append": if exists && existing != "" { env[e.GetName()] = existing + ":" + e.GetValue() } else { env[e.GetName()] = e.GetValue() } case "prepend": if exists && existing != "" { env[e.GetName()] = e.GetValue() + ":" + existing } else { env[e.GetName()] = e.GetValue() } default: // "replace" env[e.GetName()] = e.GetValue() } } return nil } func encodeSubagentEnvs(envs []*sdkproto.Env) (pqtype.NullRawMessage, error) { if len(envs) == 0 { return pqtype.NullRawMessage{}, nil } subAgentEnvs := make(map[string]string, len(envs)) if err := MergeExtraEnvs(subAgentEnvs, envs); err != nil { return pqtype.NullRawMessage{}, err } data, err := json.Marshal(subAgentEnvs) if err != nil { return pqtype.NullRawMessage{}, xerrors.Errorf("marshal env: %w", err) } return pqtype.NullRawMessage{Valid: true, RawMessage: data}, nil } // agentScriptsParams holds the parameters for inserting agent scripts and // their associated log sources. type agentScriptsParams struct { LogSourceIDs []uuid.UUID LogSourceDisplayNames []string LogSourceIcons []string ScriptIDs []uuid.UUID ScriptDisplayNames []string ScriptLogPaths []string ScriptSources []string ScriptCron []string ScriptTimeout []int32 ScriptStartBlocksLogin []bool ScriptRunOnStart []bool ScriptRunOnStop []bool } // agentScriptsFromProto converts a slice of proto scripts into the // agentScriptsParams struct needed for database insertion. 
func agentScriptsFromProto(scripts []*sdkproto.Script) agentScriptsParams { params := agentScriptsParams{ LogSourceIDs: make([]uuid.UUID, 0, len(scripts)), LogSourceDisplayNames: make([]string, 0, len(scripts)), LogSourceIcons: make([]string, 0, len(scripts)), ScriptIDs: make([]uuid.UUID, 0, len(scripts)), ScriptDisplayNames: make([]string, 0, len(scripts)), ScriptLogPaths: make([]string, 0, len(scripts)), ScriptSources: make([]string, 0, len(scripts)), ScriptCron: make([]string, 0, len(scripts)), ScriptTimeout: make([]int32, 0, len(scripts)), ScriptStartBlocksLogin: make([]bool, 0, len(scripts)), ScriptRunOnStart: make([]bool, 0, len(scripts)), ScriptRunOnStop: make([]bool, 0, len(scripts)), } for _, script := range scripts { params.LogSourceIDs = append(params.LogSourceIDs, uuid.New()) params.LogSourceDisplayNames = append(params.LogSourceDisplayNames, script.GetDisplayName()) params.LogSourceIcons = append(params.LogSourceIcons, script.GetIcon()) params.ScriptIDs = append(params.ScriptIDs, uuid.New()) params.ScriptDisplayNames = append(params.ScriptDisplayNames, script.GetDisplayName()) params.ScriptLogPaths = append(params.ScriptLogPaths, script.GetLogPath()) params.ScriptSources = append(params.ScriptSources, script.GetScript()) params.ScriptCron = append(params.ScriptCron, script.GetCron()) params.ScriptTimeout = append(params.ScriptTimeout, script.GetTimeoutSeconds()) params.ScriptStartBlocksLogin = append(params.ScriptStartBlocksLogin, script.GetStartBlocksLogin()) params.ScriptRunOnStart = append(params.ScriptRunOnStart, script.GetRunOnStart()) params.ScriptRunOnStop = append(params.ScriptRunOnStop, script.GetRunOnStop()) } return params } // insertAgentScriptsAndLogSources inserts log sources and scripts for an agent (or // subagent). It expects the caller to have built the agentScriptsParams, // allowing for additional entries to be appended before insertion (e.g. for // devcontainers). Returns nil if there are no log sources to insert. 
func insertAgentScriptsAndLogSources(ctx context.Context, db database.Store, agentID uuid.UUID, params agentScriptsParams) error { if len(params.LogSourceIDs) == 0 { return nil } _, err := db.InsertWorkspaceAgentLogSources(ctx, database.InsertWorkspaceAgentLogSourcesParams{ WorkspaceAgentID: agentID, ID: params.LogSourceIDs, CreatedAt: dbtime.Now(), DisplayName: params.LogSourceDisplayNames, Icon: params.LogSourceIcons, }) if err != nil { return xerrors.Errorf("insert log sources: %w", err) } _, err = db.InsertWorkspaceAgentScripts(ctx, database.InsertWorkspaceAgentScriptsParams{ WorkspaceAgentID: agentID, LogSourceID: params.LogSourceIDs, ID: params.ScriptIDs, LogPath: params.ScriptLogPaths, CreatedAt: dbtime.Now(), Script: params.ScriptSources, Cron: params.ScriptCron, TimeoutSeconds: params.ScriptTimeout, StartBlocksLogin: params.ScriptStartBlocksLogin, RunOnStart: params.ScriptRunOnStart, RunOnStop: params.ScriptRunOnStop, DisplayName: params.ScriptDisplayNames, }) if err != nil { return xerrors.Errorf("insert scripts: %w", err) } return nil } func insertAgentApp(ctx context.Context, db database.Store, agentID uuid.UUID, app *sdkproto.App, appSlugs map[string]struct{}, snapshot *telemetry.Snapshot) error { // Similar logic is duplicated in terraform/resources.go. slug := app.Slug if slug == "" { return xerrors.Errorf("app must have a slug or name set") } // Unlike agent names, app slugs were never permitted to contain uppercase // letters or underscores. 
if !provisioner.AppSlugRegex.MatchString(slug) { return xerrors.Errorf("app slug %q does not match regex %q", slug, provisioner.AppSlugRegex.String()) } if _, exists := appSlugs[slug]; exists { return xerrors.Errorf("duplicate app slug, must be unique per template: %q", slug) } appSlugs[slug] = struct{}{} health := database.WorkspaceAppHealthDisabled healthcheck := app.GetHealthcheck() if healthcheck == nil { healthcheck = &sdkproto.Healthcheck{} } if healthcheck.Url != "" { health = database.WorkspaceAppHealthInitializing } sharingLevel := database.AppSharingLevelOwner switch app.SharingLevel { case sdkproto.AppSharingLevel_AUTHENTICATED: sharingLevel = database.AppSharingLevelAuthenticated case sdkproto.AppSharingLevel_PUBLIC: sharingLevel = database.AppSharingLevelPublic } displayGroup := sql.NullString{ Valid: app.Group != "", String: app.Group, } openIn := database.WorkspaceAppOpenInSlimWindow switch app.OpenIn { case sdkproto.AppOpenIn_TAB: openIn = database.WorkspaceAppOpenInTab case sdkproto.AppOpenIn_SLIM_WINDOW: openIn = database.WorkspaceAppOpenInSlimWindow } var appID string if app.Id == "" || app.Id == uuid.Nil.String() { appID = uuid.NewString() } else { appID = app.Id } id, err := uuid.Parse(appID) if err != nil { return xerrors.Errorf("parse app uuid: %w", err) } // If workspace apps are "persistent", the ID will not be regenerated across workspace builds, so we have to upsert. 
dbApp, err := db.UpsertWorkspaceApp(ctx, database.UpsertWorkspaceAppParams{ ID: id, CreatedAt: dbtime.Now(), AgentID: agentID, Slug: slug, DisplayName: app.DisplayName, Icon: app.Icon, Command: sql.NullString{ String: app.Command, Valid: app.Command != "", }, Url: sql.NullString{ String: app.Url, Valid: app.Url != "", }, External: app.External, Subdomain: app.Subdomain, SharingLevel: sharingLevel, HealthcheckUrl: healthcheck.Url, HealthcheckInterval: healthcheck.Interval, HealthcheckThreshold: healthcheck.Threshold, Health: health, // #nosec G115 - Order represents a display order value that's always small and fits in int32 DisplayOrder: int32(app.Order), DisplayGroup: displayGroup, Hidden: app.Hidden, OpenIn: openIn, Tooltip: app.Tooltip, }) if err != nil { return xerrors.Errorf("upsert app: %w", err) } snapshot.WorkspaceApps = append(snapshot.WorkspaceApps, telemetry.ConvertWorkspaceApp(dbApp)) return nil }