fix(agent): install connstats callback at statsReporter creation

The stats reporter only installed the connstats callback on the TUN
device after the report loop negotiated an interval with the server.
Traffic that flowed before that point (e.g. an SSH handshake) was
silently dropped from tracking because the TUN wrapper's stats.Load()
returned nil. On macOS CI under scheduler pressure, this window was
wide enough for the entire SSH handshake to complete untracked.

Install the connstats callback immediately in newStatsReporter using
the interval from agent.Options.StatsReportInterval. The default
(5 minutes) matches the server-side default, so the report loop does
not replace the tracker on the common path. Tests set it to
agenttest.StatsInterval (500ms) to match the fake server, which also
avoids replacement. If the server returns a different interval, the
existing reconciliation in reportLocked handles it.

Add a StatsCallbackRace subtest to TestAgent_Stats_SSH that uses a
channel barrier to prove the fix is load-bearing: on unfixed code the
barrier blocks reportLoop while SSH connects (traffic lost), on fixed
code the callback is already installed at creation (traffic captured).

Fixes flaky TestAgent_Stats_SSH, TestAgent_Stats_ReconnectingPTY,
and TestAgent_Stats_Magic by ensuring the connstats callback is
always installed before network traffic can flow.

Closes https://github.com/coder/internal/issues/505
This commit is contained in:
Mathias Fredriksson
2026-05-29 08:49:13 +00:00
parent c248dfb437
commit d0e71d20fd
5 changed files with 146 additions and 55 deletions
+10 -1
View File
@@ -117,6 +117,9 @@ type Options struct {
ContextConfig agentcontextconfig.Config
// DERPTLSConfig is an optional TLS config for DERP connections.
DERPTLSConfig *tls.Config
// StatsReportInterval is the interval for the connstats callback
// installed at statsReporter creation.
StatsReportInterval time.Duration
}
type Client interface {
@@ -183,6 +186,10 @@ func New(options Options) Agent {
options.Execer = agentexec.DefaultExecer
}
if options.StatsReportInterval == 0 {
options.StatsReportInterval = DefaultStatsReportInterval
}
if options.ListeningPortsGetter == nil {
options.ListeningPortsGetter = &osListeningPortsGetter{
cacheDuration: 1 * time.Second,
@@ -216,6 +223,7 @@ func New(options Options) Agent {
ignorePorts: maps.Clone(options.IgnorePorts),
},
reportMetadataInterval: options.ReportMetadataInterval,
statsReportInterval: options.StatsReportInterval,
announcementBannersRefreshInterval: options.ServiceBannerRefreshInterval,
sshMaxTimeout: options.SSHMaxTimeout,
subsystems: options.Subsystems,
@@ -289,6 +297,7 @@ type agent struct {
// values. Callers that need secrets must explicitly load this.
secrets atomic.Pointer[[]agentsdk.WorkspaceSecret]
reportMetadataInterval time.Duration
statsReportInterval time.Duration
scriptRunner *agentscripts.Runner
announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated.
announcementBannersRefreshInterval time.Duration
@@ -1500,7 +1509,7 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
closing := a.closing
if !closing {
a.network = network
a.statsReporter = newStatsReporter(a.logger, network, a)
a.statsReporter = newStatsReporter(a.logger, network, a, a.statsReportInterval)
}
a.closeMutex.Unlock()
if closing {
+79 -27
View File
@@ -148,33 +148,11 @@ func TestAgent_Stats_SSH(t *testing.T) {
err = session.Shell()
require.NoError(t, err)
var s *proto.Stats
// We are looking for four different stats to be reported. They might not all
// arrive at the same time, so we loop until we've seen them all.
var connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen bool
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
if !ok {
return false
}
if s.ConnectionCount > 0 {
connectionCountSeen = true
}
if s.RxBytes > 0 {
rxBytesSeen = true
}
if s.TxBytes > 0 {
txBytesSeen = true
}
if s.SessionCountSsh == 1 {
sessionCountSSHSeen = true
}
return connectionCountSeen && rxBytesSeen && txBytesSeen && sessionCountSSHSeen
}, testutil.WaitLong, testutil.IntervalFast,
"never saw all stats: %+v, saw connectionCount: %t, rxBytes: %t, txBytes: %t, sessionCountSsh: %t",
s, connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen,
)
// Generate SSH traffic so the connstats window sees the session.
_, err = stdin.Write([]byte("echo test\n"))
require.NoError(t, err)
assertSSHStats(t, stats)
_, err = stdin.Write([]byte("exit 0\n"))
require.NoError(t, err, "writing exit to stdin")
_ = stdin.Close()
@@ -182,6 +160,79 @@ func TestAgent_Stats_SSH(t *testing.T) {
require.NoError(t, err, "waiting for session to exit")
})
}
// Regression test for CODAGT-517: the barrier blocks reportLoop's
// initial UpdateStats, so on unfixed code the connstats callback is
// never installed and handshake traffic is lost. On fixed code the
// callback is installed at creation, so traffic is captured.
//
// Order of operations: register the barrier, complete the SSH handshake
// and start a shell, release the barrier, then wait for the SSH stats.
t.Run("StatsCallbackRace", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Nothing is ever sent on barrier; closing it is the release signal.
barrier := make(chan struct{})
//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0,
func(c *agenttest.Client, _ *agent.Options) {
// Register the barrier on the fake client; presumably setupAgent
// runs this option before the agent starts — verify against setupAgent.
c.SetInitialUpdateStatsBarrier(barrier)
},
)
// Connect SSH while the barrier holds reportLoop blocked.
sshClient, err := conn.SSHClientOnPort(ctx, workspacesdk.AgentStandardSSHPort)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
stdin, err := session.StdinPipe()
require.NoError(t, err)
err = session.Shell()
require.NoError(t, err)
// Shell must be idle so the only traffic is the SSH handshake.
// Nothing has been written to stdin yet when the barrier is released.
close(barrier)
// The handshake happened entirely before the release, so these stats
// can only be seen if the traffic was already being tracked.
assertSSHStats(t, stats)
// Ask the remote shell to exit and wait for the session to finish.
_, err = stdin.Write([]byte("exit 0\n"))
require.NoError(t, err, "writing exit to stdin")
_ = stdin.Close()
err = session.Wait()
require.NoError(t, err, "waiting for session to exit")
})
}
// assertSSHStats drains the stats channel until every expected SSH signal
// has been observed at least once: a positive ConnectionCount, positive
// RxBytes, positive TxBytes, and a SessionCountSsh of exactly 1. The
// signals may arrive in different reports, so each one is latched as soon
// as it is seen. Every received report is logged. The test fails if the
// full set has not been seen within testutil.WaitLong.
func assertSSHStats(t *testing.T, stats <-chan *proto.Stats) {
	t.Helper()

	var sawConns, sawRx, sawTx, sawSSHSession bool

	// allSeen consumes one report per poll and folds it into the latches.
	allSeen := func() bool {
		report, open := <-stats
		if !open {
			// A closed channel can never deliver the missing signals.
			return false
		}
		t.Logf("got stats: ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountSsh=%d",
			report.ConnectionCount, report.RxBytes, report.TxBytes, report.SessionCountSsh)

		sawConns = sawConns || report.ConnectionCount > 0
		sawRx = sawRx || report.RxBytes > 0
		sawTx = sawTx || report.TxBytes > 0
		sawSSHSession = sawSSHSession || report.SessionCountSsh == 1

		return sawConns && sawRx && sawTx && sawSSHSession
	}

	require.Eventuallyf(t, allSeen, testutil.WaitLong, testutil.IntervalFast,
		"never saw all SSH stats",
	)
}
func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
@@ -3851,6 +3902,7 @@ func setupAgentWithSecrets(t testing.TB, metadata agentsdk.Manifest, secrets []a
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: ptyTimeout,
EnvironmentVariables: map[string]string{},
StatsReportInterval: agenttest.StatsInterval,
}
for _, opt := range opts {
+29 -15
View File
@@ -32,7 +32,8 @@ import (
"github.com/coder/websocket"
)
const statsInterval = 500 * time.Millisecond
// StatsInterval is the report interval returned by FakeAgentAPI.UpdateStats.
const StatsInterval = 500 * time.Millisecond
func NewClient(t testing.TB,
logger slog.Logger,
@@ -128,6 +129,12 @@ func (c *Client) RefreshToken(context.Context) error {
return nil
}
// SetInitialUpdateStatsBarrier blocks the initial (empty) UpdateStats
// RPC until the channel is closed. Must be called before the agent starts.
//
// The barrier is stored on the client's FakeAgentAPI. This method takes no
// lock around the assignment, presumably the reason for the
// call-before-start requirement. Passing nil leaves the barrier unset.
func (c *Client) SetInitialUpdateStatsBarrier(barrier <-chan struct{}) {
c.fakeAgentAPI.initialUpdateStatsBarrier = barrier
}
func (c *Client) GetNumRefreshTokenCalls() int {
c.mu.Lock()
defer c.mu.Unlock()
@@ -232,19 +239,20 @@ type FakeAgentAPI struct {
t testing.TB
logger slog.Logger
manifest *agentproto.Manifest
startupCh chan *agentproto.Startup
statsCh chan *agentproto.Stats
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
logsCh chan<- *agentproto.BatchCreateLogsRequest
lifecycleStates []codersdk.WorkspaceAgentLifecycle
metadata map[string]agentsdk.Metadata
timings []*agentproto.Timing
connectionReports []*agentproto.ReportConnectionRequest
subAgents map[uuid.UUID]*agentproto.SubAgent
subAgentDirs map[uuid.UUID]string
subAgentDisplayApps map[uuid.UUID][]agentproto.CreateSubAgentRequest_DisplayApp
subAgentApps map[uuid.UUID][]*agentproto.CreateSubAgentRequest_App
initialUpdateStatsBarrier <-chan struct{} // blocks initial (empty) UpdateStats until closed
manifest *agentproto.Manifest
startupCh chan *agentproto.Startup
statsCh chan *agentproto.Stats
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
logsCh chan<- *agentproto.BatchCreateLogsRequest
lifecycleStates []codersdk.WorkspaceAgentLifecycle
metadata map[string]agentsdk.Metadata
timings []*agentproto.Timing
connectionReports []*agentproto.ReportConnectionRequest
subAgents map[uuid.UUID]*agentproto.SubAgent
subAgentDirs map[uuid.UUID]string
subAgentDisplayApps map[uuid.UUID][]agentproto.CreateSubAgentRequest_DisplayApp
subAgentApps map[uuid.UUID][]*agentproto.CreateSubAgentRequest_App
getAnnouncementBannersFunc func() ([]codersdk.BannerConfig, error)
getResourcesMonitoringConfigurationFunc func() (*agentproto.GetResourcesMonitoringConfigurationResponse, error)
@@ -330,8 +338,14 @@ func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateSt
case f.statsCh <- req.Stats:
// OK!
}
} else if barrier := f.initialUpdateStatsBarrier; barrier != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-barrier:
}
}
return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(statsInterval)}, nil
return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(StatsInterval)}, nil
}
func (f *FakeAgentAPI) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
+24 -10
View File
@@ -42,13 +42,22 @@ type statsReporter struct {
logger slog.Logger
}
func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector) *statsReporter {
return &statsReporter{
Cond: sync.NewCond(&sync.Mutex{}),
logger: logger,
source: source,
collector: collector,
// DefaultStatsReportInterval matches coderd.Options.AgentStatsRefreshInterval.
const DefaultStatsReportInterval = 5 * time.Minute
// newStatsReporter constructs a statsReporter bound to source and
// collector, records interval as the reporter's lastInterval, and
// registers the reporter's callback on source before returning.
//
// Registering here rather than in reportLoop means connection traffic is
// tracked from the moment the reporter exists. reportLoop re-registers the
// callback only when the server-negotiated interval differs from interval.
func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector, interval time.Duration) *statsReporter {
	reporter := &statsReporter{
		lastInterval: interval,
		collector:    collector,
		source:       source,
		logger:       logger,
		Cond:         sync.NewCond(&sync.Mutex{}),
	}

	// Hook up the callback only once the reporter is fully populated.
	source.SetConnStatsCallback(interval, maxConns, reporter.callback)

	return reporter
}
func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
@@ -67,8 +76,10 @@ func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Conne
s.Broadcast()
}
// reportLoop programs the source (tailnet.Conn) to send it stats via the
// callback, then reports them to the dest.
// reportLoop reports collected stats to the server.
//
// The connstats callback is already installed by newStatsReporter;
// reportLoop only replaces it if the server returns a different interval.
//
// It's intended to be called within the larger retry loop that establishes a
// connection to the agent API, then passes that connection to go routines like
@@ -80,8 +91,11 @@ func (s *statsReporter) reportLoop(ctx context.Context, dest statsDest) error {
if err != nil {
return xerrors.Errorf("initial update: %w", err)
}
s.lastInterval = resp.ReportInterval.AsDuration()
s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
interval := resp.ReportInterval.AsDuration()
if interval != s.lastInterval {
s.lastInterval = interval
s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
}
// use a separate goroutine to monitor the context so that we notice immediately, rather than
// waiting for the next callback (which might never come if we are closing!)
+4 -2
View File
@@ -23,7 +23,9 @@ func TestStatsReporter(t *testing.T) {
fSource := newFakeNetworkStatsSource(ctx, t)
fCollector := newFakeCollector(t)
fDest := newFakeStatsDest()
uut := newStatsReporter(logger, fSource, fCollector)
uut := newStatsReporter(logger, fSource, fCollector, DefaultStatsReportInterval)
_ = testutil.TryReceive(ctx, t, fSource.period) // drain construction-time install
loopErr := make(chan error, 1)
loopCtx, loopCancel := context.WithCancel(ctx)
@@ -157,7 +159,7 @@ func newFakeNetworkStatsSource(ctx context.Context, t testing.TB) *fakeNetworkSt
f := &fakeNetworkStatsSource{
ctx: ctx,
t: t,
period: make(chan time.Duration),
period: make(chan time.Duration, 1),
}
return f
}