From b275be2e7a1731a315d0ffb9f21c86e26d9d5399 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Thu, 5 Feb 2026 16:09:41 -0600 Subject: [PATCH] chore: backport fixes (#21957) --- agent/agentsocket/client.go | 5 +- agent/reaper/reaper.go | 9 ++++ agent/reaper/reaper_stub.go | 4 +- agent/reaper/reaper_test.go | 39 +++++++++++++- agent/reaper/reaper_unix.go | 40 +++++++++++--- cli/agent.go | 64 ++++++++++++++++------- coderd/prebuilds/api.go | 1 + coderd/prebuilds/noop.go | 2 +- coderd/workspaces.go | 2 +- enterprise/cli/create_test.go | 4 +- enterprise/coderd/coderd.go | 15 ++++-- enterprise/coderd/prebuilds/claim.go | 17 +++--- enterprise/coderd/prebuilds/claim_test.go | 10 +++- enterprise/coderd/workspaces_test.go | 12 ++--- 14 files changed, 170 insertions(+), 54 deletions(-) diff --git a/agent/agentsocket/client.go b/agent/agentsocket/client.go index cc8810c987..3ed629e555 100644 --- a/agent/agentsocket/client.go +++ b/agent/agentsocket/client.go @@ -99,7 +99,10 @@ func (c *Client) SyncReady(ctx context.Context, unitName unit.ID) (bool, error) resp, err := c.client.SyncReady(ctx, &proto.SyncReadyRequest{ Unit: string(unitName), }) - return resp.Ready, err + if err != nil { + return false, xerrors.Errorf("sync ready: %w", err) + } + return resp.Ready, nil } // SyncStatus gets the status of a unit and its dependencies. diff --git a/agent/reaper/reaper.go b/agent/reaper/reaper.go index 94f5190d11..d968937a3a 100644 --- a/agent/reaper/reaper.go +++ b/agent/reaper/reaper.go @@ -4,6 +4,8 @@ import ( "os" "github.com/hashicorp/go-reap" + + "cdr.dev/slog/v3" ) type Option func(o *options) @@ -34,8 +36,15 @@ func WithCatchSignals(sigs ...os.Signal) Option { } } +func WithLogger(logger slog.Logger) Option { + return func(o *options) { + o.Logger = logger + } +} + type options struct { ExecArgs []string PIDs reap.PidCh CatchSignals []os.Signal + Logger slog.Logger } diff --git a/agent/reaper/reaper_stub.go b/agent/reaper/reaper_stub.go index 8cd87ab0bf..da4d871fc5 100644 --- a/agent/reaper/reaper_stub.go +++ b/agent/reaper/reaper_stub.go @@ -7,6 +7,6 @@ func IsInitProcess() bool { return false } -func ForkReap(_ ...Option) error { - return nil +func ForkReap(_ ...Option) (int, error) { + return 0, nil } diff --git a/agent/reaper/reaper_test.go b/agent/reaper/reaper_test.go index 84246fba06..7ef3f0a50b 100644 --- a/agent/reaper/reaper_test.go +++ b/agent/reaper/reaper_test.go @@ -32,12 +32,13 @@ func TestReap(t *testing.T) { } pids := make(reap.PidCh, 1) - err := reaper.ForkReap( + exitCode, err := reaper.ForkReap( reaper.WithPIDCallback(pids), // Provide some argument that immediately exits. reaper.WithExecArgs("/bin/sh", "-c", "exit 0"), ) require.NoError(t, err) + require.Equal(t, 0, exitCode) cmd := exec.Command("tail", "-f", "/dev/null") err = cmd.Start() @@ -65,6 +66,36 @@ func TestReap(t *testing.T) { } } +//nolint:paralleltest +func TestForkReapExitCodes(t *testing.T) { + if testutil.InCI() { + t.Skip("Detected CI, skipping reaper tests") + } + + tests := []struct { + name string + command string + expectedCode int + }{ + {"exit 0", "exit 0", 0}, + {"exit 1", "exit 1", 1}, + {"exit 42", "exit 42", 42}, + {"exit 255", "exit 255", 255}, + {"SIGKILL", "kill -9 $$", 128 + 9}, + {"SIGTERM", "kill -15 $$", 128 + 15}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + exitCode, err := reaper.ForkReap( + reaper.WithExecArgs("/bin/sh", "-c", tt.command), + ) + require.NoError(t, err) + require.Equal(t, tt.expectedCode, exitCode, "exit code mismatch for %q", tt.command) + }) + } +} + //nolint:paralleltest // Signal handling. func TestReapInterrupt(t *testing.T) { // Don't run the reaper test in CI. It does weird @@ -84,13 +115,17 @@ func TestReapInterrupt(t *testing.T) { defer signal.Stop(usrSig) go func() { - errC <- reaper.ForkReap( + exitCode, err := reaper.ForkReap( reaper.WithPIDCallback(pids), reaper.WithCatchSignals(os.Interrupt), // Signal propagation does not extend to children of children, so // we create a little bash script to ensure sleep is interrupted. reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())), ) + // The child exits with 128 + SIGTERM (15) = 143, but the trap catches + // SIGINT and sends SIGTERM to the sleep process, so exit code varies. + _ = exitCode + errC <- err }() require.Equal(t, <-usrSig, syscall.SIGUSR1) diff --git a/agent/reaper/reaper_unix.go b/agent/reaper/reaper_unix.go index 35ce9bfaa1..b095c5a7f9 100644 --- a/agent/reaper/reaper_unix.go +++ b/agent/reaper/reaper_unix.go @@ -3,12 +3,15 @@ package reaper import ( + "context" "os" "os/signal" "syscall" "github.com/hashicorp/go-reap" "golang.org/x/xerrors" + + "cdr.dev/slog/v3" ) // IsInitProcess returns true if the current process's PID is 1. @@ -16,7 +19,7 @@ func IsInitProcess() bool { return os.Getpid() == 1 } -func catchSignals(pid int, sigs []os.Signal) { +func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) { if len(sigs) == 0 { return } @@ -25,10 +28,19 @@ func catchSignals(pid int, sigs []os.Signal) { signal.Notify(sc, sigs...) defer signal.Stop(sc) + logger.Info(context.Background(), "reaper catching signals", + slog.F("signals", sigs), + slog.F("child_pid", pid), + ) + for { s := <-sc sig, ok := s.(syscall.Signal) if ok { + logger.Info(context.Background(), "reaper caught signal, killing child process", + slog.F("signal", sig.String()), + slog.F("child_pid", pid), + ) _ = syscall.Kill(pid, sig) } } @@ -40,7 +52,10 @@ func catchSignals(pid int, sigs []os.Signal) { // the reaper and an exec.Command waiting for its process to complete. // The provided 'pids' channel may be nil if the caller does not care about the // reaped children PIDs. -func ForkReap(opt ...Option) error { +// +// Returns the child's exit code (using 128+signal for signal termination) +// and any error from Wait4. +func ForkReap(opt ...Option) (int, error) { opts := &options{ ExecArgs: os.Args, } @@ -53,7 +68,7 @@ func ForkReap(opt ...Option) error { pwd, err := os.Getwd() if err != nil { - return xerrors.Errorf("get wd: %w", err) + return 1, xerrors.Errorf("get wd: %w", err) } pattrs := &syscall.ProcAttr{ @@ -72,15 +87,28 @@ func ForkReap(opt ...Option) error { //#nosec G204 pid, err := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs) if err != nil { - return xerrors.Errorf("fork exec: %w", err) + return 1, xerrors.Errorf("fork exec: %w", err) } - go catchSignals(pid, opts.CatchSignals) + go catchSignals(opts.Logger, pid, opts.CatchSignals) var wstatus syscall.WaitStatus _, err = syscall.Wait4(pid, &wstatus, 0, nil) for xerrors.Is(err, syscall.EINTR) { _, err = syscall.Wait4(pid, &wstatus, 0, nil) } - return err + + // Convert wait status to exit code using standard Unix conventions: + // - Normal exit: use the exit code + // - Signal termination: use 128 + signal number + var exitCode int + switch { + case wstatus.Exited(): + exitCode = wstatus.ExitStatus() + case wstatus.Signaled(): + exitCode = 128 + int(wstatus.Signal()) + default: + exitCode = 1 + } + return exitCode, err } diff --git a/cli/agent.go b/cli/agent.go index 56a8720a41..2e396206c9 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -9,6 +9,7 @@ import ( "net/http/pprof" "net/url" "os" + "os/signal" "path/filepath" "runtime" "slices" @@ -130,40 +131,29 @@ func workspaceAgent() *serpent.Command { sinks = append(sinks, sloghuman.Sink(logWriter)) logger := inv.Logger.AppendSinks(sinks...).Leveled(slog.LevelDebug) + logger = logger.Named("reaper") logger.Info(ctx, "spawning reaper process") // Do not start a reaper on the child process. It's important // to do this else we fork bomb ourselves. //nolint:gocritic args := append(os.Args, "--no-reap") - err := reaper.ForkReap( + exitCode, err := reaper.ForkReap( reaper.WithExecArgs(args...), reaper.WithCatchSignals(StopSignals...), + reaper.WithLogger(logger), ) if err != nil { logger.Error(ctx, "agent process reaper unable to fork", slog.Error(err)) return xerrors.Errorf("fork reap: %w", err) } - logger.Info(ctx, "reaper process exiting") - return nil + logger.Info(ctx, "child process exited, propagating exit code", + slog.F("exit_code", exitCode), + ) + return ExitError(exitCode, nil) } - // Handle interrupt signals to allow for graceful shutdown, - // note that calling stopNotify disables the signal handler - // and the next interrupt will terminate the program (you - // probably want cancel instead). - // - // Note that we don't want to handle these signals in the - // process that runs as PID 1, that's why we do this after - // the reaper forked. - ctx, stopNotify := inv.SignalNotifyContext(ctx, StopSignals...) - defer stopNotify() - - // DumpHandler does signal handling, so we call it after the - // reaper. - go DumpHandler(ctx, "agent") - logWriter := &clilog.LumberjackWriteCloseFixer{Writer: &lumberjack.Logger{ Filename: filepath.Join(logDir, "coder-agent.log"), MaxSize: 5, // MB @@ -176,6 +166,21 @@ func workspaceAgent() *serpent.Command { sinks = append(sinks, sloghuman.Sink(logWriter)) logger := inv.Logger.AppendSinks(sinks...).Leveled(slog.LevelDebug) + // Handle interrupt signals to allow for graceful shutdown, + // note that calling stopNotify disables the signal handler + // and the next interrupt will terminate the program (you + // probably want cancel instead). + // + // Note that we also handle these signals in the + // process that runs as PID 1, mainly to forward it to the agent child + // so that it can shutdown gracefully. + ctx, stopNotify := logSignalNotifyContext(ctx, logger, StopSignals...) + defer stopNotify() + + // DumpHandler does signal handling, so we call it after the + // reaper. + go DumpHandler(ctx, "agent") + version := buildinfo.Version() logger.Info(ctx, "agent is starting now", slog.F("url", agentAuth.agentURL), @@ -557,3 +562,26 @@ func urlPort(u string) (int, error) { } return -1, xerrors.Errorf("invalid port: %s", u) } + +// logSignalNotifyContext is like signal.NotifyContext but logs the received +// signal before canceling the context. +func logSignalNotifyContext(parent context.Context, logger slog.Logger, signals ...os.Signal) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancelCause(parent) + c := make(chan os.Signal, 1) + signal.Notify(c, signals...) + + go func() { + select { + case sig := <-c: + logger.Info(ctx, "agent received signal", slog.F("signal", sig.String())) + cancel(xerrors.Errorf("signal: %s", sig.String())) + case <-ctx.Done(): + logger.Info(ctx, "ctx canceled, stopping signal handler") + } + }() + + return ctx, func() { + cancel(context.Canceled) + signal.Stop(c) + } +} diff --git a/coderd/prebuilds/api.go b/coderd/prebuilds/api.go index 0deab99416..dc5092a06f 100644 --- a/coderd/prebuilds/api.go +++ b/coderd/prebuilds/api.go @@ -63,6 +63,7 @@ type StateSnapshotter interface { type Claimer interface { Claim( ctx context.Context, + store database.Store, now time.Time, userID uuid.UUID, name string, diff --git a/coderd/prebuilds/noop.go b/coderd/prebuilds/noop.go index 0859d428b4..1dda74c1dd 100644 --- a/coderd/prebuilds/noop.go +++ b/coderd/prebuilds/noop.go @@ -34,7 +34,7 @@ var DefaultReconciler ReconciliationOrchestrator = NoopReconciler{} type NoopClaimer struct{} -func (NoopClaimer) Claim(context.Context, time.Time, uuid.UUID, string, uuid.UUID, sql.NullString, sql.NullTime, sql.NullInt64) (*uuid.UUID, error) { +func (NoopClaimer) Claim(context.Context, database.Store, time.Time, uuid.UUID, string, uuid.UUID, sql.NullString, sql.NullTime, sql.NullInt64) (*uuid.UUID, error) { // Not entitled to claim prebuilds in AGPL version. return nil, ErrAGPLDoesNotSupportPrebuiltWorkspaces } diff --git a/coderd/workspaces.go b/coderd/workspaces.go index a02c16ec89..bb6983cbfe 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -937,7 +937,7 @@ func claimPrebuild( nextStartAt sql.NullTime, ttl sql.NullInt64, ) (*database.Workspace, error) { - claimedID, err := claimer.Claim(ctx, now, owner.ID, name, templateVersionPresetID, autostartSchedule, nextStartAt, ttl) + claimedID, err := claimer.Claim(ctx, db, now, owner.ID, name, templateVersionPresetID, autostartSchedule, nextStartAt, ttl) if err != nil { // TODO: enhance this by clarifying whether this *specific* prebuild failed or whether there are none to claim. return nil, xerrors.Errorf("claim prebuild: %w", err) diff --git a/enterprise/cli/create_test.go b/enterprise/cli/create_test.go index 44218abb5a..941908d17b 100644 --- a/enterprise/cli/create_test.go +++ b/enterprise/cli/create_test.go @@ -371,7 +371,7 @@ func TestEnterpriseCreateWithPreset(t *testing.T) { notifications.NewNoopEnqueuer(), newNoopUsageCheckerPtr(), ) - var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer() api.AGPL.PrebuildsClaimer.Store(&claimer) // Given: a template and a template version where the preset defines values for all required parameters, @@ -482,7 +482,7 @@ func TestEnterpriseCreateWithPreset(t *testing.T) { notifications.NewNoopEnqueuer(), newNoopUsageCheckerPtr(), ) - var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer() api.AGPL.PrebuildsClaimer.Store(&claimer) // Given: a template and a template version where the preset defines values for all required parameters, diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 2dc10de7c3..2b5c9d8d47 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -1307,7 +1307,16 @@ func (api *API) setupPrebuilds(featureEnabled bool) (agplprebuilds.Reconciliatio return agplprebuilds.DefaultReconciler, agplprebuilds.DefaultClaimer } - reconciler := prebuilds.NewStoreReconciler(api.Database, api.Pubsub, api.AGPL.FileCache, api.DeploymentValues.Prebuilds, - api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer, api.AGPL.BuildUsageChecker) - return reconciler, prebuilds.NewEnterpriseClaimer(api.Database) + reconciler := prebuilds.NewStoreReconciler( + api.Database, + api.Pubsub, + api.AGPL.FileCache, + api.DeploymentValues.Prebuilds, + api.Logger.Named("prebuilds"), + quartz.NewReal(), + api.PrometheusRegistry, + api.NotificationsEnqueuer, + api.AGPL.BuildUsageChecker, + ) + return reconciler, prebuilds.NewEnterpriseClaimer() } diff --git a/enterprise/coderd/prebuilds/claim.go b/enterprise/coderd/prebuilds/claim.go index 743513cedb..e057fb03d6 100644 --- a/enterprise/coderd/prebuilds/claim.go +++ b/enterprise/coderd/prebuilds/claim.go @@ -13,18 +13,15 @@ import ( "github.com/coder/coder/v2/coderd/prebuilds" ) -type EnterpriseClaimer struct { - store database.Store +type EnterpriseClaimer struct{} + +func NewEnterpriseClaimer() *EnterpriseClaimer { + return &EnterpriseClaimer{} } -func NewEnterpriseClaimer(store database.Store) *EnterpriseClaimer { - return &EnterpriseClaimer{ - store: store, - } -} - -func (c EnterpriseClaimer) Claim( +func (EnterpriseClaimer) Claim( ctx context.Context, + store database.Store, now time.Time, userID uuid.UUID, name string, @@ -33,7 +30,7 @@ func (c EnterpriseClaimer) Claim( nextStartAt sql.NullTime, ttl sql.NullInt64, ) (*uuid.UUID, error) { - result, err := c.store.ClaimPrebuiltWorkspace(ctx, database.ClaimPrebuiltWorkspaceParams{ + result, err := store.ClaimPrebuiltWorkspace(ctx, database.ClaimPrebuiltWorkspaceParams{ NewUserID: userID, NewName: name, Now: now, diff --git a/enterprise/coderd/prebuilds/claim_test.go b/enterprise/coderd/prebuilds/claim_test.go index 217a9ff096..79262401c8 100644 --- a/enterprise/coderd/prebuilds/claim_test.go +++ b/enterprise/coderd/prebuilds/claim_test.go @@ -167,8 +167,14 @@ func TestClaimPrebuild(t *testing.T) { defer provisionerCloser.Close() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) - var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(spy) + reconciler := prebuilds.NewStoreReconciler( + spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, + quartz.NewMock(t), + prometheus.NewRegistry(), + newNoopEnqueuer(), + newNoopUsageCheckerPtr(), + ) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer() api.AGPL.PrebuildsClaimer.Store(&claimer) version := coderdtest.CreateTemplateVersion(t, client, orgID, templateWithAgentAndPresetsWithPrebuilds(desiredInstances)) diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index 6ffb8b4c30..562ad35eb1 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -1978,7 +1978,7 @@ func TestPrebuildsAutobuild(t *testing.T) { notificationsNoop, api.AGPL.BuildUsageChecker, ) - var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer() api.AGPL.PrebuildsClaimer.Store(&claimer) // Setup user, template and template version with a preset with 1 prebuild instance @@ -2100,7 +2100,7 @@ func TestPrebuildsAutobuild(t *testing.T) { notificationsNoop, api.AGPL.BuildUsageChecker, ) - var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer() api.AGPL.PrebuildsClaimer.Store(&claimer) // Setup user, template and template version with a preset with 1 prebuild instance @@ -2222,7 +2222,7 @@ func TestPrebuildsAutobuild(t *testing.T) { notificationsNoop, api.AGPL.BuildUsageChecker, ) - var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer() api.AGPL.PrebuildsClaimer.Store(&claimer) // Setup user, template and template version with a preset with 1 prebuild instance @@ -2366,7 +2366,7 @@ func TestPrebuildsAutobuild(t *testing.T) { notificationsNoop, api.AGPL.BuildUsageChecker, ) - var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer() api.AGPL.PrebuildsClaimer.Store(&claimer) // Setup user, template and template version with a preset with 1 prebuild instance @@ -2511,7 +2511,7 @@ func TestPrebuildsAutobuild(t *testing.T) { notificationsNoop, api.AGPL.BuildUsageChecker, ) - var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer() api.AGPL.PrebuildsClaimer.Store(&claimer) // Setup user, template and template version with a preset with 1 prebuild instance @@ -2957,7 +2957,7 @@ func TestWorkspaceProvisionerdServerMetrics(t *testing.T) { notifications.NewNoopEnqueuer(), api.AGPL.BuildUsageChecker, ) - var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) + var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer() api.AGPL.PrebuildsClaimer.Store(&claimer) organizationName, err := client.Organization(ctx, owner.OrganizationID)