fix: fix panic while tearing down reconnecting PTY (#15615)

fixes https://github.com/coder/internal/issues/221

Fixes an issue where two goroutines were sharing the `err` variable, leading to a data race where we'd fail to process the error and then nil-pointer panic.

I ended up refactoring reconnecting PTY stuff into the `reconnectingpty` package, instead of having it on the agent.  That `createTailnet` routine had waaay too many deeply nested goroutines, which is I'm sure a big contributor to the bug appearing in the first place.
This commit is contained in:
Spike Curtis
2024-11-22 09:46:25 +04:00
committed by GitHub
parent 684e75e2a7
commit 103824f726
2 changed files with 208 additions and 138 deletions
+17 -138
View File
@@ -3,12 +3,10 @@ package agent
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
@@ -216,8 +214,8 @@ type agent struct {
portCacheDuration time.Duration
subsystems []codersdk.AgentSubsystem
reconnectingPTYs sync.Map
reconnectingPTYTimeout time.Duration
reconnectingPTYServer *reconnectingpty.Server
// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
// to start gracefully shutting down and "hard" which is Done when it is time to close
@@ -252,8 +250,6 @@ type agent struct {
statsReporter *statsReporter
logSender *agentsdk.LogSender
connCountReconnectingPTY atomic.Int64
prometheusRegistry *prometheus.Registry
// metrics are prometheus registered metrics that will be collected and
// labeled in Coder with the agent + workspace.
@@ -297,6 +293,13 @@ func (a *agent) init() {
// Register runner metrics. If the prom registry is nil, the metrics
// will not report anywhere.
a.scriptRunner.RegisterMetrics(a.prometheusRegistry)
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
a.sshServer,
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
a.reconnectingPTYTimeout,
)
go a.runLoop()
}
@@ -1181,55 +1184,12 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
}
}()
if err = a.trackGoroutine(func() {
logger := a.logger.Named("reconnecting-pty")
var wg sync.WaitGroup
for {
conn, err := reconnectingPTYListener.Accept()
if err != nil {
if !a.isClosed() {
logger.Debug(ctx, "accept pty failed", slog.Error(err))
}
break
}
clog := logger.With(
slog.F("remote", conn.RemoteAddr().String()),
slog.F("local", conn.LocalAddr().String()))
clog.Info(ctx, "accepted conn")
wg.Add(1)
closed := make(chan struct{})
go func() {
select {
case <-closed:
case <-a.hardCtx.Done():
_ = conn.Close()
}
wg.Done()
}()
go func() {
defer close(closed)
// This cannot use a JSON decoder, since that can
// buffer additional data that is required for the PTY.
rawLen := make([]byte, 2)
_, err = conn.Read(rawLen)
if err != nil {
return
}
length := binary.LittleEndian.Uint16(rawLen)
data := make([]byte, length)
_, err = conn.Read(data)
if err != nil {
return
}
var msg workspacesdk.AgentReconnectingPTYInit
err = json.Unmarshal(data, &msg)
if err != nil {
logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
return
}
_ = a.handleReconnectingPTY(ctx, clog, msg, conn)
}()
rPTYServeErr := a.reconnectingPTYServer.Serve(a.gracefulCtx, a.hardCtx, reconnectingPTYListener)
if rPTYServeErr != nil &&
a.gracefulCtx.Err() == nil &&
!strings.Contains(rPTYServeErr.Error(), "use of closed network connection") {
a.logger.Error(ctx, "error serving reconnecting PTY", slog.Error(err))
}
wg.Wait()
}); err != nil {
return nil, err
}
@@ -1308,9 +1268,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
_ = server.Close()
}()
err := server.Serve(apiListener)
if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err))
apiServErr := server.Serve(apiListener)
if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") {
a.logger.Critical(ctx, "serve HTTP API server", slog.Error(apiServErr))
}
}); err != nil {
return nil, err
@@ -1394,87 +1354,6 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.D
}
}
func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg workspacesdk.AgentReconnectingPTYInit, conn net.Conn) (retErr error) {
defer conn.Close()
a.metrics.connectionsTotal.Add(1)
a.connCountReconnectingPTY.Add(1)
defer a.connCountReconnectingPTY.Add(-1)
connectionID := uuid.NewString()
connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
connLogger.Debug(ctx, "starting handler")
defer func() {
if err := retErr; err != nil {
a.closeMutex.Lock()
closed := a.isClosed()
a.closeMutex.Unlock()
// If the agent is closed, we don't want to
// log this as an error since it's expected.
if closed {
connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
} else {
connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
}
}
connLogger.Info(ctx, "reconnecting pty connection closed")
}()
var rpty reconnectingpty.ReconnectingPTY
sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1)
// On store, reserve this ID to prevent multiple concurrent new connections.
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
if ok {
close(sendConnected) // Unused.
connLogger.Debug(ctx, "connecting to existing reconnecting pty")
c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY)
if !ok {
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
}
rpty, ok = <-c
if !ok || rpty == nil {
return xerrors.Errorf("reconnecting pty closed before connection")
}
c <- rpty // Put it back for the next reconnect.
} else {
connLogger.Debug(ctx, "creating new reconnecting pty")
connected := false
defer func() {
if !connected && retErr != nil {
a.reconnectingPTYs.Delete(msg.ID)
close(sendConnected)
}
}()
// Empty command will default to the users shell!
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
if err != nil {
a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
return xerrors.Errorf("create command: %w", err)
}
rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
Timeout: a.reconnectingPTYTimeout,
Metrics: a.metrics.reconnectingPTYErrors,
}, logger.With(slog.F("message_id", msg.ID)))
if err = a.trackGoroutine(func() {
rpty.Wait()
a.reconnectingPTYs.Delete(msg.ID)
}); err != nil {
rpty.Close(err)
return xerrors.Errorf("start routine: %w", err)
}
connected = true
sendConnected <- rpty
}
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
}
// Collect collects additional stats from the agent
func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
a.logger.Debug(context.Background(), "computing stats report")
@@ -1496,7 +1375,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
stats.SessionCountVscode = sshStats.VSCode
stats.SessionCountJetbrains = sshStats.JetBrains
stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
stats.SessionCountReconnectingPty = a.reconnectingPTYServer.ConnCount()
// Compute the median connection latency!
a.logger.Debug(ctx, "starting peer latency measurement for stats")
+191
View File
@@ -0,0 +1,191 @@
package reconnectingpty
import (
"context"
"encoding/binary"
"encoding/json"
"net"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type Server struct {
logger slog.Logger
connectionsTotal prometheus.Counter
errorsTotal *prometheus.CounterVec
commandCreator *agentssh.Server
connCount atomic.Int64
reconnectingPTYs sync.Map
timeout time.Duration
}
// NewServer returns a new ReconnectingPTY server
func NewServer(logger slog.Logger, commandCreator *agentssh.Server,
connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec,
timeout time.Duration,
) *Server {
return &Server{
logger: logger,
commandCreator: commandCreator,
connectionsTotal: connectionsTotal,
errorsTotal: errorsTotal,
timeout: timeout,
}
}
func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr error) {
var wg sync.WaitGroup
for {
if ctx.Err() != nil {
break
}
conn, err := l.Accept()
if err != nil {
s.logger.Debug(ctx, "accept pty failed", slog.Error(err))
retErr = err
break
}
clog := s.logger.With(
slog.F("remote", conn.RemoteAddr().String()),
slog.F("local", conn.LocalAddr().String()))
clog.Info(ctx, "accepted conn")
wg.Add(1)
closed := make(chan struct{})
go func() {
select {
case <-closed:
case <-hardCtx.Done():
_ = conn.Close()
}
wg.Done()
}()
wg.Add(1)
go func() {
defer close(closed)
defer wg.Done()
_ = s.handleConn(ctx, clog, conn)
}()
}
wg.Wait()
return retErr
}
func (s *Server) ConnCount() int64 {
return s.connCount.Load()
}
func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Conn) (retErr error) {
defer conn.Close()
s.connectionsTotal.Add(1)
s.connCount.Add(1)
defer s.connCount.Add(-1)
// This cannot use a JSON decoder, since that can
// buffer additional data that is required for the PTY.
rawLen := make([]byte, 2)
_, err := conn.Read(rawLen)
if err != nil {
// logging at info since a single incident isn't too worrying (the client could just have
// hung up), but if we get a lot of these we'd want to investigate.
logger.Info(ctx, "failed to read AgentReconnectingPTYInit length", slog.Error(err))
return nil
}
length := binary.LittleEndian.Uint16(rawLen)
data := make([]byte, length)
_, err = conn.Read(data)
if err != nil {
// logging at info since a single incident isn't too worrying (the client could just have
// hung up), but if we get a lot of these we'd want to investigate.
logger.Info(ctx, "failed to read AgentReconnectingPTYInit", slog.Error(err))
return nil
}
var msg workspacesdk.AgentReconnectingPTYInit
err = json.Unmarshal(data, &msg)
if err != nil {
logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
return nil
}
connectionID := uuid.NewString()
connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
connLogger.Debug(ctx, "starting handler")
defer func() {
if err := retErr; err != nil {
// If the context is done, we don't want to log this as an error since it's expected.
if ctx.Err() != nil {
connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
} else {
connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
}
}
connLogger.Info(ctx, "reconnecting pty connection closed")
}()
var rpty ReconnectingPTY
sendConnected := make(chan ReconnectingPTY, 1)
// On store, reserve this ID to prevent multiple concurrent new connections.
waitReady, ok := s.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
if ok {
close(sendConnected) // Unused.
connLogger.Debug(ctx, "connecting to existing reconnecting pty")
c, ok := waitReady.(chan ReconnectingPTY)
if !ok {
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
}
rpty, ok = <-c
if !ok || rpty == nil {
return xerrors.Errorf("reconnecting pty closed before connection")
}
c <- rpty // Put it back for the next reconnect.
} else {
connLogger.Debug(ctx, "creating new reconnecting pty")
connected := false
defer func() {
if !connected && retErr != nil {
s.reconnectingPTYs.Delete(msg.ID)
close(sendConnected)
}
}()
// Empty command will default to the users shell!
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil)
if err != nil {
s.errorsTotal.WithLabelValues("create_command").Add(1)
return xerrors.Errorf("create command: %w", err)
}
rpty = New(ctx, cmd, &Options{
Timeout: s.timeout,
Metrics: s.errorsTotal,
}, logger.With(slog.F("message_id", msg.ID)))
done := make(chan struct{})
go func() {
select {
case <-done:
case <-ctx.Done():
rpty.Close(ctx.Err())
}
}()
go func() {
rpty.Wait()
s.reconnectingPTYs.Delete(msg.ID)
}()
connected = true
sendConnected <- rpty
}
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
}