diff --git a/agent/agent.go b/agent/agent.go index 33ad83c804..61b116de6b 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -117,6 +117,9 @@ type Options struct { ContextConfig agentcontextconfig.Config // DERPTLSConfig is an optional TLS config for DERP connections. DERPTLSConfig *tls.Config + // StatsReportInterval is the interval for the connstats callback + // installed at statsReporter creation. + StatsReportInterval time.Duration } type Client interface { @@ -183,6 +186,10 @@ func New(options Options) Agent { options.Execer = agentexec.DefaultExecer } + if options.StatsReportInterval == 0 { + options.StatsReportInterval = DefaultStatsReportInterval + } + if options.ListeningPortsGetter == nil { options.ListeningPortsGetter = &osListeningPortsGetter{ cacheDuration: 1 * time.Second, @@ -216,6 +223,7 @@ func New(options Options) Agent { ignorePorts: maps.Clone(options.IgnorePorts), }, reportMetadataInterval: options.ReportMetadataInterval, + statsReportInterval: options.StatsReportInterval, announcementBannersRefreshInterval: options.ServiceBannerRefreshInterval, sshMaxTimeout: options.SSHMaxTimeout, subsystems: options.Subsystems, @@ -289,6 +297,7 @@ type agent struct { // values. Callers that need secrets must explicitly load this. secrets atomic.Pointer[[]agentsdk.WorkspaceSecret] reportMetadataInterval time.Duration + statsReportInterval time.Duration scriptRunner *agentscripts.Runner announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated. announcementBannersRefreshInterval time.Duration @@ -1500,7 +1509,7 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co closing := a.closing if !closing { a.network = network - a.statsReporter = newStatsReporter(a.logger, network, a) + a.statsReporter = newStatsReporter(a.logger, network, a, a.statsReportInterval) } a.closeMutex.Unlock() if closing { diff --git a/agent/agent_test.go b/agent/agent_test.go index 1fe8ad2725..f20386b5b3 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -148,33 +148,11 @@ func TestAgent_Stats_SSH(t *testing.T) { err = session.Shell() require.NoError(t, err) - var s *proto.Stats - // We are looking for four different stats to be reported. They might not all - // arrive at the same time, so we loop until we've seen them all. - var connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen bool - require.Eventuallyf(t, func() bool { - var ok bool - s, ok = <-stats - if !ok { - return false - } - if s.ConnectionCount > 0 { - connectionCountSeen = true - } - if s.RxBytes > 0 { - rxBytesSeen = true - } - if s.TxBytes > 0 { - txBytesSeen = true - } - if s.SessionCountSsh == 1 { - sessionCountSSHSeen = true - } - return connectionCountSeen && rxBytesSeen && txBytesSeen && sessionCountSSHSeen - }, testutil.WaitLong, testutil.IntervalFast, - "never saw all stats: %+v, saw connectionCount: %t, rxBytes: %t, txBytes: %t, sessionCountSsh: %t", - s, connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen, - ) + // Generate SSH traffic so the connstats window sees the session. + _, err = stdin.Write([]byte("echo test\n")) + require.NoError(t, err) + + assertSSHStats(t, stats) _, err = stdin.Write([]byte("exit 0\n")) require.NoError(t, err, "writing exit to stdin") _ = stdin.Close() @@ -182,6 +160,79 @@ func TestAgent_Stats_SSH(t *testing.T) { require.NoError(t, err, "waiting for session to exit") }) } + + // Regression test for CODAGT-517: the barrier blocks reportLoop's + // initial UpdateStats, so on unfixed code the connstats callback is + // never installed and handshake traffic is lost. On fixed code the + // callback is installed at creation, so traffic is captured. + t.Run("StatsCallbackRace", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + barrier := make(chan struct{}) + + //nolint:dogsled + conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, + func(c *agenttest.Client, _ *agent.Options) { + c.SetInitialUpdateStatsBarrier(barrier) + }, + ) + + // Connect SSH while the barrier holds reportLoop blocked. + sshClient, err := conn.SSHClientOnPort(ctx, workspacesdk.AgentStandardSSHPort) + require.NoError(t, err) + defer sshClient.Close() + session, err := sshClient.NewSession() + require.NoError(t, err) + defer session.Close() + stdin, err := session.StdinPipe() + require.NoError(t, err) + err = session.Shell() + require.NoError(t, err) + + // Shell must be idle so the only traffic is the SSH handshake. + + close(barrier) + + assertSSHStats(t, stats) + _, err = stdin.Write([]byte("exit 0\n")) + require.NoError(t, err, "writing exit to stdin") + _ = stdin.Close() + err = session.Wait() + require.NoError(t, err, "waiting for session to exit") + }) +} + +// assertSSHStats waits for ConnectionCount, RxBytes, TxBytes, and +// SessionCountSsh to be nonzero on the stats channel. +func assertSSHStats(t *testing.T, stats <-chan *proto.Stats) { + t.Helper() + var connectionCountSeen, rxBytesSeen, txBytesSeen, sessionCountSSHSeen bool + require.Eventuallyf(t, func() bool { + s, ok := <-stats + if !ok { + return false + } + t.Logf("got stats: ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountSsh=%d", + s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountSsh) + if s.ConnectionCount > 0 { + connectionCountSeen = true + } + if s.RxBytes > 0 { + rxBytesSeen = true + } + if s.TxBytes > 0 { + txBytesSeen = true + } + if s.SessionCountSsh == 1 { + sessionCountSSHSeen = true + } + return connectionCountSeen && rxBytesSeen && txBytesSeen && sessionCountSSHSeen + }, testutil.WaitLong, testutil.IntervalFast, + "never saw all SSH stats", + ) } func TestAgent_Stats_ReconnectingPTY(t *testing.T) { @@ -3851,6 +3902,7 @@ func setupAgentWithSecrets(t testing.TB, metadata agentsdk.Manifest, secrets []a Logger: logger.Named("agent"), ReconnectingPTYTimeout: ptyTimeout, EnvironmentVariables: map[string]string{}, + StatsReportInterval: agenttest.StatsInterval, } for _, opt := range opts { diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index 474469d7ff..1f0ae547a9 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -32,7 +32,8 @@ import ( "github.com/coder/websocket" ) -const statsInterval = 500 * time.Millisecond +// StatsInterval is the report interval returned by FakeAgentAPI.UpdateStats. +const StatsInterval = 500 * time.Millisecond func NewClient(t testing.TB, logger slog.Logger, @@ -128,6 +129,12 @@ func (c *Client) RefreshToken(context.Context) error { return nil } +// SetInitialUpdateStatsBarrier blocks the initial (empty) UpdateStats +// RPC until the channel is closed. Must be called before the agent starts. +func (c *Client) SetInitialUpdateStatsBarrier(barrier <-chan struct{}) { + c.fakeAgentAPI.initialUpdateStatsBarrier = barrier +} + func (c *Client) GetNumRefreshTokenCalls() int { c.mu.Lock() defer c.mu.Unlock() @@ -232,19 +239,20 @@ type FakeAgentAPI struct { t testing.TB logger slog.Logger - manifest *agentproto.Manifest - startupCh chan *agentproto.Startup - statsCh chan *agentproto.Stats - appHealthCh chan *agentproto.BatchUpdateAppHealthRequest - logsCh chan<- *agentproto.BatchCreateLogsRequest - lifecycleStates []codersdk.WorkspaceAgentLifecycle - metadata map[string]agentsdk.Metadata - timings []*agentproto.Timing - connectionReports []*agentproto.ReportConnectionRequest - subAgents map[uuid.UUID]*agentproto.SubAgent - subAgentDirs map[uuid.UUID]string - subAgentDisplayApps map[uuid.UUID][]agentproto.CreateSubAgentRequest_DisplayApp - subAgentApps map[uuid.UUID][]*agentproto.CreateSubAgentRequest_App + initialUpdateStatsBarrier <-chan struct{} // blocks initial (empty) UpdateStats until closed + manifest *agentproto.Manifest + startupCh chan *agentproto.Startup + statsCh chan *agentproto.Stats + appHealthCh chan *agentproto.BatchUpdateAppHealthRequest + logsCh chan<- *agentproto.BatchCreateLogsRequest + lifecycleStates []codersdk.WorkspaceAgentLifecycle + metadata map[string]agentsdk.Metadata + timings []*agentproto.Timing + connectionReports []*agentproto.ReportConnectionRequest + subAgents map[uuid.UUID]*agentproto.SubAgent + subAgentDirs map[uuid.UUID]string + subAgentDisplayApps map[uuid.UUID][]agentproto.CreateSubAgentRequest_DisplayApp + subAgentApps map[uuid.UUID][]*agentproto.CreateSubAgentRequest_App getAnnouncementBannersFunc func() ([]codersdk.BannerConfig, error) getResourcesMonitoringConfigurationFunc func() (*agentproto.GetResourcesMonitoringConfigurationResponse, error) @@ -330,8 +338,14 @@ func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateSt case f.statsCh <- req.Stats: // OK! } + } else if barrier := f.initialUpdateStatsBarrier; barrier != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-barrier: + } } - return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(statsInterval)}, nil + return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(StatsInterval)}, nil } func (f *FakeAgentAPI) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle { diff --git a/agent/stats.go b/agent/stats.go index 3df0fd44df..1989ff4fed 100644 --- a/agent/stats.go +++ b/agent/stats.go @@ -42,13 +42,22 @@ type statsReporter struct { logger slog.Logger } -func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector) *statsReporter { - return &statsReporter{ - Cond: sync.NewCond(&sync.Mutex{}), - logger: logger, - source: source, - collector: collector, +// DefaultStatsReportInterval matches coderd.Options.AgentStatsRefreshInterval. +const DefaultStatsReportInterval = 5 * time.Minute + +func newStatsReporter(logger slog.Logger, source networkStatsSource, collector statsCollector, interval time.Duration) *statsReporter { + s := &statsReporter{ + Cond: sync.NewCond(&sync.Mutex{}), + logger: logger, + source: source, + collector: collector, + lastInterval: interval, } + // Install the callback immediately so traffic is tracked before + // reportLoop starts. reportLoop replaces it only if the + // server-negotiated interval differs. + source.SetConnStatsCallback(interval, maxConns, s.callback) + return s } func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) { @@ -67,8 +76,10 @@ func (s *statsReporter) callback(_, _ time.Time, virtual, _ map[netlogtype.Conne s.Broadcast() } -// reportLoop programs the source (tailnet.Conn) to send it stats via the -// callback, then reports them to the dest. +// reportLoop reports collected stats to the server. +// +// The connstats callback is already installed by newStatsReporter; +// reportLoop only replaces it if the server returns a different interval. // // It's intended to be called within the larger retry loop that establishes a // connection to the agent API, then passes that connection to go routines like @@ -80,8 +91,11 @@ func (s *statsReporter) reportLoop(ctx context.Context, dest statsDest) error { if err != nil { return xerrors.Errorf("initial update: %w", err) } - s.lastInterval = resp.ReportInterval.AsDuration() - s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback) + interval := resp.ReportInterval.AsDuration() + if interval != s.lastInterval { + s.lastInterval = interval + s.source.SetConnStatsCallback(s.lastInterval, maxConns, s.callback) + } // use a separate goroutine to monitor the context so that we notice immediately, rather than // waiting for the next callback (which might never come if we are closing!) diff --git a/agent/stats_internal_test.go b/agent/stats_internal_test.go index e35fa9d3e2..f0854659fc 100644 --- a/agent/stats_internal_test.go +++ b/agent/stats_internal_test.go @@ -23,7 +23,9 @@ func TestStatsReporter(t *testing.T) { fSource := newFakeNetworkStatsSource(ctx, t) fCollector := newFakeCollector(t) fDest := newFakeStatsDest() - uut := newStatsReporter(logger, fSource, fCollector) + uut := newStatsReporter(logger, fSource, fCollector, DefaultStatsReportInterval) + + _ = testutil.TryReceive(ctx, t, fSource.period) // drain construction-time install loopErr := make(chan error, 1) loopCtx, loopCancel := context.WithCancel(ctx) @@ -157,7 +159,7 @@ func newFakeNetworkStatsSource(ctx context.Context, t testing.TB) *fakeNetworkSt f := &fakeNetworkStatsSource{ ctx: ctx, t: t, - period: make(chan time.Duration), + period: make(chan time.Duration, 1), } return f }