fix: synchronize access to drpc Send (#24600)

This commit is contained in:
Jon Ayers
2026-05-06 14:14:10 -05:00
committed by GitHub
parent 30a0e2aebd
commit 5e4647bb3a
2 changed files with 64 additions and 44 deletions
+59 -39
View File
@@ -167,7 +167,7 @@ func (c *BasicCoordinationController) NewCoordination(client CoordinatorClient)
logger: c.Logger,
errChan: make(chan error, 1),
coordinatee: c.Coordinatee,
Client: client,
client: client,
respLoopDone: make(chan struct{}),
sendAcks: c.SendAcks,
}
@@ -185,7 +185,7 @@ func (c *BasicCoordinationController) NewCoordination(client CoordinatorClient)
b.logger.Debug(context.Background(), "ignored node update because coordination is closed")
return
}
err = b.Client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
err = b.client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
if err != nil {
b.SendErr(xerrors.Errorf("write: %w", err))
}
@@ -208,46 +208,66 @@ type BasicCoordination struct {
errChan chan error
coordinatee Coordinatee
logger slog.Logger
Client CoordinatorClient
client CoordinatorClient
respLoopDone chan struct{}
sendAcks bool
}
// CloseClient forcibly closes the underlying coordinator client connection
// without sending a graceful Disconnect message. Use this when you need to
// tear down the connection immediately, for example after a send error.
func (c *BasicCoordination) CloseClient() error {
return c.client.Close()
}
// SendRequest sends a coordinate request on the client connection, holding
// the coordination lock to prevent concurrent writes on the dRPC stream.
func (c *BasicCoordination) SendRequest(req *proto.CoordinateRequest) error {
c.Lock()
defer c.Unlock()
if c.closed {
return xerrors.New("coordination is closed")
}
return c.client.Send(req)
}
// Close the coordination gracefully. If the context expires before the remote API server has hung
// up on us, we forcibly close the Client connection.
func (c *BasicCoordination) Close(ctx context.Context) (retErr error) {
c.Lock()
defer c.Unlock()
if c.closed {
c.Unlock()
return nil
}
c.closed = true
defer func() {
// We shouldn't just close the protocol right away, because the way dRPC streams work is
// that if you close them, that could take effect immediately, even before the Disconnect
// message is processed. Coordinators are supposed to hang up on us once they get a
// Disconnect message, so we should wait around for that until the context expires.
select {
case <-c.respLoopDone:
c.logger.Debug(ctx, "responses closed after disconnect")
return
case <-ctx.Done():
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
}
// forcefully close the stream
protoErr := c.Client.Close()
<-c.respLoopDone
if retErr == nil {
retErr = protoErr
}
}()
err := c.Client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
err := c.client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
c.Unlock()
if err != nil && !xerrors.Is(err, io.EOF) {
// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
return xerrors.Errorf("send disconnect: %w", err)
// Log but don't return early; we must still clean up below.
c.logger.Warn(context.Background(), "failed to send disconnect", slog.Error(err))
retErr = xerrors.Errorf("send disconnect: %w", err)
} else {
c.logger.Debug(context.Background(), "sent disconnect")
}
c.logger.Debug(context.Background(), "sent disconnect")
return nil
// We shouldn't just close the protocol right away, because the way dRPC streams work is
// that if you close them, that could take effect immediately, even before the Disconnect
// message is processed. Coordinators are supposed to hang up on us once they get a
// Disconnect message, so we should wait around for that until the context expires.
select {
case <-c.respLoopDone:
c.logger.Debug(ctx, "responses closed after disconnect")
return retErr
case <-ctx.Done():
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
}
// forcefully close the stream
protoErr := c.client.Close()
<-c.respLoopDone
if retErr == nil {
retErr = protoErr
}
return retErr
}
// Wait for the Coordination to complete
@@ -267,7 +287,7 @@ func (c *BasicCoordination) SendErr(err error) {
func (c *BasicCoordination) respLoop() {
defer func() {
cErr := c.Client.Close()
cErr := c.client.Close()
if cErr != nil {
c.logger.Debug(context.Background(),
"failed to close coordinate client after respLoop exit", slog.Error(cErr))
@@ -276,7 +296,7 @@ func (c *BasicCoordination) respLoop() {
close(c.respLoopDone)
}()
for {
resp, err := c.Client.Recv()
resp, err := c.client.Recv()
if err != nil {
c.logger.Debug(context.Background(),
"failed to read from protocol", slog.Error(err))
@@ -317,7 +337,7 @@ func (c *BasicCoordination) respLoop() {
rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id})
}
if len(rfh) > 0 {
err := c.Client.Send(&proto.CoordinateRequest{
err := c.SendRequest(&proto.CoordinateRequest{
ReadyForHandshake: rfh,
})
if err != nil {
@@ -361,7 +381,7 @@ func (c *TunnelSrcCoordController) New(client CoordinatorClient) CloserWaiter {
c.coordination = b
// resync destinations on reconnect
for dest := range c.dests {
err := client.Send(&proto.CoordinateRequest{
err := b.SendRequest(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})
if err != nil {
@@ -389,13 +409,13 @@ func (c *TunnelSrcCoordController) AddDestination(dest uuid.UUID) {
if c.coordination == nil {
return
}
err := c.coordination.Client.Send(
err := c.coordination.SendRequest(
&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})
if err != nil {
c.coordination.SendErr(err)
cErr := c.coordination.Client.Close() // close the client so we don't gracefully disconnect
cErr := c.coordination.client.Close() // close the client so we don't gracefully disconnect
if cErr != nil {
c.Logger.Debug(context.Background(),
"failed to close coordinator client after add tunnel failure",
@@ -412,13 +432,13 @@ func (c *TunnelSrcCoordController) RemoveDestination(dest uuid.UUID) {
if c.coordination == nil {
return
}
err := c.coordination.Client.Send(
err := c.coordination.SendRequest(
&proto.CoordinateRequest{
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})
if err != nil {
c.coordination.SendErr(err)
cErr := c.coordination.Client.Close() // close the client so we don't gracefully disconnect
cErr := c.coordination.client.Close() // close the client so we don't gracefully disconnect
if cErr != nil {
c.Logger.Debug(context.Background(),
"failed to close coordinator client after remove tunnel failure",
@@ -449,7 +469,7 @@ func (c *TunnelSrcCoordController) SyncDestinations(destinations []uuid.UUID) {
defer func() {
if err != nil {
c.coordination.SendErr(err)
cErr := c.coordination.Client.Close() // don't gracefully disconnect
cErr := c.coordination.client.Close() // don't gracefully disconnect
if cErr != nil {
c.Logger.Debug(context.Background(),
"failed to close coordinator client during sync destinations",
@@ -460,7 +480,7 @@ func (c *TunnelSrcCoordController) SyncDestinations(destinations []uuid.UUID) {
}()
for dest := range toAdd {
c.Coordinatee.SetTunnelDestination(dest)
err = c.coordination.Client.Send(
err = c.coordination.SendRequest(
&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})
@@ -469,7 +489,7 @@ func (c *TunnelSrcCoordController) SyncDestinations(destinations []uuid.UUID) {
}
}
for dest := range toRemove {
err = c.coordination.Client.Send(
err = c.coordination.SendRequest(
&proto.CoordinateRequest{
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})