From 5e4647bb3ab151f712ffc4138f79e2cd9d9e28e6 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 6 May 2026 14:14:10 -0500 Subject: [PATCH] fix: synchronize access to drpc Send (#24600) --- coderd/tailnet.go | 10 ++--- tailnet/controllers.go | 98 +++++++++++++++++++++++++----------------- 2 files changed, 64 insertions(+), 44 deletions(-) diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 6f591835d9..b69c687d3a 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -401,7 +401,7 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo defer m.mu.Unlock() m.coordination = b for agentID := range m.connectionTimes { - err := client.Send(&proto.CoordinateRequest{ + err := b.SendRequest(&proto.CoordinateRequest{ AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}, }) if err != nil { @@ -426,13 +426,13 @@ func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error { m.logger.Debug(context.Background(), "subscribing to agent", slog.F("agent_id", agentID)) if m.coordination != nil { - err := m.coordination.Client.Send(&proto.CoordinateRequest{ + err := m.coordination.SendRequest(&proto.CoordinateRequest{ AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}, }) if err != nil { err = xerrors.Errorf("subscribe agent: %w", err) m.coordination.SendErr(err) - _ = m.coordination.Client.Close() + _ = m.coordination.CloseClient() m.coordination = nil return err } @@ -494,7 +494,7 @@ func (m *MultiAgentController) doExpireOldAgents(ctx context.Context, cutoff tim // connections, remove the agent. if time.Since(lastConnection) > cutoff && len(m.tickets[agentID]) == 0 { if m.coordination != nil { - err := m.coordination.Client.Send(&proto.CoordinateRequest{ + err := m.coordination.SendRequest(&proto.CoordinateRequest{ RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}, }) if err != nil { @@ -502,7 +502,7 @@ func (m *MultiAgentController) doExpireOldAgents(ctx context.Context, cutoff tim m.coordination.SendErr(xerrors.Errorf("unsubscribe expired agent: %w", err)) // close the client because we do not want to do a graceful disconnect by // closing the coordination. - _ = m.coordination.Client.Close() + _ = m.coordination.CloseClient() m.coordination = nil // Here we continue deleting any inactive agents: there is no point in // re-establishing tunnels to expired agents when we eventually reconnect. diff --git a/tailnet/controllers.go b/tailnet/controllers.go index b99016e80b..35ef075878 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -167,7 +167,7 @@ func (c *BasicCoordinationController) NewCoordination(client CoordinatorClient) logger: c.Logger, errChan: make(chan error, 1), coordinatee: c.Coordinatee, - Client: client, + client: client, respLoopDone: make(chan struct{}), sendAcks: c.SendAcks, } @@ -185,7 +185,7 @@ func (c *BasicCoordinationController) NewCoordination(client CoordinatorClient) b.logger.Debug(context.Background(), "ignored node update because coordination is closed") return } - err = b.Client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}) + err = b.client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}) if err != nil { b.SendErr(xerrors.Errorf("write: %w", err)) } @@ -208,46 +208,66 @@ type BasicCoordination struct { errChan chan error coordinatee Coordinatee logger slog.Logger - Client CoordinatorClient + client CoordinatorClient respLoopDone chan struct{} sendAcks bool } +// CloseClient forcibly closes the underlying coordinator client connection +// without sending a graceful Disconnect message. Use this when you need to +// tear down the connection immediately, for example after a send error. +func (c *BasicCoordination) CloseClient() error { + return c.client.Close() +} + +// SendRequest sends a coordinate request on the client connection, holding +// the coordination lock to prevent concurrent writes on the dRPC stream. +func (c *BasicCoordination) SendRequest(req *proto.CoordinateRequest) error { + c.Lock() + defer c.Unlock() + if c.closed { + return xerrors.New("coordination is closed") + } + return c.client.Send(req) +} + // Close the coordination gracefully. If the context expires before the remote API server has hung // up on us, we forcibly close the Client connection. func (c *BasicCoordination) Close(ctx context.Context) (retErr error) { c.Lock() - defer c.Unlock() if c.closed { + c.Unlock() return nil } c.closed = true - defer func() { - // We shouldn't just close the protocol right away, because the way dRPC streams work is - // that if you close them, that could take effect immediately, even before the Disconnect - // message is processed. Coordinators are supposed to hang up on us once they get a - // Disconnect message, so we should wait around for that until the context expires. - select { - case <-c.respLoopDone: - c.logger.Debug(ctx, "responses closed after disconnect") - return - case <-ctx.Done(): - c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close") - } - // forcefully close the stream - protoErr := c.Client.Close() - <-c.respLoopDone - if retErr == nil { - retErr = protoErr - } - }() - err := c.Client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) + err := c.client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) + c.Unlock() if err != nil && !xerrors.Is(err, io.EOF) { - // Coordinator RPC hangs up when it gets disconnect, so EOF is expected. - return xerrors.Errorf("send disconnect: %w", err) + // Log but don't return early; we must still clean up below. + c.logger.Warn(context.Background(), "failed to send disconnect", slog.Error(err)) + retErr = xerrors.Errorf("send disconnect: %w", err) + } else { + c.logger.Debug(context.Background(), "sent disconnect") } - c.logger.Debug(context.Background(), "sent disconnect") - return nil + + // We shouldn't just close the protocol right away, because the way dRPC streams work is + // that if you close them, that could take effect immediately, even before the Disconnect + // message is processed. Coordinators are supposed to hang up on us once they get a + // Disconnect message, so we should wait around for that until the context expires. + select { + case <-c.respLoopDone: + c.logger.Debug(ctx, "responses closed after disconnect") + return retErr + case <-ctx.Done(): + c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close") + } + // forcefully close the stream + protoErr := c.client.Close() + <-c.respLoopDone + if retErr == nil { + retErr = protoErr + } + return retErr } // Wait for the Coordination to complete @@ -267,7 +287,7 @@ func (c *BasicCoordination) SendErr(err error) { func (c *BasicCoordination) respLoop() { defer func() { - cErr := c.Client.Close() + cErr := c.client.Close() if cErr != nil { c.logger.Debug(context.Background(), "failed to close coordinate client after respLoop exit", slog.Error(cErr)) @@ -276,7 +296,7 @@ func (c *BasicCoordination) respLoop() { close(c.respLoopDone) }() for { - resp, err := c.Client.Recv() + resp, err := c.client.Recv() if err != nil { c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err)) @@ -317,7 +337,7 @@ func (c *BasicCoordination) respLoop() { rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id}) } if len(rfh) > 0 { - err := c.Client.Send(&proto.CoordinateRequest{ + err := c.SendRequest(&proto.CoordinateRequest{ ReadyForHandshake: rfh, }) if err != nil { @@ -361,7 +381,7 @@ func (c *TunnelSrcCoordController) New(client CoordinatorClient) CloserWaiter { c.coordination = b // resync destinations on reconnect for dest := range c.dests { - err := client.Send(&proto.CoordinateRequest{ + err := b.SendRequest(&proto.CoordinateRequest{ AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, }) if err != nil { @@ -389,13 +409,13 @@ func (c *TunnelSrcCoordController) AddDestination(dest uuid.UUID) { if c.coordination == nil { return } - err := c.coordination.Client.Send( + err := c.coordination.SendRequest( &proto.CoordinateRequest{ AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, }) if err != nil { c.coordination.SendErr(err) - cErr := c.coordination.Client.Close() // close the client so we don't gracefully disconnect + cErr := c.coordination.client.Close() // close the client so we don't gracefully disconnect if cErr != nil { c.Logger.Debug(context.Background(), "failed to close coordinator client after add tunnel failure", @@ -412,13 +432,13 @@ func (c *TunnelSrcCoordController) RemoveDestination(dest uuid.UUID) { if c.coordination == nil { return } - err := c.coordination.Client.Send( + err := c.coordination.SendRequest( &proto.CoordinateRequest{ RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, }) if err != nil { c.coordination.SendErr(err) - cErr := c.coordination.Client.Close() // close the client so we don't gracefully disconnect + cErr := c.coordination.client.Close() // close the client so we don't gracefully disconnect if cErr != nil { c.Logger.Debug(context.Background(), "failed to close coordinator client after remove tunnel failure", @@ -449,7 +469,7 @@ func (c *TunnelSrcCoordController) SyncDestinations(destinations []uuid.UUID) { defer func() { if err != nil { c.coordination.SendErr(err) - cErr := c.coordination.Client.Close() // don't gracefully disconnect + cErr := c.coordination.client.Close() // don't gracefully disconnect if cErr != nil { c.Logger.Debug(context.Background(), "failed to close coordinator client during sync destinations", @@ -460,7 +480,7 @@ func (c *TunnelSrcCoordController) SyncDestinations(destinations []uuid.UUID) { }() for dest := range toAdd { c.Coordinatee.SetTunnelDestination(dest) - err = c.coordination.Client.Send( + err = c.coordination.SendRequest( &proto.CoordinateRequest{ AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, }) @@ -469,7 +489,7 @@ func (c *TunnelSrcCoordController) SyncDestinations(destinations []uuid.UUID) { } } for dest := range toRemove { - err = c.coordination.Client.Send( + err = c.coordination.SendRequest( &proto.CoordinateRequest{ RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)}, })