Files
coder/coderd/agentapi/metadatabatcher/metadata_batcher_internal_test.go
T
Callum Styan e195856c43 perf: reduce pg_notify call volume by batching together agent metadata updates (#21330)
---------

Signed-off-by: Callum Styan <callumstyan@gmail.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-22 22:47:49 -08:00

1009 lines
33 KiB
Go

package metadatabatcher
import (
"context"
"encoding/base64"
"fmt"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
prom_testutil "github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/pubsub/psmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// ============================================================================
// Custom gomock matchers for metadata batcher testing
// ============================================================================
// metadataParamsMatcher is a gomock.Matcher for
// database.BatchUpdateWorkspaceAgentMetadataParams. The five expected slices
// are parallel: index i across all of them describes one expected row. Rows
// are matched without regard to their order in the actual params.
type metadataParamsMatcher struct {
	expectedAgentIDs []uuid.UUID // agent owning row i
	expectedKeys     []string    // metadata key of row i
	expectedValues   []string    // metadata value of row i
	expectedErrors   []string    // metadata error string of row i
	expectedTimes    []time.Time // collection timestamp of row i
}
// Matches reports whether x is a database.BatchUpdateWorkspaceAgentMetadataParams
// whose rows are exactly the expected rows, ignoring row order.
//
// Rows are compared as a multiset: a row expected twice must appear exactly
// twice. Timestamps are compared as instants (normalized to UTC), so the
// location attached to a time.Time does not affect the result. A matcher whose
// parallel expected slices differ in length never matches (instead of
// panicking), so gomock reports an ordinary unmatched-call failure.
func (m metadataParamsMatcher) Matches(x interface{}) bool {
	params, ok := x.(database.BatchUpdateWorkspaceAgentMetadataParams)
	if !ok {
		return false
	}

	expectedLen := len(m.expectedKeys)

	// Guard against a malformed matcher; indexing below assumes all expected
	// slices share one length.
	if len(m.expectedAgentIDs) != expectedLen ||
		len(m.expectedValues) != expectedLen ||
		len(m.expectedErrors) != expectedLen ||
		len(m.expectedTimes) != expectedLen {
		return false
	}

	// All arrays in the actual params must have the expected length.
	if len(params.WorkspaceAgentID) != expectedLen ||
		len(params.Key) != expectedLen ||
		len(params.Value) != expectedLen ||
		len(params.Error) != expectedLen ||
		len(params.CollectedAt) != expectedLen {
		return false
	}

	// row is a comparable representation of one metadata entry. A struct key
	// avoids the ambiguity of joining fields with a delimiter that may also
	// appear inside a value.
	type row struct {
		agentID     uuid.UUID
		key         string
		value       string
		errMsg      string
		collectedAt string
	}
	formatTime := func(ts time.Time) string {
		return ts.UTC().Format(time.RFC3339Nano)
	}

	// Count how many times each expected row must appear.
	remaining := make(map[row]int, expectedLen)
	for i := range m.expectedKeys {
		remaining[row{
			agentID:     m.expectedAgentIDs[i],
			key:         m.expectedKeys[i],
			value:       m.expectedValues[i],
			errMsg:      m.expectedErrors[i],
			collectedAt: formatTime(m.expectedTimes[i]),
		}]++
	}

	// Consume one expected occurrence per actual row. Both sides have the same
	// length, so never running out of an occurrence means the multisets are
	// equal.
	for i := range params.Key {
		r := row{
			agentID:     params.WorkspaceAgentID[i],
			key:         params.Key[i],
			value:       params.Value[i],
			errMsg:      params.Error[i],
			collectedAt: formatTime(params.CollectedAt[i]),
		}
		if remaining[r] == 0 {
			return false
		}
		remaining[r]--
	}
	return true
}
// String describes the expectation; gomock uses a Matcher's String to render
// the expected argument in failure output.
func (m metadataParamsMatcher) String() string {
	entryCount := len(m.expectedKeys)
	return fmt.Sprintf(
		"metadata params with %d entries (agents: %v, keys: %v)",
		entryCount,
		m.expectedAgentIDs,
		m.expectedKeys,
	)
}
// matchMetadata returns a gomock.Matcher for
// database.BatchUpdateWorkspaceAgentMetadataParams built from parallel slices:
// agentIDs[i], keys[i], values[i], errors[i], and times[i] together describe
// one expected row.
func matchMetadata(agentIDs []uuid.UUID, keys, values, errors []string, times []time.Time) gomock.Matcher {
	var matcher metadataParamsMatcher
	matcher.expectedAgentIDs = agentIDs
	matcher.expectedKeys = keys
	matcher.expectedValues = values
	matcher.expectedErrors = errors
	matcher.expectedTimes = times
	return matcher
}
// pubsubCapture accumulates the agent IDs decoded from pubsub Publish calls so
// a test can assert which agents were notified. The ID set is guarded by mu
// because the capture callback may run on a goroutine other than the test's.
type pubsubCapture struct {
	t *testing.T // receives assertion failures

	mu       sync.Mutex
	agentIDs map[uuid.UUID]struct{} // set of decoded agent IDs; guarded by mu
}
// newPubsubCapture returns an empty pubsubCapture that reports assertion
// failures to t.
func newPubsubCapture(t *testing.T) *pubsubCapture {
	c := &pubsubCapture{t: t}
	c.agentIDs = make(map[uuid.UUID]struct{})
	return c
}
// capture is meant to be installed with gomock's Do on Pubsub.Publish. It
// checks that event is MetadataBatchPubsubChannel, splits message into
// UUIDBase64Size-byte pieces, decodes each piece as a raw (unpadded) base64
// UUID, and adds every successfully decoded ID to the captured set.
//
// It uses assert rather than require on purpose: it may run on a goroutine
// other than the test's, where t.FailNow must not be called.
func (c *pubsubCapture) capture(event string, message []byte) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// testify's argument order is (expected, actual); getting it right keeps
	// failure messages from labelling the two sides backwards.
	assert.Equal(c.t, MetadataBatchPubsubChannel, event, "unexpected pubsub event")
	assert.Equal(c.t, 0, len(message)%UUIDBase64Size, "payload length is not a multiple of the encoded UUID size")

	numAgents := len(message) / UUIDBase64Size
	for i := 0; i < numAgents; i++ {
		start := i * UUIDBase64Size
		encoded := message[start : start+UUIDBase64Size]

		var uuidBytes [16]byte
		n, err := base64.RawStdEncoding.Decode(uuidBytes[:], encoded)
		// Skip a piece that did not decode to a full UUID instead of recording
		// a bogus (partially zero) ID that would muddy later assertions.
		if !assert.NoError(c.t, err) || !assert.Equal(c.t, 16, n) {
			continue
		}
		agentID, err := uuid.FromBytes(uuidBytes[:])
		if !assert.NoError(c.t, err) {
			continue
		}
		c.agentIDs[agentID] = struct{}{}
	}
}
// requireContainsAll stops the test unless the captured set has exactly
// len(expected) IDs and contains every ID in expected. Assuming expected has no
// duplicates, that means the captured set equals expected. It uses require,
// so it must be called from the test goroutine.
func (c *pubsubCapture) requireContainsAll(expected []uuid.UUID) {
	c.mu.Lock()
	defer c.mu.Unlock()

	captured := len(c.agentIDs)
	require.Equal(c.t, len(expected), captured, "unexpected number of agent IDs in pubsub messages")

	for _, want := range expected {
		_, present := c.agentIDs[want]
		require.True(c.t, present, "expected agent ID %s not found in pubsub messages", want)
	}
}
// count reports how many distinct agent IDs have been captured so far.
func (c *pubsubCapture) count() int {
	c.mu.Lock()
	n := len(c.agentIDs)
	c.mu.Unlock()
	return n
}
// clear discards every captured agent ID so the next flush can be checked in
// isolation.
func (c *pubsubCapture) clear() {
	fresh := make(map[uuid.UUID]struct{})

	c.mu.Lock()
	defer c.mu.Unlock()
	c.agentIDs = fresh
}
// TestMetadataBatcher drives one batcher through four consecutive flushes on a
// mock clock:
//  1. a scheduled flush with nothing queued (no store or pubsub calls),
//  2. a scheduled flush of two keys for one agent,
//  3. a scheduled flush of five keys across two agents,
//  4. a capacity flush triggered by defaultMetadataBatchSize single-key
//     updates for distinct agents, whose notification is published in two
//     Publish calls.
func TestMetadataBatcher(t *testing.T) {
	t.Parallel()

	// Given: a fresh batcher with no data
	ctx := testutil.Context(t, testutil.WaitShort)
	log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctrl := gomock.NewController(t)
	store := dbmock.NewMockStore(ctrl)
	ps := psmock.NewMockPubsub(ctrl)
	clock := quartz.NewMock(t)

	// Trap timer reset calls so we can wait for them to complete.
	resetTrap := clock.Trap().TimerReset("metadataBatcher", "scheduledFlush")
	defer resetTrap.Close()
	capacityResetTrap := clock.Trap().TimerReset("metadataBatcher", "capacityFlush")
	defer capacityResetTrap.Close()

	// Generate mock agent IDs.
	agent1 := uuid.New()
	agent2 := uuid.New()

	// Create a single pubsub capture to reuse across all flushes.
	psCap := newPubsubCapture(t)

	// --- FLUSH 1: Empty flush (no calls expected) ---
	// No expectations set - if DB query called, test will fail.
	reg := prometheus.NewRegistry()
	b, err := NewBatcher(ctx, reg, store, ps,
		WithLogger(log),
		WithClock(clock),
	)
	require.NoError(t, err)
	t.Cleanup(b.Close)

	// Given: no metadata updates are added
	// When: it becomes time to flush
	// Then: no metadata should be updated (no DB call)
	clock.Advance(defaultMetadataFlushInterval).MustWait(ctx)
	resetTrap.MustWait(ctx).MustRelease(ctx) // Wait for timer reset after flush
	t.Log("flush 1 completed (expected 0 entries)")
	require.Equal(t, float64(0), prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker)))

	// --- FLUSH 2: Single agent with 2 metadata entries ---
	t2 := clock.Now()

	// Expect exactly 1 database call with exact values.
	store.EXPECT().
		BatchUpdateWorkspaceAgentMetadata(
			gomock.Any(),
			matchMetadata(
				[]uuid.UUID{agent1, agent1},
				[]string{"key1", "key2"},
				[]string{"value1", "value2"},
				[]string{"", ""},
				[]time.Time{t2, t2},
			),
		).
		Return(nil).
		Times(1)

	// Expect exactly 1 pubsub publish with correct event and agent IDs.
	ps.EXPECT().
		Publish(gomock.Any(), gomock.Any()).
		Do(psCap.capture).
		Return(nil).
		Times(1)

	// Given: a single metadata update is added for agent1
	t.Log("adding metadata for 1 agent")
	// Capture dropped count before adding.
	droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)
	require.NoError(t, b.Add(agent1, []string{"key1", "key2"}, []string{"value1", "value2"}, []string{"", ""}, []time.Time{t2, t2}))

	// Wait for the channel to be processed and verify nothing was dropped.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		channelEmpty := len(b.updateCh) == 0
		nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore
		batchHasExpected := int(b.currentBatchLen.Load()) == 2
		return channelEmpty && nothingDropped && batchHasExpected
	}, testutil.IntervalFast)

	// When: it becomes time to flush
	clock.Advance(defaultMetadataFlushInterval).MustWait(ctx)
	resetTrap.MustWait(ctx).MustRelease(ctx) // Wait for timer reset after flush
	t.Log("flush 2 completed (expected 2 entries)")
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		val := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker))
		totalMeta := prom_testutil.ToFloat64(b.Metrics.MetadataTotal)
		return float64(1) == val && totalMeta >= float64(2)
	}, testutil.IntervalFast)
	require.Equal(t, float64(2), prom_testutil.ToFloat64(b.Metrics.MetadataTotal))

	// Wait for pubsub capture to complete and verify all agent IDs were published.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return psCap.count() == 1
	}, testutil.IntervalFast)
	psCap.requireContainsAll([]uuid.UUID{agent1})

	// --- FLUSH 3: Multiple agents with 5 total metadata entries ---
	t3 := clock.Now()
	// Clear pubsub capture for the next flush.
	psCap.clear()

	// Expect exactly 1 database call with exact values for both agents.
	store.EXPECT().
		BatchUpdateWorkspaceAgentMetadata(
			gomock.Any(),
			matchMetadata(
				[]uuid.UUID{agent1, agent1, agent1, agent2, agent2},
				[]string{"key1", "key2", "key3", "key1", "key2"},
				[]string{"new_value1", "new_value2", "new_value3", "agent2_value1", "agent2_value2"},
				[]string{"", "", "", "", ""},
				[]time.Time{t3, t3, t3, t3, t3},
			),
		).
		Return(nil).
		Times(1)

	// Expect exactly 1 pubsub publish with both agent IDs.
	ps.EXPECT().
		Publish(gomock.Any(), gomock.Any()).
		Do(psCap.capture).
		Return(nil).
		Times(1)

	// Given: metadata updates are added for multiple agents
	t.Log("adding metadata for 2 agents")
	// Capture dropped count before any adds.
	droppedBefore = prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)
	require.NoError(t, b.Add(agent1, []string{"key1", "key2", "key3"}, []string{"new_value1", "new_value2", "new_value3"}, []string{"", "", ""}, []time.Time{t3, t3, t3}))
	require.NoError(t, b.Add(agent2, []string{"key1", "key2"}, []string{"agent2_value1", "agent2_value2"}, []string{"", ""}, []time.Time{t3, t3}))

	// Wait for all channel messages to be processed into the batch.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		channelEmpty := len(b.updateCh) == 0
		nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore
		batchHasExpected := int(b.currentBatchLen.Load()) == 5
		return channelEmpty && nothingDropped && batchHasExpected
	}, testutil.IntervalFast)

	// When: it becomes time to flush
	clock.Advance(defaultMetadataFlushInterval).MustWait(ctx)
	resetTrap.MustWait(ctx).MustRelease(ctx) // Wait for timer reset after flush
	t.Log("flush 3 completed (expected 5 new entries)")
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		val := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker))
		totalMeta := prom_testutil.ToFloat64(b.Metrics.MetadataTotal)
		return float64(2) == val && totalMeta >= float64(7)
	}, testutil.IntervalFast)
	require.Equal(t, float64(7), prom_testutil.ToFloat64(b.Metrics.MetadataTotal))

	// Wait for pubsub capture to complete and verify all agent IDs were published.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return psCap.count() == 2
	}, testutil.IntervalFast)
	psCap.requireContainsAll([]uuid.UUID{agent1, agent2})

	// --- FLUSH 4: Capacity flush with defaultMetadataBatchSize entries ---
	t4 := clock.Now()
	numAgents := defaultMetadataBatchSize
	// Flushes 2 and 3 already recorded 2 + 5 entries in MetadataTotal.
	const previouslyFlushed = 2 + 5
	// Clear pubsub capture for the next flush.
	psCap.clear()

	// Pre-generate all agent IDs so we can assert on exact values.
	agentIDs := make([]uuid.UUID, numAgents)
	for i := range agentIDs {
		agentIDs[i] = uuid.New()
	}

	// Build expected values for database assertion. expectedErrors stays all
	// empty strings (the zero value).
	expectedKeys := make([]string, numAgents)
	expectedValues := make([]string, numAgents)
	expectedErrors := make([]string, numAgents)
	expectedTimes := make([]time.Time, numAgents)
	for i := 0; i < numAgents; i++ {
		expectedKeys[i] = "key1"
		expectedValues[i] = "bulk_value"
		expectedTimes[i] = t4
	}

	// Assert on exact database values.
	store.EXPECT().
		BatchUpdateWorkspaceAgentMetadata(
			gomock.Any(),
			matchMetadata(agentIDs, expectedKeys, expectedValues, expectedErrors, expectedTimes),
		).
		Return(nil).
		Times(1)

	// Pubsub will be called with chunking. With defaultMetadataBatchSize (500)
	// agents, we expect exactly 2 pubsub calls due to chunking (363 + 137).
	ps.EXPECT().
		Publish(gomock.Any(), gomock.Any()).
		Do(psCap.capture).
		Return(nil).
		Times(2)

	// Add metadata updates using the pre-generated agent IDs. This runs on the
	// test goroutine because require (t.FailNow) must not be called from any
	// other goroutine. Add drops rather than blocks when its queue is full, so
	// calling it here cannot stall the test while the run loop is parked on
	// capacityResetTrap.
	droppedBefore = prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)
	t.Logf("adding metadata for %d agents", numAgents)
	for i := 0; i < numAgents; i++ {
		require.NoError(t, b.Add(agentIDs[i], []string{"key1"}, []string{"bulk_value"}, []string{""}, []time.Time{t4}))
	}

	capacityResetTrap.MustWait(ctx).MustRelease(ctx) // Wait for timer reset after capacity flush
	t.Log("flush 4 completed (capacity flush, expected", defaultMetadataBatchSize, "entries)")
	expectedTotal := float64(previouslyFlushed + numAgents)
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		capacity := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushCapacity))
		totalMeta := prom_testutil.ToFloat64(b.Metrics.MetadataTotal)
		return float64(1) == capacity && totalMeta >= expectedTotal
	}, testutil.IntervalFast)
	require.Equal(t, expectedTotal, prom_testutil.ToFloat64(b.Metrics.MetadataTotal))
	require.Equal(t, droppedBefore, prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal), "no updates should have been dropped")

	// Wait for pubsub capture to complete and verify all agent IDs were published (across all chunks).
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return psCap.count() == numAgents
	}, testutil.IntervalFast)
	psCap.requireContainsAll(agentIDs)
}
// TestMetadataBatcher_DropsWhenFull blocks the first store call, fills the
// update channel to capacity, and checks that one further Add is dropped
// (counted in DroppedKeysTotal) rather than blocking. After the flush is
// released, every queued entry must still be flushed.
func TestMetadataBatcher_DropsWhenFull(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctrl := gomock.NewController(t)
	store := dbmock.NewMockStore(ctrl)
	ps := psmock.NewMockPubsub(ctrl)
	clock := quartz.NewMock(t)
	reg := prometheus.NewRegistry()

	// Batch size of 2 means channel capacity = 10 (2 * 5)
	b, err := NewBatcher(ctx, reg, store, ps,
		WithLogger(log),
		WithBatchSize(2),
		WithClock(clock),
	)
	require.NoError(t, err)
	t.Cleanup(b.Close)

	t1 := clock.Now()

	// Channels to control when the store call blocks/unblocks
	flushStarted := make(chan struct{})
	unblockFlush := make(chan struct{})
	pubsubCap := newPubsubCapture(t)

	// Make the first store call block until we signal. After unblocking,
	// the 10 queued entries will trigger 5 more capacity flushes (10/2 = 5).
	// Total expected flushes: 1 (initial) + 5 (queued) = 6
	firstCall := true
	store.EXPECT().
		BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).
		DoAndReturn(func(ctx context.Context, params database.BatchUpdateWorkspaceAgentMetadataParams) error {
			if firstCall {
				firstCall = false
				close(flushStarted) // Signal that first flush has started
				<-unblockFlush      // Wait for signal to continue
			}
			return nil
		}).
		Times(6)
	ps.EXPECT().
		Publish(gomock.Any(), gomock.Any()).
		Do(pubsubCap.capture).
		Return(nil).
		Times(6)

	// Add 2 entries - this will trigger capacity flush (batch size = 2) that blocks
	agent1 := uuid.New()
	agent2 := uuid.New()
	require.NoError(t, b.Add(agent1, []string{"key1"}, []string{"value1"}, []string{""}, []time.Time{t1}))
	require.NoError(t, b.Add(agent2, []string{"key1"}, []string{"value2"}, []string{""}, []time.Time{t1}))

	// Wait for flush to start and block in the store call
	<-flushStarted

	// Now the flush is blocked. Channel capacity is 10.
	// Fill the channel with 10 entries
	droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)
	for i := 0; i < 10; i++ {
		agent := uuid.New()
		require.NoError(t, b.Add(agent, []string{"key1"}, []string{fmt.Sprintf("value%d", i)}, []string{""}, []time.Time{t1}))
	}

	// Channel should now be full. Next add should drop.
	agentDropped := uuid.New()
	require.NoError(t, b.Add(agentDropped, []string{"key1"}, []string{"dropped"}, []string{""}, []time.Time{t1}))

	// Verify that 1 key was dropped
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		dropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)
		return dropped == droppedBefore+1
	}, testutil.IntervalFast)

	// Unblock the flush
	close(unblockFlush)

	// Wait for all queued entries to be processed (channel should be empty)
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return len(b.updateCh) == 0
	}, testutil.IntervalFast)

	// An empty channel only means the run loop has dequeued every entry; the
	// final capacity flush and its metric updates may still be in progress.
	// Wait for the counters to reach their final values before asserting, or
	// the exact-equality checks below can race with that last flush.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		metadata := prom_testutil.ToFloat64(b.Metrics.MetadataTotal)
		batches := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushCapacity))
		return metadata >= float64(12) && batches >= float64(6)
	}, testutil.IntervalFast)

	// Verify final state: 1 key was dropped, 12 metadata sent in 6 capacity batches
	require.Equal(t, droppedBefore+1, prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal))
	require.Equal(t, float64(12), prom_testutil.ToFloat64(b.Metrics.MetadataTotal))
	require.Equal(t, float64(6), prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushCapacity)))
}
// TestMetadataBatcher_Deduplication issues two Add calls for one agent, the
// second stamped 1ms after the first. The scheduled flush must contain one row
// per distinct key. A key written by the second call carries that call's value
// and timestamp; a key written only by the first call keeps its original value
// and timestamp.
func TestMetadataBatcher_Deduplication(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		// First Add call
		add1Keys   []string
		add1Values []string
		// Second Add call
		add2Keys   []string
		add2Values []string
		// Expected result after deduplication
		wantKeys   []string
		wantValues []string
	}{
		{
			name:       "same key updated twice keeps newest",
			add1Keys:   []string{"key1"},
			add1Values: []string{"first_value"},
			add2Keys:   []string{"key1"},
			add2Values: []string{"second_value"},
			wantKeys:   []string{"key1"},
			wantValues: []string{"second_value"},
		},
		{
			name:       "mixed keys with partial overlap",
			add1Keys:   []string{"key1", "key2"},
			add1Values: []string{"value1", "value2"},
			add2Keys:   []string{"key1", "key3"},
			add2Values: []string{"new_value1", "value3"},
			wantKeys:   []string{"key1", "key2", "key3"},
			wantValues: []string{"new_value1", "value2", "value3"},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ctx := testutil.Context(t, testutil.WaitShort)
			log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
			ctrl := gomock.NewController(t)
			store := dbmock.NewMockStore(ctrl)
			ps := psmock.NewMockPubsub(ctrl)
			clock := quartz.NewMock(t)
			agent := uuid.New()
			reg := prometheus.NewRegistry()

			b, err := NewBatcher(ctx, reg, store, ps,
				WithLogger(log),
				WithClock(clock),
			)
			require.NoError(t, err)
			t.Cleanup(b.Close)

			// The second Add is stamped 1ms after the first so its values win.
			older := clock.Now()
			newer := older.Add(time.Millisecond)

			// stampAll returns n copies of ts.
			stampAll := func(n int, ts time.Time) []time.Time {
				out := make([]time.Time, n)
				for i := range out {
					out[i] = ts
				}
				return out
			}
			firstTimes := stampAll(len(tc.add1Keys), older)
			secondTimes := stampAll(len(tc.add2Keys), newer)

			// Keys written by the second Add carry its (newer) timestamp; every
			// other expected key must have come from the first Add.
			writtenLater := make(map[string]bool, len(tc.add2Keys))
			for _, key := range tc.add2Keys {
				writtenLater[key] = true
			}
			rowCount := len(tc.wantKeys)
			expectedTimes := make([]time.Time, rowCount)
			expectedAgents := make([]uuid.UUID, rowCount)
			for i, key := range tc.wantKeys {
				if writtenLater[key] {
					expectedTimes[i] = newer
				} else {
					expectedTimes[i] = older
				}
				expectedAgents[i] = agent
			}
			// Every expected error string is empty (the zero value).
			expectedErrors := make([]string, rowCount)

			psCap := newPubsubCapture(t)

			store.EXPECT().
				BatchUpdateWorkspaceAgentMetadata(
					gomock.Any(),
					matchMetadata(
						expectedAgents,
						tc.wantKeys,
						tc.wantValues,
						expectedErrors,
						expectedTimes,
					),
				).
				Return(nil).
				Times(1)
			ps.EXPECT().
				Publish(gomock.Any(), gomock.Any()).
				Do(psCap.capture).
				Return(nil).
				Times(1)

			// Perform both adds, each with all-empty error strings.
			droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)
			firstErrors := make([]string, len(tc.add1Keys))
			require.NoError(t, b.Add(agent, tc.add1Keys, tc.add1Values, firstErrors, firstTimes))
			secondErrors := make([]string, len(tc.add2Keys))
			require.NoError(t, b.Add(agent, tc.add2Keys, tc.add2Values, secondErrors, secondTimes))

			// Wait until the run loop has folded both adds into the batch.
			testutil.Eventually(ctx, t, func(ctx context.Context) bool {
				drained := len(b.updateCh) == 0
				noDrops := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore
				sized := int(b.currentBatchLen.Load()) == rowCount
				return drained && noDrops && sized
			}, testutil.IntervalFast)

			// Trigger the scheduled flush.
			clock.Advance(defaultMetadataFlushInterval).MustWait(ctx)

			// One ticker batch holding one row per distinct key.
			testutil.Eventually(ctx, t, func(ctx context.Context) bool {
				return float64(1) == prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker))
			}, testutil.IntervalFast)
			require.Equal(t, float64(rowCount), prom_testutil.ToFloat64(b.Metrics.MetadataTotal))

			// The agent must have been announced over pubsub.
			testutil.Eventually(ctx, t, func(ctx context.Context) bool {
				return psCap.count() == 1
			}, testutil.IntervalFast)
			psCap.requireContainsAll([]uuid.UUID{agent})
		})
	}
}
// TestMetadataBatcher_TimestampOrdering adds three updates for the same agent
// and key in the timestamp order t2, t1, t3. It checks that the batch holds a
// single row and that the flushed row carries the value stamped with the
// newest timestamp (t3). The call order matters: the t1 update arrives after
// the newer t2 update, and per the comments below it is expected not to
// overwrite it.
func TestMetadataBatcher_TimestampOrdering(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctrl := gomock.NewController(t)
	store := dbmock.NewMockStore(ctrl)
	ps := psmock.NewMockPubsub(ctrl)
	clock := quartz.NewMock(t)
	reg := prometheus.NewRegistry()

	b, err := NewBatcher(ctx, reg, store, ps,
		WithLogger(log),
		WithClock(clock),
	)
	require.NoError(t, err)
	t.Cleanup(b.Close)

	// Generate mock agent ID.
	agent := uuid.New()

	// Three strictly increasing timestamps, one second apart: t1 < t2 < t3.
	t1 := clock.Now()
	t2 := t1.Add(time.Second)
	t3 := t2.Add(time.Second)

	// Set up pubsub capture for the flush.
	psCap := newPubsubCapture(t)

	// Expect the store to be called with only the newest timestamp.
	store.EXPECT().
		BatchUpdateWorkspaceAgentMetadata(
			gomock.Any(),
			matchMetadata(
				[]uuid.UUID{agent},
				[]string{"key1"},
				[]string{"newest_value"},
				[]string{""},
				[]time.Time{t3},
			),
		).
		Return(nil).
		Times(1)

	// Expect pubsub publish to be called when flush happens.
	ps.EXPECT().
		Publish(gomock.Any(), gomock.Any()).
		Do(psCap.capture).
		Return(nil).
		Times(1)

	// Capture dropped count before any adds.
	droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)
	// Add update with t2 timestamp
	require.NoError(t, b.Add(agent, []string{"key1"}, []string{"newer_value"}, []string{""}, []time.Time{t2}))
	// Try to add older update with t1 timestamp - should be ignored
	require.NoError(t, b.Add(agent, []string{"key1"}, []string{"older_value"}, []string{""}, []time.Time{t1}))
	// Add even newer update with t3 timestamp - should overwrite
	require.NoError(t, b.Add(agent, []string{"key1"}, []string{"newest_value"}, []string{""}, []time.Time{t3}))

	// Wait for all channel messages to be processed by the run() goroutine into the batch.
	// All three updates target the same agent/key, so the batch holds one entry.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		channelEmpty := len(b.updateCh) == 0
		nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore
		batchHasExpected := int(b.currentBatchLen.Load()) == 1
		return channelEmpty && nothingDropped && batchHasExpected
	}, testutil.IntervalFast)

	// Flush and verify entry was sent.
	// Advance the full flush interval from when the batcher was created.
	clock.Advance(defaultMetadataFlushInterval).MustWait(ctx)

	// Wait for pubsub capture to complete and verify all agent IDs were published.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return psCap.count() == 1
	}, testutil.IntervalFast)
	psCap.requireContainsAll([]uuid.UUID{agent})

	// Verify only 1 entry was flushed (newest timestamp wins)
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return float64(1) == prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker))
	}, testutil.IntervalFast)
	require.Equal(t, float64(1), prom_testutil.ToFloat64(b.Metrics.MetadataTotal))
}
// TestMetadataBatcher_PubsubChunking queues 600 single-key updates for
// distinct agents and checks two things:
//   - how the store writes are split: a capacity flush of the first 500
//     entries, then a scheduled flush of the remaining 100;
//   - how the pubsub notifications are chunked: two Publish calls for the
//     500-agent flush and one for the 100-agent flush.
//
// NOTE(review): the literals 500, 499, 100, and 600 and the Times(3) on
// Publish assume defaultMetadataBatchSize == 500 and roughly 363 agent IDs per
// pubsub message, as the comments below state. They must be updated if either
// value changes.
func TestMetadataBatcher_PubsubChunking(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctrl := gomock.NewController(t)
	store := dbmock.NewMockStore(ctrl)
	ps := psmock.NewMockPubsub(ctrl)
	clock := quartz.NewMock(t)
	reg := prometheus.NewRegistry()

	b, err := NewBatcher(ctx, reg, store, ps,
		WithLogger(log),
		WithClock(clock),
	)
	require.NoError(t, err)
	t.Cleanup(b.Close)

	t1 := clock.Now()

	// Create enough agents to exceed maxAgentIDsPerChunk.
	// With base64 encoding, each UUID is 22 characters, so we can fit
	// ~363 agent IDs per chunk (8000 / 22 = 363.6).
	// Let's create 600 agents to force chunking into 2 messages.
	numAgents := 600
	agents := make([]uuid.UUID, numAgents)
	expectedKeys := make([]string, numAgents)
	expectedValues := make([]string, numAgents)
	expectedErrors := make([]string, numAgents)
	expectedTimes := make([]time.Time, numAgents)
	for i := 0; i < numAgents; i++ {
		agents[i] = uuid.New()
		expectedKeys[i] = "key1"
		expectedValues[i] = "value1"
		expectedErrors[i] = ""
		expectedTimes[i] = t1
	}

	// Set up pubsub capture for the flush.
	psCap := newPubsubCapture(t)

	// With 600 agents and default batch size of 500:
	// - First flush at 500 agents (capacity): 2 pubsub chunks (363 + 137)
	// - Second flush at 100 agents (scheduled): 1 pubsub chunk
	// Total: 3 publishes, 2 store calls
	// Expect the store to be called twice - once for first 500, once for remaining 100.
	store.EXPECT().
		BatchUpdateWorkspaceAgentMetadata(
			gomock.Any(),
			matchMetadata(
				agents[:500],
				expectedKeys[:500],
				expectedValues[:500],
				expectedErrors[:500],
				expectedTimes[:500],
			),
		).
		Return(nil).
		Times(1)
	store.EXPECT().
		BatchUpdateWorkspaceAgentMetadata(
			gomock.Any(),
			matchMetadata(
				agents[500:],
				expectedKeys[500:],
				expectedValues[500:],
				expectedErrors[500:],
				expectedTimes[500:],
			),
		).
		Return(nil).
		Times(1)

	// Expect pubsub publish to be called when flush happens.
	// With base64 encoding, each UUID is 22 characters.
	// With 8KB limit, we can fit ~363 agents per chunk (8000 / 22 = 363.6).
	// With 600 agents and batch size of 500:
	// - First flush at 500 agents: 2 chunks (363 + 137)
	// - Second flush at 100 agents: 1 chunk
	// Total: 3 publishes
	ps.EXPECT().
		Publish(gomock.Any(), gomock.Any()).
		Do(psCap.capture).
		Return(nil).
		Times(3)

	// Add first 499 metadata updates (just under the capacity threshold of 500)
	// Capture dropped count before any adds.
	droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)
	for i := 0; i < 499; i++ {
		require.NoError(t, b.Add(agents[i], []string{"key1"}, []string{"value1"}, []string{""}, []time.Time{t1}))
	}

	// Wait for all channel messages to be processed into the batch.
	// Batch should have 499 entries, no capacity flush yet.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		channelEmpty := len(b.updateCh) == 0
		nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore
		batchHasExpected := int(b.currentBatchLen.Load()) == 499
		return channelEmpty && nothingDropped && batchHasExpected
	}, testutil.IntervalFast)

	// Add next 101 metadata updates (will trigger capacity flush at 500)
	for i := 499; i < numAgents; i++ {
		require.NoError(t, b.Add(agents[i], []string{"key1"}, []string{"value1"}, []string{""}, []time.Time{t1}))
	}

	// Wait for all channel messages to be processed. The 500th entry should have
	// triggered an automatic capacity flush, leaving 100 entries in the batch.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		channelEmpty := len(b.updateCh) == 0
		nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore
		batchHasExpected := int(b.currentBatchLen.Load()) == 100
		return channelEmpty && nothingDropped && batchHasExpected
	}, testutil.IntervalFast)

	// Verify capacity flush metrics and total metadata count.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		capacity := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushCapacity))
		totalMeta := prom_testutil.ToFloat64(b.Metrics.MetadataTotal)
		// Should have 1 capacity flush (500 entries) so far
		return capacity == float64(1) && totalMeta == float64(500)
	}, testutil.IntervalFast)

	// Flush remaining entries and verify all updates were processed
	clock.Advance(defaultMetadataFlushInterval).MustWait(ctx)

	// Wait for pubsub capture to complete and verify all agent IDs were published.
	// The capture set is never cleared here, so it accumulates IDs from both flushes.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return psCap.count() == numAgents
	}, testutil.IntervalFast)
	psCap.requireContainsAll(agents)

	// Verify that all metadata was flushed successfully.
	// We should have 1 capacity flush (500 entries) and 1 scheduled flush (100 entries).
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		capacity := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushCapacity))
		scheduled := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker))
		totalMeta := prom_testutil.ToFloat64(b.Metrics.MetadataTotal)
		// Check that we've had 1 capacity flush and 1 scheduled flush
		return capacity == float64(1) && scheduled == float64(1) && totalMeta == float64(600)
	}, testutil.IntervalFast)
	require.Equal(t, float64(numAgents), prom_testutil.ToFloat64(b.Metrics.MetadataTotal))
}
// TestMetadataBatcher_ConcurrentAddsToSameAgent has numGoroutines goroutines
// call Add concurrently for the same agent and the same three keys, each using
// a distinct timestamp. The scheduled flush must contain exactly three rows,
// carrying the values from the goroutine with the latest timestamp.
func TestMetadataBatcher_ConcurrentAddsToSameAgent(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctrl := gomock.NewController(t)
	store := dbmock.NewMockStore(ctrl)
	ps := psmock.NewMockPubsub(ctrl)
	clock := quartz.NewMock(t)
	reg := prometheus.NewRegistry()

	b, err := NewBatcher(ctx, reg, store, ps,
		WithLogger(log),
		WithClock(clock),
	)
	require.NoError(t, err)
	t.Cleanup(b.Close)

	// Single agent, multiple goroutines updating same keys concurrently
	agentID := uuid.New()
	numGoroutines := 20
	timestamps := make([]time.Time, numGoroutines)
	initialTS := clock.Now()
	for i := 0; i < numGoroutines; i++ {
		timestamps[i] = initialTS.Add(time.Duration(i) * time.Millisecond)
	}

	// The latest timestamp will have the final values, since deduplication keeps the newest value for each key.
	latestTimestamp := timestamps[numGoroutines-1]
	latestValue := fmt.Sprintf("value_from_goroutine_%d", numGoroutines-1)

	// Set up pubsub capture for the flush.
	psCap := newPubsubCapture(t)

	// Expect the store to be called with exactly 3 keys (after deduplication).
	// The values should be from the goroutine with the latest timestamp.
	store.EXPECT().
		BatchUpdateWorkspaceAgentMetadata(
			gomock.Any(),
			matchMetadata(
				[]uuid.UUID{agentID, agentID, agentID},
				[]string{"key1", "key2", "key3"},
				[]string{latestValue, latestValue, latestValue},
				[]string{"", "", ""},
				[]time.Time{latestTimestamp, latestTimestamp, latestTimestamp},
			),
		).
		Return(nil).
		Times(1)
	ps.EXPECT().
		Publish(gomock.Any(), gomock.Any()).
		Do(psCap.capture).
		Return(nil).
		Times(1)

	var wg sync.WaitGroup
	wg.Add(numGoroutines)

	// Capture dropped count before any adds.
	droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)

	// Each goroutine updates the same set of keys with different values
	for i := 0; i < numGoroutines; i++ {
		go func(routineNum int) {
			defer wg.Done()
			timestamp := timestamps[routineNum]
			value := fmt.Sprintf("value_from_goroutine_%d", routineNum)
			// Check the error instead of discarding it so a rejected update is
			// reported directly rather than as a confusing mismatch at flush
			// time. assert (not require) is used because this is not the test
			// goroutine, and t.FailNow must only be called from that goroutine.
			assert.NoError(t, b.Add(agentID, []string{"key1", "key2", "key3"},
				[]string{value, value, value},
				[]string{"", "", ""},
				[]time.Time{timestamp, timestamp, timestamp}))
		}(i)
	}
	wg.Wait()

	// Wait for all channel messages to be processed by the run() goroutine into the batch.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		channelEmpty := len(b.updateCh) == 0
		nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore
		batchHasExpected := int(b.currentBatchLen.Load()) == 3
		return channelEmpty && nothingDropped && batchHasExpected
	}, testutil.IntervalFast)

	// Flush and check that we have exactly 3 keys (deduplication worked).
	// Advance the full flush interval from when the batcher was created.
	clock.Advance(defaultMetadataFlushInterval).MustWait(ctx)

	// Wait for pubsub capture to complete and verify all agent IDs were published.
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return psCap.count() == 1
	}, testutil.IntervalFast)
	psCap.requireContainsAll([]uuid.UUID{agentID})

	// Verify exactly 3 unique keys were flushed
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		return float64(1) == prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker))
	}, testutil.IntervalFast)
	require.Equal(t, float64(3), prom_testutil.ToFloat64(b.Metrics.MetadataTotal))
}