package reconnectingpty import ( "context" "encoding/binary" "encoding/json" "net" "sync" "sync/atomic" "time" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" "cdr.dev/slog/v3" "github.com/coder/coder/v2/agent/agentcontainers" "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/usershell" "github.com/coder/coder/v2/codersdk/workspacesdk" ) type reportConnectionFunc func(id uuid.UUID, ip string) (disconnected func(code int, reason string)) type Server struct { logger slog.Logger connectionsTotal prometheus.Counter errorsTotal *prometheus.CounterVec commandCreator *agentssh.Server reportConnection reportConnectionFunc connCount atomic.Int64 reconnectingPTYs sync.Map timeout time.Duration // Experimental: allow connecting to running containers via Docker exec. // Note that this is different from the devcontainers feature, which uses // subagents. ExperimentalContainers bool } // NewServer returns a new ReconnectingPTY server func NewServer(logger slog.Logger, commandCreator *agentssh.Server, reportConnection reportConnectionFunc, connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec, timeout time.Duration, opts ...func(*Server), ) *Server { if reportConnection == nil { reportConnection = func(uuid.UUID, string) func(int, string) { return func(int, string) {} } } s := &Server{ logger: logger, commandCreator: commandCreator, reportConnection: reportConnection, connectionsTotal: connectionsTotal, errorsTotal: errorsTotal, timeout: timeout, } for _, o := range opts { o(s) } return s } func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr error) { var wg sync.WaitGroup for { if ctx.Err() != nil { break } conn, err := l.Accept() if err != nil { s.logger.Debug(ctx, "accept pty failed", slog.Error(err)) retErr = err break } clog := s.logger.With( slog.F("remote", conn.RemoteAddr()), slog.F("local", conn.LocalAddr())) clog.Info(ctx, "accepted conn") // It's not safe to assume RemoteAddr() returns a non-nil value. slog.F usage is fine because it correctly // handles nil. // c.f. https://github.com/coder/internal/issues/1143 remoteAddr := conn.RemoteAddr() remoteAddrString := "" if remoteAddr != nil { remoteAddrString = remoteAddr.String() } wg.Add(1) disconnected := s.reportConnection(uuid.New(), remoteAddrString) closed := make(chan struct{}) go func() { defer wg.Done() select { case <-closed: case <-hardCtx.Done(): disconnected(1, "server shut down") _ = conn.Close() } }() wg.Add(1) go func() { defer close(closed) defer wg.Done() err := s.handleConn(ctx, clog, conn) if err != nil { if ctx.Err() != nil { disconnected(1, "server shutting down") } else { disconnected(1, err.Error()) } } else { disconnected(0, "") } }() } wg.Wait() return retErr } func (s *Server) ConnCount() int64 { return s.connCount.Load() } func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Conn) (retErr error) { defer conn.Close() s.connectionsTotal.Add(1) s.connCount.Add(1) defer s.connCount.Add(-1) // This cannot use a JSON decoder, since that can // buffer additional data that is required for the PTY. rawLen := make([]byte, 2) _, err := conn.Read(rawLen) if err != nil { // logging at info since a single incident isn't too worrying (the client could just have // hung up), but if we get a lot of these we'd want to investigate. logger.Info(ctx, "failed to read AgentReconnectingPTYInit length", slog.Error(err)) return nil } length := binary.LittleEndian.Uint16(rawLen) data := make([]byte, length) _, err = conn.Read(data) if err != nil { // logging at info since a single incident isn't too worrying (the client could just have // hung up), but if we get a lot of these we'd want to investigate. logger.Info(ctx, "failed to read AgentReconnectingPTYInit", slog.Error(err)) return nil } var msg workspacesdk.AgentReconnectingPTYInit err = json.Unmarshal(data, &msg) if err != nil { logger.Warn(ctx, "failed to unmarshal init", slog.F("raw", data)) return nil } connectionID := uuid.NewString() connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID), slog.F("container", msg.Container), slog.F("container_user", msg.ContainerUser)) connLogger.Debug(ctx, "starting handler") defer func() { if err := retErr; err != nil { // If the context is done, we don't want to log this as an error since it's expected. if ctx.Err() != nil { connLogger.Info(ctx, "reconnecting pty failed with attach error (agent closed)", slog.Error(err)) } else { connLogger.Error(ctx, "reconnecting pty failed with attach error", slog.Error(err)) } } connLogger.Info(ctx, "reconnecting pty connection closed") }() var rpty ReconnectingPTY sendConnected := make(chan ReconnectingPTY, 1) // On store, reserve this ID to prevent multiple concurrent new connections. waitReady, ok := s.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected) if ok { close(sendConnected) // Unused. connLogger.Debug(ctx, "connecting to existing reconnecting pty") c, ok := waitReady.(chan ReconnectingPTY) if !ok { return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady) } rpty, ok = <-c if !ok || rpty == nil { return xerrors.Errorf("reconnecting pty closed before connection") } c <- rpty // Put it back for the next reconnect. } else { connLogger.Debug(ctx, "creating new reconnecting pty") connected := false defer func() { if !connected && retErr != nil { s.reconnectingPTYs.Delete(msg.ID) close(sendConnected) } }() var ei usershell.EnvInfoer if s.ExperimentalContainers && msg.Container != "" { dei, err := agentcontainers.EnvInfo(ctx, s.commandCreator.Execer, msg.Container, msg.ContainerUser) if err != nil { return xerrors.Errorf("get container env info: %w", err) } ei = dei s.logger.Info(ctx, "got container env info", slog.F("container", msg.Container)) } // Empty command will default to the users shell! cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil, ei) if err != nil { s.errorsTotal.WithLabelValues("create_command").Add(1) return xerrors.Errorf("create command: %w", err) } rpty = New(ctx, logger.With(slog.F("message_id", msg.ID)), s.commandCreator.Execer, cmd, &Options{ Timeout: s.timeout, Metrics: s.errorsTotal, BackendType: msg.BackendType, }, ) done := make(chan struct{}) go func() { select { case <-done: case <-ctx.Done(): rpty.Close(ctx.Err()) } }() go func() { rpty.Wait() s.reconnectingPTYs.Delete(msg.ID) }() connected = true sendConnected <- rpty } return rpty.Attach(ctx, connectionID, conn, msg.Height, msg.Width, connLogger) }