chore: backport fixes (#21957)

This commit is contained in:
Jon Ayers
2026-02-05 16:09:41 -06:00
committed by GitHub
parent 72afd3677c
commit b275be2e7a
14 changed files with 170 additions and 54 deletions
+4 -1
View File
@@ -99,7 +99,10 @@ func (c *Client) SyncReady(ctx context.Context, unitName unit.ID) (bool, error)
resp, err := c.client.SyncReady(ctx, &proto.SyncReadyRequest{
Unit: string(unitName),
})
return resp.Ready, err
if err != nil {
return false, xerrors.Errorf("sync ready: %w", err)
}
return resp.Ready, nil
}
// SyncStatus gets the status of a unit and its dependencies.
+9
View File
@@ -4,6 +4,8 @@ import (
"os"
"github.com/hashicorp/go-reap"
"cdr.dev/slog/v3"
)
type Option func(o *options)
@@ -34,8 +36,15 @@ func WithCatchSignals(sigs ...os.Signal) Option {
}
}
func WithLogger(logger slog.Logger) Option {
return func(o *options) {
o.Logger = logger
}
}
type options struct {
ExecArgs []string
PIDs reap.PidCh
CatchSignals []os.Signal
Logger slog.Logger
}
+2 -2
View File
@@ -7,6 +7,6 @@ func IsInitProcess() bool {
return false
}
func ForkReap(_ ...Option) error {
return nil
func ForkReap(_ ...Option) (int, error) {
return 0, nil
}
+37 -2
View File
@@ -32,12 +32,13 @@ func TestReap(t *testing.T) {
}
pids := make(reap.PidCh, 1)
err := reaper.ForkReap(
exitCode, err := reaper.ForkReap(
reaper.WithPIDCallback(pids),
// Provide some argument that immediately exits.
reaper.WithExecArgs("/bin/sh", "-c", "exit 0"),
)
require.NoError(t, err)
require.Equal(t, 0, exitCode)
cmd := exec.Command("tail", "-f", "/dev/null")
err = cmd.Start()
@@ -65,6 +66,36 @@ func TestReap(t *testing.T) {
}
}
//nolint:paralleltest
func TestForkReapExitCodes(t *testing.T) {
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
tests := []struct {
name string
command string
expectedCode int
}{
{"exit 0", "exit 0", 0},
{"exit 1", "exit 1", 1},
{"exit 42", "exit 42", 42},
{"exit 255", "exit 255", 255},
{"SIGKILL", "kill -9 $$", 128 + 9},
{"SIGTERM", "kill -15 $$", 128 + 15},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
exitCode, err := reaper.ForkReap(
reaper.WithExecArgs("/bin/sh", "-c", tt.command),
)
require.NoError(t, err)
require.Equal(t, tt.expectedCode, exitCode, "exit code mismatch for %q", tt.command)
})
}
}
//nolint:paralleltest // Signal handling.
func TestReapInterrupt(t *testing.T) {
// Don't run the reaper test in CI. It does weird
@@ -84,13 +115,17 @@ func TestReapInterrupt(t *testing.T) {
defer signal.Stop(usrSig)
go func() {
errC <- reaper.ForkReap(
exitCode, err := reaper.ForkReap(
reaper.WithPIDCallback(pids),
reaper.WithCatchSignals(os.Interrupt),
// Signal propagation does not extend to children of children, so
// we create a little bash script to ensure sleep is interrupted.
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
)
// The child exits with 128 + SIGTERM (15) = 143, but the trap catches
// SIGINT and sends SIGTERM to the sleep process, so exit code varies.
_ = exitCode
errC <- err
}()
require.Equal(t, <-usrSig, syscall.SIGUSR1)
+34 -6
View File
@@ -3,12 +3,15 @@
package reaper
import (
"context"
"os"
"os/signal"
"syscall"
"github.com/hashicorp/go-reap"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
// IsInitProcess returns true if the current process's PID is 1.
@@ -16,7 +19,7 @@ func IsInitProcess() bool {
return os.Getpid() == 1
}
func catchSignals(pid int, sigs []os.Signal) {
func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) {
if len(sigs) == 0 {
return
}
@@ -25,10 +28,19 @@ func catchSignals(pid int, sigs []os.Signal) {
signal.Notify(sc, sigs...)
defer signal.Stop(sc)
logger.Info(context.Background(), "reaper catching signals",
slog.F("signals", sigs),
slog.F("child_pid", pid),
)
for {
s := <-sc
sig, ok := s.(syscall.Signal)
if ok {
logger.Info(context.Background(), "reaper caught signal, killing child process",
slog.F("signal", sig.String()),
slog.F("child_pid", pid),
)
_ = syscall.Kill(pid, sig)
}
}
@@ -40,7 +52,10 @@ func catchSignals(pid int, sigs []os.Signal) {
// the reaper and an exec.Command waiting for its process to complete.
// The provided 'pids' channel may be nil if the caller does not care about the
// reaped children PIDs.
func ForkReap(opt ...Option) error {
//
// Returns the child's exit code (using 128+signal for signal termination)
// and any error from Wait4.
func ForkReap(opt ...Option) (int, error) {
opts := &options{
ExecArgs: os.Args,
}
@@ -53,7 +68,7 @@ func ForkReap(opt ...Option) error {
pwd, err := os.Getwd()
if err != nil {
return xerrors.Errorf("get wd: %w", err)
return 1, xerrors.Errorf("get wd: %w", err)
}
pattrs := &syscall.ProcAttr{
@@ -72,15 +87,28 @@ func ForkReap(opt ...Option) error {
//#nosec G204
pid, err := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
if err != nil {
return xerrors.Errorf("fork exec: %w", err)
return 1, xerrors.Errorf("fork exec: %w", err)
}
go catchSignals(pid, opts.CatchSignals)
go catchSignals(opts.Logger, pid, opts.CatchSignals)
var wstatus syscall.WaitStatus
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
for xerrors.Is(err, syscall.EINTR) {
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
}
return err
// Convert wait status to exit code using standard Unix conventions:
// - Normal exit: use the exit code
// - Signal termination: use 128 + signal number
var exitCode int
switch {
case wstatus.Exited():
exitCode = wstatus.ExitStatus()
case wstatus.Signaled():
exitCode = 128 + int(wstatus.Signal())
default:
exitCode = 1
}
return exitCode, err
}
+46 -18
View File
@@ -9,6 +9,7 @@ import (
"net/http/pprof"
"net/url"
"os"
"os/signal"
"path/filepath"
"runtime"
"slices"
@@ -130,40 +131,29 @@ func workspaceAgent() *serpent.Command {
sinks = append(sinks, sloghuman.Sink(logWriter))
logger := inv.Logger.AppendSinks(sinks...).Leveled(slog.LevelDebug)
logger = logger.Named("reaper")
logger.Info(ctx, "spawning reaper process")
// Do not start a reaper on the child process. It's important
// to do this else we fork bomb ourselves.
//nolint:gocritic
args := append(os.Args, "--no-reap")
err := reaper.ForkReap(
exitCode, err := reaper.ForkReap(
reaper.WithExecArgs(args...),
reaper.WithCatchSignals(StopSignals...),
reaper.WithLogger(logger),
)
if err != nil {
logger.Error(ctx, "agent process reaper unable to fork", slog.Error(err))
return xerrors.Errorf("fork reap: %w", err)
}
logger.Info(ctx, "reaper process exiting")
return nil
logger.Info(ctx, "child process exited, propagating exit code",
slog.F("exit_code", exitCode),
)
return ExitError(exitCode, nil)
}
// Handle interrupt signals to allow for graceful shutdown,
// note that calling stopNotify disables the signal handler
// and the next interrupt will terminate the program (you
// probably want cancel instead).
//
// Note that we don't want to handle these signals in the
// process that runs as PID 1, that's why we do this after
// the reaper forked.
ctx, stopNotify := inv.SignalNotifyContext(ctx, StopSignals...)
defer stopNotify()
// DumpHandler does signal handling, so we call it after the
// reaper.
go DumpHandler(ctx, "agent")
logWriter := &clilog.LumberjackWriteCloseFixer{Writer: &lumberjack.Logger{
Filename: filepath.Join(logDir, "coder-agent.log"),
MaxSize: 5, // MB
@@ -176,6 +166,21 @@ func workspaceAgent() *serpent.Command {
sinks = append(sinks, sloghuman.Sink(logWriter))
logger := inv.Logger.AppendSinks(sinks...).Leveled(slog.LevelDebug)
// Handle interrupt signals to allow for graceful shutdown,
// note that calling stopNotify disables the signal handler
// and the next interrupt will terminate the program (you
// probably want cancel instead).
//
// Note that we also handle these signals in the
// process that runs as PID 1, mainly to forward it to the agent child
// so that it can shutdown gracefully.
ctx, stopNotify := logSignalNotifyContext(ctx, logger, StopSignals...)
defer stopNotify()
// DumpHandler does signal handling, so we call it after the
// reaper.
go DumpHandler(ctx, "agent")
version := buildinfo.Version()
logger.Info(ctx, "agent is starting now",
slog.F("url", agentAuth.agentURL),
@@ -557,3 +562,26 @@ func urlPort(u string) (int, error) {
}
return -1, xerrors.Errorf("invalid port: %s", u)
}
// logSignalNotifyContext is like signal.NotifyContext but logs the received
// signal before canceling the context.
func logSignalNotifyContext(parent context.Context, logger slog.Logger, signals ...os.Signal) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancelCause(parent)
c := make(chan os.Signal, 1)
signal.Notify(c, signals...)
go func() {
select {
case sig := <-c:
logger.Info(ctx, "agent received signal", slog.F("signal", sig.String()))
cancel(xerrors.Errorf("signal: %s", sig.String()))
case <-ctx.Done():
logger.Info(ctx, "ctx canceled, stopping signal handler")
}
}()
return ctx, func() {
cancel(context.Canceled)
signal.Stop(c)
}
}
+1
View File
@@ -63,6 +63,7 @@ type StateSnapshotter interface {
type Claimer interface {
Claim(
ctx context.Context,
store database.Store,
now time.Time,
userID uuid.UUID,
name string,
+1 -1
View File
@@ -34,7 +34,7 @@ var DefaultReconciler ReconciliationOrchestrator = NoopReconciler{}
type NoopClaimer struct{}
func (NoopClaimer) Claim(context.Context, time.Time, uuid.UUID, string, uuid.UUID, sql.NullString, sql.NullTime, sql.NullInt64) (*uuid.UUID, error) {
func (NoopClaimer) Claim(context.Context, database.Store, time.Time, uuid.UUID, string, uuid.UUID, sql.NullString, sql.NullTime, sql.NullInt64) (*uuid.UUID, error) {
// Not entitled to claim prebuilds in AGPL version.
return nil, ErrAGPLDoesNotSupportPrebuiltWorkspaces
}
+1 -1
View File
@@ -937,7 +937,7 @@ func claimPrebuild(
nextStartAt sql.NullTime,
ttl sql.NullInt64,
) (*database.Workspace, error) {
claimedID, err := claimer.Claim(ctx, now, owner.ID, name, templateVersionPresetID, autostartSchedule, nextStartAt, ttl)
claimedID, err := claimer.Claim(ctx, db, now, owner.ID, name, templateVersionPresetID, autostartSchedule, nextStartAt, ttl)
if err != nil {
// TODO: enhance this by clarifying whether this *specific* prebuild failed or whether there are none to claim.
return nil, xerrors.Errorf("claim prebuild: %w", err)
+2 -2
View File
@@ -371,7 +371,7 @@ func TestEnterpriseCreateWithPreset(t *testing.T) {
notifications.NewNoopEnqueuer(),
newNoopUsageCheckerPtr(),
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer()
api.AGPL.PrebuildsClaimer.Store(&claimer)
// Given: a template and a template version where the preset defines values for all required parameters,
@@ -482,7 +482,7 @@ func TestEnterpriseCreateWithPreset(t *testing.T) {
notifications.NewNoopEnqueuer(),
newNoopUsageCheckerPtr(),
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer()
api.AGPL.PrebuildsClaimer.Store(&claimer)
// Given: a template and a template version where the preset defines values for all required parameters,
+12 -3
View File
@@ -1307,7 +1307,16 @@ func (api *API) setupPrebuilds(featureEnabled bool) (agplprebuilds.Reconciliatio
return agplprebuilds.DefaultReconciler, agplprebuilds.DefaultClaimer
}
reconciler := prebuilds.NewStoreReconciler(api.Database, api.Pubsub, api.AGPL.FileCache, api.DeploymentValues.Prebuilds,
api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer, api.AGPL.BuildUsageChecker)
return reconciler, prebuilds.NewEnterpriseClaimer(api.Database)
reconciler := prebuilds.NewStoreReconciler(
api.Database,
api.Pubsub,
api.AGPL.FileCache,
api.DeploymentValues.Prebuilds,
api.Logger.Named("prebuilds"),
quartz.NewReal(),
api.PrometheusRegistry,
api.NotificationsEnqueuer,
api.AGPL.BuildUsageChecker,
)
return reconciler, prebuilds.NewEnterpriseClaimer()
}
+7 -10
View File
@@ -13,18 +13,15 @@ import (
"github.com/coder/coder/v2/coderd/prebuilds"
)
type EnterpriseClaimer struct {
store database.Store
type EnterpriseClaimer struct{}
func NewEnterpriseClaimer() *EnterpriseClaimer {
return &EnterpriseClaimer{}
}
func NewEnterpriseClaimer(store database.Store) *EnterpriseClaimer {
return &EnterpriseClaimer{
store: store,
}
}
func (c EnterpriseClaimer) Claim(
func (EnterpriseClaimer) Claim(
ctx context.Context,
store database.Store,
now time.Time,
userID uuid.UUID,
name string,
@@ -33,7 +30,7 @@ func (c EnterpriseClaimer) Claim(
nextStartAt sql.NullTime,
ttl sql.NullInt64,
) (*uuid.UUID, error) {
result, err := c.store.ClaimPrebuiltWorkspace(ctx, database.ClaimPrebuiltWorkspaceParams{
result, err := store.ClaimPrebuiltWorkspace(ctx, database.ClaimPrebuiltWorkspaceParams{
NewUserID: userID,
NewName: name,
Now: now,
+8 -2
View File
@@ -167,8 +167,14 @@ func TestClaimPrebuild(t *testing.T) {
defer provisionerCloser.Close()
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(spy)
reconciler := prebuilds.NewStoreReconciler(
spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger,
quartz.NewMock(t),
prometheus.NewRegistry(),
newNoopEnqueuer(),
newNoopUsageCheckerPtr(),
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer()
api.AGPL.PrebuildsClaimer.Store(&claimer)
version := coderdtest.CreateTemplateVersion(t, client, orgID, templateWithAgentAndPresetsWithPrebuilds(desiredInstances))
+6 -6
View File
@@ -1978,7 +1978,7 @@ func TestPrebuildsAutobuild(t *testing.T) {
notificationsNoop,
api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer()
api.AGPL.PrebuildsClaimer.Store(&claimer)
// Setup user, template and template version with a preset with 1 prebuild instance
@@ -2100,7 +2100,7 @@ func TestPrebuildsAutobuild(t *testing.T) {
notificationsNoop,
api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer()
api.AGPL.PrebuildsClaimer.Store(&claimer)
// Setup user, template and template version with a preset with 1 prebuild instance
@@ -2222,7 +2222,7 @@ func TestPrebuildsAutobuild(t *testing.T) {
notificationsNoop,
api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer()
api.AGPL.PrebuildsClaimer.Store(&claimer)
// Setup user, template and template version with a preset with 1 prebuild instance
@@ -2366,7 +2366,7 @@ func TestPrebuildsAutobuild(t *testing.T) {
notificationsNoop,
api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer()
api.AGPL.PrebuildsClaimer.Store(&claimer)
// Setup user, template and template version with a preset with 1 prebuild instance
@@ -2511,7 +2511,7 @@ func TestPrebuildsAutobuild(t *testing.T) {
notificationsNoop,
api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer()
api.AGPL.PrebuildsClaimer.Store(&claimer)
// Setup user, template and template version with a preset with 1 prebuild instance
@@ -2957,7 +2957,7 @@ func TestWorkspaceProvisionerdServerMetrics(t *testing.T) {
notifications.NewNoopEnqueuer(),
api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer()
api.AGPL.PrebuildsClaimer.Store(&claimer)
organizationName, err := client.Organization(ctx, owner.OrganizationID)