From af3fdc68c301c2b06deeab9acaa2b60ef2eb09d3 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 23 Feb 2024 11:04:23 +0400 Subject: [PATCH] chore: refactor agent routines that use the v2 API (#12223) In anticipation of needing the `LogSender` to run on a context that doesn't get immediately canceled when you `Close()` the agent, I've undertaken a little refactor to manage the goroutines that get run against the Tailnet and Agent API connection. This handles controlling two contexts, one that gets canceled right away at the start of graceful shutdown, and another that stays up to allow graceful shutdown to complete. --- agent/agent.go | 692 ++++++++++++------- cli/ssh_test.go | 10 +- coderd/coderdtest/coderdtest.go | 2 +- codersdk/workspaceagents.go | 2 +- enterprise/coderd/coderd_test.go | 7 +- enterprise/coderd/replicas_test.go | 8 +- enterprise/coderd/workspaceagents_test.go | 24 +- enterprise/coderd/workspaceportshare_test.go | 12 +- tailnet/coordinator.go | 3 +- 9 files changed, 501 insertions(+), 259 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index c0a61fa97f..e0256d2e22 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -150,13 +150,17 @@ func New(options Options) Agent { options.Syscaller = agentproc.NewSyscaller() } - ctx, cancelFunc := context.WithCancel(context.Background()) + hardCtx, hardCancel := context.WithCancel(context.Background()) + gracefulCtx, gracefulCancel := context.WithCancel(hardCtx) a := &agent{ tailnetListenPort: options.TailnetListenPort, reconnectingPTYTimeout: options.ReconnectingPTYTimeout, logger: options.Logger, - closeCancel: cancelFunc, - closed: make(chan struct{}), + gracefulCtx: gracefulCtx, + gracefulCancel: gracefulCancel, + hardCtx: hardCtx, + hardCancel: hardCancel, + coordDisconnected: make(chan struct{}), environmentVariables: options.EnvironmentVariables, client: options.Client, exchangeToken: options.ExchangeToken, @@ -181,9 +185,14 @@ func New(options Options) Agent { prometheusRegistry: 
prometheusRegistry, metrics: newAgentMetrics(prometheusRegistry), } + // Initially, we have a closed channel, reflecting the fact that we are not initially connected. + // Each time we connect we replace the channel (while holding the closeMutex) with a new one + // that gets closed on disconnection. This is used to wait for graceful disconnection from the + // coordinator during shut down. + close(a.coordDisconnected) a.serviceBanner.Store(new(codersdk.ServiceBannerConfig)) a.sessionToken.Store(new(string)) - a.init(ctx) + a.init() return a } @@ -206,10 +215,16 @@ type agent struct { reconnectingPTYs sync.Map reconnectingPTYTimeout time.Duration - connCloseWait sync.WaitGroup - closeCancel context.CancelFunc - closeMutex sync.Mutex - closed chan struct{} + // we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time + // to start gracefully shutting down and "hard" which is Done when it is time to close + // everything down (regardless of whether graceful shutdown completed). + gracefulCtx context.Context + gracefulCancel context.CancelFunc + hardCtx context.Context + hardCancel context.CancelFunc + closeWaitGroup sync.WaitGroup + closeMutex sync.Mutex + coordDisconnected chan struct{} environmentVariables map[string]string @@ -249,8 +264,9 @@ func (a *agent) TailnetConn() *tailnet.Conn { return a.network } -func (a *agent) init(ctx context.Context) { - sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{ +func (a *agent) init() { + // pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown. 
+ sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{ MaxTimeout: a.sshMaxTimeout, MOTDFile: func() string { return a.manifest.Load().MOTDFile }, ServiceBanner: func() *codersdk.ServiceBannerConfig { return a.serviceBanner.Load() }, @@ -272,22 +288,24 @@ func (a *agent) init(ctx context.Context) { // Register runner metrics. If the prom registry is nil, the metrics // will not report anywhere. a.scriptRunner.RegisterMetrics(a.prometheusRegistry) - go a.runLoop(ctx) + go a.runLoop() } // runLoop attempts to start the agent in a retry loop. // Coder may be offline temporarily, a connection issue // may be happening, but regardless after the intermittent // failure, you'll want the agent to reconnect. -func (a *agent) runLoop(ctx context.Context) { - go a.reportLifecycleLoop(ctx) - go a.reportMetadataLoop(ctx) - go a.manageProcessPriorityLoop(ctx) +func (a *agent) runLoop() { + go a.reportLifecycleUntilClose() + go a.reportMetadataUntilGracefulShutdown() + go a.manageProcessPriorityUntilGracefulShutdown() + // need to keep retrying up to the hardCtx so that we can send graceful shutdown-related + // messages. + ctx := a.hardCtx for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { a.logger.Info(ctx, "connecting to coderd") - err := a.run(ctx) - // Cancel after the run is complete to clean up any leaked resources! + err := a.run() if err == nil { continue } @@ -386,7 +404,9 @@ func (t *trySingleflight) Do(key string, fn func()) { fn() } -func (a *agent) reportMetadataLoop(ctx context.Context) { +func (a *agent) reportMetadataUntilGracefulShutdown() { + // metadata reporting can cease as soon as we start gracefully shutting down. 
+ ctx := a.gracefulCtx tickerDone := make(chan struct{}) collectDone := make(chan struct{}) ctx, cancel := context.WithCancel(ctx) @@ -595,9 +615,12 @@ func (a *agent) reportMetadataLoop(ctx context.Context) { } } -// reportLifecycleLoop reports the current lifecycle state once. All state +// reportLifecycleUntilClose reports the current lifecycle state once. All state // changes are reported in order. -func (a *agent) reportLifecycleLoop(ctx context.Context) { +func (a *agent) reportLifecycleUntilClose() { + // part of graceful shut down is reporting the final lifecycle states, e.g "ShuttingDown" so the + // lifecycle reporting has to be via the "hard" context. + ctx := a.hardCtx lastReportedIndex := 0 // Start off with the created state without reporting it. for { select { @@ -623,6 +646,8 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) { err := a.client.PostLifecycle(ctx, report) if err == nil { + a.logger.Debug(ctx, "successfully reported lifecycle state", slog.F("payload", report)) + r.Reset() // don't back off when we are successful lastReportedIndex++ select { case a.lifecycleReported <- report.State: @@ -638,6 +663,7 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) { break } if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { + a.logger.Debug(ctx, "canceled reporting lifecycle state", slog.F("payload", report)) return } // If we fail to report the state we probably shouldn't exit, log only. @@ -648,7 +674,7 @@ func (a *agent) reportLifecycleLoop(ctx context.Context) { // setLifecycle sets the lifecycle state and notifies the lifecycle loop. // The state is only updated if it's a valid state transition. 
-func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentLifecycle) { +func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) { report := agentsdk.PostLifecycleRequest{ State: state, ChangedAt: dbtime.Now(), @@ -657,12 +683,12 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL a.lifecycleMu.Lock() lastReport := a.lifecycleStates[len(a.lifecycleStates)-1] if slices.Index(codersdk.WorkspaceAgentLifecycleOrder, lastReport.State) >= slices.Index(codersdk.WorkspaceAgentLifecycleOrder, report.State) { - a.logger.Warn(ctx, "attempted to set lifecycle state to a previous state", slog.F("last", lastReport), slog.F("current", report)) + a.logger.Warn(context.Background(), "attempted to set lifecycle state to a previous state", slog.F("last", lastReport), slog.F("current", report)) a.lifecycleMu.Unlock() return } a.lifecycleStates = append(a.lifecycleStates, report) - a.logger.Debug(ctx, "set lifecycle state", slog.F("current", report), slog.F("last", lastReport)) + a.logger.Debug(context.Background(), "set lifecycle state", slog.F("current", report), slog.F("last", lastReport)) a.lifecycleMu.Unlock() select { @@ -674,7 +700,8 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL // fetchServiceBannerLoop fetches the service banner on an interval. It will // not be fetched immediately; the expectation is that it is primed elsewhere // (and must be done before the session actually starts). 
-func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgentClient) error { +func (a *agent) fetchServiceBannerLoop(ctx context.Context, conn drpc.Conn) error { + aAPI := proto.NewDRPCAgentClient(conn) ticker := time.NewTicker(a.serviceBannerRefreshInterval) defer ticker.Stop() for { @@ -696,205 +723,272 @@ func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgent } } -func (a *agent) run(ctx context.Context) error { +func (a *agent) run() (retErr error) { // This allows the agent to refresh it's token if necessary. // For instance identity this is required, since the instance // may not have re-provisioned, but a new agent ID was created. - sessionToken, err := a.exchangeToken(ctx) + sessionToken, err := a.exchangeToken(a.hardCtx) if err != nil { return xerrors.Errorf("exchange token: %w", err) } a.sessionToken.Store(&sessionToken) // ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs - conn, err := a.client.ConnectRPC(ctx) + conn, err := a.client.ConnectRPC(a.hardCtx) if err != nil { return err } defer func() { cErr := conn.Close() if cErr != nil { - a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err)) + a.logger.Debug(a.hardCtx, "error closing drpc connection", slog.Error(cErr)) } }() - aAPI := proto.NewDRPCAgentClient(conn) - sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{}) - if err != nil { - return xerrors.Errorf("fetch service banner: %w", err) - } - serviceBanner := agentsdk.ServiceBannerFromProto(sbp) - a.serviceBanner.Store(&serviceBanner) + // A lot of routines need the agent API / tailnet API connection. We run them in their own + // goroutines in parallel, but errors in any routine will cause them all to exit so we can + // redial the coder server and retry.
+ connMan := newAPIConnRoutineManager(a.gracefulCtx, a.hardCtx, a.logger, conn) - mp, err := aAPI.GetManifest(ctx, &proto.GetManifestRequest{}) - if err != nil { - return xerrors.Errorf("fetch metadata: %w", err) - } - a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp)) - manifest, err := agentsdk.ManifestFromProto(mp) - if err != nil { - a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err)) - return xerrors.Errorf("convert manifest: %w", err) - } - if manifest.AgentID == uuid.Nil { - return xerrors.New("nil agentID returned by manifest") - } - a.client.RewriteDERPMap(manifest.DERPMap) - - // Expand the directory and send it back to coderd so external - // applications that rely on the directory can use it. - // - // An example is VS Code Remote, which must know the directory - // before initializing a connection. - manifest.Directory, err = expandDirectory(manifest.Directory) - if err != nil { - return xerrors.Errorf("expand directory: %w", err) - } - subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems) - if err != nil { - a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err)) - return xerrors.Errorf("failed to convert subsystems: %w", err) - } - _, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{ - Version: buildinfo.Version(), - ExpandedDirectory: manifest.Directory, - Subsystems: subsys, - }}) - if err != nil { - return xerrors.Errorf("update workspace agent startup: %w", err) - } - - oldManifest := a.manifest.Swap(&manifest) - - // The startup script should only execute on the first run! - if oldManifest == nil { - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStarting) - - // Perform overrides early so that Git auth can work even if users - // connect to a workspace that is not yet ready. We don't run this - // concurrently with the startup script to avoid conflicts between - // them. 
- if manifest.GitAuthConfigs > 0 { - // If this fails, we should consider surfacing the error in the - // startup log and setting the lifecycle state to be "start_error" - // (after startup script completion), but for now we'll just log it. - err := gitauth.OverrideVSCodeConfigs(a.filesystem) + connMan.start("init service banner", gracefulShutdownBehaviorStop, + func(ctx context.Context, conn drpc.Conn) error { + aAPI := proto.NewDRPCAgentClient(conn) + sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{}) if err != nil { - a.logger.Warn(ctx, "failed to override vscode git auth configs", slog.Error(err)) + return xerrors.Errorf("fetch service banner: %w", err) } - } + serviceBanner := agentsdk.ServiceBannerFromProto(sbp) + a.serviceBanner.Store(&serviceBanner) + return nil + }, + ) - err = a.scriptRunner.Init(manifest.Scripts) - if err != nil { - return xerrors.Errorf("init script runner: %w", err) - } - err = a.trackConnGoroutine(func() { - start := time.Now() - err := a.scriptRunner.Execute(ctx, func(script codersdk.WorkspaceAgentScript) bool { - return script.RunOnStart - }) - // Measure the time immediately after the script has finished - dur := time.Since(start).Seconds() - if err != nil { - a.logger.Warn(ctx, "startup script(s) failed", slog.Error(err)) - if errors.Is(err, agentscripts.ErrTimeout) { - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartTimeout) - } else { - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartError) - } - } else { - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleReady) - } + // channels to sync goroutines below + // handle manifest + // | + // manifestOK + // | | + // | +----------------------+ + // V | + // app health reporter | + // V + // create or update network + // | + // networkOK + // | + // coordination <--------------------------+ + // derp map subscriber <----------------+ + // stats report loop <---------------+ + networkOK := make(chan struct{}) + manifestOK := make(chan struct{}) 
- label := "false" - if err == nil { - label = "true" + connMan.start("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK)) + + connMan.start("app health reporter", gracefulShutdownBehaviorStop, + func(ctx context.Context, conn drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-manifestOK: + manifest := a.manifest.Load() + NewWorkspaceAppHealthReporter( + a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)), + )(ctx) + return nil } - a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur) - a.scriptRunner.StartCron() }) - if err != nil { - return xerrors.Errorf("track conn goroutine: %w", err) + + connMan.start("create or update network", gracefulShutdownBehaviorStop, + a.createOrUpdateNetwork(manifestOK, networkOK)) + + connMan.start("coordination", gracefulShutdownBehaviorStop, + func(ctx context.Context, conn drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-networkOK: + } + return a.runCoordinator(ctx, conn, a.network) + }, + ) + + connMan.start("derp map subscriber", gracefulShutdownBehaviorStop, + func(ctx context.Context, conn drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-networkOK: + } + return a.runDERPMapSubscriber(ctx, conn, a.network) + }) + + connMan.start("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop) + + connMan.start("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-networkOK: } + return a.statsReporter.reportLoop(ctx, proto.NewDRPCAgentClient(conn)) + }) + + return connMan.wait() +} + +// handleManifest returns a function that fetches and processes the manifest +func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Context, conn drpc.Conn) error { + return func(ctx context.Context, conn drpc.Conn) error { + aAPI := proto.NewDRPCAgentClient(conn) + mp, err 
:= aAPI.GetManifest(ctx, &proto.GetManifestRequest{}) + if err != nil { + return xerrors.Errorf("fetch metadata: %w", err) + } + a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp)) + manifest, err := agentsdk.ManifestFromProto(mp) + if err != nil { + a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err)) + return xerrors.Errorf("convert manifest: %w", err) + } + if manifest.AgentID == uuid.Nil { + return xerrors.New("nil agentID returned by manifest") + } + a.client.RewriteDERPMap(manifest.DERPMap) + + // Expand the directory and send it back to coderd so external + // applications that rely on the directory can use it. + // + // An example is VS Code Remote, which must know the directory + // before initializing a connection. + manifest.Directory, err = expandDirectory(manifest.Directory) + if err != nil { + return xerrors.Errorf("expand directory: %w", err) + } + subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems) + if err != nil { + a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err)) + return xerrors.Errorf("failed to convert subsystems: %w", err) + } + _, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{ + Version: buildinfo.Version(), + ExpandedDirectory: manifest.Directory, + Subsystems: subsys, + }}) + if err != nil { + if xerrors.Is(err, context.Canceled) { + return nil + } + return xerrors.Errorf("update workspace agent startup: %w", err) + } + + oldManifest := a.manifest.Swap(&manifest) + close(manifestOK) + + // The startup script should only execute on the first run! + if oldManifest == nil { + a.setLifecycle(codersdk.WorkspaceAgentLifecycleStarting) + + // Perform overrides early so that Git auth can work even if users + // connect to a workspace that is not yet ready. We don't run this + // concurrently with the startup script to avoid conflicts between + // them. 
+ if manifest.GitAuthConfigs > 0 { + // If this fails, we should consider surfacing the error in the + // startup log and setting the lifecycle state to be "start_error" + // (after startup script completion), but for now we'll just log it. + err := gitauth.OverrideVSCodeConfigs(a.filesystem) + if err != nil { + a.logger.Warn(ctx, "failed to override vscode git auth configs", slog.Error(err)) + } + } + + err = a.scriptRunner.Init(manifest.Scripts) + if err != nil { + return xerrors.Errorf("init script runner: %w", err) + } + err = a.trackGoroutine(func() { + start := time.Now() + // here we use the graceful context because the script runner is not directly tied + // to the agent API. + err := a.scriptRunner.Execute(a.gracefulCtx, func(script codersdk.WorkspaceAgentScript) bool { + return script.RunOnStart + }) + // Measure the time immediately after the script has finished + dur := time.Since(start).Seconds() + if err != nil { + a.logger.Warn(ctx, "startup script(s) failed", slog.Error(err)) + if errors.Is(err, agentscripts.ErrTimeout) { + a.setLifecycle(codersdk.WorkspaceAgentLifecycleStartTimeout) + } else { + a.setLifecycle(codersdk.WorkspaceAgentLifecycleStartError) + } + } else { + a.setLifecycle(codersdk.WorkspaceAgentLifecycleReady) + } + + label := "false" + if err == nil { + label = "true" + } + a.metrics.startupScriptSeconds.WithLabelValues(label).Set(dur) + a.scriptRunner.StartCron() + }) + if err != nil { + return xerrors.Errorf("track conn goroutine: %w", err) + } + } + return nil } +} - // This automatically closes when the context ends! 
- appReporterCtx, appReporterCtxCancel := context.WithCancel(ctx) - defer appReporterCtxCancel() - go NewWorkspaceAppHealthReporter( - a.logger, manifest.Apps, agentsdk.AppHealthPoster(aAPI))(appReporterCtx) - - a.closeMutex.Lock() - network := a.network - a.closeMutex.Unlock() - if network == nil { - network, err = a.createTailnet(ctx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections) - if err != nil { - return xerrors.Errorf("create tailnet: %w", err) +// createOrUpdateNetwork waits for the manifest to be set using manifestOK, then creates or updates +// the tailnet using the information in the manifest +func (a *agent) createOrUpdateNetwork(manifestOK <-chan struct{}, networkOK chan<- struct{}) func(context.Context, drpc.Conn) error { + return func(ctx context.Context, _ drpc.Conn) error { + select { + case <-ctx.Done(): + return nil + case <-manifestOK: } + var err error + manifest := a.manifest.Load() a.closeMutex.Lock() - // Re-check if agent was closed while initializing the network. - closed := a.isClosed() - if !closed { - a.network = network - a.statsReporter = newStatsReporter(a.logger, network, a) - } + network := a.network a.closeMutex.Unlock() - if closed { - _ = network.Close() - return xerrors.New("agent is closed") + if network == nil { + // use the graceful context here, because creating the tailnet is not itself tied to the + // agent API. + network, err = a.createTailnet(a.gracefulCtx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections) + if err != nil { + return xerrors.Errorf("create tailnet: %w", err) + } + a.closeMutex.Lock() + // Re-check if agent was closed while initializing the network. 
+ closed := a.isClosed() + if !closed { + a.network = network + a.statsReporter = newStatsReporter(a.logger, network, a) + } + a.closeMutex.Unlock() + if closed { + _ = network.Close() + return xerrors.New("agent is closed") + } + } else { + // Update the wireguard IPs if the agent ID changed. + err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID)) + if err != nil { + a.logger.Error(a.gracefulCtx, "update tailnet addresses", slog.Error(err)) + } + // Update the DERP map, force WebSocket setting and allow/disallow + // direct connections. + network.SetDERPMap(manifest.DERPMap) + network.SetDERPForceWebSockets(manifest.DERPForceWebSockets) + network.SetBlockEndpoints(manifest.DisableDirectConnections) } - } else { - // Update the wireguard IPs if the agent ID changed. - err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID)) - if err != nil { - a.logger.Error(ctx, "update tailnet addresses", slog.Error(err)) - } - // Update the DERP map, force WebSocket setting and allow/disallow - // direct connections. 
- network.SetDERPMap(manifest.DERPMap) - network.SetDERPForceWebSockets(manifest.DERPForceWebSockets) - network.SetBlockEndpoints(manifest.DisableDirectConnections) + close(networkOK) + return nil } - - eg, egCtx := errgroup.WithContext(ctx) - eg.Go(func() error { - a.logger.Debug(egCtx, "running tailnet connection coordinator") - err := a.runCoordinator(egCtx, conn, network) - if err != nil { - return xerrors.Errorf("run coordinator: %w", err) - } - return nil - }) - - eg.Go(func() error { - a.logger.Debug(egCtx, "running derp map subscriber") - err := a.runDERPMapSubscriber(egCtx, conn, network) - if err != nil { - return xerrors.Errorf("run derp map subscriber: %w", err) - } - return nil - }) - - eg.Go(func() error { - a.logger.Debug(egCtx, "running fetch server banner loop") - err := a.fetchServiceBannerLoop(egCtx, aAPI) - if err != nil { - return xerrors.Errorf("fetch server banner loop: %w", err) - } - return nil - }) - - eg.Go(func() error { - a.logger.Debug(egCtx, "running stats report loop") - err := a.statsReporter.reportLoop(egCtx, aAPI) - if err != nil { - return xerrors.Errorf("report stats loop: %w", err) - } - return nil - }) - - return eg.Wait() } // updateCommandEnv updates the provided command environment with the @@ -995,15 +1089,15 @@ func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix { return a.addresses } -func (a *agent) trackConnGoroutine(fn func()) error { +func (a *agent) trackGoroutine(fn func()) error { a.closeMutex.Lock() defer a.closeMutex.Unlock() if a.isClosed() { return xerrors.New("track conn goroutine: agent is closed") } - a.connCloseWait.Add(1) + a.closeWaitGroup.Add(1) go func() { - defer a.connCloseWait.Done() + defer a.closeWaitGroup.Done() fn() }() return nil @@ -1037,7 +1131,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t _ = sshListener.Close() } }() - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { _ = a.sshServer.Serve(sshListener) }); 
err != nil { return nil, err @@ -1052,7 +1146,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t _ = reconnectingPTYListener.Close() } }() - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { logger := a.logger.Named("reconnecting-pty") var wg sync.WaitGroup for { @@ -1072,7 +1166,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t go func() { select { case <-closed: - case <-a.closed: + case <-a.hardCtx.Done(): _ = conn.Close() } wg.Done() @@ -1115,7 +1209,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t _ = speedtestListener.Close() } }() - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { var wg sync.WaitGroup for { conn, err := speedtestListener.Accept() @@ -1134,7 +1228,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t go func() { select { case <-closed: - case <-a.closed: + case <-a.hardCtx.Done(): _ = conn.Close() } wg.Done() @@ -1163,7 +1257,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t _ = apiListener.Close() } }() - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { defer apiListener.Close() server := &http.Server{ Handler: a.apiHandler(), @@ -1175,7 +1269,7 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t go func() { select { case <-ctx.Done(): - case <-a.closed: + case <-a.hardCtx.Done(): } _ = server.Close() }() @@ -1196,7 +1290,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tailnet.Conn) error { defer a.logger.Debug(ctx, "disconnected from coordination RPC") tClient := tailnetproto.NewDRPCTailnetClient(conn) - coordinate, err := tClient.Coordinate(ctx) + // we run the RPC on the hardCtx so that we have a chance to send the disconnect message if we + 
// gracefully shut down. + coordinate, err := tClient.Coordinate(a.hardCtx) if err != nil { return xerrors.Errorf("failed to connect to the coordinate endpoint: %w", err) } @@ -1207,13 +1303,36 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai } }() a.logger.Info(ctx, "connected to coordination RPC") - coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil) - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-coordination.Error(): - return err + + // This allows the Close() routine to wait for the coordinator to gracefully disconnect. + a.closeMutex.Lock() + if a.isClosed() { + // release closeMutex before bailing out; returning with it held would block every later Lock() (e.g. in trackGoroutine). + a.closeMutex.Unlock() + return nil } + disconnected := make(chan struct{}) + a.coordDisconnected = disconnected + defer close(disconnected) + a.closeMutex.Unlock() + + coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil) + + errCh := make(chan error, 1) + go func() { + defer close(errCh) + select { + case <-ctx.Done(): + err := coordination.Close() + if err != nil { + a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err)) + } + return + case err := <-coordination.Error(): + errCh <- err + } + }() + return <-errCh } // runDERPMapSubscriber runs a coordinator and returns if a reconnect should occur.
@@ -1311,7 +1428,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m Metrics: a.metrics.reconnectingPTYErrors, }, logger.With(slog.F("message_id", msg.ID))) - if err = a.trackConnGoroutine(func() { + if err = a.trackGoroutine(func() { rpty.Wait() a.reconnectingPTYs.Delete(msg.ID) }); err != nil { @@ -1406,7 +1523,9 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect var prioritizedProcs = []string{"coder agent"} -func (a *agent) manageProcessPriorityLoop(ctx context.Context) { +func (a *agent) manageProcessPriorityUntilGracefulShutdown() { + // process priority can stop as soon as we are gracefully shutting down + ctx := a.gracefulCtx defer func() { if r := recover(); r != nil { a.logger.Critical(ctx, "recovered from panic", @@ -1515,12 +1634,7 @@ func (a *agent) manageProcessPriority(ctx context.Context) ([]*agentproc.Process // isClosed returns whether the API is closed or not. func (a *agent) isClosed() bool { - select { - case <-a.closed: - return true - default: - return false - } + return a.hardCtx.Err() != nil } func (a *agent) HTTPDebug() http.Handler { @@ -1584,59 +1698,82 @@ func (a *agent) Close() error { return nil } - ctx := context.Background() - a.logger.Info(ctx, "shutting down agent") - a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown) + a.logger.Info(a.hardCtx, "shutting down agent") + a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown) // Attempt to gracefully shut down all active SSH connections and // stop accepting new ones. 
- err := a.sshServer.Shutdown(ctx) + err := a.sshServer.Shutdown(a.hardCtx) if err != nil { - a.logger.Error(ctx, "ssh server shutdown", slog.Error(err)) + a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err)) } + err = a.sshServer.Close() + if err != nil { + a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err)) + } + // wait for SSH to shut down before the general graceful cancel, because + // this triggers a disconnect in the tailnet layer, telling all clients to + // shut down their wireguard tunnels to us. If SSH sessions are still up, + // they might hang instead of being closed. + a.gracefulCancel() lifecycleState := codersdk.WorkspaceAgentLifecycleOff - err = a.scriptRunner.Execute(ctx, func(script codersdk.WorkspaceAgentScript) bool { + err = a.scriptRunner.Execute(a.hardCtx, func(script codersdk.WorkspaceAgentScript) bool { return script.RunOnStop }) if err != nil { - a.logger.Warn(ctx, "shutdown script(s) failed", slog.Error(err)) + a.logger.Warn(a.hardCtx, "shutdown script(s) failed", slog.Error(err)) if errors.Is(err, agentscripts.ErrTimeout) { lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownTimeout } else { lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError } } - a.setLifecycle(ctx, lifecycleState) + a.setLifecycle(lifecycleState) err = a.scriptRunner.Close() if err != nil { - a.logger.Error(ctx, "script runner close", slog.Error(err)) + a.logger.Error(a.hardCtx, "script runner close", slog.Error(err)) } - // Wait for the lifecycle to be reported, but don't wait forever so + // Wait for the graceful shutdown to complete, but don't wait forever so // that we don't break user expectations. 
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() + go func() { + defer a.hardCancel() + select { + case <-a.hardCtx.Done(): + case <-time.After(5 * time.Second): + } + }() + + // Wait for lifecycle to be reported lifecycleWaitLoop: for { select { - case <-ctx.Done(): + case <-a.hardCtx.Done(): + a.logger.Warn(context.Background(), "failed to report final lifecycle state") break lifecycleWaitLoop case s := <-a.lifecycleReported: if s == lifecycleState { + a.logger.Debug(context.Background(), "reported final lifecycle state") break lifecycleWaitLoop } } } - close(a.closed) - a.closeCancel() - _ = a.sshServer.Close() + // Wait for graceful disconnect from the Coordinator RPC + select { + case <-a.hardCtx.Done(): + a.logger.Warn(context.Background(), "timed out waiting for Coordinator RPC disconnect") + case <-a.coordDisconnected: + a.logger.Debug(context.Background(), "coordinator RPC disconnected") + } + + a.hardCancel() if a.network != nil { _ = a.network.Close() } - a.connCloseWait.Wait() + a.closeWaitGroup.Wait() return nil } @@ -1688,3 +1825,94 @@ func expandDirectory(dir string) (string, error) { // specialized environment in which the agent is running // (e.g. envbox, envbuilder). const EnvAgentSubsystem = "CODER_AGENT_SUBSYSTEM" + +// eitherContext returns a context that is canceled when either context ends. 
+func eitherContext(a, b context.Context) context.Context { + ctx, cancel := context.WithCancel(a) + go func() { + defer cancel() + select { + case <-a.Done(): + case <-b.Done(): + } + }() + return ctx +} + +type gracefulShutdownBehavior int + +const ( + gracefulShutdownBehaviorStop gracefulShutdownBehavior = iota + gracefulShutdownBehaviorRemain +) + +type apiConnRoutineManager struct { + logger slog.Logger + conn drpc.Conn + eg *errgroup.Group + stopCtx context.Context + remainCtx context.Context +} + +func newAPIConnRoutineManager(gracefulCtx, hardCtx context.Context, logger slog.Logger, conn drpc.Conn) *apiConnRoutineManager { + // routines that remain in operation during graceful shutdown use the remainCtx. They'll still + // exit if the errgroup hits an error, which usually means a problem with the conn. + eg, remainCtx := errgroup.WithContext(hardCtx) + + // routines that stop operation during graceful shutdown use the stopCtx, which ends when the + // first of remainCtx or gracefulContext ends (an error or start of graceful shutdown). 
+ // + // +------------------------------------------+ + // | hardCtx | + // | +------------------------------------+ | + // | | stopCtx | | + // | | +--------------+ +--------------+ | | + // | | | remainCtx | | gracefulCtx | | | + // | | +--------------+ +--------------+ | | + // | +------------------------------------+ | + // +------------------------------------------+ + stopCtx := eitherContext(remainCtx, gracefulCtx) + return &apiConnRoutineManager{ + logger: logger, + conn: conn, + eg: eg, + stopCtx: stopCtx, + remainCtx: remainCtx, + } +} + +func (a *apiConnRoutineManager) start(name string, b gracefulShutdownBehavior, f func(context.Context, drpc.Conn) error) { + logger := a.logger.With(slog.F("name", name)) + var ctx context.Context + switch b { + case gracefulShutdownBehaviorStop: + ctx = a.stopCtx + case gracefulShutdownBehaviorRemain: + ctx = a.remainCtx + default: + panic("unknown behavior") + } + a.eg.Go(func() error { + logger.Debug(ctx, "starting routine") + err := f(ctx, a.conn) + if xerrors.Is(err, context.Canceled) && ctx.Err() != nil { + logger.Debug(ctx, "swallowing context canceled") + // Don't propagate context canceled errors to the error group, because we don't want the + // graceful context being canceled to halt the work of routines with + // gracefulShutdownBehaviorRemain. Note that we check both that the error is + // context.Canceled and that *our* context is currently canceled, because when Coderd + // unilaterally closes the API connection (for example if the build is outdated), it can + // sometimes show up as context.Canceled in our RPC calls. 
+ return nil + } + logger.Debug(ctx, "routine exited", slog.Error(err)) + if err != nil { + return xerrors.Errorf("error in routine %s: %w", name, err) + } + return nil + }) +} + +func (a *apiConnRoutineManager) wait() error { + return a.eg.Wait() +} diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 81019788c7..cc88bc52d7 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -162,7 +162,13 @@ func TestSSH(t *testing.T) { coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspaceBuild.ID) // Update template version - version = coderdtest.UpdateTemplateVersion(t, ownerClient, owner.OrganizationID, echoResponses, template.ID) + authToken2 := uuid.NewString() + echoResponses2 := &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: echo.PlanComplete, + ProvisionApply: echo.ProvisionApplyWithAgent(authToken2), + } + version = coderdtest.UpdateTemplateVersion(t, ownerClient, owner.OrganizationID, echoResponses2, template.ID) coderdtest.AwaitTemplateVersionJobCompleted(t, ownerClient, version.ID) err := ownerClient.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{ ID: version.ID, @@ -184,7 +190,7 @@ func TestSSH(t *testing.T) { // When the agent connects, the workspace was started, and we should // have access to the shell. - _ = agenttest.New(t, client.URL, authToken) + _ = agenttest.New(t, client.URL, authToken2) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) // Shells on Mac, Windows, and Linux all exit shells with the "exit" command. 
diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index b1c496e4ba..e75c32f9b0 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -193,7 +193,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can options = &Options{} } if options.Logger == nil { - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("coderd") options.Logger = &logger } if options.GoogleTokenValidator == nil { diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 20d1e221e2..7a8b2e5d3b 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -534,7 +534,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) { tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect") crdErr := coordination.Close() if crdErr != nil { - tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(err)) + tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err)) } case err = <-coordination.Error(): if err != nil && diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index f69fbff8d4..1564b6587e 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -231,13 +231,14 @@ func TestAuditLogging(t *testing.T) { }, DontAddLicense: true, }) - workspace, agent := setupWorkspaceAgent(t, client, user, 0) - conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil) //nolint:gocritic // RBAC is not the purpose of this test + r := setupWorkspaceAgent(t, client, user, 0) + conn, err := client.DialWorkspaceAgent(ctx, r.sdkAgent.ID, nil) //nolint:gocritic // RBAC is not the purpose of this test require.NoError(t, err) defer conn.Close() connected := conn.AwaitReachable(ctx) require.True(t, connected) - build := coderdtest.CreateWorkspaceBuild(t, client, workspace, 
database.WorkspaceTransitionStop) + _ = r.agent.Close() // close first so we don't drop error logs from outdated build + build := coderdtest.CreateWorkspaceBuild(t, client, r.workspace, database.WorkspaceTransitionStop) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID) }) } diff --git a/enterprise/coderd/replicas_test.go b/enterprise/coderd/replicas_test.go index 1081ec81e3..6d348db782 100644 --- a/enterprise/coderd/replicas_test.go +++ b/enterprise/coderd/replicas_test.go @@ -81,8 +81,8 @@ func TestReplicas(t *testing.T) { require.NoError(t, err) require.Len(t, replicas, 2) - _, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0) - conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{ + r := setupWorkspaceAgent(t, firstClient, firstUser, 0) + conn, err := secondClient.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, &codersdk.DialWorkspaceAgentOptions{ BlockEndpoints: true, Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), }) @@ -127,8 +127,8 @@ func TestReplicas(t *testing.T) { require.NoError(t, err) require.Len(t, replicas, 2) - _, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0) - conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{ + r := setupWorkspaceAgent(t, firstClient, firstUser, 0) + conn, err := secondClient.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, &codersdk.DialWorkspaceAgentOptions{ BlockEndpoints: true, Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), }) diff --git a/enterprise/coderd/workspaceagents_test.go b/enterprise/coderd/workspaceagents_test.go index 7745eb7289..a6cf84a594 100644 --- a/enterprise/coderd/workspaceagents_test.go +++ b/enterprise/coderd/workspaceagents_test.go @@ -44,9 +44,9 @@ func TestBlockNonBrowser(t *testing.T) { }, }, }) - _, agent := setupWorkspaceAgent(t, client, user, 0) + r := setupWorkspaceAgent(t, client, user, 0) 
//nolint:gocritic // Testing that even the owner gets blocked. - _, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil) + _, err := client.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, nil) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusConflict, apiErr.StatusCode()) @@ -63,15 +63,21 @@ func TestBlockNonBrowser(t *testing.T) { }, }, }) - _, agent := setupWorkspaceAgent(t, client, user, 0) + r := setupWorkspaceAgent(t, client, user, 0) //nolint:gocritic // Testing RBAC is not the point of this test. - conn, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil) + conn, err := client.DialWorkspaceAgent(context.Background(), r.sdkAgent.ID, nil) require.NoError(t, err) _ = conn.Close() }) } -func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.CreateFirstUserResponse, appPort uint16) (codersdk.Workspace, codersdk.WorkspaceAgent) { +type setupResp struct { + workspace codersdk.Workspace + sdkAgent codersdk.WorkspaceAgent + agent agent.Agent +} + +func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.CreateFirstUserResponse, appPort uint16) setupResp { authToken := uuid.NewString() version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ Parse: echo.ParseComplete, @@ -127,20 +133,20 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr }, } agentClient.SetSessionToken(authToken) - agentCloser := agent.New(agent.Options{ + agnt := agent.New(agent.Options{ Client: agentClient, Logger: slogtest.Make(t, nil).Named("agent"), }) t.Cleanup(func() { - _ = agentCloser.Close() + _ = agnt.Close() }) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - agnt, err := client.WorkspaceAgent(ctx, resources[0].Agents[0].ID) + sdkAgent, err := client.WorkspaceAgent(ctx, 
resources[0].Agents[0].ID) require.NoError(t, err) - return workspace, agnt + return setupResp{workspace, sdkAgent, agnt} } diff --git a/enterprise/coderd/workspaceportshare_test.go b/enterprise/coderd/workspaceportshare_test.go index 04d2d83967..1a8543db68 100644 --- a/enterprise/coderd/workspaceportshare_test.go +++ b/enterprise/coderd/workspaceportshare_test.go @@ -31,7 +31,7 @@ func TestWorkspacePortShare(t *testing.T) { }, }) client, user := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin()) - workspace, agent := setupWorkspaceAgent(t, client, codersdk.CreateFirstUserResponse{ + r := setupWorkspaceAgent(t, client, codersdk.CreateFirstUserResponse{ UserID: user.ID, OrganizationID: owner.OrganizationID, }, 0) @@ -39,8 +39,8 @@ func TestWorkspacePortShare(t *testing.T) { defer cancel() // try to update port share with template max port share level owner - _, err := client.UpsertWorkspaceAgentPortShare(ctx, workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{ - AgentName: agent.Name, + _, err := client.UpsertWorkspaceAgentPortShare(ctx, r.workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{ + AgentName: r.sdkAgent.Name, Port: 8080, ShareLevel: codersdk.WorkspaceAgentPortShareLevelPublic, }) @@ -48,13 +48,13 @@ func TestWorkspacePortShare(t *testing.T) { // update the template max port share level to public var level codersdk.WorkspaceAgentPortShareLevel = codersdk.WorkspaceAgentPortShareLevelPublic - client.UpdateTemplateMeta(ctx, workspace.TemplateID, codersdk.UpdateTemplateMeta{ + client.UpdateTemplateMeta(ctx, r.workspace.TemplateID, codersdk.UpdateTemplateMeta{ MaxPortShareLevel: &level, }) // OK - ps, err := client.UpsertWorkspaceAgentPortShare(ctx, workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{ - AgentName: agent.Name, + ps, err := client.UpsertWorkspaceAgentPortShare(ctx, r.workspace.ID, codersdk.UpsertWorkspaceAgentPortShareRequest{ + AgentName: r.sdkAgent.Name, Port: 8080, ShareLevel: 
codersdk.WorkspaceAgentPortShareLevelPublic, }) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 530b42aea3..842a6bcbfa 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -131,7 +131,8 @@ func (c *remoteCoordination) Close() (retErr error) { } }() err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) - if err != nil { + if err != nil && !xerrors.Is(err, io.EOF) { + // Coordinator RPC hangs up when it gets disconnect, so EOF is expected. return xerrors.Errorf("send disconnect: %w", err) } c.logger.Debug(context.Background(), "sent disconnect")