chore: extract testutil.FakeSink for slog test assertions (#23208)

Follow-up to [review comment on
#23025](https://github.com/coder/coder/pull/23025#discussion_r2930309487)
from @mafredri.

Extracts the repeated `logSink` / `fakeSink` test pattern into a shared
`testutil.FakeSink` and migrates all existing call sites.

> 🤖 This PR was created with the help of Coder Agents, and will be
reviewed by my human. 🧑‍💻

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Cian Johnston
2026-03-18 17:02:38 +00:00
committed by GitHub
parent 2577d16af2
commit 65b7658568
6 changed files with 314 additions and 137 deletions
+76
View File
@@ -0,0 +1,76 @@
package testutil
import (
"context"
"sync"
"testing"
"cdr.dev/slog/v3"
)
// FakeSink is a thread-safe slog.Sink that captures log entries so
// tests can assert on what was logged. It requires a testing.TB
// as it is only meant for use in tests.
type FakeSink struct {
mu sync.RWMutex
entries []slog.SinkEntry
}
// NewFakeSink returns a FakeSink ready for use.
func NewFakeSink(_ testing.TB) *FakeSink {
return &FakeSink{}
}
// LogEntry implements slog.Sink. It appends the entry to the
// internal slice.
func (s *FakeSink) LogEntry(_ context.Context, e slog.SinkEntry) {
s.mu.Lock()
defer s.mu.Unlock()
s.entries = append(s.entries, e)
}
// Sync implements slog.Sink.
func (*FakeSink) Sync() {}
// Entries returns a copy of the captured entries. If filters are
// provided, only entries matching ALL filters are returned. This
// lets callers compose simple predicates instead of needing
// dedicated methods for each field.
func (s *FakeSink) Entries(filters ...func(slog.SinkEntry) bool) []slog.SinkEntry {
s.mu.RLock()
cpy := make([]slog.SinkEntry, len(s.entries))
copy(cpy, s.entries)
s.mu.RUnlock()
filtered := make([]slog.SinkEntry, 0)
for _, e := range cpy {
if !matchAll(e, filters) {
continue
}
filtered = append(filtered, e)
}
return filtered
}
// Logger returns a slog.Logger backed by this sink at the given
// level. If no level is provided it defaults to LevelDebug, which
// captures everything. If more than one level is provided, the
// first one wins.
func (s *FakeSink) Logger(level ...slog.Level) slog.Logger {
l := slog.LevelDebug
if len(level) > 0 {
l = level[0]
}
return slog.Make(s).Leveled(l)
}
func matchAll(e slog.SinkEntry, filters []func(slog.SinkEntry) bool) bool {
for _, f := range filters {
if f == nil {
continue
}
if !f(e) {
return false
}
}
return true
}
+170
View File
@@ -0,0 +1,170 @@
package testutil_test
import (
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/testutil"
)
func TestFakeSink(t *testing.T) {
t.Parallel()
t.Run("BasicCapture", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
sink := testutil.NewFakeSink(t)
logger := sink.Logger()
logger.Debug(ctx, "first test message")
logger.Debug(ctx, "second test message")
logger.Debug(ctx, "third test message")
entries := sink.Entries()
require.Len(t, entries, 3)
})
t.Run("FilterByLevel", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
sink := testutil.NewFakeSink(t)
logger := sink.Logger()
logger.Debug(ctx, "debug level message")
logger.Info(ctx, "info level message")
logger.Error(ctx, "error level message")
errorOnly := sink.Entries(func(e slog.SinkEntry) bool {
return e.Level == slog.LevelError
})
require.Len(t, errorOnly, 1)
assert.Equal(t, "error level message", errorOnly[0].Message)
})
t.Run("MultipleFiltersAND", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
sink := testutil.NewFakeSink(t)
logger := sink.Logger()
logger.Info(ctx, "hello world filter test")
logger.Info(ctx, "goodbye world filter")
logger.Error(ctx, "hello error filter test")
byLevel := func(e slog.SinkEntry) bool {
return e.Level == slog.LevelInfo
}
byMessage := func(e slog.SinkEntry) bool {
return strings.Contains(e.Message, "hello")
}
matched := sink.Entries(byLevel, byMessage)
require.Len(t, matched, 1)
assert.Equal(t, "hello world filter test", matched[0].Message)
})
t.Run("NilFilterSkipped", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
sink := testutil.NewFakeSink(t)
logger := sink.Logger()
logger.Info(ctx, "nil filter test msg")
// A nil filter should be harmlessly skipped.
entries := sink.Entries(nil)
require.Len(t, entries, 1)
})
t.Run("NoFilters", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
sink := testutil.NewFakeSink(t)
logger := sink.Logger()
logger.Debug(ctx, "no filter debug msg")
logger.Info(ctx, "no filter info msg")
entries := sink.Entries()
require.Len(t, entries, 2)
})
t.Run("EmptySink", func(t *testing.T) {
t.Parallel()
sink := testutil.NewFakeSink(t)
entries := sink.Entries()
assert.Empty(t, entries)
})
t.Run("ThreadSafety", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
sink := testutil.NewFakeSink(t)
logger := sink.Logger()
const goroutines = 10
const entriesPerGoroutine = 100
var wg sync.WaitGroup
wg.Add(goroutines)
for range goroutines {
go func() {
defer wg.Done()
for range entriesPerGoroutine {
logger.Debug(ctx, "concurrent log entry")
}
}()
}
wg.Wait()
entries := sink.Entries()
require.Len(t, entries, goroutines*entriesPerGoroutine)
})
t.Run("LoggerConvenience", func(t *testing.T) {
t.Parallel()
t.Run("DefaultDebug", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
sink := testutil.NewFakeSink(t)
logger := sink.Logger()
logger.Debug(ctx, "captured at debug level")
entries := sink.Entries()
require.Len(t, entries, 1)
assert.Equal(t, slog.LevelDebug, entries[0].Level)
})
t.Run("RespectsLevel", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
sink := testutil.NewFakeSink(t)
logger := sink.Logger(slog.LevelInfo)
// Debug should be filtered out by the logger
// because the level is set to Info.
logger.Debug(ctx, "filtered out by level")
logger.Info(ctx, "kept by info level")
entries := sink.Entries()
require.Len(t, entries, 1)
assert.Equal(t, slog.LevelInfo, entries[0].Level)
})
})
}