From 71c6dc404361b950b3c8fa25fcce89c7e2366cb7 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 15 Dec 2025 12:04:01 +0400 Subject: [PATCH] fix: stop disconnecting from coderd early and record disconnect correctly (#21250) fixes https://github.com/coder/internal/issues/1196 The above issue exposes two different bugs in Coder. In the agent, there is a race where if the agent is closed while starting up networking, it will erroneously disconnect from Coderd, which delays or breaks writing final status and logs. In Coderd, there is a bug where we don't properly record the latest agent disconnection time if the agent had previously disconnected. This causes us to report the agent status as "Connected" even after it has disconnected up until the inactivity timeout fires. This PR fixes both issues. It also slightly reworks when we send workspace updates based on connection and disconnection. Previously we would send two updates when the agent connected in certain circumstances, even though the status would be the same in both (only times changed). Now we universally only send one on connect, and then another on disconnect. --- agent/agent.go | 62 +++--- coderd/workspaceagentsrpc.go | 35 ++-- coderd/workspaceagentsrpc_internal_test.go | 211 +++++++++++++-------- 3 files changed, 180 insertions(+), 128 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 9a23ec5210..115735bc69 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -71,6 +71,8 @@ const ( EnvProcOOMScore = "CODER_PROC_OOM_SCORE" ) +var ErrAgentClosing = xerrors.New("agent is closing") + type Options struct { Filesystem afero.Fs LogDir string @@ -401,6 +403,7 @@ func (a *agent) runLoop() { // need to keep retrying up to the hardCtx so that we can send graceful shutdown-related // messages. ctx := a.hardCtx + defer a.logger.Info(ctx, "agent main loop exited") for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { a.logger.Info(ctx, "connecting to coderd") err := a.run() @@ -1348,7 +1351,7 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co a.closeMutex.Unlock() if closing { _ = network.Close() - return xerrors.New("agent is closing") + return xerrors.Errorf("agent closed while creating tailnet: %w", ErrAgentClosing) } } else { // Update the wireguard IPs if the agent ID changed. @@ -1471,7 +1474,7 @@ func (a *agent) trackGoroutine(fn func()) error { a.closeMutex.Lock() defer a.closeMutex.Unlock() if a.closing { - return xerrors.New("track conn goroutine: agent is closing") + return xerrors.Errorf("track conn goroutine: %w", ErrAgentClosing) } a.closeWaitGroup.Add(1) go func() { @@ -2152,16 +2155,7 @@ func (a *apiConnRoutineManager) startAgentAPI( a.eg.Go(func() error { logger.Debug(ctx, "starting agent routine") err := f(ctx, a.aAPI) - if xerrors.Is(err, context.Canceled) && ctx.Err() != nil { - logger.Debug(ctx, "swallowing context canceled") - // Don't propagate context canceled errors to the error group, because we don't want the - // graceful context being canceled to halt the work of routines with - // gracefulShutdownBehaviorRemain. Note that we check both that the error is - // context.Canceled and that *our* context is currently canceled, because when Coderd - // unilaterally closes the API connection (for example if the build is outdated), it can - // sometimes show up as context.Canceled in our RPC calls. - return nil - } + err = shouldPropagateError(ctx, logger, err) logger.Debug(ctx, "routine exited", slog.Error(err)) if err != nil { return xerrors.Errorf("error in routine %s: %w", name, err) @@ -2189,21 +2183,7 @@ func (a *apiConnRoutineManager) startTailnetAPI( a.eg.Go(func() error { logger.Debug(ctx, "starting tailnet routine") err := f(ctx, a.tAPI) - if (xerrors.Is(err, context.Canceled) || - xerrors.Is(err, io.EOF)) && - ctx.Err() != nil { - logger.Debug(ctx, "swallowing error because context is canceled", slog.Error(err)) - // Don't propagate context canceled errors to the error group, because we don't want the - // graceful context being canceled to halt the work of routines with - // gracefulShutdownBehaviorRemain. Unfortunately, the dRPC library closes the stream - // when context is canceled on an RPC, so canceling the context can also show up as - // io.EOF. Also, when Coderd unilaterally closes the API connection (for example if the - // build is outdated), it can sometimes show up as context.Canceled in our RPC calls. - // We can't reliably distinguish between a context cancelation and a legit EOF, so we - // also check that *our* context is currently canceled. If it is, we can safely ignore - // the error. - return nil - } + err = shouldPropagateError(ctx, logger, err) logger.Debug(ctx, "routine exited", slog.Error(err)) if err != nil { return xerrors.Errorf("error in routine %s: %w", name, err) @@ -2212,6 +2192,34 @@ func (a *apiConnRoutineManager) startTailnetAPI( }) } +// shouldPropagateError decides whether an error from an API connection routine should be propagated to the +// apiConnRoutineManager. Its purpose is to prevent errors related to shutting down from propagating to the manager's +// error group, which will tear down the API connection and potentially stop graceful shutdown from succeeding. +func shouldPropagateError(ctx context.Context, logger slog.Logger, err error) error { + if (xerrors.Is(err, context.Canceled) || + xerrors.Is(err, io.EOF)) && + ctx.Err() != nil { + logger.Debug(ctx, "swallowing error because context is canceled", slog.Error(err)) + // Don't propagate context canceled errors to the error group, because we don't want the + // graceful context being canceled to halt the work of routines with + // gracefulShutdownBehaviorRemain. Unfortunately, the dRPC library closes the stream + // when context is canceled on an RPC, so canceling the context can also show up as + // io.EOF. Also, when Coderd unilaterally closes the API connection (for example if the + // build is outdated), it can sometimes show up as context.Canceled in our RPC calls. + // We can't reliably distinguish between a context cancelation and a legit EOF, so we + // also check that *our* context is currently canceled. If it is, we can safely ignore + // the error. + return nil + } + if xerrors.Is(err, ErrAgentClosing) { + logger.Debug(ctx, "swallowing error because agent is closing") + // This can only be generated when the agent is closing, so we never want it to propagate to other routines. + // (They are signaled to exit via canceled contexts.) + return nil + } + return err +} + func (a *apiConnRoutineManager) wait() error { return a.eg.Wait() } diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index 37d5e6d3b7..3046a22d89 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -359,7 +359,16 @@ func (m *agentConnectionMonitor) start(ctx context.Context) { } func (m *agentConnectionMonitor) monitor(ctx context.Context) { + reason := "disconnect" defer func() { + m.logger.Debug(ctx, "agent connection monitor is closing connection", + slog.F("reason", reason)) + _ = m.conn.Close(websocket.StatusGoingAway, reason) + m.disconnectedAt = sql.NullTime{ + Time: dbtime.Now(), + Valid: true, + } + // If connection closed then context will be canceled, try to // ensure our final update is sent. By waiting at most the agent // inactive disconnect timeout we ensure that we don't block but @@ -372,13 +381,6 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { finalCtx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(m.apiCtx), m.disconnectTimeout) defer cancel() - // Only update timestamp if the disconnect is new. - if !m.disconnectedAt.Valid { - m.disconnectedAt = sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - } - } err := m.updateConnectionTimes(finalCtx) if err != nil { // This is a bug with unit tests that cancel the app context and @@ -398,12 +400,6 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { AgentID: &m.workspaceAgent.ID, }) }() - reason := "disconnect" - defer func() { - m.logger.Debug(ctx, "agent connection monitor is closing connection", - slog.F("reason", reason)) - _ = m.conn.Close(websocket.StatusGoingAway, reason) - }() err := m.updateConnectionTimes(ctx) if err != nil { @@ -432,8 +428,7 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { m.logger.Warn(ctx, "connection to agent timed out") return } - connectionStatusChanged := m.disconnectedAt.Valid - m.disconnectedAt = sql.NullTime{} + m.lastConnectedAt = sql.NullTime{ Time: dbtime.Now(), Valid: true, @@ -447,13 +442,9 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { } return } - if connectionStatusChanged { - m.updater.publishWorkspaceUpdate(ctx, m.workspace.OwnerID, wspubsub.WorkspaceEvent{ - Kind: wspubsub.WorkspaceEventKindAgentConnectionUpdate, - WorkspaceID: m.workspaceBuild.WorkspaceID, - AgentID: &m.workspaceAgent.ID, - }) - } + // we don't need to publish a workspace update here because we published an update when the workspace first + // connected. Since all we've done is updated lastConnectedAt, the workspace is still connected and hasn't + // changed status. We don't expect to get updates just for the times changing. ctx, err := dbauthz.WithWorkspaceRBAC(ctx, m.workspace.RBACObject()) if err != nil { diff --git a/coderd/workspaceagentsrpc_internal_test.go b/coderd/workspaceagentsrpc_internal_test.go index 5c254b41fe..88d08bc4e3 100644 --- a/coderd/workspaceagentsrpc_internal_test.go +++ b/coderd/workspaceagentsrpc_internal_test.go @@ -23,76 +23,107 @@ import ( func TestAgentConnectionMonitor_ContextCancel(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) now := dbtime.Now() - fConn := &fakePingerCloser{} - ctrl := gomock.NewController(t) - mDB := dbmock.NewMockStore(ctrl) - fUpdater := &fakeUpdater{} - logger := testutil.Logger(t) - agent := database.WorkspaceAgent{ - ID: uuid.New(), - FirstConnectedAt: sql.NullTime{ - Time: now.Add(-time.Minute), - Valid: true, + agentID := uuid.UUID{1} + replicaID := uuid.UUID{2} + testCases := []struct { + name string + agent database.WorkspaceAgent + initialMatcher connectionUpdateMatcher + }{ + { + name: "no disconnected at", + agent: database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: now.Add(-time.Minute), + Valid: true, + }, + }, + initialMatcher: connectionUpdate(agentID, replicaID), + }, + { + name: "disconnected at", + agent: database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: now.Add(-time.Minute), + Valid: true, + }, + DisconnectedAt: sql.NullTime{ + Time: now.Add(-2 * time.Minute), + Valid: true, + }, + }, + initialMatcher: connectionUpdate(agentID, replicaID, withDisconnectedAt(now.Add(-2*time.Minute))), }, } - build := database.WorkspaceBuild{ - ID: uuid.New(), - WorkspaceID: uuid.New(), + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + fConn := &fakePingerCloser{} + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + fUpdater := &fakeUpdater{} + logger := testutil.Logger(t) + build := database.WorkspaceBuild{ + ID: uuid.New(), + WorkspaceID: uuid.New(), + } + + uut := &agentConnectionMonitor{ + apiCtx: ctx, + workspaceAgent: tc.agent, + workspaceBuild: build, + conn: fConn, + db: mDB, + replicaID: replicaID, + updater: fUpdater, + logger: logger, + pingPeriod: testutil.IntervalFast, + disconnectTimeout: testutil.WaitShort, + } + uut.init() + + connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( + gomock.Any(), + tc.initialMatcher, + ). + AnyTimes(). + Return(nil) + mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( + gomock.Any(), + connectionUpdate(agentID, replicaID, withDisconnectedAfter(now)), + ). + After(connected). + Times(1). + Return(nil) + mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID). + AnyTimes(). + Return(database.WorkspaceBuild{ID: build.ID}, nil) + + closeCtx, cancel := context.WithCancel(ctx) + defer cancel() + done := make(chan struct{}) + go func() { + uut.monitor(closeCtx) + close(done) + }() + // wait a couple intervals, but not long enough for a disconnect + time.Sleep(3 * testutil.IntervalFast) + fConn.requireNotClosed(t) + fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID) + n := fUpdater.getUpdates() + cancel() + fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "canceled") + + // make sure we got at least one additional update on close + _ = testutil.TryReceive(ctx, t, done) + m := fUpdater.getUpdates() + require.Greater(t, m, n) + }) } - replicaID := uuid.New() - - uut := &agentConnectionMonitor{ - apiCtx: ctx, - workspaceAgent: agent, - workspaceBuild: build, - conn: fConn, - db: mDB, - replicaID: replicaID, - updater: fUpdater, - logger: logger, - pingPeriod: testutil.IntervalFast, - disconnectTimeout: testutil.WaitShort, - } - uut.init() - - connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( - gomock.Any(), - connectionUpdate(agent.ID, replicaID), - ). - AnyTimes(). - Return(nil) - mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( - gomock.Any(), - connectionUpdate(agent.ID, replicaID, withDisconnected()), - ). - After(connected). - Times(1). - Return(nil) - mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID). - AnyTimes(). - Return(database.WorkspaceBuild{ID: build.ID}, nil) - - closeCtx, cancel := context.WithCancel(ctx) - defer cancel() - done := make(chan struct{}) - go func() { - uut.monitor(closeCtx) - close(done) - }() - // wait a couple intervals, but not long enough for a disconnect - time.Sleep(3 * testutil.IntervalFast) - fConn.requireNotClosed(t) - fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID) - n := fUpdater.getUpdates() - cancel() - fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "canceled") - - // make sure we got at least one additional update on close - _ = testutil.TryReceive(ctx, t, done) - m := fUpdater.getUpdates() - require.Greater(t, m, n) } func TestAgentConnectionMonitor_PingTimeout(t *testing.T) { @@ -141,7 +172,7 @@ func TestAgentConnectionMonitor_PingTimeout(t *testing.T) { Return(nil) mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( gomock.Any(), - connectionUpdate(agent.ID, replicaID, withDisconnected()), + connectionUpdate(agent.ID, replicaID, withDisconnectedAfter(now)), ). After(connected). Times(1). @@ -204,7 +235,7 @@ func TestAgentConnectionMonitor_BuildOutdated(t *testing.T) { Return(nil) mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( gomock.Any(), - connectionUpdate(agent.ID, replicaID, withDisconnected()), + connectionUpdate(agent.ID, replicaID, withDisconnectedAfter(now)), ). After(connected). Times(1). @@ -289,7 +320,7 @@ func TestAgentConnectionMonitor_StartClose(t *testing.T) { Return(nil) mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( gomock.Any(), - connectionUpdate(agent.ID, replicaID, withDisconnected()), + connectionUpdate(agent.ID, replicaID, withDisconnectedAfter(now)), ). After(connected). Times(1). @@ -392,9 +423,10 @@ func (f *fakeUpdater) getUpdates() int { } type connectionUpdateMatcher struct { - agentID uuid.UUID - replicaID uuid.UUID - disconnected bool + agentID uuid.UUID + replicaID uuid.UUID + disconnectedAt sql.NullTime + disconnectedAfter sql.NullTime } type connectionUpdateMatcherOption func(m connectionUpdateMatcher) connectionUpdateMatcher @@ -410,9 +442,22 @@ func connectionUpdate(id, replica uuid.UUID, opts ...connectionUpdateMatcherOpti return m } -func withDisconnected() connectionUpdateMatcherOption { +func withDisconnectedAfter(t time.Time) connectionUpdateMatcherOption { return func(m connectionUpdateMatcher) connectionUpdateMatcher { - m.disconnected = true + m.disconnectedAfter = sql.NullTime{ + Valid: true, + Time: t, + } + return m + } +} + +func withDisconnectedAt(t time.Time) connectionUpdateMatcherOption { + return func(m connectionUpdateMatcher) connectionUpdateMatcher { + m.disconnectedAt = sql.NullTime{ + Valid: true, + Time: t, + } return m } } @@ -431,15 +476,23 @@ func (m connectionUpdateMatcher) Matches(x interface{}) bool { if args.LastConnectedReplicaID.UUID != m.replicaID { return false } - if args.DisconnectedAt.Valid != m.disconnected { + if m.disconnectedAfter.Valid { + if !args.DisconnectedAt.Valid { + return false + } + if !args.DisconnectedAt.Time.After(m.disconnectedAfter.Time) { + return false + } + // disconnectedAfter takes precedence over disconnectedAt + } else if args.DisconnectedAt != m.disconnectedAt { return false } return true } func (m connectionUpdateMatcher) String() string { - return fmt.Sprintf("{agent=%s, replica=%s, disconnected=%t}", - m.agentID.String(), m.replicaID.String(), m.disconnected) + return fmt.Sprintf("{agent=%s, replica=%s, disconnectedAt=%v, disconnectedAfter=%v}", + m.agentID.String(), m.replicaID.String(), m.disconnectedAt, m.disconnectedAfter) } func (connectionUpdateMatcher) Got(x interface{}) string { @@ -447,6 +500,6 @@ func (connectionUpdateMatcher) Got(x interface{}) string { if !ok { return fmt.Sprintf("type=%T", x) } - return fmt.Sprintf("{agent=%s, replica=%s, disconnected=%t}", - args.ID, args.LastConnectedReplicaID.UUID, args.DisconnectedAt.Valid) + return fmt.Sprintf("{agent=%s, replica=%s, disconnectedAt=%v}", + args.ID, args.LastConnectedReplicaID.UUID, args.DisconnectedAt) }