refactor(cli/cliui): extract agentWaiter struct from agent connection state machine (#21104)

The Agent function had complex nested control flow and cross-case state sharing
via the showStartupLogs flag. This made the code hard to follow and maintain.

This change extracts an agentWaiter struct with self-contained methods:

- wait: main state machine loop
- waitForConnection: handles Connecting/Timeout states
- handleConnected: handles Connected state and startup scripts
- streamLogs: handles log streaming/fetching
- waitForReconnection: handles Disconnected state
- pollWhile: helper to consolidate polling loops

Each handler is now self-contained with no cross-method state sharing and the 
showStartupLogs flag is replaced by return values and the waitedForConnection
tracking variable.
This commit is contained in:
Mathias Fredriksson
2025-12-08 16:00:25 +02:00
committed by GitHub
parent 04d5ff88e4
commit 0c453d7f8e
+248 -161
View File
@@ -20,6 +20,12 @@ import (
var errAgentShuttingDown = xerrors.New("agent is shutting down")
// fetchAgentResult is used to pass agent fetch results through channels.
// The background fetch loop in Agent sends one value per fetch attempt:
// either err is set (fetch failed) or agent is set (fetch succeeded).
type fetchAgentResult struct {
// agent is the fetched workspace agent; it is left as the zero value
// when the result is sent with err set.
agent codersdk.WorkspaceAgent
// err is the wrapped fetch error, or nil on success.
err error
}
type AgentOptions struct {
FetchInterval time.Duration
Fetch func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error)
@@ -28,6 +34,14 @@ type AgentOptions struct {
DocsURL string
}
// agentWaiter encapsulates the state machine for waiting on a workspace agent.
// It bundles the options, output writer and fetch function that the
// individual state handlers (waitForConnection, handleConnected, streamLogs,
// waitForReconnection, pollWhile) share.
type agentWaiter struct {
// opts holds the caller-supplied options (Wait, FetchLogs, DocsURL, ...).
opts AgentOptions
// sw receives stage transitions (Start/Complete/Fail) and log lines.
sw *stageWriter
// logSources maps a log source ID to its definition; streamLogs uses it
// to prefix log output with the source's display name when one is set.
logSources map[uuid.UUID]codersdk.WorkspaceAgentLogSource
// fetchAgent returns the next agent state or an error; pollWhile calls
// it once per iteration.
fetchAgent func(context.Context) (codersdk.WorkspaceAgent, error)
}
// Agent displays a spinning indicator that waits for a workspace agent to connect.
func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentOptions) error {
ctx, cancel := context.WithCancel(ctx)
@@ -44,11 +58,7 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
}
}
type fetchAgent struct {
agent codersdk.WorkspaceAgent
err error
}
fetchedAgent := make(chan fetchAgent, 1)
fetchedAgent := make(chan fetchAgentResult, 1)
go func() {
t := time.NewTimer(0)
defer t.Stop()
@@ -67,10 +77,10 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
default:
}
if err != nil {
fetchedAgent <- fetchAgent{err: xerrors.Errorf("fetch workspace agent: %w", err)}
fetchedAgent <- fetchAgentResult{err: xerrors.Errorf("fetch workspace agent: %w", err)}
return
}
fetchedAgent <- fetchAgent{agent: agent}
fetchedAgent <- fetchAgentResult{agent: agent}
// Adjust the interval based on how long we've been waiting.
elapsed := time.Since(startTime)
@@ -79,7 +89,7 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
}
}
}()
fetch := func() (codersdk.WorkspaceAgent, error) {
fetch := func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
select {
case <-ctx.Done():
return codersdk.WorkspaceAgent{}, ctx.Err()
@@ -91,7 +101,7 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
}
}
agent, err := fetch()
agent, err := fetch(ctx)
if err != nil {
return xerrors.Errorf("fetch: %w", err)
}
@@ -100,9 +110,23 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
logSources[source.ID] = source
}
sw := &stageWriter{w: writer}
w := &agentWaiter{
opts: opts,
sw: &stageWriter{w: writer},
logSources: logSources,
fetchAgent: fetch,
}
return w.wait(ctx, agent, fetchedAgent)
}
// wait runs the main state machine loop.
func (aw *agentWaiter) wait(ctx context.Context, agent codersdk.WorkspaceAgent, fetchedAgent chan fetchAgentResult) error {
var err error
// Track whether we've gone through a wait state, which determines if we
// should show startup logs when connected.
waitedForConnection := false
showStartupLogs := false
for {
// It doesn't matter if we're connected or not, if the agent is
// shutting down, we don't know if it's coming back.
@@ -112,167 +136,230 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
switch agent.Status {
case codersdk.WorkspaceAgentConnecting, codersdk.WorkspaceAgentTimeout:
// Since we were waiting for the agent to connect, also show
// startup logs if applicable.
showStartupLogs = true
stage := "Waiting for the workspace agent to connect"
sw.Start(stage)
for agent.Status == codersdk.WorkspaceAgentConnecting {
if agent, err = fetch(); err != nil {
return xerrors.Errorf("fetch: %w", err)
}
}
if agent.Status == codersdk.WorkspaceAgentTimeout {
now := time.Now()
sw.Log(now, codersdk.LogLevelInfo, "The workspace agent is having trouble connecting, wait for it to connect or restart your workspace.")
sw.Log(now, codersdk.LogLevelInfo, troubleshootingMessage(agent, fmt.Sprintf("%s/admin/templates/troubleshooting#agent-connection-issues", opts.DocsURL)))
for agent.Status == codersdk.WorkspaceAgentTimeout {
if agent, err = fetch(); err != nil {
return xerrors.Errorf("fetch: %w", err)
}
}
}
sw.Complete(stage, agent.FirstConnectedAt.Sub(agent.CreatedAt))
case codersdk.WorkspaceAgentConnected:
if !showStartupLogs && agent.LifecycleState == codersdk.WorkspaceAgentLifecycleReady {
// The workspace is ready, there's nothing to do but connect.
return nil
}
stage := "Running workspace agent startup scripts"
follow := opts.Wait && agent.LifecycleState.Starting()
if !follow {
stage += " (non-blocking)"
}
sw.Start(stage)
if follow {
sw.Log(time.Time{}, codersdk.LogLevelInfo, "==> ︎ To connect immediately, reconnect with --wait=no or CODER_SSH_WAIT=no, see --help for more information.")
}
err = func() error { // Use func because of defer in for loop.
logStream, logsCloser, err := opts.FetchLogs(ctx, agent.ID, 0, follow)
if err != nil {
return xerrors.Errorf("fetch workspace agent startup logs: %w", err)
}
defer logsCloser.Close()
var lastLog codersdk.WorkspaceAgentLog
fetchedAgentWhileFollowing := fetchedAgent
if !follow {
fetchedAgentWhileFollowing = nil
}
for {
// This select is essentially and inline `fetch()`.
select {
case <-ctx.Done():
return ctx.Err()
case f := <-fetchedAgentWhileFollowing:
if f.err != nil {
return xerrors.Errorf("fetch: %w", f.err)
}
agent = f.agent
// If the agent is no longer starting, stop following
// logs because FetchLogs will keep streaming forever.
// We do one last non-follow request to ensure we have
// fetched all logs.
if !agent.LifecycleState.Starting() {
_ = logsCloser.Close()
fetchedAgentWhileFollowing = nil
logStream, logsCloser, err = opts.FetchLogs(ctx, agent.ID, lastLog.ID, false)
if err != nil {
return xerrors.Errorf("fetch workspace agent startup logs: %w", err)
}
// Logs are already primed, so we can call close.
_ = logsCloser.Close()
}
case logs, ok := <-logStream:
if !ok {
return nil
}
for _, log := range logs {
source, hasSource := logSources[log.SourceID]
output := log.Output
if hasSource && source.DisplayName != "" {
output = source.DisplayName + ": " + output
}
sw.Log(log.CreatedAt, log.Level, output)
lastLog = log
}
}
}
}()
agent, err = aw.waitForConnection(ctx, agent)
if err != nil {
return err
}
// Since we were waiting for the agent to connect, also show
// startup logs if applicable.
waitedForConnection = true
for follow && agent.LifecycleState.Starting() {
if agent, err = fetch(); err != nil {
return xerrors.Errorf("fetch: %w", err)
}
}
switch agent.LifecycleState {
case codersdk.WorkspaceAgentLifecycleReady:
sw.Complete(stage, safeDuration(sw, agent.ReadyAt, agent.StartedAt))
case codersdk.WorkspaceAgentLifecycleStartTimeout:
// Backwards compatibility: Avoid printing warning if
// coderd is old and doesn't set ReadyAt for timeouts.
if agent.ReadyAt == nil {
sw.Fail(stage, 0)
} else {
sw.Fail(stage, safeDuration(sw, agent.ReadyAt, agent.StartedAt))
}
sw.Log(time.Time{}, codersdk.LogLevelWarn, "Warning: A startup script timed out and your workspace may be incomplete.")
case codersdk.WorkspaceAgentLifecycleStartError:
sw.Fail(stage, safeDuration(sw, agent.ReadyAt, agent.StartedAt))
// Use zero time (omitted) to separate these from the startup logs.
sw.Log(time.Time{}, codersdk.LogLevelWarn, "Warning: A startup script exited with an error and your workspace may be incomplete.")
sw.Log(time.Time{}, codersdk.LogLevelWarn, troubleshootingMessage(agent, fmt.Sprintf("%s/admin/templates/troubleshooting#startup-script-exited-with-an-error", opts.DocsURL)))
default:
switch {
case agent.LifecycleState.Starting():
// Use zero time (omitted) to separate these from the startup logs.
sw.Log(time.Time{}, codersdk.LogLevelWarn, "Notice: The startup scripts are still running and your workspace may be incomplete.")
sw.Log(time.Time{}, codersdk.LogLevelWarn, troubleshootingMessage(agent, fmt.Sprintf("%s/admin/templates/troubleshooting#your-workspace-may-be-incomplete", opts.DocsURL)))
// Note: We don't complete or fail the stage here, it's
// intentionally left open to indicate this stage didn't
// complete.
case agent.LifecycleState.ShuttingDown():
// We no longer know if the startup script failed or not,
// but we need to tell the user something.
sw.Complete(stage, safeDuration(sw, agent.ReadyAt, agent.StartedAt))
return errAgentShuttingDown
}
}
return nil
case codersdk.WorkspaceAgentConnected:
return aw.handleConnected(ctx, agent, waitedForConnection, fetchedAgent)
case codersdk.WorkspaceAgentDisconnected:
// If the agent was still starting during disconnect, we'll
// show startup logs.
showStartupLogs = agent.LifecycleState.Starting()
stage := "The workspace agent lost connection"
sw.Start(stage)
sw.Log(time.Now(), codersdk.LogLevelWarn, "Wait for it to reconnect or restart your workspace.")
sw.Log(time.Now(), codersdk.LogLevelWarn, troubleshootingMessage(agent, fmt.Sprintf("%s/admin/templates/troubleshooting#agent-connection-issues", opts.DocsURL)))
disconnectedAt := agent.DisconnectedAt
for agent.Status == codersdk.WorkspaceAgentDisconnected {
if agent, err = fetch(); err != nil {
return xerrors.Errorf("fetch: %w", err)
}
agent, waitedForConnection, err = aw.waitForReconnection(ctx, agent)
if err != nil {
return err
}
sw.Complete(stage, safeDuration(sw, agent.LastConnectedAt, disconnectedAt))
}
}
}
// waitForConnection drives the Connecting and Timeout states. It opens the
// "waiting to connect" stage, polls until the agent leaves Connecting, and if
// the agent then reports Timeout it prints troubleshooting hints and keeps
// polling until the agent leaves Timeout as well. The stage is completed with
// the time between agent creation and first connection.
//
// On a polling error the agent value produced by pollWhile is returned
// together with the error, and the stage is left open.
func (aw *agentWaiter) waitForConnection(ctx context.Context, agent codersdk.WorkspaceAgent) (codersdk.WorkspaceAgent, error) {
	const stage = "Waiting for the workspace agent to connect"

	isConnecting := func(a codersdk.WorkspaceAgent) bool {
		return a.Status == codersdk.WorkspaceAgentConnecting
	}
	hasTimedOut := func(a codersdk.WorkspaceAgent) bool {
		return a.Status == codersdk.WorkspaceAgentTimeout
	}

	aw.sw.Start(stage)

	var err error
	if agent, err = aw.pollWhile(ctx, agent, isConnecting); err != nil {
		return agent, err
	}

	if hasTimedOut(agent) {
		// Both hint lines share one timestamp so they render as a pair.
		now := time.Now()
		aw.sw.Log(now, codersdk.LogLevelInfo, "The workspace agent is having trouble connecting, wait for it to connect or restart your workspace.")
		docsURL := fmt.Sprintf("%s/admin/templates/troubleshooting#agent-connection-issues", aw.opts.DocsURL)
		aw.sw.Log(now, codersdk.LogLevelInfo, troubleshootingMessage(agent, docsURL))

		if agent, err = aw.pollWhile(ctx, agent, hasTimedOut); err != nil {
			return agent, err
		}
	}

	connectDuration := agent.FirstConnectedAt.Sub(agent.CreatedAt)
	aw.sw.Complete(stage, connectDuration)
	return agent, nil
}
// handleConnected is the terminal handler for the Connected state. It
// returns immediately when the agent is already ready and startup logs were
// not requested. Otherwise it opens the startup-scripts stage, prints the
// startup logs (streaming them when blocking on startup), optionally waits
// for startup to finish, and reports the final lifecycle state.
//
// It returns errAgentShuttingDown when the agent ends up shutting down, the
// underlying error when log streaming or polling fails, and nil otherwise.
//
//nolint:revive // Control flag is acceptable for internal method.
func (aw *agentWaiter) handleConnected(ctx context.Context, agent codersdk.WorkspaceAgent, showStartupLogs bool, fetchedAgent chan fetchAgentResult) error {
	// Fast path: nothing to show and nothing to wait for.
	if agent.LifecycleState == codersdk.WorkspaceAgentLifecycleReady && !showStartupLogs {
		return nil
	}

	// Blocking mode is only used when the caller asked to wait and the
	// startup scripts are actually still running.
	follow := aw.opts.Wait && agent.LifecycleState.Starting()

	const baseStage = "Running workspace agent startup scripts"
	stage := baseStage
	if !follow {
		stage = baseStage + " (non-blocking)"
	}

	aw.sw.Start(stage)
	if follow {
		aw.sw.Log(time.Time{}, codersdk.LogLevelInfo, "==> ︎ To connect immediately, reconnect with --wait=no or CODER_SSH_WAIT=no, see --help for more information.")
	}

	agent, err := aw.streamLogs(ctx, agent, follow, fetchedAgent)
	if err != nil {
		return err
	}

	if follow {
		// The log stream may end before the lifecycle state settles, so
		// keep polling until startup is over.
		stillStarting := func(a codersdk.WorkspaceAgent) bool {
			return a.LifecycleState.Starting()
		}
		if agent, err = aw.pollWhile(ctx, agent, stillStarting); err != nil {
			return err
		}
	}

	// Report the final lifecycle state. Case order matters: the explicit
	// states are checked before the Starting/ShuttingDown predicates.
	state := agent.LifecycleState
	switch {
	case state == codersdk.WorkspaceAgentLifecycleReady:
		aw.sw.Complete(stage, safeDuration(aw.sw, agent.ReadyAt, agent.StartedAt))

	case state == codersdk.WorkspaceAgentLifecycleStartTimeout:
		// Backwards compatibility: an old coderd does not set ReadyAt for
		// timeouts, so fail with a zero duration instead of computing one.
		if agent.ReadyAt != nil {
			aw.sw.Fail(stage, safeDuration(aw.sw, agent.ReadyAt, agent.StartedAt))
		} else {
			aw.sw.Fail(stage, 0)
		}
		aw.sw.Log(time.Time{}, codersdk.LogLevelWarn, "Warning: A startup script timed out and your workspace may be incomplete.")

	case state == codersdk.WorkspaceAgentLifecycleStartError:
		aw.sw.Fail(stage, safeDuration(aw.sw, agent.ReadyAt, agent.StartedAt))
		aw.sw.Log(time.Time{}, codersdk.LogLevelWarn, "Warning: A startup script exited with an error and your workspace may be incomplete.")
		docsURL := fmt.Sprintf("%s/admin/templates/troubleshooting#startup-script-exited-with-an-error", aw.opts.DocsURL)
		aw.sw.Log(time.Time{}, codersdk.LogLevelWarn, troubleshootingMessage(agent, docsURL))

	case state.Starting():
		aw.sw.Log(time.Time{}, codersdk.LogLevelWarn, "Notice: The startup scripts are still running and your workspace may be incomplete.")
		docsURL := fmt.Sprintf("%s/admin/templates/troubleshooting#your-workspace-may-be-incomplete", aw.opts.DocsURL)
		aw.sw.Log(time.Time{}, codersdk.LogLevelWarn, troubleshootingMessage(agent, docsURL))
		// The stage is deliberately neither completed nor failed here, so
		// it stays open to show that startup did not finish.

	case state.ShuttingDown():
		// Whether the startup script failed is no longer knowable, but the
		// user still needs an outcome for the stage.
		aw.sw.Complete(stage, safeDuration(aw.sw, agent.ReadyAt, agent.StartedAt))
		return errAgentShuttingDown
	}

	return nil
}
// streamLogs prints the agent's startup logs to the stage writer. With
// follow set it streams logs and simultaneously watches fetchedAgent for
// agent updates; once the agent is no longer starting it swaps the stream for
// a final non-follow fetch that starts after the last log seen. Without
// follow it prints whatever the single fetch yields.
//
// It returns the most recent agent state observed, and returns once the log
// channel is closed, the context is done, or a fetch fails.
//
//nolint:revive // Control flag is acceptable for internal method.
func (aw *agentWaiter) streamLogs(ctx context.Context, agent codersdk.WorkspaceAgent, follow bool, fetchedAgent chan fetchAgentResult) (codersdk.WorkspaceAgent, error) {
	logCh, closer, err := aw.opts.FetchLogs(ctx, agent.ID, 0, follow)
	if err != nil {
		return agent, xerrors.Errorf("fetch workspace agent startup logs: %w", err)
	}
	defer closer.Close()

	// A nil channel never becomes ready in a select, so leaving this unset
	// disables the agent-update case when not following.
	var agentUpdates chan fetchAgentResult
	if follow {
		agentUpdates = fetchedAgent
	}

	// printLog writes one log entry, prefixed with its source's display
	// name when the source is known and has one.
	printLog := func(entry codersdk.WorkspaceAgentLog) {
		line := entry.Output
		if src, known := aw.logSources[entry.SourceID]; known && src.DisplayName != "" {
			line = src.DisplayName + ": " + line
		}
		aw.sw.Log(entry.CreatedAt, entry.Level, line)
	}

	var lastLog codersdk.WorkspaceAgentLog
	for {
		select {
		case <-ctx.Done():
			return agent, ctx.Err()

		case update := <-agentUpdates:
			if update.err != nil {
				return agent, xerrors.Errorf("fetch: %w", update.err)
			}
			agent = update.agent
			if agent.LifecycleState.Starting() {
				continue
			}

			// Startup is over. A following stream would never end, so close
			// it, stop watching for agent updates, and issue one last
			// non-follow request to pick up any logs not yet seen.
			_ = closer.Close()
			agentUpdates = nil
			logCh, closer, err = aw.opts.FetchLogs(ctx, agent.ID, lastLog.ID, false)
			if err != nil {
				return agent, xerrors.Errorf("fetch workspace agent startup logs: %w", err)
			}
			// Logs are already primed, so we can call close.
			_ = closer.Close()

		case batch, open := <-logCh:
			if !open {
				return agent, nil
			}
			for _, entry := range batch {
				printLog(entry)
				lastLog = entry
			}
		}
	}
}
// waitForReconnection drives the Disconnected state. It opens the "lost
// connection" stage, prints troubleshooting hints, and polls until the agent
// is no longer Disconnected, then completes the stage.
//
// The boolean result reports whether the agent was still starting at the
// moment this handler was entered; it is returned on both the success and
// the error path.
func (aw *agentWaiter) waitForReconnection(ctx context.Context, agent codersdk.WorkspaceAgent) (codersdk.WorkspaceAgent, bool, error) {
	const stage = "The workspace agent lost connection"

	// Snapshot what we need from the disconnected agent before polling
	// replaces it with newer state.
	wasStarting := agent.LifecycleState.Starting()
	lostAt := agent.DisconnectedAt
	docsURL := fmt.Sprintf("%s/admin/templates/troubleshooting#agent-connection-issues", aw.opts.DocsURL)

	aw.sw.Start(stage)
	aw.sw.Log(time.Now(), codersdk.LogLevelWarn, "Wait for it to reconnect or restart your workspace.")
	aw.sw.Log(time.Now(), codersdk.LogLevelWarn, troubleshootingMessage(agent, docsURL))

	stillDisconnected := func(a codersdk.WorkspaceAgent) bool {
		return a.Status == codersdk.WorkspaceAgentDisconnected
	}
	agent, err := aw.pollWhile(ctx, agent, stillDisconnected)
	if err != nil {
		return agent, wasStarting, err
	}

	aw.sw.Complete(stage, safeDuration(aw.sw, agent.LastConnectedAt, lostAt))
	return agent, wasStarting, nil
}
// pollWhile repeatedly replaces agent with the result of aw.fetchAgent for as
// long as cond(agent) holds. A fetch failure is returned wrapped with a
// "fetch" prefix, together with whatever agent value the failed fetch
// produced. Once cond no longer holds, the current agent is returned along
// with ctx.Err(), which is nil unless the context has been canceled or has
// expired.
func (aw *agentWaiter) pollWhile(ctx context.Context, agent codersdk.WorkspaceAgent, cond func(agent codersdk.WorkspaceAgent) bool) (codersdk.WorkspaceAgent, error) {
	for {
		if !cond(agent) {
			break
		}
		next, err := aw.fetchAgent(ctx)
		if err != nil {
			return next, xerrors.Errorf("fetch: %w", err)
		}
		agent = next
	}
	return agent, ctx.Err()
}
func troubleshootingMessage(agent codersdk.WorkspaceAgent, url string) string {
m := "For more information and troubleshooting, see " + url
if agent.TroubleshootingURL != "" {