mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat(agent): add connection reporting for SSH and reconnecting PTY (#16652)
Updates #15139
This commit is contained in:
committed by
GitHub
parent
6dd51f92fb
commit
4ba5a8a2ba
+158
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -28,6 +29,7 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/netlogtype"
|
||||
@@ -90,6 +92,7 @@ type Options struct {
|
||||
ContainerLister agentcontainers.Lister
|
||||
|
||||
ExperimentalContainersEnabled bool
|
||||
ExperimentalConnectionReports bool
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@@ -177,6 +180,7 @@ func New(options Options) Agent {
|
||||
lifecycleUpdate: make(chan struct{}, 1),
|
||||
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
|
||||
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
|
||||
reportConnectionsUpdate: make(chan struct{}, 1),
|
||||
ignorePorts: options.IgnorePorts,
|
||||
portCacheDuration: options.PortCacheDuration,
|
||||
reportMetadataInterval: options.ReportMetadataInterval,
|
||||
@@ -192,6 +196,7 @@ func New(options Options) Agent {
|
||||
lister: options.ContainerLister,
|
||||
|
||||
experimentalDevcontainersEnabled: options.ExperimentalContainersEnabled,
|
||||
experimentalConnectionReports: options.ExperimentalConnectionReports,
|
||||
}
|
||||
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
||||
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
||||
@@ -252,6 +257,10 @@ type agent struct {
|
||||
lifecycleStates []agentsdk.PostLifecycleRequest
|
||||
lifecycleLastReportedIndex int // Keeps track of the last lifecycle state we successfully reported.
|
||||
|
||||
reportConnectionsUpdate chan struct{}
|
||||
reportConnectionsMu sync.Mutex
|
||||
reportConnections []*proto.ReportConnectionRequest
|
||||
|
||||
network *tailnet.Conn
|
||||
statsReporter *statsReporter
|
||||
logSender *agentsdk.LogSender
|
||||
@@ -264,6 +273,7 @@ type agent struct {
|
||||
lister agentcontainers.Lister
|
||||
|
||||
experimentalDevcontainersEnabled bool
|
||||
experimentalConnectionReports bool
|
||||
}
|
||||
|
||||
func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
@@ -279,6 +289,24 @@ func (a *agent) init() {
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
|
||||
var connectionType proto.Connection_Type
|
||||
switch magicType {
|
||||
case agentssh.MagicSessionTypeSSH:
|
||||
connectionType = proto.Connection_SSH
|
||||
case agentssh.MagicSessionTypeVSCode:
|
||||
connectionType = proto.Connection_VSCODE
|
||||
case agentssh.MagicSessionTypeJetBrains:
|
||||
connectionType = proto.Connection_JETBRAINS
|
||||
case agentssh.MagicSessionTypeUnknown:
|
||||
connectionType = proto.Connection_TYPE_UNSPECIFIED
|
||||
default:
|
||||
a.logger.Error(a.hardCtx, "unhandled magic session type when reporting connection", slog.F("magic_type", magicType))
|
||||
connectionType = proto.Connection_TYPE_UNSPECIFIED
|
||||
}
|
||||
|
||||
return a.reportConnection(id, connectionType, ip)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -301,6 +329,9 @@ func (a *agent) init() {
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
a.sshServer,
|
||||
func(id uuid.UUID, ip string) func(code int, reason string) {
|
||||
return a.reportConnection(id, proto.Connection_RECONNECTING_PTY, ip)
|
||||
},
|
||||
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
|
||||
a.reconnectingPTYTimeout,
|
||||
func(s *reconnectingpty.Server) {
|
||||
@@ -713,6 +744,129 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
|
||||
}
|
||||
}
|
||||
|
||||
// reportConnectionsLoop reports connections to the agent for auditing.
|
||||
func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient24) error {
|
||||
for {
|
||||
select {
|
||||
case <-a.reportConnectionsUpdate:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
for {
|
||||
a.reportConnectionsMu.Lock()
|
||||
if len(a.reportConnections) == 0 {
|
||||
a.reportConnectionsMu.Unlock()
|
||||
break
|
||||
}
|
||||
payload := a.reportConnections[0]
|
||||
// Release lock while we send the payload, this is safe
|
||||
// since we only append to the slice.
|
||||
a.reportConnectionsMu.Unlock()
|
||||
|
||||
logger := a.logger.With(slog.F("payload", payload))
|
||||
logger.Debug(ctx, "reporting connection")
|
||||
_, err := aAPI.ReportConnection(ctx, payload)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to report connection: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "successfully reported connection")
|
||||
|
||||
// Remove the payload we sent.
|
||||
a.reportConnectionsMu.Lock()
|
||||
a.reportConnections[0] = nil // Release the pointer from the underlying array.
|
||||
a.reportConnections = a.reportConnections[1:]
|
||||
a.reportConnectionsMu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// reportConnectionBufferLimit limits the number of connection reports we
|
||||
// buffer to avoid growing the buffer indefinitely. This should not happen
|
||||
// unless the agent has lost connection to coderd for a long time or if
|
||||
// the agent is being spammed with connections.
|
||||
//
|
||||
// If we assume ~150 byte per connection report, this would be around 300KB
|
||||
// of memory which seems acceptable. We could reduce this if necessary by
|
||||
// not using the proto struct directly.
|
||||
reportConnectionBufferLimit = 2048
|
||||
)
|
||||
|
||||
func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) {
|
||||
// If the experiment hasn't been enabled, we don't report connections.
|
||||
if !a.experimentalConnectionReports {
|
||||
return func(int, string) {} // Noop.
|
||||
}
|
||||
|
||||
// Remove the port from the IP because ports are not supported in coderd.
|
||||
if host, _, err := net.SplitHostPort(ip); err != nil {
|
||||
a.logger.Error(a.hardCtx, "split host and port for connection report failed", slog.F("ip", ip), slog.Error(err))
|
||||
} else {
|
||||
// Best effort.
|
||||
ip = host
|
||||
}
|
||||
|
||||
a.reportConnectionsMu.Lock()
|
||||
defer a.reportConnectionsMu.Unlock()
|
||||
|
||||
if len(a.reportConnections) >= reportConnectionBufferLimit {
|
||||
a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping connect",
|
||||
slog.F("limit", reportConnectionBufferLimit),
|
||||
slog.F("connection_id", id),
|
||||
slog.F("connection_type", connectionType),
|
||||
slog.F("ip", ip),
|
||||
)
|
||||
} else {
|
||||
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
|
||||
Connection: &proto.Connection{
|
||||
Id: id[:],
|
||||
Action: proto.Connection_CONNECT,
|
||||
Type: connectionType,
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
Ip: ip,
|
||||
StatusCode: 0,
|
||||
Reason: nil,
|
||||
},
|
||||
})
|
||||
select {
|
||||
case a.reportConnectionsUpdate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return func(code int, reason string) {
|
||||
a.reportConnectionsMu.Lock()
|
||||
defer a.reportConnectionsMu.Unlock()
|
||||
if len(a.reportConnections) >= reportConnectionBufferLimit {
|
||||
a.logger.Warn(a.hardCtx, "connection report buffer limit reached, dropping disconnect",
|
||||
slog.F("limit", reportConnectionBufferLimit),
|
||||
slog.F("connection_id", id),
|
||||
slog.F("connection_type", connectionType),
|
||||
slog.F("ip", ip),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
|
||||
Connection: &proto.Connection{
|
||||
Id: id[:],
|
||||
Action: proto.Connection_DISCONNECT,
|
||||
Type: connectionType,
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
Ip: ip,
|
||||
StatusCode: int32(code), //nolint:gosec
|
||||
Reason: &reason,
|
||||
},
|
||||
})
|
||||
select {
|
||||
case a.reportConnectionsUpdate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fetchServiceBannerLoop fetches the service banner on an interval. It will
|
||||
// not be fetched immediately; the expectation is that it is primed elsewhere
|
||||
// (and must be done before the session actually starts).
|
||||
@@ -823,6 +977,10 @@ func (a *agent) run() (retErr error) {
|
||||
return resourcesmonitor.Start(ctx)
|
||||
})
|
||||
|
||||
// Connection reports are part of auditing, we should keep sending them via
|
||||
// gracefulShutdownBehaviorRemain.
|
||||
connMan.startAgentAPI("report connections", gracefulShutdownBehaviorRemain, a.reportConnectionsLoop)
|
||||
|
||||
// channels to sync goroutines below
|
||||
// handle manifest
|
||||
// |
|
||||
|
||||
+79
-8
@@ -163,7 +163,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
//nolint:dogsled
|
||||
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.ExperimentalConnectionReports = true
|
||||
})
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
@@ -193,6 +195,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
|
||||
_ = stdin.Close()
|
||||
err = session.Wait()
|
||||
require.NoError(t, err)
|
||||
|
||||
assertConnectionReport(t, agentClient, proto.Connection_VSCODE, 0, "")
|
||||
})
|
||||
|
||||
t.Run("TracksJetBrains", func(t *testing.T) {
|
||||
@@ -229,7 +233,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
|
||||
remotePort := sc.Text()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.ExperimentalConnectionReports = true
|
||||
})
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -265,6 +271,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"never saw stats after conn closes",
|
||||
)
|
||||
|
||||
assertConnectionReport(t, agentClient, proto.Connection_JETBRAINS, 0, "")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -922,7 +930,9 @@ func TestAgent_SFTP(t *testing.T) {
|
||||
home = "/" + strings.ReplaceAll(home, "\\", "/")
|
||||
}
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.ExperimentalConnectionReports = true
|
||||
})
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
@@ -945,6 +955,10 @@ func TestAgent_SFTP(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
_, err = os.Stat(tempFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close the client to trigger disconnect event.
|
||||
_ = client.Close()
|
||||
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
|
||||
}
|
||||
|
||||
func TestAgent_SCP(t *testing.T) {
|
||||
@@ -954,7 +968,9 @@ func TestAgent_SCP(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.ExperimentalConnectionReports = true
|
||||
})
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
@@ -967,6 +983,10 @@ func TestAgent_SCP(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
_, err = os.Stat(tempFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close the client to trigger disconnect event.
|
||||
scpClient.Close()
|
||||
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
|
||||
}
|
||||
|
||||
func TestAgent_FileTransferBlocked(t *testing.T) {
|
||||
@@ -991,8 +1011,9 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockFileTransfer = true
|
||||
o.ExperimentalConnectionReports = true
|
||||
})
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
@@ -1000,6 +1021,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
|
||||
_, err = sftp.NewClient(sshClient)
|
||||
require.Error(t, err)
|
||||
assertFileTransferBlocked(t, err.Error())
|
||||
|
||||
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
|
||||
})
|
||||
|
||||
t.Run("SCP with go-scp package", func(t *testing.T) {
|
||||
@@ -1009,8 +1032,9 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockFileTransfer = true
|
||||
o.ExperimentalConnectionReports = true
|
||||
})
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
@@ -1022,6 +1046,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
|
||||
err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755")
|
||||
require.Error(t, err)
|
||||
assertFileTransferBlocked(t, err.Error())
|
||||
|
||||
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
|
||||
})
|
||||
|
||||
t.Run("Forbidden commands", func(t *testing.T) {
|
||||
@@ -1035,8 +1061,9 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockFileTransfer = true
|
||||
o.ExperimentalConnectionReports = true
|
||||
})
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
@@ -1057,6 +1084,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
|
||||
msg, err := io.ReadAll(stdout)
|
||||
require.NoError(t, err)
|
||||
assertFileTransferBlocked(t, string(msg))
|
||||
|
||||
assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -1665,8 +1694,18 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.ExperimentalConnectionReports = true
|
||||
})
|
||||
id := uuid.New()
|
||||
|
||||
// Test that the connection is reported. This must be tested in the
|
||||
// first connection because we care about verifying all of these.
|
||||
netConn0, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
|
||||
require.NoError(t, err)
|
||||
_ = netConn0.Close()
|
||||
assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "")
|
||||
|
||||
// --norc disables executing .bashrc, which is often used to customize the bash prompt
|
||||
netConn1, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
|
||||
require.NoError(t, err)
|
||||
@@ -2763,3 +2802,35 @@ func requireEcho(t *testing.T, conn net.Conn) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", string(b))
|
||||
}
|
||||
|
||||
func assertConnectionReport(t testing.TB, agentClient *agenttest.Client, connectionType proto.Connection_Type, status int, reason string) {
|
||||
t.Helper()
|
||||
|
||||
var reports []*proto.ReportConnectionRequest
|
||||
if !assert.Eventually(t, func() bool {
|
||||
reports = agentClient.GetConnectionReports()
|
||||
return len(reports) >= 2
|
||||
}, testutil.WaitMedium, testutil.IntervalFast, "waiting for 2 connection reports or more; got %d", len(reports)) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Len(t, reports, 2, "want 2 connection reports")
|
||||
|
||||
assert.Equal(t, proto.Connection_CONNECT, reports[0].GetConnection().GetAction(), "first report should be connect")
|
||||
assert.Equal(t, proto.Connection_DISCONNECT, reports[1].GetConnection().GetAction(), "second report should be disconnect")
|
||||
assert.Equal(t, connectionType, reports[0].GetConnection().GetType(), "connect type should be %s", connectionType)
|
||||
assert.Equal(t, connectionType, reports[1].GetConnection().GetType(), "disconnect type should be %s", connectionType)
|
||||
t1 := reports[0].GetConnection().GetTimestamp().AsTime()
|
||||
t2 := reports[1].GetConnection().GetTimestamp().AsTime()
|
||||
assert.True(t, t1.Before(t2) || t1.Equal(t2), "connect timestamp should be before or equal to disconnect timestamp")
|
||||
assert.NotEmpty(t, reports[0].GetConnection().GetIp(), "connect ip should not be empty")
|
||||
assert.NotEmpty(t, reports[1].GetConnection().GetIp(), "disconnect ip should not be empty")
|
||||
assert.Equal(t, 0, int(reports[0].GetConnection().GetStatusCode()), "connect status code should be 0")
|
||||
assert.Equal(t, status, int(reports[1].GetConnection().GetStatusCode()), "disconnect status code should be %d", status)
|
||||
assert.Equal(t, "", reports[0].GetConnection().GetReason(), "connect reason should be empty")
|
||||
if reason != "" {
|
||||
assert.Contains(t, reports[1].GetConnection().GetReason(), reason, "disconnect reason should contain %s", reason)
|
||||
} else {
|
||||
t.Logf("connection report disconnect reason: %s", reports[1].GetConnection().GetReason())
|
||||
}
|
||||
}
|
||||
|
||||
+77
-10
@@ -78,6 +78,8 @@ const (
|
||||
// BlockedFileTransferCommands contains a list of restricted file transfer commands.
|
||||
var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"}
|
||||
|
||||
type reportConnectionFunc func(id uuid.UUID, sessionType MagicSessionType, ip string) (disconnected func(code int, reason string))
|
||||
|
||||
// Config sets configuration parameters for the agent SSH server.
|
||||
type Config struct {
|
||||
// MaxTimeout sets the absolute connection timeout, none if empty. If set to
|
||||
@@ -100,6 +102,8 @@ type Config struct {
|
||||
X11DisplayOffset *int
|
||||
// BlockFileTransfer restricts use of file transfer applications.
|
||||
BlockFileTransfer bool
|
||||
// ReportConnection.
|
||||
ReportConnection reportConnectionFunc
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
@@ -152,6 +156,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
return home
|
||||
}
|
||||
}
|
||||
if config.ReportConnection == nil {
|
||||
config.ReportConnection = func(uuid.UUID, MagicSessionType, string) func(int, string) { return func(int, string) {} }
|
||||
}
|
||||
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := newForwardedUnixHandler(logger)
|
||||
@@ -174,7 +181,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
||||
// Wrapper is designed to find and track JetBrains Gateway connections.
|
||||
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains)
|
||||
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains)
|
||||
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
|
||||
},
|
||||
"direct-streamlocal@openssh.com": directStreamLocalHandler,
|
||||
@@ -288,6 +295,35 @@ func extractMagicSessionType(env []string) (magicType MagicSessionType, rawType
|
||||
})
|
||||
}
|
||||
|
||||
// sessionCloseTracker is a wrapper around Session that tracks the exit code.
|
||||
type sessionCloseTracker struct {
|
||||
ssh.Session
|
||||
exitOnce sync.Once
|
||||
code atomic.Int64
|
||||
}
|
||||
|
||||
var _ ssh.Session = &sessionCloseTracker{}
|
||||
|
||||
func (s *sessionCloseTracker) track(code int) {
|
||||
s.exitOnce.Do(func() {
|
||||
s.code.Store(int64(code))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *sessionCloseTracker) exitCode() int {
|
||||
return int(s.code.Load())
|
||||
}
|
||||
|
||||
func (s *sessionCloseTracker) Exit(code int) error {
|
||||
s.track(code)
|
||||
return s.Session.Exit(code)
|
||||
}
|
||||
|
||||
func (s *sessionCloseTracker) Close() error {
|
||||
s.track(1)
|
||||
return s.Session.Close()
|
||||
}
|
||||
|
||||
func (s *Server) sessionHandler(session ssh.Session) {
|
||||
ctx := session.Context()
|
||||
id := uuid.New()
|
||||
@@ -300,17 +336,23 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
)
|
||||
logger.Info(ctx, "handling ssh session")
|
||||
|
||||
env := session.Environ()
|
||||
magicType, magicTypeRaw, env := extractMagicSessionType(env)
|
||||
|
||||
if !s.trackSession(session, true) {
|
||||
reason := "unable to accept new session, server is closing"
|
||||
// Report connection attempt even if we couldn't accept it.
|
||||
disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String())
|
||||
defer disconnected(1, reason)
|
||||
|
||||
logger.Info(ctx, reason)
|
||||
// See (*Server).Close() for why we call Close instead of Exit.
|
||||
_ = session.Close()
|
||||
logger.Info(ctx, "unable to accept new session, server is closing")
|
||||
return
|
||||
}
|
||||
defer s.trackSession(session, false)
|
||||
|
||||
env := session.Environ()
|
||||
magicType, magicTypeRaw, env := extractMagicSessionType(env)
|
||||
|
||||
reportSession := true
|
||||
switch magicType {
|
||||
case MagicSessionTypeVSCode:
|
||||
s.connCountVSCode.Add(1)
|
||||
@@ -318,6 +360,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
case MagicSessionTypeJetBrains:
|
||||
// Do nothing here because JetBrains launches hundreds of ssh sessions.
|
||||
// We instead track JetBrains in the single persistent tcp forwarding channel.
|
||||
reportSession = false
|
||||
case MagicSessionTypeSSH:
|
||||
s.connCountSSHSession.Add(1)
|
||||
defer s.connCountSSHSession.Add(-1)
|
||||
@@ -325,6 +368,20 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("raw_type", magicTypeRaw))
|
||||
}
|
||||
|
||||
closeCause := func(string) {}
|
||||
if reportSession {
|
||||
var reason string
|
||||
closeCause = func(r string) { reason = r }
|
||||
|
||||
scr := &sessionCloseTracker{Session: session}
|
||||
session = scr
|
||||
|
||||
disconnected := s.config.ReportConnection(id, magicType, session.RemoteAddr().String())
|
||||
defer func() {
|
||||
disconnected(scr.exitCode(), reason)
|
||||
}()
|
||||
}
|
||||
|
||||
if s.fileTransferBlocked(session) {
|
||||
s.logger.Warn(ctx, "file transfer blocked", slog.F("session_subsystem", session.Subsystem()), slog.F("raw_command", session.RawCommand()))
|
||||
|
||||
@@ -333,6 +390,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
errorMessage := fmt.Sprintf("\x02%s\n", BlockedFileTransferErrorMessage)
|
||||
_, _ = session.Write([]byte(errorMessage))
|
||||
}
|
||||
closeCause("file transfer blocked")
|
||||
_ = session.Exit(BlockedFileTransferErrorCode)
|
||||
return
|
||||
}
|
||||
@@ -340,10 +398,14 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
switch ss := session.Subsystem(); ss {
|
||||
case "":
|
||||
case "sftp":
|
||||
s.sftpHandler(logger, session)
|
||||
err := s.sftpHandler(logger, session)
|
||||
if err != nil {
|
||||
closeCause(err.Error())
|
||||
}
|
||||
return
|
||||
default:
|
||||
logger.Warn(ctx, "unsupported subsystem", slog.F("subsystem", ss))
|
||||
closeCause(fmt.Sprintf("unsupported subsystem: %s", ss))
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
}
|
||||
@@ -352,8 +414,9 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
if hasX11 {
|
||||
display, handled := s.x11Handler(session.Context(), x11)
|
||||
if !handled {
|
||||
_ = session.Exit(1)
|
||||
logger.Error(ctx, "x11 handler failed")
|
||||
closeCause("x11 handler failed")
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
}
|
||||
env = append(env, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
|
||||
@@ -380,6 +443,8 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
slog.F("exit_code", code),
|
||||
)
|
||||
|
||||
closeCause(fmt.Sprintf("process exited with error status: %d", exitError.ExitCode()))
|
||||
|
||||
// TODO(mafredri): For signal exit, there's also an "exit-signal"
|
||||
// request (session.Exit sends "exit-status"), however, since it's
|
||||
// not implemented on the session interface and not used by
|
||||
@@ -391,6 +456,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
logger.Warn(ctx, "ssh session failed", slog.Error(err))
|
||||
// This exit code is designed to be unlikely to be confused for a legit exit code
|
||||
// from the process.
|
||||
closeCause(err.Error())
|
||||
_ = session.Exit(MagicSessionErrorCode)
|
||||
return
|
||||
}
|
||||
@@ -650,7 +716,7 @@ func handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signa
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
|
||||
func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) error {
|
||||
s.metrics.sftpConnectionsTotal.Add(1)
|
||||
|
||||
ctx := session.Context()
|
||||
@@ -674,7 +740,7 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
|
||||
server, err := sftp.NewServer(session, opts...)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "initialize sftp server", slog.Error(err))
|
||||
return
|
||||
return xerrors.Errorf("initialize sftp server: %w", err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
@@ -689,11 +755,12 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
|
||||
// code but `scp` on macOS does (when using the default
|
||||
// SFTP backend).
|
||||
_ = session.Exit(0)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
logger.Warn(ctx, "sftp server closed with error", slog.Error(err))
|
||||
s.metrics.sftpServerErrors.Add(1)
|
||||
_ = session.Exit(1)
|
||||
return xerrors.Errorf("sftp server closed with error: %w", err)
|
||||
}
|
||||
|
||||
// CreateCommand processes raw command input with OpenSSH-like behavior.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/atomic"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
@@ -28,9 +29,11 @@ type JetbrainsChannelWatcher struct {
|
||||
gossh.NewChannel
|
||||
jetbrainsCounter *atomic.Int64
|
||||
logger slog.Logger
|
||||
originAddr string
|
||||
reportConnection reportConnectionFunc
|
||||
}
|
||||
|
||||
func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel {
|
||||
func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, reportConnection reportConnectionFunc, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel {
|
||||
d := localForwardChannelData{}
|
||||
if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil {
|
||||
// If the data fails to unmarshal, do nothing.
|
||||
@@ -61,12 +64,17 @@ func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel
|
||||
NewChannel: newChannel,
|
||||
jetbrainsCounter: counter,
|
||||
logger: logger.With(slog.F("destination_port", d.DestPort)),
|
||||
originAddr: d.OriginAddr,
|
||||
reportConnection: reportConnection,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) {
|
||||
disconnected := w.reportConnection(uuid.New(), MagicSessionTypeJetBrains, w.originAddr)
|
||||
|
||||
c, r, err := w.NewChannel.Accept()
|
||||
if err != nil {
|
||||
disconnected(1, err.Error())
|
||||
return c, r, err
|
||||
}
|
||||
w.jetbrainsCounter.Add(1)
|
||||
@@ -77,6 +85,7 @@ func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request
|
||||
Channel: c,
|
||||
done: func() {
|
||||
w.jetbrainsCounter.Add(-1)
|
||||
disconnected(0, "")
|
||||
// nolint: gocritic // JetBrains is a proper noun and should be capitalized
|
||||
w.logger.Debug(context.Background(), "JetBrains watcher channel closed")
|
||||
},
|
||||
|
||||
+20
-10
@@ -158,20 +158,24 @@ func (c *Client) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) {
|
||||
c.fakeAgentAPI.SetLogsChannel(ch)
|
||||
}
|
||||
|
||||
func (c *Client) GetConnectionReports() []*agentproto.ReportConnectionRequest {
|
||||
return c.fakeAgentAPI.GetConnectionReports()
|
||||
}
|
||||
|
||||
type FakeAgentAPI struct {
|
||||
sync.Mutex
|
||||
t testing.TB
|
||||
logger slog.Logger
|
||||
|
||||
manifest *agentproto.Manifest
|
||||
startupCh chan *agentproto.Startup
|
||||
statsCh chan *agentproto.Stats
|
||||
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
|
||||
logsCh chan<- *agentproto.BatchCreateLogsRequest
|
||||
lifecycleStates []codersdk.WorkspaceAgentLifecycle
|
||||
metadata map[string]agentsdk.Metadata
|
||||
timings []*agentproto.Timing
|
||||
connections []*agentproto.Connection
|
||||
manifest *agentproto.Manifest
|
||||
startupCh chan *agentproto.Startup
|
||||
statsCh chan *agentproto.Stats
|
||||
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
|
||||
logsCh chan<- *agentproto.BatchCreateLogsRequest
|
||||
lifecycleStates []codersdk.WorkspaceAgentLifecycle
|
||||
metadata map[string]agentsdk.Metadata
|
||||
timings []*agentproto.Timing
|
||||
connectionReports []*agentproto.ReportConnectionRequest
|
||||
|
||||
getAnnouncementBannersFunc func() ([]codersdk.BannerConfig, error)
|
||||
getResourcesMonitoringConfigurationFunc func() (*agentproto.GetResourcesMonitoringConfigurationResponse, error)
|
||||
@@ -348,12 +352,18 @@ func (f *FakeAgentAPI) ScriptCompleted(_ context.Context, req *agentproto.Worksp
|
||||
|
||||
func (f *FakeAgentAPI) ReportConnection(_ context.Context, req *agentproto.ReportConnectionRequest) (*emptypb.Empty, error) {
|
||||
f.Lock()
|
||||
f.connections = append(f.connections, req.GetConnection())
|
||||
f.connectionReports = append(f.connectionReports, req)
|
||||
f.Unlock()
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (f *FakeAgentAPI) GetConnectionReports() []*agentproto.ReportConnectionRequest {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
return slices.Clone(f.connectionReports)
|
||||
}
|
||||
|
||||
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI {
|
||||
return &FakeAgentAPI{
|
||||
t: t,
|
||||
|
||||
@@ -20,11 +20,14 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type reportConnectionFunc func(id uuid.UUID, ip string) (disconnected func(code int, reason string))
|
||||
|
||||
type Server struct {
|
||||
logger slog.Logger
|
||||
connectionsTotal prometheus.Counter
|
||||
errorsTotal *prometheus.CounterVec
|
||||
commandCreator *agentssh.Server
|
||||
reportConnection reportConnectionFunc
|
||||
connCount atomic.Int64
|
||||
reconnectingPTYs sync.Map
|
||||
timeout time.Duration
|
||||
@@ -33,13 +36,19 @@ type Server struct {
|
||||
}
|
||||
|
||||
// NewServer returns a new ReconnectingPTY server
|
||||
func NewServer(logger slog.Logger, commandCreator *agentssh.Server,
|
||||
func NewServer(logger slog.Logger, commandCreator *agentssh.Server, reportConnection reportConnectionFunc,
|
||||
connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec,
|
||||
timeout time.Duration, opts ...func(*Server),
|
||||
) *Server {
|
||||
if reportConnection == nil {
|
||||
reportConnection = func(uuid.UUID, string) func(int, string) {
|
||||
return func(int, string) {}
|
||||
}
|
||||
}
|
||||
s := &Server{
|
||||
logger: logger,
|
||||
commandCreator: commandCreator,
|
||||
reportConnection: reportConnection,
|
||||
connectionsTotal: connectionsTotal,
|
||||
errorsTotal: errorsTotal,
|
||||
timeout: timeout,
|
||||
@@ -67,20 +76,31 @@ func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr err
|
||||
slog.F("local", conn.LocalAddr().String()))
|
||||
clog.Info(ctx, "accepted conn")
|
||||
wg.Add(1)
|
||||
disconnected := s.reportConnection(uuid.New(), conn.RemoteAddr().String())
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-closed:
|
||||
case <-hardCtx.Done():
|
||||
disconnected(1, "server shut down")
|
||||
_ = conn.Close()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer close(closed)
|
||||
defer wg.Done()
|
||||
_ = s.handleConn(ctx, clog, conn)
|
||||
err := s.handleConn(ctx, clog, conn)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
disconnected(1, "server shutting down")
|
||||
} else {
|
||||
disconnected(1, err.Error())
|
||||
}
|
||||
} else {
|
||||
disconnected(0, "")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
@@ -54,6 +54,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
agentHeaderCommand string
|
||||
agentHeader []string
|
||||
devcontainersEnabled bool
|
||||
|
||||
experimentalConnectionReports bool
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "agent",
|
||||
@@ -325,6 +327,10 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
containerLister = agentcontainers.NewDocker(execer)
|
||||
}
|
||||
|
||||
if experimentalConnectionReports {
|
||||
logger.Info(ctx, "experimental connection reports enabled")
|
||||
}
|
||||
|
||||
agnt := agent.New(agent.Options{
|
||||
Client: client,
|
||||
Logger: logger,
|
||||
@@ -353,6 +359,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
ContainerLister: containerLister,
|
||||
|
||||
ExperimentalContainersEnabled: devcontainersEnabled,
|
||||
ExperimentalConnectionReports: experimentalConnectionReports,
|
||||
})
|
||||
|
||||
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
|
||||
@@ -482,6 +489,14 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
Description: "Allow the agent to automatically detect running devcontainers.",
|
||||
Value: serpent.BoolOf(&devcontainersEnabled),
|
||||
},
|
||||
{
|
||||
Flag: "experimental-connection-reports-enable",
|
||||
Hidden: true,
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_EXPERIMENTAL_CONNECTION_REPORTS_ENABLE",
|
||||
Description: "Enable experimental connection reports.",
|
||||
Value: serpent.BoolOf(&experimentalConnectionReports),
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
|
||||
Reference in New Issue
Block a user