mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(agent): install connstats callback at statsReporter creation
The stats reporter only installed the connstats callback on the TUN device after the report loop negotiated an interval with the server. Traffic that flowed before that point (e.g. an SSH handshake) was silently dropped from tracking because the TUN wrapper's stats.Load() returned nil. On macOS CI under scheduler pressure, this window was wide enough for the entire SSH handshake to complete untracked.

Install the connstats callback immediately in newStatsReporter using the interval from agent.Options.StatsReportInterval. The default (5 minutes) matches the server-side default, so the report loop does not replace the tracker on the common path. Tests set it to agenttest.StatsInterval (500ms) to match the fake server, also avoiding replacement. If the server returns a different interval, the existing reconciliation in reportLocked handles it.

Add a StatsCallbackRace subtest to TestAgent_Stats_SSH that uses a channel barrier to prove the fix is load-bearing: on unfixed code the barrier blocks reportLoop while SSH connects (traffic lost); on fixed code the callback is already installed at creation (traffic captured).

Fixes flaky TestAgent_Stats_SSH, TestAgent_Stats_ReconnectingPTY, and TestAgent_Stats_Magic by ensuring the connstats callback is always installed before network traffic can flow.

Closes https://github.com/coder/internal/issues/505
This commit is contained in:
+10
-1
@@ -117,6 +117,9 @@ type Options struct {
|
||||
ContextConfig agentcontextconfig.Config
|
||||
// DERPTLSConfig is an optional TLS config for DERP connections.
|
||||
DERPTLSConfig *tls.Config
|
||||
// StatsReportInterval is the interval for the connstats callback
|
||||
// installed at statsReporter creation.
|
||||
StatsReportInterval time.Duration
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@@ -183,6 +186,10 @@ func New(options Options) Agent {
|
||||
options.Execer = agentexec.DefaultExecer
|
||||
}
|
||||
|
||||
if options.StatsReportInterval == 0 {
|
||||
options.StatsReportInterval = DefaultStatsReportInterval
|
||||
}
|
||||
|
||||
if options.ListeningPortsGetter == nil {
|
||||
options.ListeningPortsGetter = &osListeningPortsGetter{
|
||||
cacheDuration: 1 * time.Second,
|
||||
@@ -216,6 +223,7 @@ func New(options Options) Agent {
|
||||
ignorePorts: maps.Clone(options.IgnorePorts),
|
||||
},
|
||||
reportMetadataInterval: options.ReportMetadataInterval,
|
||||
statsReportInterval: options.StatsReportInterval,
|
||||
announcementBannersRefreshInterval: options.ServiceBannerRefreshInterval,
|
||||
sshMaxTimeout: options.SSHMaxTimeout,
|
||||
subsystems: options.Subsystems,
|
||||
@@ -289,6 +297,7 @@ type agent struct {
|
||||
// values. Callers that need secrets must explicitly load this.
|
||||
secrets atomic.Pointer[[]agentsdk.WorkspaceSecret]
|
||||
reportMetadataInterval time.Duration
|
||||
statsReportInterval time.Duration
|
||||
scriptRunner *agentscripts.Runner
|
||||
announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated.
|
||||
announcementBannersRefreshInterval time.Duration
|
||||
@@ -1500,7 +1509,7 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
|
||||
closing := a.closing
|
||||
if !closing {
|
||||
a.network = network
|
||||
a.statsReporter = newStatsReporter(a.logger, network, a)
|
||||
a.statsReporter = newStatsReporter(a.logger, network, a, a.statsReportInterval)
|
||||
}
|
||||
a.closeMutex.Unlock()
|
||||
if closing {
|
||||
|
||||
+79
-27
@@ -148,33 +148,11 @@ func TestAgent_Stats_SSH(t *testing.T) {
|
||||
err = session.Shell()
|
||||
require.NoError(t, err)
|
||||
|
||||
var s *proto.Stats
|
||||
// We are looking for four different stats to be reported. They might not all
|
||||
// arrive at the same time, so we loop until we've seen them all.
|
||||
var connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen bool
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var ok bool
|
||||
s, ok = <-stats
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if s.ConnectionCount > 0 {
|
||||
connectionCountSeen = true
|
||||
}
|
||||
if s.RxBytes > 0 {
|
||||
rxBytesSeen = true
|
||||
}
|
||||
if s.TxBytes > 0 {
|
||||
txBytesSeen = true
|
||||
}
|
||||
if s.SessionCountSsh == 1 {
|
||||
sessionCountSSHSeen = true
|
||||
}
|
||||
return connectionCountSeen && rxBytesSeen && txBytesSeen && sessionCountSSHSeen
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"never saw all stats: %+v, saw connectionCount: %t, rxBytes: %t, txBytes: %t, sessionCountSsh: %t",
|
||||
s, connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen,
|
||||
)
|
||||
// Generate SSH traffic so the connstats window sees the session.
|
||||
_, err = stdin.Write([]byte("echo test\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assertSSHStats(t, stats)
|
||||
_, err = stdin.Write([]byte("exit 0\n"))
|
||||
require.NoError(t, err, "writing exit to stdin")
|
||||
_ = stdin.Close()
|
||||
@@ -182,6 +160,79 @@ func TestAgent_Stats_SSH(t *testing.T) {
|
||||
require.NoError(t, err, "waiting for session to exit")
|
||||
})
|
||||
}
|
||||
|
||||
// Regression test for CODAGT-517: the barrier blocks reportLoop's
|
||||
// initial UpdateStats, so on unfixed code the connstats callback is
|
||||
// not installed until the barrier is released, by which time the SSH
|
||||
// handshake traffic has gone untracked. On fixed code the callback is
|
||||
// installed at statsReporter creation, so the handshake is captured.
|
||||
t.Run("StatsCallbackRace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
barrier := make(chan struct{})
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0,
|
||||
func(c *agenttest.Client, _ *agent.Options) {
|
||||
c.SetInitialUpdateStatsBarrier(barrier)
|
||||
},
|
||||
)
|
||||
|
||||
// Connect SSH while the barrier holds reportLoop blocked.
|
||||
sshClient, err := conn.SSHClientOnPort(ctx, workspacesdk.AgentStandardSSHPort)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
stdin, err := session.StdinPipe()
|
||||
require.NoError(t, err)
|
||||
err = session.Shell()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Shell must be idle so the only traffic is the SSH handshake.
|
||||
|
||||
close(barrier)
|
||||
|
||||
assertSSHStats(t, stats)
|
||||
_, err = stdin.Write([]byte("exit 0\n"))
|
||||
require.NoError(t, err, "writing exit to stdin")
|
||||
_ = stdin.Close()
|
||||
err = session.Wait()
|
||||
require.NoError(t, err, "waiting for session to exit")
|
||||
})
|
||||
}
|
||||
|
||||
// assertSSHStats reads reports from the stats channel until it has seen
|
||||
// ConnectionCount, RxBytes, and TxBytes nonzero and SessionCountSsh equal
|
||||
// to 1. The conditions may be satisfied by different reports.
|
||||
func assertSSHStats(t *testing.T, stats <-chan *proto.Stats) {
|
||||
t.Helper()
|
||||
var connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen bool
|
||||
require.Eventuallyf(t, func() bool {
|
||||
s, ok := <-stats
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
t.Logf("got stats: ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountSsh=%d",
|
||||
s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountSsh)
|
||||
if s.ConnectionCount > 0 {
|
||||
connectionCountSeen = true
|
||||
}
|
||||
if s.RxBytes > 0 {
|
||||
rxBytesSeen = true
|
||||
}
|
||||
if s.TxBytes > 0 {
|
||||
txBytesSeen = true
|
||||
}
|
||||
if s.SessionCountSsh == 1 {
|
||||
sessionCountSSHSeen = true
|
||||
}
|
||||
return connectionCountSeen && rxBytesSeen && txBytesSeen && sessionCountSSHSeen
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"never saw all SSH stats",
|
||||
)
|
||||
}
|
||||
|
||||
func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
|
||||
@@ -3851,6 +3902,7 @@ func setupAgentWithSecrets(t testing.TB, metadata agentsdk.Manifest, secrets []a
|
||||
Logger: logger.Named("agent"),
|
||||
ReconnectingPTYTimeout: ptyTimeout,
|
||||
EnvironmentVariables: map[string]string{},
|
||||
StatsReportInterval: agenttest.StatsInterval,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
|
||||
+29
-15
@@ -32,7 +32,8 @@ import (
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
const statsInterval = 500 * time.Millisecond
|
||||
// StatsInterval is the report interval returned by FakeAgentAPI.UpdateStats.
|
||||
const StatsInterval = 500 * time.Millisecond
|
||||
|
||||
func NewClient(t testing.TB,
|
||||
logger slog.Logger,
|
||||
@@ -128,6 +129,12 @@ func (c *Client) RefreshToken(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetInitialUpdateStatsBarrier blocks the initial (empty) UpdateStats
|
||||
// RPC until the channel is closed. Must be called before the agent starts.
|
||||
func (c *Client) SetInitialUpdateStatsBarrier(barrier <-chan struct{}) {
|
||||
c.fakeAgentAPI.initialUpdateStatsBarrier = barrier
|
||||
}
|
||||
|
||||
func (c *Client) GetNumRefreshTokenCalls() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
@@ -232,19 +239,20 @@ type FakeAgentAPI struct {
|
||||
t testing.TB
|
||||
logger slog.Logger
|
||||
|
||||
manifest *agentproto.Manifest
|
||||
startupCh chan *agentproto.Startup
|
||||
statsCh chan *agentproto.Stats
|
||||
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
|
||||
logsCh chan<- *agentproto.BatchCreateLogsRequest
|
||||
lifecycleStates []codersdk.WorkspaceAgentLifecycle
|
||||
metadata map[string]agentsdk.Metadata
|
||||
timings []*agentproto.Timing
|
||||
connectionReports []*agentproto.ReportConnectionRequest
|
||||
subAgents map[uuid.UUID]*agentproto.SubAgent
|
||||
subAgentDirs map[uuid.UUID]string
|
||||
subAgentDisplayApps map[uuid.UUID][]agentproto.CreateSubAgentRequest_DisplayApp
|
||||
subAgentApps map[uuid.UUID][]*agentproto.CreateSubAgentRequest_App
|
||||
initialUpdateStatsBarrier <-chan struct{} // blocks initial (empty) UpdateStats until closed
|
||||
manifest *agentproto.Manifest
|
||||
startupCh chan *agentproto.Startup
|
||||
statsCh chan *agentproto.Stats
|
||||
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
|
||||
logsCh chan<- *agentproto.BatchCreateLogsRequest
|
||||
lifecycleStates []codersdk.WorkspaceAgentLifecycle
|
||||
metadata map[string]agentsdk.Metadata
|
||||
timings []*agentproto.Timing
|
||||
connectionReports []*agentproto.ReportConnectionRequest
|
||||
subAgents map[uuid.UUID]*agentproto.SubAgent
|
||||
subAgentDirs map[uuid.UUID]string
|
||||
subAgentDisplayApps map[uuid.UUID][]agentproto.CreateSubAgentRequest_DisplayApp
|
||||
subAgentApps map[uuid.UUID][]*agentproto.CreateSubAgentRequest_App
|
||||
|
||||
getAnnouncementBannersFunc func() ([]codersdk.BannerConfig, error)
|
||||
getResourcesMonitoringConfigurationFunc func() (*agentproto.GetResourcesMonitoringConfigurationResponse, error)
|
||||
@@ -330,8 +338,14 @@ func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateSt
|
||||
case f.statsCh <- req.Stats:
|
||||
// OK!
|
||||
}
|
||||
} else if barrier := f.initialUpdateStatsBarrier; barrier != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-barrier:
|
||||
}
|
||||
}
|
||||
return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(statsInterval)}, nil
|
||||
return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(StatsInterval)}, nil
|
||||
}
|
||||
|
||||
func (f *FakeAgentAPI) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
|
||||
|
||||
+24
-10
@@ -42,13 +42,22 @@ type statsReporter struct {
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector) *statsReporter {
|
||||
return &statsReporter{
|
||||
Cond: sync.NewCond(&sync.Mutex{}),
|
||||
logger: logger,
|
||||
source: source,
|
||||
collector: collector,
|
||||
// DefaultStatsReportInterval matches coderd.Options.AgentStatsRefreshInterval.
|
||||
const DefaultStatsReportInterval = 5 * time.Minute
|
||||
|
||||
func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector, interval time.Duration) *statsReporter {
|
||||
s := &statsReporter{
|
||||
Cond: sync.NewCond(&sync.Mutex{}),
|
||||
logger: logger,
|
||||
source: source,
|
||||
collector: collector,
|
||||
lastInterval: interval,
|
||||
}
|
||||
// Install the callback immediately so traffic is tracked before
|
||||
// reportLoop starts. reportLoop replaces it only if the
|
||||
// server-negotiated interval differs.
|
||||
source.SetConnStatsCallback(interval, maxConns, s.callback)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) {
|
||||
@@ -67,8 +76,10 @@ func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Conne
|
||||
s.Broadcast()
|
||||
}
|
||||
|
||||
// reportLoop programs the source (tailnet.Conn) to send it stats via the
|
||||
// callback, then reports them to the dest.
|
||||
// reportLoop reports collected stats to the server.
|
||||
//
|
||||
// The connstats callback is already installed by newStatsReporter;
|
||||
// reportLoop only replaces it if the server returns a different interval.
|
||||
//
|
||||
// It's intended to be called within the larger retry loop that establishes a
|
||||
// connection to the agent API, then passes that connection to go routines like
|
||||
@@ -80,8 +91,11 @@ func (s *statsReporter) reportLoop(ctx context.Context, dest statsDest) error {
|
||||
if err != nil {
|
||||
return xerrors.Errorf("initial update: %w", err)
|
||||
}
|
||||
s.lastInterval = resp.ReportInterval.AsDuration()
|
||||
s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
|
||||
interval := resp.ReportInterval.AsDuration()
|
||||
if interval != s.lastInterval {
|
||||
s.lastInterval = interval
|
||||
s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback)
|
||||
}
|
||||
|
||||
// use a separate goroutine to monitor the context so that we notice immediately, rather than
|
||||
// waiting for the next callback (which might never come if we are closing!)
|
||||
|
||||
@@ -23,7 +23,9 @@ func TestStatsReporter(t *testing.T) {
|
||||
fSource := newFakeNetworkStatsSource(ctx, t)
|
||||
fCollector := newFakeCollector(t)
|
||||
fDest := newFakeStatsDest()
|
||||
uut := newStatsReporter(logger, fSource, fCollector)
|
||||
uut := newStatsReporter(logger, fSource, fCollector, DefaultStatsReportInterval)
|
||||
|
||||
_ = testutil.TryReceive(ctx, t, fSource.period) // drain construction-time install
|
||||
|
||||
loopErr := make(chan error, 1)
|
||||
loopCtx, loopCancel := context.WithCancel(ctx)
|
||||
@@ -157,7 +159,7 @@ func newFakeNetworkStatsSource(ctx context.Context, t testing.TB) *fakeNetworkSt
|
||||
f := &fakeNetworkStatsSource{
|
||||
ctx: ctx,
|
||||
t: t,
|
||||
period: make(chan time.Duration),
|
||||
period: make(chan time.Duration, 1),
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user