mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: fix panic while tearing down reconnecting PTY (#15615)
Fixes https://github.com/coder/internal/issues/221. Fixes an issue where two goroutines were sharing the `err` variable, leading to a data race in which we'd fail to process the error and then hit a nil-pointer panic. I ended up refactoring the reconnecting PTY code into the `reconnectingpty` package instead of keeping it on the agent. The `createTailnet` routine had way too many deeply nested goroutines, which I'm sure was a big contributor to the bug appearing in the first place.
This commit is contained in:
+17
-138
@@ -3,12 +3,10 @@ package agent
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -216,8 +214,8 @@ type agent struct {
|
||||
portCacheDuration time.Duration
|
||||
subsystems []codersdk.AgentSubsystem
|
||||
|
||||
reconnectingPTYs sync.Map
|
||||
reconnectingPTYTimeout time.Duration
|
||||
reconnectingPTYServer *reconnectingpty.Server
|
||||
|
||||
// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
|
||||
// to start gracefully shutting down and "hard" which is Done when it is time to close
|
||||
@@ -252,8 +250,6 @@ type agent struct {
|
||||
statsReporter *statsReporter
|
||||
logSender *agentsdk.LogSender
|
||||
|
||||
connCountReconnectingPTY atomic.Int64
|
||||
|
||||
prometheusRegistry *prometheus.Registry
|
||||
// metrics are prometheus registered metrics that will be collected and
|
||||
// labeled in Coder with the agent + workspace.
|
||||
@@ -297,6 +293,13 @@ func (a *agent) init() {
|
||||
// Register runner metrics. If the prom registry is nil, the metrics
|
||||
// will not report anywhere.
|
||||
a.scriptRunner.RegisterMetrics(a.prometheusRegistry)
|
||||
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
a.sshServer,
|
||||
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
|
||||
a.reconnectingPTYTimeout,
|
||||
)
|
||||
go a.runLoop()
|
||||
}
|
||||
|
||||
@@ -1181,55 +1184,12 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
|
||||
}
|
||||
}()
|
||||
if err = a.trackGoroutine(func() {
|
||||
logger := a.logger.Named("reconnecting-pty")
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
conn, err := reconnectingPTYListener.Accept()
|
||||
if err != nil {
|
||||
if !a.isClosed() {
|
||||
logger.Debug(ctx, "accept pty failed", slog.Error(err))
|
||||
}
|
||||
break
|
||||
}
|
||||
clog := logger.With(
|
||||
slog.F("remote", conn.RemoteAddr().String()),
|
||||
slog.F("local", conn.LocalAddr().String()))
|
||||
clog.Info(ctx, "accepted conn")
|
||||
wg.Add(1)
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-closed:
|
||||
case <-a.hardCtx.Done():
|
||||
_ = conn.Close()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
defer close(closed)
|
||||
// This cannot use a JSON decoder, since that can
|
||||
// buffer additional data that is required for the PTY.
|
||||
rawLen := make([]byte, 2)
|
||||
_, err = conn.Read(rawLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
length := binary.LittleEndian.Uint16(rawLen)
|
||||
data := make([]byte, length)
|
||||
_, err = conn.Read(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var msg workspacesdk.AgentReconnectingPTYInit
|
||||
err = json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
|
||||
return
|
||||
}
|
||||
_ = a.handleReconnectingPTY(ctx, clog, msg, conn)
|
||||
}()
|
||||
rPTYServeErr := a.reconnectingPTYServer.Serve(a.gracefulCtx, a.hardCtx, reconnectingPTYListener)
|
||||
if rPTYServeErr != nil &&
|
||||
a.gracefulCtx.Err() == nil &&
|
||||
!strings.Contains(rPTYServeErr.Error(), "use of closed network connection") {
|
||||
a.logger.Error(ctx, "error serving reconnecting PTY", slog.Error(err))
|
||||
}
|
||||
wg.Wait()
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1308,9 +1268,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
|
||||
_ = server.Close()
|
||||
}()
|
||||
|
||||
err := server.Serve(apiListener)
|
||||
if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err))
|
||||
apiServErr := server.Serve(apiListener)
|
||||
if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") {
|
||||
a.logger.Critical(ctx, "serve HTTP API server", slog.Error(apiServErr))
|
||||
}
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
@@ -1394,87 +1354,6 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.D
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg workspacesdk.AgentReconnectingPTYInit, conn net.Conn) (retErr error) {
|
||||
defer conn.Close()
|
||||
a.metrics.connectionsTotal.Add(1)
|
||||
|
||||
a.connCountReconnectingPTY.Add(1)
|
||||
defer a.connCountReconnectingPTY.Add(-1)
|
||||
|
||||
connectionID := uuid.NewString()
|
||||
connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
|
||||
connLogger.Debug(ctx, "starting handler")
|
||||
|
||||
defer func() {
|
||||
if err := retErr; err != nil {
|
||||
a.closeMutex.Lock()
|
||||
closed := a.isClosed()
|
||||
a.closeMutex.Unlock()
|
||||
|
||||
// If the agent is closed, we don't want to
|
||||
// log this as an error since it's expected.
|
||||
if closed {
|
||||
connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
|
||||
} else {
|
||||
connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
|
||||
}
|
||||
}
|
||||
connLogger.Info(ctx, "reconnecting pty connection closed")
|
||||
}()
|
||||
|
||||
var rpty reconnectingpty.ReconnectingPTY
|
||||
sendConnected := make(chan reconnectingpty.ReconnectingPTY, 1)
|
||||
// On store, reserve this ID to prevent multiple concurrent new connections.
|
||||
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
|
||||
if ok {
|
||||
close(sendConnected) // Unused.
|
||||
connLogger.Debug(ctx, "connecting to existing reconnecting pty")
|
||||
c, ok := waitReady.(chan reconnectingpty.ReconnectingPTY)
|
||||
if !ok {
|
||||
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
|
||||
}
|
||||
rpty, ok = <-c
|
||||
if !ok || rpty == nil {
|
||||
return xerrors.Errorf("reconnecting pty closed before connection")
|
||||
}
|
||||
c <- rpty // Put it back for the next reconnect.
|
||||
} else {
|
||||
connLogger.Debug(ctx, "creating new reconnecting pty")
|
||||
|
||||
connected := false
|
||||
defer func() {
|
||||
if !connected && retErr != nil {
|
||||
a.reconnectingPTYs.Delete(msg.ID)
|
||||
close(sendConnected)
|
||||
}
|
||||
}()
|
||||
|
||||
// Empty command will default to the users shell!
|
||||
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
|
||||
if err != nil {
|
||||
a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
|
||||
return xerrors.Errorf("create command: %w", err)
|
||||
}
|
||||
|
||||
rpty = reconnectingpty.New(ctx, cmd, &reconnectingpty.Options{
|
||||
Timeout: a.reconnectingPTYTimeout,
|
||||
Metrics: a.metrics.reconnectingPTYErrors,
|
||||
}, logger.With(slog.F("message_id", msg.ID)))
|
||||
|
||||
if err = a.trackGoroutine(func() {
|
||||
rpty.Wait()
|
||||
a.reconnectingPTYs.Delete(msg.ID)
|
||||
}); err != nil {
|
||||
rpty.Close(err)
|
||||
return xerrors.Errorf("start routine: %w", err)
|
||||
}
|
||||
|
||||
connected = true
|
||||
sendConnected <- rpty
|
||||
}
|
||||
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
|
||||
}
|
||||
|
||||
// Collect collects additional stats from the agent
|
||||
func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
|
||||
a.logger.Debug(context.Background(), "computing stats report")
|
||||
@@ -1496,7 +1375,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
|
||||
stats.SessionCountVscode = sshStats.VSCode
|
||||
stats.SessionCountJetbrains = sshStats.JetBrains
|
||||
|
||||
stats.SessionCountReconnectingPty = a.connCountReconnectingPTY.Load()
|
||||
stats.SessionCountReconnectingPty = a.reconnectingPTYServer.ConnCount()
|
||||
|
||||
// Compute the median connection latency!
|
||||
a.logger.Debug(ctx, "starting peer latency measurement for stats")
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
package reconnectingpty
|
||||
|
||||
import (
	"context"
	"encoding/binary"
	"encoding/json"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/xerrors"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/agent/agentssh"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
)
|
||||
|
||||
type Server struct {
|
||||
logger slog.Logger
|
||||
connectionsTotal prometheus.Counter
|
||||
errorsTotal *prometheus.CounterVec
|
||||
commandCreator *agentssh.Server
|
||||
connCount atomic.Int64
|
||||
reconnectingPTYs sync.Map
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewServer returns a new ReconnectingPTY server
|
||||
func NewServer(logger slog.Logger, commandCreator *agentssh.Server,
|
||||
connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec,
|
||||
timeout time.Duration,
|
||||
) *Server {
|
||||
return &Server{
|
||||
logger: logger,
|
||||
commandCreator: commandCreator,
|
||||
connectionsTotal: connectionsTotal,
|
||||
errorsTotal: errorsTotal,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr error) {
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
s.logger.Debug(ctx, "accept pty failed", slog.Error(err))
|
||||
retErr = err
|
||||
break
|
||||
}
|
||||
clog := s.logger.With(
|
||||
slog.F("remote", conn.RemoteAddr().String()),
|
||||
slog.F("local", conn.LocalAddr().String()))
|
||||
clog.Info(ctx, "accepted conn")
|
||||
wg.Add(1)
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-closed:
|
||||
case <-hardCtx.Done():
|
||||
_ = conn.Close()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer close(closed)
|
||||
defer wg.Done()
|
||||
_ = s.handleConn(ctx, clog, conn)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
return retErr
|
||||
}
|
||||
|
||||
func (s *Server) ConnCount() int64 {
|
||||
return s.connCount.Load()
|
||||
}
|
||||
|
||||
func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Conn) (retErr error) {
|
||||
defer conn.Close()
|
||||
s.connectionsTotal.Add(1)
|
||||
s.connCount.Add(1)
|
||||
defer s.connCount.Add(-1)
|
||||
|
||||
// This cannot use a JSON decoder, since that can
|
||||
// buffer additional data that is required for the PTY.
|
||||
rawLen := make([]byte, 2)
|
||||
_, err := conn.Read(rawLen)
|
||||
if err != nil {
|
||||
// logging at info since a single incident isn't too worrying (the client could just have
|
||||
// hung up), but if we get a lot of these we'd want to investigate.
|
||||
logger.Info(ctx, "failed to read AgentReconnectingPTYInit length", slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
length := binary.LittleEndian.Uint16(rawLen)
|
||||
data := make([]byte, length)
|
||||
_, err = conn.Read(data)
|
||||
if err != nil {
|
||||
// logging at info since a single incident isn't too worrying (the client could just have
|
||||
// hung up), but if we get a lot of these we'd want to investigate.
|
||||
logger.Info(ctx, "failed to read AgentReconnectingPTYInit", slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
var msg workspacesdk.AgentReconnectingPTYInit
|
||||
err = json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data))
|
||||
return nil
|
||||
}
|
||||
|
||||
connectionID := uuid.NewString()
|
||||
connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
|
||||
connLogger.Debug(ctx, "starting handler")
|
||||
|
||||
defer func() {
|
||||
if err := retErr; err != nil {
|
||||
// If the context is done, we don't want to log this as an error since it's expected.
|
||||
if ctx.Err() != nil {
|
||||
connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err))
|
||||
} else {
|
||||
connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err))
|
||||
}
|
||||
}
|
||||
connLogger.Info(ctx, "reconnecting pty connection closed")
|
||||
}()
|
||||
|
||||
var rpty ReconnectingPTY
|
||||
sendConnected := make(chan ReconnectingPTY, 1)
|
||||
// On store, reserve this ID to prevent multiple concurrent new connections.
|
||||
waitReady, ok := s.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
|
||||
if ok {
|
||||
close(sendConnected) // Unused.
|
||||
connLogger.Debug(ctx, "connecting to existing reconnecting pty")
|
||||
c, ok := waitReady.(chan ReconnectingPTY)
|
||||
if !ok {
|
||||
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
|
||||
}
|
||||
rpty, ok = <-c
|
||||
if !ok || rpty == nil {
|
||||
return xerrors.Errorf("reconnecting pty closed before connection")
|
||||
}
|
||||
c <- rpty // Put it back for the next reconnect.
|
||||
} else {
|
||||
connLogger.Debug(ctx, "creating new reconnecting pty")
|
||||
|
||||
connected := false
|
||||
defer func() {
|
||||
if !connected && retErr != nil {
|
||||
s.reconnectingPTYs.Delete(msg.ID)
|
||||
close(sendConnected)
|
||||
}
|
||||
}()
|
||||
|
||||
// Empty command will default to the users shell!
|
||||
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil)
|
||||
if err != nil {
|
||||
s.errorsTotal.WithLabelValues("create_command").Add(1)
|
||||
return xerrors.Errorf("create command: %w", err)
|
||||
}
|
||||
|
||||
rpty = New(ctx, cmd, &Options{
|
||||
Timeout: s.timeout,
|
||||
Metrics: s.errorsTotal,
|
||||
}, logger.With(slog.F("message_id", msg.ID)))
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
rpty.Close(ctx.Err())
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
rpty.Wait()
|
||||
s.reconnectingPTYs.Delete(msg.ID)
|
||||
}()
|
||||
|
||||
connected = true
|
||||
sendConnected <- rpty
|
||||
}
|
||||
return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger)
|
||||
}
|
||||
Reference in New Issue
Block a user