feat(agent): add connection reporting for SSH and reconnecting PTY (#16652)

Updates #15139
This commit is contained in:
Mathias Fredriksson
2025-02-27 12:45:45 +02:00
committed by GitHub
parent 6dd51f92fb
commit 4ba5a8a2ba
7 changed files with 382 additions and 32 deletions
+158
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"hash/fnv"
"io"
"net"
"net/http"
"net/netip"
"os"
@@ -28,6 +29,7 @@ import (
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"tailscale.com/types/netlogtype"
@@ -90,6 +92,7 @@ type Options struct {
ContainerLister agentcontainers.Lister
ExperimentalContainersEnabled bool
ExperimentalConnectionReports bool
}
type Client interface {
@@ -177,6 +180,7 @@ func New(options Options) Agent {
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
reportConnectionsUpdate: make(chan struct{}, 1),
ignorePorts: options.IgnorePorts,
portCacheDuration: options.PortCacheDuration,
reportMetadataInterval: options.ReportMetadataInterval,
@@ -192,6 +196,7 @@ func New(options Options) Agent {
lister: options.ContainerLister,
experimentalDevcontainersEnabled: options.ExperimentalContainersEnabled,
experimentalConnectionReports: options.ExperimentalConnectionReports,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@@ -252,6 +257,10 @@ type agent struct {
lifecycleStates []agentsdk.PostLifecycleRequest
lifecycleLastReportedIndex int // Keeps track of the last lifecycle state we successfully reported.
reportConnectionsUpdate chan struct{}
reportConnectionsMu sync.Mutex
reportConnections []*proto.ReportConnectionRequest
network *tailnet.Conn
statsReporter *statsReporter
logSender *agentsdk.LogSender
@@ -264,6 +273,7 @@ type agent struct {
lister agentcontainers.Lister
experimentalDevcontainersEnabled bool
experimentalConnectionReports bool
}
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -279,6 +289,24 @@ func (a *agent) init() {
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
BlockFileTransfer: a.blockFileTransfer,
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
var connectionType proto.Connection_Type
switch magicType {
case agentssh.MagicSessionTypeSSH:
connectionType = proto.Connection_SSH
case agentssh.MagicSessionTypeVSCode:
connectionType = proto.Connection_VSCODE
case agentssh.MagicSessionTypeJetBrains:
connectionType = proto.Connection_JETBRAINS
case agentssh.MagicSessionTypeUnknown:
connectionType = proto.Connection_TYPE_UNSPECIFIED
default:
a.logger.Error(a.hardCtx, "unhandled magic session type when reporting connection", slog.F("magic_type", magicType))
connectionType = proto.Connection_TYPE_UNSPECIFIED
}
return a.reportConnection(id, connectionType, ip)
},
})
if err != nil {
panic(err)
@@ -301,6 +329,9 @@ func (a *agent) init() {
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
a.sshServer,
func(id uuid.UUID, ip string) func(code int, reason string) {
return a.reportConnection(id, proto.Connection_RECONNECTING_PTY, ip)
},
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
a.reconnectingPTYTimeout,
func(s *reconnectingpty.Server) {
@@ -713,6 +744,129 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
}
}
// reportConnectionsLoop drains the buffered connection reports by sending
// them, oldest first, through aAPI.ReportConnection. It waits for a signal on
// reportConnectionsUpdate before each drain and returns when ctx is done or
// when a send fails. A report is removed from the buffer only after it has
// been sent successfully, so a failed report stays at the head of the buffer.
func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient24) error {
	// head returns the oldest buffered report without removing it.
	head := func() (*proto.ReportConnectionRequest, bool) {
		a.reportConnectionsMu.Lock()
		defer a.reportConnectionsMu.Unlock()
		if len(a.reportConnections) == 0 {
			return nil, false
		}
		return a.reportConnections[0], true
	}
	// discardHead removes the oldest buffered report.
	discardHead := func() {
		a.reportConnectionsMu.Lock()
		defer a.reportConnectionsMu.Unlock()
		a.reportConnections[0] = nil // Release the pointer from the underlying array.
		a.reportConnections = a.reportConnections[1:]
	}

	for {
		// Sleep until there is something to report or we are told to stop.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-a.reportConnectionsUpdate:
		}

		for {
			payload, ok := head()
			if !ok {
				break
			}

			// The lock is not held while sending. This is safe because
			// producers only ever append to the slice, so the head element
			// is still the one we are sending when we come back to remove it.
			logger := a.logger.With(slog.F("payload", payload))
			logger.Debug(ctx, "reporting connection")
			if _, err := aAPI.ReportConnection(ctx, payload); err != nil {
				return xerrors.Errorf("failed to report connection: %w", err)
			}
			logger.Debug(ctx, "successfully reported connection")

			discardHead()
		}
	}
}
// reportConnectionBufferLimit caps the number of connection reports that are
// buffered, so the buffer cannot grow indefinitely. Hitting the cap should
// not happen unless the agent has lost its connection to coderd for a long
// time or the agent is being spammed with connections.
//
// Assuming ~150 bytes per connection report, a full buffer is around 300KB
// of memory, which seems acceptable. This could be reduced if necessary by
// not using the proto struct directly.
const reportConnectionBufferLimit = 2048
// reportConnection buffers a CONNECT report for the connection identified by
// id and returns a function that buffers the matching DISCONNECT report when
// called with the exit code and reason.
//
// When experimentalConnectionReports is disabled nothing is buffered and the
// returned function is a no-op. When the buffer already holds
// reportConnectionBufferLimit entries, the report is dropped and a warning is
// logged instead; the connect and the disconnect are checked independently.
func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) {
	// If the experiment hasn't been enabled, we don't report connections.
	if !a.experimentalConnectionReports {
		return func(int, string) {} // Noop.
	}

	// Strip the port from the address because ports are not supported in
	// coderd. This is best effort: on failure the address is used as given.
	host, _, err := net.SplitHostPort(ip)
	if err == nil {
		ip = host
	} else {
		a.logger.Error(a.hardCtx, "split host and port for connection report failed", slog.F("ip", ip), slog.Error(err))
	}

	// push appends req to the buffer and nudges the reporting loop without
	// blocking. The caller must hold reportConnectionsMu.
	push := func(req *proto.ReportConnectionRequest) {
		a.reportConnections = append(a.reportConnections, req)
		select {
		case a.reportConnectionsUpdate <- struct{}{}:
		default:
		}
	}
	// warnDropped logs that a report was discarded because the buffer is full.
	warnDropped := func(msg string) {
		a.logger.Warn(a.hardCtx, msg,
			slog.F("limit", reportConnectionBufferLimit),
			slog.F("connection_id", id),
			slog.F("connection_type", connectionType),
			slog.F("ip", ip),
		)
	}

	a.reportConnectionsMu.Lock()
	defer a.reportConnectionsMu.Unlock()

	if len(a.reportConnections) < reportConnectionBufferLimit {
		push(&proto.ReportConnectionRequest{
			Connection: &proto.Connection{
				Id:         id[:],
				Action:     proto.Connection_CONNECT,
				Type:       connectionType,
				Timestamp:  timestamppb.New(time.Now()),
				Ip:         ip,
				StatusCode: 0,
				Reason:     nil,
			},
		})
	} else {
		warnDropped("connection report buffer limit reached, dropping connect")
	}

	return func(code int, reason string) {
		a.reportConnectionsMu.Lock()
		defer a.reportConnectionsMu.Unlock()

		if len(a.reportConnections) >= reportConnectionBufferLimit {
			warnDropped("connection report buffer limit reached, dropping disconnect")
			return
		}

		push(&proto.ReportConnectionRequest{
			Connection: &proto.Connection{
				Id:         id[:],
				Action:     proto.Connection_DISCONNECT,
				Type:       connectionType,
				Timestamp:  timestamppb.New(time.Now()),
				Ip:         ip,
				StatusCode: int32(code), //nolint:gosec
				Reason:     &reason,
			},
		})
	}
}
// fetchServiceBannerLoop fetches the service banner on an interval. It will
// not be fetched immediately; the expectation is that it is primed elsewhere
// (and must be done before the session actually starts).
@@ -823,6 +977,10 @@ func (a *agent) run() (retErr error) {
return resourcesmonitor.Start(ctx)
})
// Connection reports are part of auditing, we should keep sending them via
// gracefulShutdownBehaviorRemain.
connMan.startAgentAPI("report connections", gracefulShutdownBehaviorRemain, a.reportConnectionsLoop)
// channels to sync goroutines below
// handle manifest
// |
+79 -8
View File
@@ -163,7 +163,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.ExperimentalConnectionReports = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
@@ -193,6 +195,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
_ = stdin.Close()
err = session.Wait()
require.NoError(t, err)
assertConnectionReport(t, agentClient, proto.Connection_VSCODE, 0, "")
})
t.Run("TracksJetBrains", func(t *testing.T) {
@@ -229,7 +233,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
remotePort := sc.Text()
//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.ExperimentalConnectionReports = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
@@ -265,6 +271,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats after conn closes",
)
assertConnectionReport(t, agentClient, proto.Connection_JETBRAINS, 0, "")
})
}
@@ -922,7 +930,9 @@ func TestAgent_SFTP(t *testing.T) {
home = "/" + strings.ReplaceAll(home, "\\", "/")
}
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.ExperimentalConnectionReports = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
@@ -945,6 +955,10 @@ func TestAgent_SFTP(t *testing.T) {
require.NoError(t, err)
_, err = os.Stat(tempFile)
require.NoError(t, err)
// Close the client to trigger disconnect event.
_ = client.Close()
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
}
func TestAgent_SCP(t *testing.T) {
@@ -954,7 +968,9 @@ func TestAgent_SCP(t *testing.T) {
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.ExperimentalConnectionReports = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
@@ -967,6 +983,10 @@ func TestAgent_SCP(t *testing.T) {
require.NoError(t, err)
_, err = os.Stat(tempFile)
require.NoError(t, err)
// Close the client to trigger disconnect event.
scpClient.Close()
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
}
func TestAgent_FileTransferBlocked(t *testing.T) {
@@ -991,8 +1011,9 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockFileTransfer = true
o.ExperimentalConnectionReports = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
@@ -1000,6 +1021,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
_, err = sftp.NewClient(sshClient)
require.Error(t, err)
assertFileTransferBlocked(t, err.Error())
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
})
t.Run("SCP with go-scp package", func(t *testing.T) {
@@ -1009,8 +1032,9 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockFileTransfer = true
o.ExperimentalConnectionReports = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
@@ -1022,6 +1046,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755")
require.Error(t, err)
assertFileTransferBlocked(t, err.Error())
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
})
t.Run("Forbidden commands", func(t *testing.T) {
@@ -1035,8 +1061,9 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockFileTransfer = true
o.ExperimentalConnectionReports = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
@@ -1057,6 +1084,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
msg, err := io.ReadAll(stdout)
require.NoError(t, err)
assertFileTransferBlocked(t, string(msg))
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
})
}
})
@@ -1665,8 +1694,18 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.ExperimentalConnectionReports = true
})
id := uuid.New()
// Test that the connection is reported. This must be tested in the
// first connection because we care about verifying all of these.
netConn0, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
require.NoError(t, err)
_ = netConn0.Close()
assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "")
// --norc disables executing .bashrc, which is often used to customize the bash prompt
netConn1, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
require.NoError(t, err)
@@ -2763,3 +2802,35 @@ func requireEcho(t *testing.T, conn net.Conn) {
require.NoError(t, err)
require.Equal(t, "test", string(b))
}
// assertConnectionReport waits until agentClient has recorded at least two
// connection reports and then asserts that there are exactly two: a CONNECT
// followed by a DISCONNECT, both of connectionType, both with a non-empty IP,
// with the connect timestamp not after the disconnect timestamp.
//
// status is the status code expected on the disconnect report. When reason is
// non-empty the disconnect reason must contain it; otherwise the disconnect
// reason is only logged.
func assertConnectionReport(t testing.TB, agentClient *agenttest.Client, connectionType proto.Connection_Type, status int, reason string) {
	t.Helper()

	var reports []*proto.ReportConnectionRequest
	if !assert.Eventually(t, func() bool {
		reports = agentClient.GetConnectionReports()
		return len(reports) >= 2
	}, testutil.WaitMedium, testutil.IntervalFast, "waiting for 2 connection reports or more") {
		// The count cannot be passed as a message argument to Eventually:
		// arguments are evaluated at the call, before any polling happened,
		// so it would always read 0. Fetch the current count here instead.
		t.Logf("got %d connection reports before giving up", len(agentClient.GetConnectionReports()))
		return
	}

	assert.Len(t, reports, 2, "want 2 connection reports")

	assert.Equal(t, proto.Connection_CONNECT, reports[0].GetConnection().GetAction(), "first report should be connect")
	assert.Equal(t, proto.Connection_DISCONNECT, reports[1].GetConnection().GetAction(), "second report should be disconnect")
	assert.Equal(t, connectionType, reports[0].GetConnection().GetType(), "connect type should be %s", connectionType)
	assert.Equal(t, connectionType, reports[1].GetConnection().GetType(), "disconnect type should be %s", connectionType)

	t1 := reports[0].GetConnection().GetTimestamp().AsTime()
	t2 := reports[1].GetConnection().GetTimestamp().AsTime()
	assert.True(t, t1.Before(t2) || t1.Equal(t2), "connect timestamp should be before or equal to disconnect timestamp")

	assert.NotEmpty(t, reports[0].GetConnection().GetIp(), "connect ip should not be empty")
	assert.NotEmpty(t, reports[1].GetConnection().GetIp(), "disconnect ip should not be empty")

	assert.Equal(t, 0, int(reports[0].GetConnection().GetStatusCode()), "connect status code should be 0")
	assert.Equal(t, status, int(reports[1].GetConnection().GetStatusCode()), "disconnect status code should be %d", status)

	assert.Equal(t, "", reports[0].GetConnection().GetReason(), "connect reason should be empty")
	if reason != "" {
		assert.Contains(t, reports[1].GetConnection().GetReason(), reason, "disconnect reason should contain %s", reason)
	} else {
		t.Logf("connection report disconnect reason: %s", reports[1].GetConnection().GetReason())
	}
}
+77 -10
View File
@@ -78,6 +78,8 @@ const (
// BlockedFileTransferCommands contains a list of restricted file transfer commands.
var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"}
// reportConnectionFunc is the callback signature used to report a
// connection. It receives the connection id, the magic session type and the
// remote address, and returns a function to be called with the exit code and
// a reason once the connection has ended.
type reportConnectionFunc func(id uuid.UUID, sessionType MagicSessionType, ip string) (disconnected func(code int, reason string))
// Config sets configuration parameters for the agent SSH server.
type Config struct {
// MaxTimeout sets the absolute connection timeout, none if empty. If set to
@@ -100,6 +102,8 @@ type Config struct {
X11DisplayOffset *int
// BlockFileTransfer restricts use of file transfer applications.
BlockFileTransfer bool
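// ReportConnection is called to report a new connection; the function it
// returns is called with the exit code and reason when the connection ends.
// NewServer substitutes a no-op when this is nil.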
ReportConnection reportConnectionFunc
}
type Server struct {
@@ -152,6 +156,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
return home
}
}
if config.ReportConnection == nil {
config.ReportConnection = func(uuid.UUID, MagicSessionType, string) func(int, string) { return func(int, string) {} }
}
forwardHandler := &ssh.ForwardedTCPHandler{}
unixForwardHandler := newForwardedUnixHandler(logger)
@@ -174,7 +181,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
// Wrapper is designed to find and track JetBrains Gateway connections.
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains)
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains)
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
},
"direct-streamlocal@openssh.com": directStreamLocalHandler,
@@ -288,6 +295,35 @@ func extractMagicSessionType(env []string) (magicType MagicSessionType, rawType
})
}
// sessionCloseTracker is a wrapper around Session that tracks the exit code.
// Only the first code recorded through Exit or Close is kept.
type sessionCloseTracker struct {
ssh.Session
// exitOnce guards code so that only the first recorded value is stored.
exitOnce sync.Once
// code holds the recorded exit code.
code atomic.Int64
}
// Compile-time check that sessionCloseTracker satisfies ssh.Session.
var _ ssh.Session = &sessionCloseTracker{}
// track records code as the session's exit code. Only the first call has an
// effect; later calls are ignored.
func (s *sessionCloseTracker) track(code int) {
	record := func() {
		s.code.Store(int64(code))
	}
	s.exitOnce.Do(record)
}
// exitCode returns the exit code currently stored in the tracker.
func (s *sessionCloseTracker) exitCode() int {
	stored := s.code.Load()
	return int(stored)
}
// Exit records code via track and then forwards the call to the wrapped
// Session, returning its error.
func (s *sessionCloseTracker) Exit(code int) error {
	s.track(code)
	err := s.Session.Exit(code)
	return err
}
// Close records exit code 1 via track (a no-op if a code was already
// recorded) and then closes the wrapped Session, returning its error.
func (s *sessionCloseTracker) Close() error {
	const closedWithoutExit = 1
	s.track(closedWithoutExit)
	err := s.Session.Close()
	return err
}
func (s *Server) sessionHandler(session ssh.Session) {
ctx := session.Context()
id := uuid.New()
@@ -300,17 +336,23 @@ func (s *Server) sessionHandler(session ssh.Session) {
)
logger.Info(ctx, "handling ssh session")
env := session.Environ()
magicType, magicTypeRaw, env := extractMagicSessionType(env)
if !s.trackSession(session, true) {
reason := "unable to accept new session, server is closing"
// Report connection attempt even if we couldn't accept it.
disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String())
defer disconnected(1, reason)
logger.Info(ctx, reason)
// See (*Server).Close() for why we call Close instead of Exit.
_ = session.Close()
logger.Info(ctx, "unable to accept new session, server is closing")
return
}
defer s.trackSession(session, false)
env := session.Environ()
magicType, magicTypeRaw, env := extractMagicSessionType(env)
reportSession := true
switch magicType {
case MagicSessionTypeVSCode:
s.connCountVSCode.Add(1)
@@ -318,6 +360,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
case MagicSessionTypeJetBrains:
// Do nothing here because JetBrains launches hundreds of ssh sessions.
// We instead track JetBrains in the single persistent tcp forwarding channel.
reportSession = false
case MagicSessionTypeSSH:
s.connCountSSHSession.Add(1)
defer s.connCountSSHSession.Add(-1)
@@ -325,6 +368,20 @@ func (s *Server) sessionHandler(session ssh.Session) {
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("raw_type", magicTypeRaw))
}
closeCause := func(string) {}
if reportSession {
var reason string
closeCause = func(r string) { reason = r }
scr := &sessionCloseTracker{Session: session}
session = scr
disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String())
defer func() {
disconnected(scr.exitCode(), reason)
}()
}
if s.fileTransferBlocked(session) {
s.logger.Warn(ctx, "file transfer blocked", slog.F("session_subsystem", session.Subsystem()), slog.F("raw_command", session.RawCommand()))
@@ -333,6 +390,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
errorMessage := fmt.Sprintf("\x02%s\n", BlockedFileTransferErrorMessage)
_, _ = session.Write([]byte(errorMessage))
}
closeCause("file transfer blocked")
_ = session.Exit(BlockedFileTransferErrorCode)
return
}
@@ -340,10 +398,14 @@ func (s *Server) sessionHandler(session ssh.Session) {
switch ss := session.Subsystem(); ss {
case "":
case "sftp":
s.sftpHandler(logger, session)
err := s.sftpHandler(logger, session)
if err != nil {
closeCause(err.Error())
}
return
default:
logger.Warn(ctx, "unsupported subsystem", slog.F("subsystem", ss))
closeCause(fmt.Sprintf("unsupported subsystem: %s", ss))
_ = session.Exit(1)
return
}
@@ -352,8 +414,9 @@ func (s *Server) sessionHandler(session ssh.Session) {
if hasX11 {
display, handled := s.x11Handler(session.Context(), x11)
if !handled {
_ = session.Exit(1)
logger.Error(ctx, "x11 handler failed")
closeCause("x11 handler failed")
_ = session.Exit(1)
return
}
env = append(env, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
@@ -380,6 +443,8 @@ func (s *Server) sessionHandler(session ssh.Session) {
slog.F("exit_code", code),
)
closeCause(fmt.Sprintf("process exited with error status: %d", exitError.ExitCode()))
// TODO(mafredri): For signal exit, there's also an "exit-signal"
// request (session.Exit sends "exit-status"), however, since it's
// not implemented on the session interface and not used by
@@ -391,6 +456,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
logger.Warn(ctx, "ssh session failed", slog.Error(err))
// This exit code is designed to be unlikely to be confused for a legit exit code
// from the process.
closeCause(err.Error())
_ = session.Exit(MagicSessionErrorCode)
return
}
@@ -650,7 +716,7 @@ func handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signa
}
}
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) error {
s.metrics.sftpConnectionsTotal.Add(1)
ctx := session.Context()
@@ -674,7 +740,7 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
server, err := sftp.NewServer(session, opts...)
if err != nil {
logger.Debug(ctx, "initialize sftp server", slog.Error(err))
return
return xerrors.Errorf("initialize sftp server: %w", err)
}
defer server.Close()
@@ -689,11 +755,12 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
// code but `scp` on macOS does (when using the default
// SFTP backend).
_ = session.Exit(0)
return
return nil
}
logger.Warn(ctx, "sftp server closed with error", slog.Error(err))
s.metrics.sftpServerErrors.Add(1)
_ = session.Exit(1)
return xerrors.Errorf("sftp server closed with error: %w", err)
}
// CreateCommand processes raw command input with OpenSSH-like behavior.
+10 -1
View File
@@ -6,6 +6,7 @@ import (
"sync"
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"go.uber.org/atomic"
gossh "golang.org/x/crypto/ssh"
@@ -28,9 +29,11 @@ type JetbrainsChannelWatcher struct {
gossh.NewChannel
jetbrainsCounter *atomic.Int64
logger slog.Logger
originAddr string
reportConnection reportConnectionFunc
}
func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel {
func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, reportConnection reportConnectionFunc, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel {
d := localForwardChannelData{}
if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil {
// If the data fails to unmarshal, do nothing.
@@ -61,12 +64,17 @@ func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel
NewChannel: newChannel,
jetbrainsCounter: counter,
logger: logger.With(slog.F("destination_port", d.DestPort)),
originAddr: d.OriginAddr,
reportConnection: reportConnection,
}
}
func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) {
disconnected := w.reportConnection(uuid.New(), MagicSessionTypeJetBrains, w.originAddr)
c, r, err := w.NewChannel.Accept()
if err != nil {
disconnected(1, err.Error())
return c, r, err
}
w.jetbrainsCounter.Add(1)
@@ -77,6 +85,7 @@ func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request
Channel: c,
done: func() {
w.jetbrainsCounter.Add(-1)
disconnected(0, "")
// nolint: gocritic // JetBrains is a proper noun and should be capitalized
w.logger.Debug(context.Background(), "JetBrains watcher channel closed")
},
+20 -10
View File
@@ -158,20 +158,24 @@ func (c *Client) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) {
c.fakeAgentAPI.SetLogsChannel(ch)
}
// GetConnectionReports returns the connection reports recorded by the
// client's fake agent API.
func (c *Client) GetConnectionReports() []*agentproto.ReportConnectionRequest {
	reports := c.fakeAgentAPI.GetConnectionReports()
	return reports
}
type FakeAgentAPI struct {
sync.Mutex
t testing.TB
logger slog.Logger
manifest *agentproto.Manifest
startupCh chan *agentproto.Startup
statsCh chan *agentproto.Stats
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
logsCh chan<- *agentproto.BatchCreateLogsRequest
lifecycleStates []codersdk.WorkspaceAgentLifecycle
metadata map[string]agentsdk.Metadata
timings []*agentproto.Timing
connections []*agentproto.Connection
manifest *agentproto.Manifest
startupCh chan *agentproto.Startup
statsCh chan *agentproto.Stats
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
logsCh chan<- *agentproto.BatchCreateLogsRequest
lifecycleStates []codersdk.WorkspaceAgentLifecycle
metadata map[string]agentsdk.Metadata
timings []*agentproto.Timing
connectionReports []*agentproto.ReportConnectionRequest
getAnnouncementBannersFunc func() ([]codersdk.BannerConfig, error)
getResourcesMonitoringConfigurationFunc func() (*agentproto.GetResourcesMonitoringConfigurationResponse, error)
@@ -348,12 +352,18 @@ func (f *FakeAgentAPI) ScriptCompleted(_ context.Context, req *agentproto.Worksp
func (f *FakeAgentAPI) ReportConnection(_ context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) {
f.Lock()
f.connections = append(f.connections, req.GetConnection())
f.connectionReports = append(f.connectionReports, req)
f.Unlock()
return &emptypb.Empty{}, nil
}
// GetConnectionReports returns a copy of the connection reports received so
// far. The slice is cloned while holding the lock, so the caller's slice is
// not affected by reports appended later.
func (f *FakeAgentAPI) GetConnectionReports() []*agentproto.ReportConnectionRequest {
	f.Lock()
	snapshot := slices.Clone(f.connectionReports)
	f.Unlock()
	return snapshot
}
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI {
return &FakeAgentAPI{
t: t,
+23 -3
View File
@@ -20,11 +20,14 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// reportConnectionFunc is the callback signature used to report a
// connection. It receives the connection id and the remote address, and
// returns a function to be called with an exit code and a reason once the
// connection has ended.
type reportConnectionFunc func(id uuid.UUID, ip string) (disconnected func(code int, reason string))
type Server struct {
logger slog.Logger
connectionsTotal prometheus.Counter
errorsTotal *prometheus.CounterVec
commandCreator *agentssh.Server
reportConnection reportConnectionFunc
connCount atomic.Int64
reconnectingPTYs sync.Map
timeout time.Duration
@@ -33,13 +36,19 @@ type Server struct {
}
// NewServer returns a new ReconnectingPTY server
func NewServer(logger slog.Logger, commandCreator *agentssh.Server,
func NewServer(logger slog.Logger, commandCreator *agentssh.Server, reportConnection reportConnectionFunc,
connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec,
timeout time.Duration, opts ...func(*Server),
) *Server {
if reportConnection == nil {
reportConnection = func(uuid.UUID, string) func(int, string) {
return func(int, string) {}
}
}
s := &Server{
logger: logger,
commandCreator: commandCreator,
reportConnection: reportConnection,
connectionsTotal: connectionsTotal,
errorsTotal: errorsTotal,
timeout: timeout,
@@ -67,20 +76,31 @@ func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr err
slog.F("local", conn.LocalAddr().String()))
clog.Info(ctx, "accepted conn")
wg.Add(1)
disconnected := s.reportConnection(uuid.New(), conn.RemoteAddr().String())
closed := make(chan struct{})
go func() {
defer wg.Done()
select {
case <-closed:
case <-hardCtx.Done():
disconnected(1, "server shut down")
_ = conn.Close()
}
wg.Done()
}()
wg.Add(1)
go func() {
defer close(closed)
defer wg.Done()
_ = s.handleConn(ctx, clog, conn)
err := s.handleConn(ctx, clog, conn)
if err != nil {
if ctx.Err() != nil {
disconnected(1, "server shutting down")
} else {
disconnected(1, err.Error())
}
} else {
disconnected(0, "")
}
}()
}
wg.Wait()
+15
View File
@@ -54,6 +54,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
agentHeaderCommand string
agentHeader []string
devcontainersEnabled bool
experimentalConnectionReports bool
)
cmd := &serpent.Command{
Use: "agent",
@@ -325,6 +327,10 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
containerLister = agentcontainers.NewDocker(execer)
}
if experimentalConnectionReports {
logger.Info(ctx, "experimental connection reports enabled")
}
agnt := agent.New(agent.Options{
Client: client,
Logger: logger,
@@ -353,6 +359,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
ContainerLister: containerLister,
ExperimentalContainersEnabled: devcontainersEnabled,
ExperimentalConnectionReports: experimentalConnectionReports,
})
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
@@ -482,6 +489,14 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
Description: "Allow the agent to automatically detect running devcontainers.",
Value: serpent.BoolOf(&devcontainersEnabled),
},
{
Flag: "experimental-connection-reports-enable",
Hidden: true,
Default: "false",
Env: "CODER_AGENT_EXPERIMENTAL_CONNECTION_REPORTS_ENABLE",
Description: "Enable experimental connection reports.",
Value: serpent.BoolOf(&experimentalConnectionReports),
},
}
return cmd