mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
5861e516b9
Refactors our use of `slogtest` to instantiate a "standard logger" across most of our tests. This standard logger incorporates https://github.com/coder/slog/pull/217 to also ignore database query canceled errors by default, which are a source of low-severity flakes. Any test that has set non-default `slogtest.Options` is left alone. In particular, `coderdtest` defaults to ignoring all errors. We might consider revisiting that decision now that we have better tools to target the really common flaky Error logs on shutdown.
271 lines
7.3 KiB
Go
271 lines
7.3 KiB
Go
package agent
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/netip"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/types/known/durationpb"
|
|
"tailscale.com/types/ipproto"
|
|
|
|
"tailscale.com/types/netlogtype"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogjson"
|
|
"github.com/coder/coder/v2/agent/proto"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestStatsReporter(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
fSource := newFakeNetworkStatsSource(ctx, t)
|
|
fCollector := newFakeCollector(t)
|
|
fDest := newFakeStatsDest()
|
|
uut := newStatsReporter(logger, fSource, fCollector)
|
|
|
|
loopErr := make(chan error, 1)
|
|
loopCtx, loopCancel := context.WithCancel(ctx)
|
|
go func() {
|
|
err := uut.reportLoop(loopCtx, fDest)
|
|
loopErr <- err
|
|
}()
|
|
|
|
// initial request to get duration
|
|
req := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
|
|
require.NotNil(t, req)
|
|
require.Nil(t, req.Stats)
|
|
interval := time.Second * 34
|
|
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})
|
|
|
|
// call to source to set the callback and interval
|
|
gotInterval := testutil.RequireRecvCtx(ctx, t, fSource.period)
|
|
require.Equal(t, interval, gotInterval)
|
|
|
|
// callback returning netstats
|
|
netStats := map[netlogtype.Connection]netlogtype.Counts{
|
|
{
|
|
Proto: ipproto.TCP,
|
|
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
|
|
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
|
|
}: {
|
|
TxPackets: 22,
|
|
TxBytes: 23,
|
|
RxPackets: 24,
|
|
RxBytes: 25,
|
|
},
|
|
}
|
|
fSource.callback(time.Now(), time.Now(), netStats, nil)
|
|
|
|
// collector called to complete the stats
|
|
gotNetStats := testutil.RequireRecvCtx(ctx, t, fCollector.calls)
|
|
require.Equal(t, netStats, gotNetStats)
|
|
|
|
// while we are collecting the stats, send in two new netStats to simulate
|
|
// what happens if we don't keep up. Only the latest should be kept.
|
|
netStats0 := map[netlogtype.Connection]netlogtype.Counts{
|
|
{
|
|
Proto: ipproto.TCP,
|
|
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
|
|
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
|
|
}: {
|
|
TxPackets: 10,
|
|
TxBytes: 10,
|
|
RxPackets: 10,
|
|
RxBytes: 10,
|
|
},
|
|
}
|
|
fSource.callback(time.Now(), time.Now(), netStats0, nil)
|
|
netStats1 := map[netlogtype.Connection]netlogtype.Counts{
|
|
{
|
|
Proto: ipproto.TCP,
|
|
Src: netip.MustParseAddrPort("192.168.1.33:4887"),
|
|
Dst: netip.MustParseAddrPort("192.168.2.99:9999"),
|
|
}: {
|
|
TxPackets: 11,
|
|
TxBytes: 11,
|
|
RxPackets: 11,
|
|
RxBytes: 11,
|
|
},
|
|
}
|
|
fSource.callback(time.Now(), time.Now(), netStats1, nil)
|
|
|
|
// complete first collection
|
|
stats := &proto.Stats{SessionCountJetbrains: 55}
|
|
testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)
|
|
|
|
// destination called to report the first stats
|
|
update := testutil.RequireRecvCtx(ctx, t, fDest.reqs)
|
|
require.NotNil(t, update)
|
|
require.Equal(t, stats, update.Stats)
|
|
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval)})
|
|
|
|
// second update -- only netStats1 is reported
|
|
gotNetStats = testutil.RequireRecvCtx(ctx, t, fCollector.calls)
|
|
require.Equal(t, netStats1, gotNetStats)
|
|
stats = &proto.Stats{SessionCountJetbrains: 66}
|
|
testutil.RequireSendCtx(ctx, t, fCollector.stats, stats)
|
|
update = testutil.RequireRecvCtx(ctx, t, fDest.reqs)
|
|
require.NotNil(t, update)
|
|
require.Equal(t, stats, update.Stats)
|
|
interval2 := 27 * time.Second
|
|
testutil.RequireSendCtx(ctx, t, fDest.resps, &proto.UpdateStatsResponse{ReportInterval: durationpb.New(interval2)})
|
|
|
|
// set the new interval
|
|
gotInterval = testutil.RequireRecvCtx(ctx, t, fSource.period)
|
|
require.Equal(t, interval2, gotInterval)
|
|
|
|
loopCancel()
|
|
err := testutil.RequireRecvCtx(ctx, t, loopErr)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
type fakeNetworkStatsSource struct {
|
|
sync.Mutex
|
|
ctx context.Context
|
|
t testing.TB
|
|
callback func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)
|
|
period chan time.Duration
|
|
}
|
|
|
|
func (f *fakeNetworkStatsSource) SetConnStatsCallback(maxPeriod time.Duration, _ int, dump func(start time.Time, end time.Time, virtual map[netlogtype.Connection]netlogtype.Counts, physical map[netlogtype.Connection]netlogtype.Counts)) {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
f.callback = dump
|
|
select {
|
|
case <-f.ctx.Done():
|
|
f.t.Error("timeout")
|
|
case f.period <- maxPeriod:
|
|
// OK
|
|
}
|
|
}
|
|
|
|
func newFakeNetworkStatsSource(ctx context.Context, t testing.TB) *fakeNetworkStatsSource {
|
|
f := &fakeNetworkStatsSource{
|
|
ctx: ctx,
|
|
t: t,
|
|
period: make(chan time.Duration),
|
|
}
|
|
return f
|
|
}
|
|
|
|
type fakeCollector struct {
|
|
t testing.TB
|
|
calls chan map[netlogtype.Connection]netlogtype.Counts
|
|
stats chan *proto.Stats
|
|
}
|
|
|
|
func (f *fakeCollector) Collect(ctx context.Context, networkStats map[netlogtype.Connection]netlogtype.Counts) *proto.Stats {
|
|
select {
|
|
case <-ctx.Done():
|
|
f.t.Error("timeout on collect")
|
|
return nil
|
|
case f.calls <- networkStats:
|
|
// ok
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
f.t.Error("timeout on collect")
|
|
return nil
|
|
case s := <-f.stats:
|
|
return s
|
|
}
|
|
}
|
|
|
|
func newFakeCollector(t testing.TB) *fakeCollector {
|
|
return &fakeCollector{
|
|
t: t,
|
|
calls: make(chan map[netlogtype.Connection]netlogtype.Counts),
|
|
stats: make(chan *proto.Stats),
|
|
}
|
|
}
|
|
|
|
type fakeStatsDest struct {
|
|
reqs chan *proto.UpdateStatsRequest
|
|
resps chan *proto.UpdateStatsResponse
|
|
}
|
|
|
|
func (f *fakeStatsDest) UpdateStats(ctx context.Context, req *proto.UpdateStatsRequest) (*proto.UpdateStatsResponse, error) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case f.reqs <- req:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case resp := <-f.resps:
|
|
return resp, nil
|
|
}
|
|
}
|
|
|
|
func newFakeStatsDest() *fakeStatsDest {
|
|
return &fakeStatsDest{
|
|
reqs: make(chan *proto.UpdateStatsRequest),
|
|
resps: make(chan *proto.UpdateStatsResponse),
|
|
}
|
|
}
|
|
|
|
func Test_logDebouncer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
buf bytes.Buffer
|
|
logger = slog.Make(slogjson.Sink(&buf))
|
|
ctx = context.Background()
|
|
)
|
|
|
|
debouncer := &logDebouncer{
|
|
logger: logger,
|
|
messages: map[string]time.Time{},
|
|
interval: time.Minute,
|
|
}
|
|
|
|
fields := map[string]interface{}{
|
|
"field_1": float64(1),
|
|
"field_2": "2",
|
|
}
|
|
|
|
debouncer.Error(ctx, "my message", "field_1", 1, "field_2", "2")
|
|
debouncer.Warn(ctx, "another message", "field_1", 1, "field_2", "2")
|
|
// Shouldn't log this.
|
|
debouncer.Warn(ctx, "another message", "field_1", 1, "field_2", "2")
|
|
|
|
require.Len(t, debouncer.messages, 2)
|
|
|
|
type entry struct {
|
|
Msg string `json:"msg"`
|
|
Level string `json:"level"`
|
|
Fields map[string]interface{} `json:"fields"`
|
|
}
|
|
|
|
assertLog := func(msg string, level string, fields map[string]interface{}) {
|
|
line, err := buf.ReadString('\n')
|
|
require.NoError(t, err)
|
|
|
|
var e entry
|
|
err = json.Unmarshal([]byte(line), &e)
|
|
require.NoError(t, err)
|
|
require.Equal(t, msg, e.Msg)
|
|
require.Equal(t, level, e.Level)
|
|
require.Equal(t, fields, e.Fields)
|
|
}
|
|
assertLog("my message", "ERROR", fields)
|
|
assertLog("another message", "WARN", fields)
|
|
|
|
debouncer.messages["another message"] = time.Now().Add(-2 * time.Minute)
|
|
debouncer.Warn(ctx, "another message", "field_1", 1, "field_2", "2")
|
|
assertLog("another message", "WARN", fields)
|
|
// Assert nothing else was written.
|
|
_, err := buf.ReadString('\n')
|
|
require.ErrorIs(t, err, io.EOF)
|
|
}
|