diff --git a/coderd/agentapi/api.go b/coderd/agentapi/api.go index 0253b27b9d..6907dcad75 100644 --- a/coderd/agentapi/api.go +++ b/coderd/agentapi/api.go @@ -17,6 +17,7 @@ import ( "cdr.dev/slog/v3" agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" "github.com/coder/coder/v2/coderd/agentapi/resourcesmonitor" "github.com/coder/coder/v2/coderd/appearance" "github.com/coder/coder/v2/coderd/connectionlog" @@ -81,6 +82,7 @@ type Options struct { DerpMapFn func() *tailcfg.DERPMap TailnetCoordinator *atomic.Pointer[tailnet.Coordinator] StatsReporter *workspacestats.Reporter + MetadataBatcher *metadatabatcher.Batcher AppearanceFetcher *atomic.Pointer[appearance.Fetcher] PublishWorkspaceUpdateFn func(ctx context.Context, userID uuid.UUID, event wspubsub.WorkspaceEvent) PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) @@ -179,8 +181,8 @@ func New(opts Options, workspace database.Workspace) *API { AgentFn: api.agent, Workspace: api.cachedWorkspaceFields, Database: opts.Database, - Pubsub: opts.Pubsub, Log: opts.Log, + Batcher: opts.MetadataBatcher, } api.LogsAPI = &LogsAPI{ diff --git a/coderd/agentapi/metadata.go b/coderd/agentapi/metadata.go index ca2708092c..67482c0317 100644 --- a/coderd/agentapi/metadata.go +++ b/coderd/agentapi/metadata.go @@ -2,27 +2,25 @@ package agentapi import ( "context" - "encoding/json" "fmt" "time" - "github.com/google/uuid" "golang.org/x/xerrors" "cdr.dev/slog/v3" agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" - "github.com/coder/coder/v2/coderd/database/pubsub" ) type MetadataAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) Workspace *CachedWorkspaceFields Database 
database.Store - Pubsub pubsub.Pubsub Log slog.Logger + Batcher *metadatabatcher.Batcher TimeNowFn func() time.Time // defaults to dbtime.Now() } @@ -122,21 +120,10 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B ) } - err = a.Database.UpdateWorkspaceAgentMetadata(rbacCtx, dbUpdate) + // Use batcher to batch metadata updates. + err = a.Batcher.Add(workspaceAgent.ID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt) if err != nil { - return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err) - } - - payload, err := json.Marshal(WorkspaceAgentMetadataChannelPayload{ - CollectedAt: collectedAt, - Keys: dbUpdate.Key, - }) - if err != nil { - return nil, xerrors.Errorf("marshal workspace agent metadata channel payload: %w", err) - } - err = a.Pubsub.Publish(WatchWorkspaceAgentMetadataChannel(workspaceAgent.ID), payload) - if err != nil { - return nil, xerrors.Errorf("publish workspace agent metadata: %w", err) + return nil, xerrors.Errorf("add metadata to batcher: %w", err) } // If the metadata keys were too large, we return an error so the agent can @@ -154,12 +141,3 @@ func ellipse(v string, n int) string { } return v } - -type WorkspaceAgentMetadataChannelPayload struct { - CollectedAt time.Time `json:"collected_at"` - Keys []string `json:"keys"` -} - -func WatchWorkspaceAgentMetadataChannel(id uuid.UUID) string { - return "workspace_agent_metadata:" + id.String() -} diff --git a/coderd/agentapi/metadata_test.go b/coderd/agentapi/metadata_test.go index 866b2a8bf2..ba5621e855 100644 --- a/coderd/agentapi/metadata_test.go +++ b/coderd/agentapi/metadata_test.go @@ -2,44 +2,26 @@ package agentapi_test import ( "context" - "database/sql" - "encoding/json" - "sync/atomic" "testing" "time" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" + prom_testutil "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" 
"google.golang.org/protobuf/types/known/timestamppb" agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" - "github.com/coder/coder/v2/coderd/rbac" - "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" ) -type fakePublisher struct { - // Nil pointer to pass interface check. - pubsub.Pubsub - publishes [][]byte -} - -var _ pubsub.Pubsub = &fakePublisher{} - -func (f *fakePublisher) Publish(_ string, message []byte) error { - f.publishes = append(f.publishes, message) - return nil -} - func TestBatchUpdateMetadata(t *testing.T) { t.Parallel() @@ -50,8 +32,12 @@ func TestBatchUpdateMetadata(t *testing.T) { t.Run("OK", func(t *testing.T) { t.Parallel() - dbM := dbmock.NewMockStore(gomock.NewController(t)) - pub := &fakePublisher{} + ctx := testutil.Context(t, testutil.WaitShort) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + ps := pubsub.NewInMemory() + reg := prometheus.NewRegistry() now := dbtime.Now() req := &agentproto.BatchUpdateMetadataRequest{ @@ -76,24 +62,30 @@ func TestBatchUpdateMetadata(t *testing.T) { }, }, } + batchSize := len(req.Metadata) + // This test sends 2 metadata entries. With batch size 2, we expect + // exactly 1 capacity flush. + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()). + Return(nil). 
+ Times(1) - dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ - WorkspaceAgentID: agent.ID, - Key: []string{req.Metadata[0].Key, req.Metadata[1].Key}, - Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value}, - Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error}, - // The value from the agent is ignored. - CollectedAt: []time.Time{now, now}, - }).Return(nil) + // Create a real batcher for the test with batch size matching the number + // of metadata entries to trigger exactly one capacity flush. + batcher, err := metadatabatcher.NewBatcher(ctx, reg, store, ps, + metadatabatcher.WithLogger(testutil.Logger(t)), + metadatabatcher.WithBatchSize(batchSize), + ) + require.NoError(t, err) + t.Cleanup(batcher.Close) api := &agentapi.MetadataAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, Workspace: &agentapi.CachedWorkspaceFields{}, - Database: dbM, - Pubsub: pub, Log: testutil.Logger(t), + Batcher: batcher, TimeNowFn: func() time.Time { return now }, @@ -103,27 +95,33 @@ func TestBatchUpdateMetadata(t *testing.T) { require.NoError(t, err) require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp) - require.Equal(t, 1, len(pub.publishes)) - var gotEvent agentapi.WorkspaceAgentMetadataChannelPayload - require.NoError(t, json.Unmarshal(pub.publishes[0], &gotEvent)) - require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{ - CollectedAt: now, - Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key}, - }, gotEvent) + // Wait for the capacity flush to complete before test ends. 
+ testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return prom_testutil.ToFloat64(batcher.Metrics.MetadataTotal) == 2.0 + }, testutil.IntervalFast) }) t.Run("ExceededLength", func(t *testing.T) { t.Parallel() - dbM := dbmock.NewMockStore(gomock.NewController(t)) - pub := pubsub.NewInMemory() + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + ps := pubsub.NewInMemory() + reg := prometheus.NewRegistry() + // This test sends 4 metadata entries with some exceeding length limits. We set the batcher's batch size so that + // we can reliably ensure a batch is sent within the WaitShort time period. + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()). + Return(nil). + Times(1) + + now := dbtime.Now() almostLongValue := "" for i := 0; i < 2048; i++ { almostLongValue += "a" } - - now := dbtime.Now() req := &agentproto.BatchUpdateMetadataRequest{ Metadata: []*agentproto.Metadata{ { @@ -152,34 +150,21 @@ func TestBatchUpdateMetadata(t *testing.T) { }, }, } - - dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ - WorkspaceAgentID: agent.ID, - Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key, req.Metadata[3].Key}, - Value: []string{ - almostLongValue, - almostLongValue, // truncated - "", - "", - }, - Error: []string{ - "", - "value of 2049 bytes exceeded 2048 bytes", - almostLongValue, - "error of 2049 bytes exceeded 2048 bytes", // replaced - }, - // The value from the agent is ignored. 
- CollectedAt: []time.Time{now, now, now, now}, - }).Return(nil) + batchSize := len(req.Metadata) + batcher, err := metadatabatcher.NewBatcher(ctx, reg, store, ps, + metadatabatcher.WithLogger(testutil.Logger(t)), + metadatabatcher.WithBatchSize(batchSize), + ) + require.NoError(t, err) + t.Cleanup(batcher.Close) api := &agentapi.MetadataAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, Workspace: &agentapi.CachedWorkspaceFields{}, - Database: dbM, - Pubsub: pub, Log: testutil.Logger(t), + Batcher: batcher, TimeNowFn: func() time.Time { return now }, @@ -188,13 +173,21 @@ func TestBatchUpdateMetadata(t *testing.T) { resp, err := api.BatchUpdateMetadata(context.Background(), req) require.NoError(t, err) require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp) + // Wait for the capacity flush to complete before test ends. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return prom_testutil.ToFloat64(batcher.Metrics.MetadataTotal) == 4.0 + }, testutil.IntervalFast) }) t.Run("KeysTooLong", func(t *testing.T) { t.Parallel() - dbM := dbmock.NewMockStore(gomock.NewController(t)) - pub := pubsub.NewInMemory() + ctx := testutil.Context(t, testutil.WaitShort) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + ps := pubsub.NewInMemory() + reg := prometheus.NewRegistry() now := dbtime.Now() req := &agentproto.BatchUpdateMetadataRequest{ @@ -231,595 +224,40 @@ func TestBatchUpdateMetadata(t *testing.T) { }, }, } + batchSize := len(req.Metadata) - dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ - WorkspaceAgentID: agent.ID, - // No key 4. 
- Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key}, - Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value, req.Metadata[2].Result.Value}, - Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error, req.Metadata[2].Result.Error}, - // The value from the agent is ignored. - CollectedAt: []time.Time{now, now, now}, - }).Return(nil) + // This test sends 4 metadata entries but rejects the last one due to excessive key length. + // We set the batcher's batch size so that we can reliably ensure a batch is sent within the WaitShort time period. + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()). + Return(nil). + Times(1) + + batcher, err := metadatabatcher.NewBatcher(ctx, reg, store, ps, + metadatabatcher.WithLogger(testutil.Logger(t)), + metadatabatcher.WithBatchSize(batchSize-1), // one of the keys will be rejected + ) + require.NoError(t, err) + t.Cleanup(batcher.Close) api := &agentapi.MetadataAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, Workspace: &agentapi.CachedWorkspaceFields{}, - Database: dbM, - Pubsub: pub, Log: testutil.Logger(t), + Batcher: batcher, TimeNowFn: func() time.Time { return now }, } - // Watch the pubsub for events. - var ( - eventCount int64 - gotEvent agentapi.WorkspaceAgentMetadataChannelPayload - ) - cancel, err := pub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(agent.ID), func(ctx context.Context, message []byte) { - if atomic.AddInt64(&eventCount, 1) > 1 { - return - } - require.NoError(t, json.Unmarshal(message, &gotEvent)) - }) - require.NoError(t, err) - defer cancel() - resp, err := api.BatchUpdateMetadata(context.Background(), req) + // Should return error because keys are too long. 
require.Error(t, err) - require.Equal(t, "metadata keys of 6145 bytes exceeded 6144 bytes", err.Error()) require.Nil(t, resp) - - require.Equal(t, int64(1), atomic.LoadInt64(&eventCount)) - require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{ - CollectedAt: now, - // No key 4. - Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key}, - }, gotEvent) - }) - - // Test RBAC fast path with valid RBAC object - should NOT call GetWorkspaceByAgentID - // This test verifies that when a valid RBAC object is present in context, the dbauthz layer - // uses the fast path and skips the GetWorkspaceByAgentID database call. - t.Run("WorkspaceCached_SkipsDBCall", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - dbM = dbmock.NewMockStore(ctrl) - pub = &fakePublisher{} - now = dbtime.Now() - // Set up consistent IDs that represent a valid workspace->agent relationship - workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012") - templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000") - ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321") - orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111") - agentID = uuid.MustParse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") - ) - - agent := database.WorkspaceAgent{ - ID: agentID, - // In a real scenario, this agent would belong to a resource in the workspace above - } - - req := &agentproto.BatchUpdateMetadataRequest{ - Metadata: []*agentproto.Metadata{ - { - Key: "test_key", - Result: &agentproto.WorkspaceAgentMetadata_Result{ - CollectedAt: timestamppb.New(now.Add(-time.Second)), - Age: 1, - Value: "test_value", - }, - }, - }, - } - - // Expect UpdateWorkspaceAgentMetadata to be called - dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ - WorkspaceAgentID: agent.ID, - Key: []string{"test_key"}, - Value: []string{"test_value"}, - Error: []string{""}, - CollectedAt: []time.Time{now}, - 
}).Return(nil) - - // DO NOT expect GetWorkspaceByAgentID - the fast path should skip this call - // If GetWorkspaceByAgentID is called, the test will fail with "unexpected call" - - // dbauthz will call Wrappers() to check for wrapped databases - dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes() - - // Set up dbauthz to test the actual authorization layer - auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()) - accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{} - var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{} - accessControlStore.Store(&acs) - - api := &agentapi.MetadataAPI{ - AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { - return agent, nil - }, - Workspace: &agentapi.CachedWorkspaceFields{}, - Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore), - Pubsub: pub, - Log: testutil.Logger(t), - TimeNowFn: func() time.Time { - return now - }, - } - - api.Workspace.UpdateValues(database.Workspace{ - ID: workspaceID, - OwnerID: ownerID, - OrganizationID: orgID, - }) - - // Create roles with workspace permissions - userRoles := rbac.Roles([]rbac.Role{ - { - Identifier: rbac.RoleMember(), - User: []rbac.Permission{ - { - Negate: false, - ResourceType: rbac.ResourceWorkspace.Type, - Action: policy.WildcardSymbol, - }, - }, - ByOrgID: map[string]rbac.OrgPermissions{ - orgID.String(): { - Member: []rbac.Permission{ - { - Negate: false, - ResourceType: rbac.ResourceWorkspace.Type, - Action: policy.WildcardSymbol, - }, - }, - }, - }, - }, - }) - - agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{ - WorkspaceID: workspaceID, - OwnerID: ownerID, - TemplateID: templateID, - VersionID: uuid.New(), - }) - - ctx := dbauthz.As(context.Background(), rbac.Subject{ - Type: rbac.SubjectTypeUser, - FriendlyName: "testuser", - Email: "testuser@example.com", - ID: ownerID.String(), - Roles: userRoles, - Groups: []string{orgID.String()}, - Scope: agentScope, - 
}.WithCachedASTValue()) - - resp, err := api.BatchUpdateMetadata(ctx, req) - require.NoError(t, err) - require.NotNil(t, resp) - }) - // Test RBAC slow path - invalid RBAC object should fall back to GetWorkspaceByAgentID - // This test verifies that when the RBAC object has invalid IDs (nil UUIDs), the dbauthz layer - // falls back to the slow path and calls GetWorkspaceByAgentID. - t.Run("InvalidWorkspaceCached_RequiresDBCall", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - dbM = dbmock.NewMockStore(ctrl) - pub = &fakePublisher{} - now = dbtime.Now() - workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012") - templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000") - ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321") - orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111") - agentID = uuid.MustParse("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") - ) - - agent := database.WorkspaceAgent{ - ID: agentID, - } - - req := &agentproto.BatchUpdateMetadataRequest{ - Metadata: []*agentproto.Metadata{ - { - Key: "test_key", - Result: &agentproto.WorkspaceAgentMetadata_Result{ - CollectedAt: timestamppb.New(now.Add(-time.Second)), - Age: 1, - Value: "test_value", - }, - }, - }, - } - - // EXPECT GetWorkspaceByAgentID to be called because the RBAC fast path validation fails - dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(database.Workspace{ - ID: workspaceID, - OwnerID: ownerID, - OrganizationID: orgID, - }, nil) - - // Expect UpdateWorkspaceAgentMetadata to be called after authorization - dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ - WorkspaceAgentID: agent.ID, - Key: []string{"test_key"}, - Value: []string{"test_value"}, - Error: []string{""}, - CollectedAt: []time.Time{now}, - }).Return(nil) - - // dbauthz will call Wrappers() to check for wrapped databases - dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes() - - // Set up 
dbauthz to test the actual authorization layer - auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()) - accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{} - var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{} - accessControlStore.Store(&acs) - - api := &agentapi.MetadataAPI{ - AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { - return agent, nil - }, - - Workspace: &agentapi.CachedWorkspaceFields{}, - Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore), - Pubsub: pub, - Log: testutil.Logger(t), - TimeNowFn: func() time.Time { - return now - }, - } - - // Create an invalid RBAC object with nil UUIDs for owner/org - // This will fail dbauthz fast path validation and trigger GetWorkspaceByAgentID - api.Workspace.UpdateValues(database.Workspace{ - ID: uuid.MustParse("cccccccc-cccc-cccc-cccc-cccccccccccc"), - OwnerID: uuid.Nil, // Invalid: fails dbauthz fast path validation - OrganizationID: uuid.Nil, // Invalid: fails dbauthz fast path validation - }) - - // Create roles with workspace permissions - userRoles := rbac.Roles([]rbac.Role{ - { - Identifier: rbac.RoleMember(), - User: []rbac.Permission{ - { - Negate: false, - ResourceType: rbac.ResourceWorkspace.Type, - Action: policy.WildcardSymbol, - }, - }, - ByOrgID: map[string]rbac.OrgPermissions{ - orgID.String(): { - Member: []rbac.Permission{ - { - Negate: false, - ResourceType: rbac.ResourceWorkspace.Type, - Action: policy.WildcardSymbol, - }, - }, - }, - }, - }, - }) - - agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{ - WorkspaceID: workspaceID, - OwnerID: ownerID, - TemplateID: templateID, - VersionID: uuid.New(), - }) - - ctx := dbauthz.As(context.Background(), rbac.Subject{ - Type: rbac.SubjectTypeUser, - FriendlyName: "testuser", - Email: "testuser@example.com", - ID: ownerID.String(), - Roles: userRoles, - Groups: []string{orgID.String()}, - Scope: agentScope, - }.WithCachedASTValue()) - - 
resp, err := api.BatchUpdateMetadata(ctx, req) - require.NoError(t, err) - require.NotNil(t, resp) - }) - - // Test RBAC slow path - no RBAC object in context - // This test verifies that when no RBAC object is present in context, the dbauthz layer - // falls back to the slow path and calls GetWorkspaceByAgentID. - t.Run("WorkspaceNotCached_RequiresDBCall", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - dbM = dbmock.NewMockStore(ctrl) - pub = &fakePublisher{} - now = dbtime.Now() - workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012") - templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000") - ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321") - orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111") - agentID = uuid.MustParse("dddddddd-dddd-dddd-dddd-dddddddddddd") - ) - - agent := database.WorkspaceAgent{ - ID: agentID, - } - - req := &agentproto.BatchUpdateMetadataRequest{ - Metadata: []*agentproto.Metadata{ - { - Key: "test_key", - Result: &agentproto.WorkspaceAgentMetadata_Result{ - CollectedAt: timestamppb.New(now.Add(-time.Second)), - Age: 1, - Value: "test_value", - }, - }, - }, - } - - // EXPECT GetWorkspaceByAgentID to be called because no RBAC object is in context - dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(database.Workspace{ - ID: workspaceID, - OwnerID: ownerID, - OrganizationID: orgID, - }, nil) - - // Expect UpdateWorkspaceAgentMetadata to be called after authorization - dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ - WorkspaceAgentID: agent.ID, - Key: []string{"test_key"}, - Value: []string{"test_value"}, - Error: []string{""}, - CollectedAt: []time.Time{now}, - }).Return(nil) - - // dbauthz will call Wrappers() to check for wrapped databases - dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes() - - // Set up dbauthz to test the actual authorization layer - auth := 
rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()) - accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{} - var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{} - accessControlStore.Store(&acs) - - api := &agentapi.MetadataAPI{ - AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { - return agent, nil - }, - Workspace: &agentapi.CachedWorkspaceFields{}, - Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore), - Pubsub: pub, - Log: testutil.Logger(t), - TimeNowFn: func() time.Time { - return now - }, - } - - // Create roles with workspace permissions - userRoles := rbac.Roles([]rbac.Role{ - { - Identifier: rbac.RoleMember(), - User: []rbac.Permission{ - { - Negate: false, - ResourceType: rbac.ResourceWorkspace.Type, - Action: policy.WildcardSymbol, - }, - }, - ByOrgID: map[string]rbac.OrgPermissions{ - orgID.String(): { - Member: []rbac.Permission{ - { - Negate: false, - ResourceType: rbac.ResourceWorkspace.Type, - Action: policy.WildcardSymbol, - }, - }, - }, - }, - }, - }) - - agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{ - WorkspaceID: workspaceID, - OwnerID: ownerID, - TemplateID: templateID, - VersionID: uuid.New(), - }) - - ctx := dbauthz.As(context.Background(), rbac.Subject{ - Type: rbac.SubjectTypeUser, - FriendlyName: "testuser", - Email: "testuser@example.com", - ID: ownerID.String(), - Roles: userRoles, - Groups: []string{orgID.String()}, - Scope: agentScope, - }.WithCachedASTValue()) - - resp, err := api.BatchUpdateMetadata(ctx, req) - require.NoError(t, err) - require.NotNil(t, resp) - }) - - // Test cache refresh - AutostartSchedule updated - // This test verifies that the cache refresh mechanism actually calls GetWorkspaceByID - // and updates the cached workspace fields when the workspace is modified (e.g., autostart schedule changes). 
- t.Run("CacheRefreshed_AutostartScheduleUpdated", func(t *testing.T) { - t.Parallel() - - var ( - ctrl = gomock.NewController(t) - dbM = dbmock.NewMockStore(ctrl) - pub = &fakePublisher{} - now = dbtime.Now() - mClock = quartz.NewMock(t) - tickerTrap = mClock.Trap().TickerFunc("cache_refresh") - - workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012") - ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321") - orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111") - templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000") - agentID = uuid.MustParse("ffffffff-ffff-ffff-ffff-ffffffffffff") - ) - - agent := database.WorkspaceAgent{ - ID: agentID, - } - - // Initial workspace - has Monday-Friday 9am autostart - initialWorkspace := database.Workspace{ - ID: workspaceID, - OwnerID: ownerID, - OrganizationID: orgID, - TemplateID: templateID, - Name: "my-workspace", - OwnerUsername: "testuser", - TemplateName: "test-template", - AutostartSchedule: sql.NullString{Valid: true, String: "CRON_TZ=UTC 0 9 * * 1-5"}, - } - - // Updated workspace - user changed autostart to 5pm and renamed workspace - updatedWorkspace := database.Workspace{ - ID: workspaceID, - OwnerID: ownerID, - OrganizationID: orgID, - TemplateID: templateID, - Name: "my-workspace-renamed", // Changed! - OwnerUsername: "testuser", - TemplateName: "test-template", - AutostartSchedule: sql.NullString{Valid: true, String: "CRON_TZ=UTC 0 17 * * 1-5"}, // Changed! 
- DormantAt: sql.NullTime{}, - } - - req := &agentproto.BatchUpdateMetadataRequest{ - Metadata: []*agentproto.Metadata{ - { - Key: "test_key", - Result: &agentproto.WorkspaceAgentMetadata_Result{ - CollectedAt: timestamppb.New(now.Add(-time.Second)), - Age: 1, - Value: "test_value", - }, - }, - }, - } - - // EXPECT GetWorkspaceByID to be called during cache refresh - // This is the key assertion - proves the refresh mechanism is working - dbM.EXPECT().GetWorkspaceByID(gomock.Any(), workspaceID).Return(updatedWorkspace, nil) - - // API needs to fetch the agent when calling metadata update - dbM.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(agent, nil) - - // After refresh, metadata update should work with updated cache - dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).DoAndReturn( - func(ctx context.Context, params database.UpdateWorkspaceAgentMetadataParams) error { - require.Equal(t, agent.ID, params.WorkspaceAgentID) - require.Equal(t, []string{"test_key"}, params.Key) - require.Equal(t, []string{"test_value"}, params.Value) - require.Equal(t, []string{""}, params.Error) - require.Len(t, params.CollectedAt, 1) - return nil - }, - ).AnyTimes() - - // May call GetWorkspaceByAgentID if slow path is used before refresh - dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(updatedWorkspace, nil).AnyTimes() - - // dbauthz will call Wrappers() - dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes() - - // Set up dbauthz - auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry()) - accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{} - var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{} - accessControlStore.Store(&acs) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Create roles with workspace permissions - userRoles := rbac.Roles([]rbac.Role{ - { - Identifier: rbac.RoleMember(), - User: []rbac.Permission{ - { - Negate: false, - 
ResourceType: rbac.ResourceWorkspace.Type, - Action: policy.WildcardSymbol, - }, - }, - ByOrgID: map[string]rbac.OrgPermissions{ - orgID.String(): { - Member: []rbac.Permission{ - { - Negate: false, - ResourceType: rbac.ResourceWorkspace.Type, - Action: policy.WildcardSymbol, - }, - }, - }, - }, - }, - }) - - agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{ - WorkspaceID: workspaceID, - OwnerID: ownerID, - TemplateID: templateID, - VersionID: uuid.New(), - }) - - ctxWithActor := dbauthz.As(ctx, rbac.Subject{ - Type: rbac.SubjectTypeUser, - FriendlyName: "testuser", - Email: "testuser@example.com", - ID: ownerID.String(), - Roles: userRoles, - Groups: []string{orgID.String()}, - Scope: agentScope, - }.WithCachedASTValue()) - - // Create full API with cached workspace fields (initial state) - api := agentapi.New(agentapi.Options{ - AuthenticatedCtx: ctxWithActor, - AgentID: agentID, - WorkspaceID: workspaceID, - OwnerID: ownerID, - OrganizationID: orgID, - Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore), - Log: testutil.Logger(t), - Clock: mClock, - Pubsub: pub, - }, initialWorkspace) // Cache is initialized with 9am schedule and "my-workspace" name - - // Wait for ticker to be set up and release it so it can fire - tickerTrap.MustWait(ctx).MustRelease(ctx) - tickerTrap.Close() - - // Advance clock to trigger cache refresh and wait for it to complete - _, aw := mClock.AdvanceNext() - aw.MustWait(ctx) - - // At this point, GetWorkspaceByID should have been called and cache updated - // The cache now has the 5pm schedule and "my-workspace-renamed" name - - // Now call metadata update to verify the refreshed cache works - resp, err := api.MetadataAPI.BatchUpdateMetadata(ctxWithActor, req) - require.NoError(t, err) - require.NotNil(t, resp) + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return prom_testutil.ToFloat64(batcher.Metrics.MetadataTotal) == 3.0 + }, testutil.IntervalFast) }) } diff --git 
a/coderd/agentapi/metadatabatcher/agentid_chunks.go b/coderd/agentapi/metadatabatcher/agentid_chunks.go new file mode 100644 index 0000000000..a0932118fa --- /dev/null +++ b/coderd/agentapi/metadatabatcher/agentid_chunks.go @@ -0,0 +1,59 @@ +package metadatabatcher + +import ( + "encoding/base64" + + "github.com/google/uuid" + "golang.org/x/xerrors" +) + +const ( + // UUIDBase64Size is the size of a base64-encoded UUID without padding (22 characters). + UUIDBase64Size = 22 + + // maxAgentIDsPerChunk is the maximum number of agent IDs that can fit in a + // single pubsub message. PostgreSQL NOTIFY has an 8KB limit. + // With base64 encoding, each UUID is 22 characters, so we can fit + // ~363 agent IDs per chunk (8000 / 22 = 363.6). + maxAgentIDsPerChunk = maxPubsubPayloadSize / UUIDBase64Size +) + +func EncodeAgentID(agentID uuid.UUID, dst []byte) error { + // Encode UUID bytes to base64 without padding (RawStdEncoding). + // This produces exactly 22 characters per UUID. + reqLen := base64.RawStdEncoding.EncodedLen(len(agentID)) + if len(dst) < reqLen { + return xerrors.Errorf("destination byte slice was too small %d, required %d", len(dst), reqLen) + } + base64.RawStdEncoding.Encode(dst, agentID[:]) + return nil +} + +// EncodeAgentIDChunks encodes agent IDs into chunks that fit within the +// PostgreSQL NOTIFY 8KB payload size limit. Each UUID is base64-encoded +// (without padding) and concatenated into a single byte slice per chunk. +func EncodeAgentIDChunks(agentIDs []uuid.UUID) ([][]byte, error) { + chunks := make([][]byte, 0, (len(agentIDs)+maxAgentIDsPerChunk-1)/maxAgentIDsPerChunk) + + for i := 0; i < len(agentIDs); i += maxAgentIDsPerChunk { + end := i + maxAgentIDsPerChunk + if end > len(agentIDs) { + end = len(agentIDs) + } + + chunk := agentIDs[i:end] + + // Build payload by base64-encoding each UUID (without padding) and + // concatenating them. This is UTF-8 safe for PostgreSQL NOTIFY. 
+ payload := make([]byte, len(chunk)*UUIDBase64Size) + for i, agentID := range chunk { + err := EncodeAgentID(agentID, payload[i*UUIDBase64Size:(i+1)*UUIDBase64Size]) + if err != nil { + return nil, err + } + } + chunks = append(chunks, payload) + } + + return chunks, nil +} diff --git a/coderd/agentapi/metadatabatcher/agentid_chunks_test.go b/coderd/agentapi/metadatabatcher/agentid_chunks_test.go new file mode 100644 index 0000000000..68119dd08b --- /dev/null +++ b/coderd/agentapi/metadatabatcher/agentid_chunks_test.go @@ -0,0 +1,122 @@ +package metadatabatcher_test + +import ( + "encoding/base64" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" +) + +func TestEncodeDecodeRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + agentIDs []uuid.UUID + }{ + { + name: "Empty", + agentIDs: []uuid.UUID{}, + }, + { + name: "Single", + agentIDs: []uuid.UUID{uuid.New()}, + }, + { + name: "Multiple", + agentIDs: []uuid.UUID{ + uuid.New(), + uuid.New(), + uuid.New(), + }, + }, + { + name: "Exactly 363 (one chunk)", + agentIDs: func() []uuid.UUID { + ids := make([]uuid.UUID, 363) + for i := range ids { + ids[i] = uuid.New() + } + return ids + }(), + }, + { + name: "364 (two chunks)", + agentIDs: func() []uuid.UUID { + ids := make([]uuid.UUID, 364) + for i := range ids { + ids[i] = uuid.New() + } + return ids + }(), + }, + { + name: "600 (multiple chunks)", + agentIDs: func() []uuid.UUID { + ids := make([]uuid.UUID, 600) + for i := range ids { + ids[i] = uuid.New() + } + return ids + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Encode the agent IDs into chunks. + chunks, err := metadatabatcher.EncodeAgentIDChunks(tt.agentIDs) + require.NoError(t, err) + + // Decode all chunks and collect the agent IDs. 
+ var decoded []uuid.UUID + for _, chunk := range chunks { + for i := 0; i < len(chunk); i += metadatabatcher.UUIDBase64Size { + var u uuid.UUID + _, err := base64.RawStdEncoding.Decode(u[:], chunk[i:i+metadatabatcher.UUIDBase64Size]) + require.NoError(t, err) + decoded = append(decoded, u) + } + } + + // Verify we got the same agent IDs back. + if len(tt.agentIDs) == 0 { + require.Empty(t, decoded) + } else { + require.Equal(t, tt.agentIDs, decoded) + } + }) + } +} + +// TestEncodeAgentIDChunks_PGPubsubSize ensures that each pubsub message generated via EncodeAgentIDChunks fits within +// the max allowed 8kb by Postgres. +func TestEncodeAgentIDChunks_PGPubsubSize(t *testing.T) { + t.Parallel() + + // Create 600 agents (should split into 2 chunks: 363 + 237). + agentIDs := make([]uuid.UUID, 600) + for i := range agentIDs { + agentIDs[i] = uuid.New() + } + + chunks, err := metadatabatcher.EncodeAgentIDChunks(agentIDs) + require.NoError(t, err) + require.Len(t, chunks, 2) + + // First chunk should have 363 IDs (363 * 22 = 7986 bytes). + require.Equal(t, 363*22, len(chunks[0])) + + // Second chunk should have 237 IDs (237 * 22 = 5214 bytes). + require.Equal(t, 237*22, len(chunks[1])) + + // Each chunk should be under 8KB. 
+ for i, chunk := range chunks { + require.LessOrEqual(t, len(chunk), 8000, "chunk %d exceeds 8KB limit", i) + } +} diff --git a/coderd/agentapi/metadatabatcher/metadata_batcher.go b/coderd/agentapi/metadatabatcher/metadata_batcher.go new file mode 100644 index 0000000000..c5322d5976 --- /dev/null +++ b/coderd/agentapi/metadatabatcher/metadata_batcher.go @@ -0,0 +1,398 @@ +package metadatabatcher + +import ( + "context" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/quartz" +) + +const ( + // defaultMetadataBatchSize is the maximum number of metadata entries + // (key-value pairs across all agents) to batch before forcing a flush. + // With typical agents having 5-15 metadata keys, this accommodates + // 30-100 agents per batch. + defaultMetadataBatchSize = 500 + + // defaultChannelBufferMultiplier is the multiplier for the channel buffer size + // relative to the batch size. A 5x multiplier provides significant headroom + // for bursts while the batch is being flushed. + defaultChannelBufferMultiplier = 5 + + // defaultMetadataFlushInterval is how frequently to flush batched metadata + // updates to the database and pubsub. 5 seconds provides a good balance + // between reducing database load and maintaining reasonable UI update + // latency. + defaultMetadataFlushInterval = 5 * time.Second + + // maxPubsubPayloadSize is the maximum size of a single pubsub message. + // NOTE(review): PostgreSQL documents the default NOTIFY payload limit as "shorter than 8000 bytes"; confirm.
+ maxPubsubPayloadSize = 8000 // With whole 22-byte IDs a chunk is at most 363*22 = 7986 bytes + + // finalFlushTimeout bounds the final flush that run performs once the batcher's context is done. + finalFlushTimeout = 15 * time.Second + + // MetadataBatchPubsubChannel is the pubsub channel that batched metadata updates are published to. Each message + // carries the encoded IDs of agents that have an update in the most recent batch. + MetadataBatchPubsubChannel = "workspace_agent_metadata_batch" + + // Flush reasons, passed to flush and used as the label value on its per-reason metrics. + flushCapacity = "capacity" + flushTicker = "scheduled" + flushExit = "shutdown" +) + +// compositeKey uniquely identifies a metadata entry by agent ID and key name. +type compositeKey struct { + agentID uuid.UUID + key string +} + +// value holds a single metadata key-value pair with its error state +// and collection timestamp. +type value struct { + v string + error string + collectedAt time.Time +} + +// update represents a single metadata update to be batched. +type update struct { + compositeKey + value +} + +// Batcher holds a buffer of agent metadata updates and periodically +// flushes them to the database and pubsub. This reduces database write +// frequency and pubsub publish rate. +type Batcher struct { + store database.Store + ps pubsub.Pubsub + log slog.Logger + + // updateCh is the buffered channel that receives metadata updates from Add() calls. + updateCh chan update + + // batch holds the current batch being accumulated. For updates with the same composite key the most recent value wins. + batch map[compositeKey]value + currentBatchLen atomic.Int64 + maxBatchSize int + + clock quartz.Clock + timer *quartz.Timer + interval time.Duration + // warnTicker limits how often dropped keys are logged at warn level, as it could be noisy in failure scenarios. + warnTicker *quartz.Ticker + + // ctx is the context for the batcher. Used to check if shutdown has begun. 
+ ctx context.Context + cancel context.CancelFunc + done chan struct{} + + // Metrics collects Prometheus metrics for the batcher.
+ Metrics Metrics +} + +// Option is a functional option for configuring a Batcher. +type Option func(b *Batcher) + +func WithBatchSize(size int) Option { + return func(b *Batcher) { + b.maxBatchSize = size + } +} + +func WithInterval(d time.Duration) Option { + return func(b *Batcher) { + b.interval = d + } +} + +func WithLogger(log slog.Logger) Option { + return func(b *Batcher) { + b.log = log + } +} + +func WithClock(clock quartz.Clock) Option { + return func(b *Batcher) { + b.clock = clock + } +} + +// NewBatcher creates a new Batcher and starts it. Here ctx controls the lifetime of the batcher, canceling it will +// result in the Batcher exiting its processing routine (run). +func NewBatcher(ctx context.Context, reg prometheus.Registerer, store database.Store, ps pubsub.Pubsub, opts ...Option) (*Batcher, error) { + b := &Batcher{ + store: store, + ps: ps, + Metrics: NewMetrics(), + done: make(chan struct{}), + log: slog.Logger{}, + clock: quartz.NewReal(), + } + + for _, opt := range opts { + opt(b) + } + + b.Metrics.register(reg) + + if b.interval <= 0 { + b.interval = defaultMetadataFlushInterval + } + + if b.maxBatchSize <= 0 { + b.maxBatchSize = defaultMetadataBatchSize + } + + // Create warn ticker after options are applied so it uses the correct clock. + b.warnTicker = b.clock.NewTicker(10 * time.Second) + + if b.timer == nil { + b.timer = b.clock.NewTimer(b.interval) + } + + // Create buffered channel with 5x batch size capacity + channelSize := b.maxBatchSize * defaultChannelBufferMultiplier + b.updateCh = make(chan update, channelSize) + + // Initialize batch map + b.batch = make(map[compositeKey]value) + + b.ctx, b.cancel = context.WithCancel(ctx) + go func() { + b.run(b.ctx) + close(b.done) + }() + + return b, nil +} + +func (b *Batcher) Close() { + b.cancel() + // Stop the flush timer and the warn ticker; NewBatcher always creates both. + b.timer.Stop() + b.warnTicker.Stop() + // Wait for the run function to end, it may be sending one last batch.
+ <-b.done +} + +// Add adds metadata updates for an agent to the batcher by writing to a +// buffered channel. If the channel is full, updates are dropped. Updates +// to the same metadata key for the same agent are deduplicated in the batch, +// keeping only the value with the most recent collectedAt timestamp. +func (b *Batcher) Add(agentID uuid.UUID, keys []string, values []string, errors []string, collectedAt []time.Time) error { + if !(len(keys) == len(values) && len(values) == len(errors) && len(errors) == len(collectedAt)) { + return xerrors.Errorf("invalid Add call, all inputs must have the same number of items; keys: %d, values: %d, errors: %d, collectedAt: %d", len(keys), len(values), len(errors), len(collectedAt)) + } + + // Write each update to the channel. If the channel is full, drop the update. + var u update + droppedCount := 0 + for i := range keys { + u.agentID = agentID + u.key = keys[i] + u.v = values[i] + u.error = errors[i] + u.collectedAt = collectedAt[i] + + select { + case b.updateCh <- u: + // Successfully queued + default: + // Channel is full, drop this update + droppedCount++ + } + } + + // Log dropped keys if any were dropped. + if droppedCount > 0 { + msg := "metadata channel at capacity, dropped updates" + fields := []slog.Field{ + slog.F("agent_id", agentID), + slog.F("channel_size", cap(b.updateCh)), + slog.F("dropped_count", droppedCount), + } + select { + case <-b.warnTicker.C: + b.log.Warn(context.Background(), msg, fields...) + default: + b.log.Debug(context.Background(), msg, fields...) + } + + b.Metrics.DroppedKeysTotal.Add(float64(droppedCount)) + } + + return nil +} + +// processUpdate adds a metadata update to the batch with deduplication based on timestamp. +func (b *Batcher) processUpdate(update update) { + ck := compositeKey{ + agentID: update.agentID, + key: update.key, + } + + // Check if key already exists and only update if new value is newer.
+ existing, exists := b.batch[ck] + if exists && update.collectedAt.Before(existing.collectedAt) { + return + } + + b.batch[ck] = value{ + v: update.v, + error: update.error, + collectedAt: update.collectedAt, + } + if !exists { + b.currentBatchLen.Add(1) + } +} + +// run runs the batcher loop, reading from the update channel and flushing +// periodically or when the batch reaches capacity. +func (b *Batcher) run(ctx context.Context) { + // nolint:gocritic // This is only ever used for one thing - updating agent metadata. + authCtx := dbauthz.AsSystemRestricted(ctx) + for { + select { + case update := <-b.updateCh: + b.processUpdate(update) + + // Check if batch has reached capacity + if int(b.currentBatchLen.Load()) >= b.maxBatchSize { + b.flush(authCtx, flushCapacity) + // Reset timer so the next scheduled flush is interval duration + // from now, not from when it was originally scheduled. + b.timer.Reset(b.interval, "metadataBatcher", "capacityFlush") + } + + case <-b.timer.C: + b.flush(authCtx, flushTicker) + // Reset timer to schedule the next flush. + b.timer.Reset(b.interval, "metadataBatcher", "scheduledFlush") + + case <-ctx.Done(): + b.log.Debug(ctx, "context done, flushing before exit") + + // We must create a new context here as the parent context is done. + ctxTimeout, cancel := context.WithTimeout(context.Background(), finalFlushTimeout) + defer cancel() //nolint:revive // We're returning, defer is fine. + + // nolint:gocritic // This is only ever used for one thing - updating agent metadata. + b.flush(dbauthz.AsSystemRestricted(ctxTimeout), flushExit) + return + } + } +} + +// flush flushes the current batch to the database and pubsub. 
+func (b *Batcher) flush(ctx context.Context, reason string) { + count := len(b.batch) + + if count == 0 { + return + } + + start := b.clock.Now() + b.log.Debug(ctx, "flushing metadata batch", + slog.F("reason", reason), + slog.F("count", count), + ) + + // Convert batch map to parallel arrays for the batch query.
+ // Also build map of agent IDs for per-agent metrics and pubsub. + var ( + agentIDs = make([]uuid.UUID, 0, count) + keys = make([]string, 0, count) + values = make([]string, 0, count) + errors = make([]string, 0, count) + collectedAt = make([]time.Time, 0, count) + agentKeys = make(map[uuid.UUID]int) // Track keys per agent for metrics + ) + + for ck, mv := range b.batch { + agentIDs = append(agentIDs, ck.agentID) + keys = append(keys, ck.key) + values = append(values, mv.v) + errors = append(errors, mv.error) + collectedAt = append(collectedAt, mv.collectedAt) + agentKeys[ck.agentID]++ + } + + // Batch has been processed into slices for our DB request, so we can clear it. + // It's safe to clear before we know whether the flush is successful as agent metadata is not critical, and therefore + // we do not retry failed flushes and losing a batch of metadata is okay. + b.batch = make(map[compositeKey]value) + b.currentBatchLen.Store(0) + + // Record per-agent utilization metrics. + for _, keyCount := range agentKeys { + b.Metrics.BatchUtilization.Observe(float64(keyCount)) + } + + // Update the database with all metadata updates in a single query. + err := b.store.BatchUpdateWorkspaceAgentMetadata(ctx, database.BatchUpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agentIDs, + Key: keys, + Value: values, + Error: errors, + CollectedAt: collectedAt, + }) + elapsed := b.clock.Since(start) + + if err != nil { + if database.IsQueryCanceledError(err) { + b.log.Debug(ctx, "query canceled, skipping update of workspace agent metadata", slog.F("elapsed", elapsed)) + return + } + b.log.Error(ctx, "error updating workspace agent metadata", slog.Error(err), slog.F("elapsed", elapsed)) + return + } + + // Build list of unique agent IDs for pubsub notification. + uniqueAgentIDs := make([]uuid.UUID, 0, len(agentKeys)) + for agentID := range agentKeys { + uniqueAgentIDs = append(uniqueAgentIDs, agentID) + } + + // Encode agent IDs into chunks and publish them.
+ chunks, err := EncodeAgentIDChunks(uniqueAgentIDs) + if err != nil { + b.log.Error(ctx, "Agent ID chunk encoding for pubsub failed", + slog.Error(err)) + } + for _, chunk := range chunks { + if err := b.ps.Publish(MetadataBatchPubsubChannel, chunk); err != nil { + b.log.Error(ctx, "failed to publish workspace agent metadata batch", + slog.Error(err), + slog.F("chunk_size", len(chunk)/UUIDBase64Size), + slog.F("payload_size", len(chunk)), + ) + b.Metrics.PublishErrors.Inc() + } + } + + // Record batch size and flush duration once the database write and publish attempts are done. + b.Metrics.BatchSize.Observe(float64(count)) + b.Metrics.MetadataTotal.Add(float64(count)) + b.Metrics.BatchesTotal.WithLabelValues(reason).Inc() + // Measure with the injected clock, matching start above, so durations stay consistent under a mock clock. + elapsed = b.clock.Since(start) + b.Metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds()) + b.log.Debug(ctx, "flush complete", + slog.F("count", count), + slog.F("elapsed", elapsed), + slog.F("reason", reason), + ) +} diff --git a/coderd/agentapi/metadatabatcher/metadata_batcher_internal_test.go b/coderd/agentapi/metadatabatcher/metadata_batcher_internal_test.go new file mode 100644 index 0000000000..cc27da299f --- /dev/null +++ b/coderd/agentapi/metadatabatcher/metadata_batcher_internal_test.go @@ -0,0 +1,1008 @@ +package metadatabatcher + +import ( + "context" + "encoding/base64" + "fmt" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + prom_testutil "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/pubsub/psmock" + 
"github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +// 
============================================================================ +// Custom gomock matchers for metadata batcher testing +// ============================================================================ + +// metadataParamsMatcher validates BatchUpdateWorkspaceAgentMetadataParams by checking all fields match expected values. +type metadataParamsMatcher struct { + expectedAgentIDs []uuid.UUID + expectedKeys []string + expectedValues []string + expectedErrors []string + expectedTimes []time.Time +} + +func (m metadataParamsMatcher) Matches(x interface{}) bool { + params, ok := x.(database.BatchUpdateWorkspaceAgentMetadataParams) + if !ok { + return false + } + + // All arrays must have the same length. + expectedLen := len(m.expectedKeys) + if len(params.WorkspaceAgentID) != expectedLen || + len(params.Key) != expectedLen || + len(params.Value) != expectedLen || + len(params.Error) != expectedLen || + len(params.CollectedAt) != expectedLen { + return false + } + + // Check each field matches expected values. We create a map of expected entries and verify all actual entries match. + expectedEntries := make(map[string]bool) + for i := 0; i < len(m.expectedKeys); i++ { + key := fmt.Sprintf("%s|%s|%s|%s|%s", + m.expectedAgentIDs[i].String(), + m.expectedKeys[i], + m.expectedValues[i], + m.expectedErrors[i], + m.expectedTimes[i].Format(time.RFC3339Nano)) + expectedEntries[key] = false // not yet found + } + + // Check all actual entries are expected. + for i := 0; i < len(params.Key); i++ { + key := fmt.Sprintf("%s|%s|%s|%s|%s", + params.WorkspaceAgentID[i].String(), + params.Key[i], + params.Value[i], + params.Error[i], + params.CollectedAt[i].Format(time.RFC3339Nano)) + + if _, exists := expectedEntries[key]; !exists { + return false + } + expectedEntries[key] = true + } + + // Check all expected entries were found. 
+ for _, found := range expectedEntries { + if !found { + return false + } + } + + return true +} + +func (m metadataParamsMatcher) String() string { + return fmt.Sprintf("metadata params with %d entries (agents: %v, keys: %v)", + len(m.expectedKeys), m.expectedAgentIDs, m.expectedKeys) +} + +// matchMetadata creates a matcher that checks all values in the metadata params. +func matchMetadata(agentIDs []uuid.UUID, keys, values, errors []string, times []time.Time) gomock.Matcher { + return metadataParamsMatcher{ + expectedAgentIDs: agentIDs, + expectedKeys: keys, + expectedValues: values, + expectedErrors: errors, + expectedTimes: times, + } +} + +// pubsubCapture captures and decodes pubsub publish calls to accumulate agent IDs. +type pubsubCapture struct { + t *testing.T + mu sync.Mutex + + agentIDs map[uuid.UUID]struct{} +} + +func newPubsubCapture(t *testing.T) *pubsubCapture { + return &pubsubCapture{ + agentIDs: make(map[uuid.UUID]struct{}), + t: t, + } +} + +func (c *pubsubCapture) capture(event string, message []byte) { + c.mu.Lock() + defer c.mu.Unlock() + + // Verify correct event (expected value first, so failure output is labeled correctly). + assert.Equal(c.t, MetadataBatchPubsubChannel, event) + + // The payload must be a whole number of base64-encoded agent IDs. + assert.Equal(c.t, 0, len(message)%UUIDBase64Size) + + numAgents := len(message) / UUIDBase64Size + for i := 0; i < numAgents; i++ { + start := i * UUIDBase64Size + end := start + UUIDBase64Size + encoded := message[start:end] + + var uuidBytes [16]byte + n, err := base64.RawStdEncoding.Decode(uuidBytes[:], encoded) + assert.NoError(c.t, err) + assert.Equal(c.t, 16, n) + + agentID, err := uuid.FromBytes(uuidBytes[:]) + assert.NoError(c.t, err) + + c.agentIDs[agentID] = struct{}{} + } +} + +func (c *pubsubCapture) requireContainsAll(expected []uuid.UUID) { + c.mu.Lock() + defer c.mu.Unlock() + + // Check we don't have extra IDs.
+ require.Equal(c.t, len(expected), len(c.agentIDs), "unexpected number of agent IDs in pubsub messages") + + // Check all expected IDs are present. + for _, expectedID := range expected { + _, ok := c.agentIDs[expectedID] + require.True(c.t, ok, "expected agent ID %s not found in pubsub messages", expectedID) + } +} + +func (c *pubsubCapture) count() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.agentIDs) +} + +func (c *pubsubCapture) clear() { + c.mu.Lock() + defer c.mu.Unlock() + c.agentIDs = make(map[uuid.UUID]struct{}) +} + +func TestMetadataBatcher(t *testing.T) { + t.Parallel() + + // Given: a fresh batcher with no data + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + ps := psmock.NewMockPubsub(ctrl) + clock := quartz.NewMock(t) + + // Trap timer reset calls so we can wait for them to complete. + resetTrap := clock.Trap().TimerReset("metadataBatcher", "scheduledFlush") + defer resetTrap.Close() + capacityResetTrap := clock.Trap().TimerReset("metadataBatcher", "capacityFlush") + defer capacityResetTrap.Close() + + // Generate mock agent IDs. + agent1 := uuid.New() + agent2 := uuid.New() + + // Create a single pubsub capture to reuse across all flushes. + psCap := newPubsubCapture(t) + + // --- FLUSH 1: Empty flush (no calls expected) --- + // No expectations set - if DB query called, test will fail. 
+ reg := prometheus.NewRegistry() + b, err := NewBatcher(ctx, reg, store, ps, + WithLogger(log), + WithClock(clock), + ) + require.NoError(t, err) + t.Cleanup(b.Close) + + // Given: no metadata updates are added + // When: it becomes time to flush + // Then: no metadata should be updated (no DB call) + clock.Advance(defaultMetadataFlushInterval).MustWait(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) // Wait for timer reset after flush + t.Log("flush 1 completed (expected 0 entries)") + require.Equal(t, float64(0), prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker))) + + // --- FLUSH 2: Single agent with 2 metadata entries --- + t2 := clock.Now() + + // Expect exactly 1 database call with exact values. + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata( + gomock.Any(), + matchMetadata( + []uuid.UUID{agent1, agent1}, + []string{"key1", "key2"}, + []string{"value1", "value2"}, + []string{"", ""}, + []time.Time{t2, t2}, + ), + ). + Return(nil). + Times(1) + + // Expect exactly 1 pubsub publish with correct event and agent IDs. + ps.EXPECT(). + Publish(gomock.Any(), gomock.Any()). + Do(psCap.capture). + Return(nil). + Times(1) + + // Given: a single metadata update is added for agent1 + t.Log("adding metadata for 1 agent") + + // Capture dropped count before adding. + droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) + + require.NoError(t, b.Add(agent1, []string{"key1", "key2"}, []string{"value1", "value2"}, []string{"", ""}, []time.Time{t2, t2})) + + // Wait for the channel to be processed and verify nothing was dropped. 
+ testutil.Eventually(ctx, t, func(ctx context.Context) bool { + channelEmpty := len(b.updateCh) == 0 + nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore + batchHasExpected := int(b.currentBatchLen.Load()) == 2 + return channelEmpty && nothingDropped && batchHasExpected + }, testutil.IntervalFast) + + // When: it becomes time to flush + clock.Advance(defaultMetadataFlushInterval).MustWait(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) // Wait for timer reset after flush + t.Log("flush 2 completed (expected 2 entries)") + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + val := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker)) + totalMeta := prom_testutil.ToFloat64(b.Metrics.MetadataTotal) + return float64(1) == val && totalMeta >= float64(2) + }, testutil.IntervalFast) + require.Equal(t, float64(2), prom_testutil.ToFloat64(b.Metrics.MetadataTotal)) + + // Wait for pubsub capture to complete and verify all agent IDs were published. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return psCap.count() == 1 + }, testutil.IntervalFast) + psCap.requireContainsAll([]uuid.UUID{agent1}) + + // --- FLUSH 3: Multiple agents with 5 total metadata entries --- + t3 := clock.Now() + + // Clear pubsub capture for the next flush. + psCap.clear() + + // Expect exactly 1 database call with exact values for both agents. + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata( + gomock.Any(), + matchMetadata( + []uuid.UUID{agent1, agent1, agent1, agent2, agent2}, + []string{"key1", "key2", "key3", "key1", "key2"}, + []string{"new_value1", "new_value2", "new_value3", "agent2_value1", "agent2_value2"}, + []string{"", "", "", "", ""}, + []time.Time{t3, t3, t3, t3, t3}, + ), + ). + Return(nil). + Times(1) + + // Expect exactly 1 pubsub publish with both agent IDs. + ps.EXPECT(). + Publish(gomock.Any(), gomock.Any()). + Do(psCap.capture). + Return(nil). 
+ Times(1) + + // Given: metadata updates are added for multiple agents + t.Log("adding metadata for 2 agents") + + // Capture dropped count before any adds. + droppedBefore = prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) + + require.NoError(t, b.Add(agent1, []string{"key1", "key2", "key3"}, []string{"new_value1", "new_value2", "new_value3"}, []string{"", "", ""}, []time.Time{t3, t3, t3})) + require.NoError(t, b.Add(agent2, []string{"key1", "key2"}, []string{"agent2_value1", "agent2_value2"}, []string{"", ""}, []time.Time{t3, t3})) + + // Wait for all channel messages to be processed into the batch. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + channelEmpty := len(b.updateCh) == 0 + nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore + batchHasExpected := int(b.currentBatchLen.Load()) == 5 + return channelEmpty && nothingDropped && batchHasExpected + }, testutil.IntervalFast) + + // When: it becomes time to flush + clock.Advance(defaultMetadataFlushInterval).MustWait(ctx) + resetTrap.MustWait(ctx).MustRelease(ctx) // Wait for timer reset after flush + t.Log("flush 3 completed (expected 5 new entries)") + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + val := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker)) + totalMeta := prom_testutil.ToFloat64(b.Metrics.MetadataTotal) + return float64(2) == val && totalMeta >= float64(7) + }, testutil.IntervalFast) + require.Equal(t, float64(7), prom_testutil.ToFloat64(b.Metrics.MetadataTotal)) + + // Wait for pubsub capture to complete and verify all agent IDs were published. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return psCap.count() == 2 + }, testutil.IntervalFast) + psCap.requireContainsAll([]uuid.UUID{agent1, agent2}) + + // --- FLUSH 4: Capacity flush with defaultMetadataBatchSize entries --- + t4 := clock.Now() + numAgents := defaultMetadataBatchSize + + // Clear pubsub capture for the next flush. 
+ psCap.clear() + + // Pre-generate all agent IDs so we can assert on exact values. + agentIDs := make([]uuid.UUID, numAgents) + for i := 0; i < numAgents; i++ { + agentIDs[i] = uuid.New() + } + + // Build expected values for database assertion. + expectedKeys := make([]string, numAgents) + expectedValues := make([]string, numAgents) + expectedErrors := make([]string, numAgents) + expectedTimes := make([]time.Time, numAgents) + for i := 0; i < numAgents; i++ { + expectedKeys[i] = "key1" + expectedValues[i] = "bulk_value" + expectedErrors[i] = "" + expectedTimes[i] = t4 + } + + // Assert on exact database values. + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata( + gomock.Any(), + matchMetadata(agentIDs, expectedKeys, expectedValues, expectedErrors, expectedTimes), + ). + Return(nil). + Times(1) + + // Pubsub will be called with chunking. + // With 500 agents, we expect exactly 2 pubsub calls due to chunking (363 + 137). + ps.EXPECT(). + Publish(gomock.Any(), gomock.Any()). + Do(psCap.capture). + Return(nil). + Times(2) + + // Add metadata updates using the pre-generated agent IDs. + done := make(chan struct{}) + + go func() { + defer close(done) + t.Logf("adding metadata for %d agents", numAgents) + for i := 0; i < numAgents; i++ { + require.NoError(t, b.Add(agentIDs[i], []string{"key1"}, []string{"bulk_value"}, []string{""}, []time.Time{t4})) + } + }() + + // Wait for all updates to be added + <-done + capacityResetTrap.MustWait(ctx).MustRelease(ctx) // Wait for timer reset after capacity flush + t.Log("flush 4 completed (capacity flush, expected", defaultMetadataBatchSize, "entries)") + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return float64(1) == prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushCapacity)) + }, testutil.IntervalFast) + require.Equal(t, float64(507), prom_testutil.ToFloat64(b.Metrics.MetadataTotal)) + + // Wait for pubsub capture to complete and verify all agent IDs were published (across all chunks). 
+ testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return psCap.count() == numAgents + }, testutil.IntervalFast) + psCap.requireContainsAll(agentIDs) +} + +func TestMetadataBatcher_DropsWhenFull(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + ps := psmock.NewMockPubsub(ctrl) + clock := quartz.NewMock(t) + + reg := prometheus.NewRegistry() + // Batch size of 2 means channel capacity = 10 (2 * 5) + b, err := NewBatcher(ctx, reg, store, ps, + WithLogger(log), + WithBatchSize(2), + WithClock(clock), + ) + require.NoError(t, err) + t.Cleanup(b.Close) + + t1 := clock.Now() + + // Channels to control when the store call blocks/unblocks + flushStarted := make(chan struct{}) + unblockFlush := make(chan struct{}) + + pubsubCap := newPubsubCapture(t) + + // Make the first store call block until we signal. After unblocking, + // the 10 queued entries will trigger 5 more capacity flushes (10/2 = 5). + // Total expected flushes: 1 (initial) + 5 (queued) = 6 + firstCall := true + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, params database.BatchUpdateWorkspaceAgentMetadataParams) error { + if firstCall { + firstCall = false + close(flushStarted) // Signal that first flush has started + <-unblockFlush // Wait for signal to continue + } + return nil + }). + Times(6) + + ps.EXPECT(). + Publish(gomock.Any(), gomock.Any()). + Do(pubsubCap.capture). + Return(nil). 
+ Times(6) + + // Add 2 entries - this will trigger capacity flush (batch size = 2) that blocks + agent1 := uuid.New() + agent2 := uuid.New() + require.NoError(t, b.Add(agent1, []string{"key1"}, []string{"value1"}, []string{""}, []time.Time{t1})) + require.NoError(t, b.Add(agent2, []string{"key1"}, []string{"value2"}, []string{""}, []time.Time{t1})) + + // Wait for flush to start and block in the store call + <-flushStarted + + // Now the flush is blocked. Channel capacity is 10. + // Fill the channel with 10 entries + droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) + + for i := 0; i < 10; i++ { + agent := uuid.New() + require.NoError(t, b.Add(agent, []string{"key1"}, []string{fmt.Sprintf("value%d", i)}, []string{""}, []time.Time{t1})) + } + + // Channel should now be full. Next add should drop. + agentDropped := uuid.New() + require.NoError(t, b.Add(agentDropped, []string{"key1"}, []string{"dropped"}, []string{""}, []time.Time{t1})) + + // Verify that 1 key was dropped + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + dropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) + return dropped == droppedBefore+1 + }, testutil.IntervalFast) + + // Unblock the flush + close(unblockFlush) + + // Wait for all queued entries to be processed (channel should be empty) + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return len(b.updateCh) == 0 + }, testutil.IntervalFast) + + // Verify final state: 1 key was dropped, 12 metadata sent in 6 capacity batches + require.Equal(t, droppedBefore+1, prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal)) + require.Equal(t, float64(12), prom_testutil.ToFloat64(b.Metrics.MetadataTotal)) + require.Equal(t, float64(6), prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushCapacity))) +} + +// TestMetadataBatcher_Deduplication executes two Add calls, the second with a later timestamp than the first, to check +// that existing keys within a batch have their values 
updated. +func TestMetadataBatcher_Deduplication(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + + // First Add call + add1Keys []string + add1Values []string + + // Second Add call + add2Keys []string + add2Values []string + + // Expected result after deduplication + wantKeys []string + wantValues []string + }{ + { + name: "same key updated twice keeps newest", + + add1Keys: []string{"key1"}, + add1Values: []string{"first_value"}, + + add2Keys: []string{"key1"}, + add2Values: []string{"second_value"}, + + wantKeys: []string{"key1"}, + wantValues: []string{"second_value"}, + }, + { + name: "mixed keys with partial overlap", + + add1Keys: []string{"key1", "key2"}, + add1Values: []string{"value1", "value2"}, + + add2Keys: []string{"key1", "key3"}, + add2Values: []string{"new_value1", "value3"}, + + wantKeys: []string{"key1", "key2", "key3"}, + wantValues: []string{"new_value1", "value2", "value3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + ps := psmock.NewMockPubsub(ctrl) + clock := quartz.NewMock(t) + + agent := uuid.New() + + reg := prometheus.NewRegistry() + b, err := NewBatcher(ctx, reg, store, ps, + WithLogger(log), + WithClock(clock), + ) + require.NoError(t, err) + t.Cleanup(b.Close) + + // Set up timestamps - t2 is 1ms after t1 + t1 := clock.Now() + t2 := t1.Add(time.Millisecond) + + // Create time slices for add1 (all t1) and add2 (all t2) + add1Times := make([]time.Time, len(tt.add1Keys)) + for i := range add1Times { + add1Times[i] = t1 + } + add2Times := make([]time.Time, len(tt.add2Keys)) + for i := range add2Times { + add2Times[i] = t2 + } + + // Build expected times based on which add they came from. + // If a key appears in add2, it gets t2 (newer), otherwise t1. 
+ expectedTimes := make([]time.Time, len(tt.wantKeys)) + for i, wantKey := range tt.wantKeys { + // Check if key appears in add2 (newer) + foundInAdd2 := false + for _, add2Key := range tt.add2Keys { + if add2Key == wantKey { + expectedTimes[i] = t2 + foundInAdd2 = true + break + } + } + if !foundInAdd2 { + // Must be from add1 + expectedTimes[i] = t1 + } + } + + // Set up mock expectations + psCap := newPubsubCapture(t) + + // Build expected errors (all empty) and agent IDs (all same agent) + expectedErrors := make([]string, len(tt.wantKeys)) + for i := range expectedErrors { + expectedErrors[i] = "" + } + expectedAgents := make([]uuid.UUID, len(tt.wantKeys)) + for i := range expectedAgents { + expectedAgents[i] = agent + } + + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata( + gomock.Any(), + matchMetadata( + expectedAgents, + tt.wantKeys, + tt.wantValues, + expectedErrors, + expectedTimes, + ), + ). + Return(nil). + Times(1) + + ps.EXPECT(). + Publish(gomock.Any(), gomock.Any()). + Do(psCap.capture). + Return(nil). 
+ Times(1) + + // Perform the adds + droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) + + // First add with all empty error strings + add1Errors := make([]string, len(tt.add1Keys)) + require.NoError(t, b.Add(agent, tt.add1Keys, tt.add1Values, add1Errors, add1Times)) + + // Second add with all empty error strings + add2Errors := make([]string, len(tt.add2Keys)) + require.NoError(t, b.Add(agent, tt.add2Keys, tt.add2Values, add2Errors, add2Times)) + + // Wait for all channel messages to be processed into the batch + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + channelEmpty := len(b.updateCh) == 0 + nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore + batchHasExpected := int(b.currentBatchLen.Load()) == len(tt.wantKeys) + return channelEmpty && nothingDropped && batchHasExpected + }, testutil.IntervalFast) + + // Trigger scheduled flush + clock.Advance(defaultMetadataFlushInterval).MustWait(ctx) + + // Verify flush occurred with correct number of entries + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return float64(1) == prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker)) + }, testutil.IntervalFast) + require.Equal(t, float64(len(tt.wantKeys)), prom_testutil.ToFloat64(b.Metrics.MetadataTotal)) + + // Verify pubsub published the agent ID + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return psCap.count() == 1 + }, testutil.IntervalFast) + psCap.requireContainsAll([]uuid.UUID{agent}) + }) + } +} + +func TestMetadataBatcher_TimestampOrdering(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + ps := psmock.NewMockPubsub(ctrl) + clock := quartz.NewMock(t) + + reg := prometheus.NewRegistry() + b, err := NewBatcher(ctx, reg, store, ps, + WithLogger(log), 
+ WithClock(clock), + ) + require.NoError(t, err) + t.Cleanup(b.Close) + + // Generate mock agent ID. + agent := uuid.New() + + t1 := clock.Now() + t2 := t1.Add(time.Second) + t3 := t2.Add(time.Second) + + // Set up pubsub capture for the flush. + psCap := newPubsubCapture(t) + + // Expect the store to be called with only the newest timestamp. + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata( + gomock.Any(), + matchMetadata( + []uuid.UUID{agent}, + []string{"key1"}, + []string{"newest_value"}, + []string{""}, + []time.Time{t3}, + ), + ). + Return(nil). + Times(1) + + // Expect pubsub publish to be called when flush happens. + ps.EXPECT(). + Publish(gomock.Any(), gomock.Any()). + Do(psCap.capture). + Return(nil). + Times(1) + + // Capture dropped count before any adds. + droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) + + // Add update with t2 timestamp. + require.NoError(t, b.Add(agent, []string{"key1"}, []string{"newer_value"}, []string{""}, []time.Time{t2})) + + // Try to add older update with t1 timestamp - should be ignored + require.NoError(t, b.Add(agent, []string{"key1"}, []string{"older_value"}, []string{""}, []time.Time{t1})) + + // Add even newer update with t3 timestamp - should overwrite + require.NoError(t, b.Add(agent, []string{"key1"}, []string{"newest_value"}, []string{""}, []time.Time{t3})) + + // Wait for all channel messages to be processed by the run() goroutine into the batch. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + channelEmpty := len(b.updateCh) == 0 + nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore + batchHasExpected := int(b.currentBatchLen.Load()) == 1 + return channelEmpty && nothingDropped && batchHasExpected + }, testutil.IntervalFast) + + // Flush and verify entry was sent. + // Advance the full flush interval from when the batcher was created. 
+ clock.Advance(defaultMetadataFlushInterval).MustWait(ctx) + + // Wait for pubsub capture to complete and verify all agent IDs were published. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return psCap.count() == 1 + }, testutil.IntervalFast) + psCap.requireContainsAll([]uuid.UUID{agent}) + + // Verify only 1 entry was flushed (newest timestamp wins) + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return float64(1) == prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker)) + }, testutil.IntervalFast) + require.Equal(t, float64(1), prom_testutil.ToFloat64(b.Metrics.MetadataTotal)) +} + +func TestMetadataBatcher_PubsubChunking(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + ps := psmock.NewMockPubsub(ctrl) + clock := quartz.NewMock(t) + + reg := prometheus.NewRegistry() + b, err := NewBatcher(ctx, reg, store, ps, + WithLogger(log), + WithClock(clock), + ) + require.NoError(t, err) + t.Cleanup(b.Close) + + t1 := clock.Now() + + // Create enough agents to exceed maxAgentIDsPerChunk. + // With base64 encoding, each UUID is 22 characters, so we can fit + // ~363 agent IDs per chunk (8000 / 22 = 363.6). + // Let's create 600 agents to force chunking into 2 messages. + numAgents := 600 + agents := make([]uuid.UUID, numAgents) + expectedKeys := make([]string, numAgents) + expectedValues := make([]string, numAgents) + expectedErrors := make([]string, numAgents) + expectedTimes := make([]time.Time, numAgents) + + for i := 0; i < numAgents; i++ { + agents[i] = uuid.New() + expectedKeys[i] = "key1" + expectedValues[i] = "value1" + expectedErrors[i] = "" + expectedTimes[i] = t1 + } + + // Set up pubsub capture for the flush. 
+ psCap := newPubsubCapture(t) + + // With 600 agents and default batch size of 500: + // - First flush at 500 agents (capacity): 2 pubsub chunks (363 + 137) + // - Second flush at 100 agents (scheduled): 1 pubsub chunk + // Total: 3 publishes, 2 store calls + + // Expect the store to be called twice - once for first 500, once for remaining 100. + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata( + gomock.Any(), + matchMetadata( + agents[:500], + expectedKeys[:500], + expectedValues[:500], + expectedErrors[:500], + expectedTimes[:500], + ), + ). + Return(nil). + Times(1) + + store.EXPECT(). + BatchUpdateWorkspaceAgentMetadata( + gomock.Any(), + matchMetadata( + agents[500:], + expectedKeys[500:], + expectedValues[500:], + expectedErrors[500:], + expectedTimes[500:], + ), + ). + Return(nil). + Times(1) + + // Expect pubsub publish to be called when flush happens. + // With base64 encoding, each UUID is 22 characters. + // With 8KB limit, we can fit ~363 agents per chunk (8000 / 22 = 363.6). + // With 600 agents and batch size of 500: + // - First flush at 500 agents: 2 chunks (363 + 137) + // - Second flush at 100 agents: 1 chunk + // Total: 3 publishes + ps.EXPECT(). + Publish(gomock.Any(), gomock.Any()). + Do(psCap.capture). + Return(nil). + Times(3) + + // Add first 499 metadata updates (just under the capacity threshold of 500) + // Capture dropped count before any adds. + droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) + + for i := 0; i < 499; i++ { + require.NoError(t, b.Add(agents[i], []string{"key1"}, []string{"value1"}, []string{""}, []time.Time{t1})) + } + + // Wait for all channel messages to be processed into the batch. + // Batch should have 499 entries, no capacity flush yet. 
+ testutil.Eventually(ctx, t, func(ctx context.Context) bool { + channelEmpty := len(b.updateCh) == 0 + nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore + batchHasExpected := int(b.currentBatchLen.Load()) == 499 + return channelEmpty && nothingDropped && batchHasExpected + }, testutil.IntervalFast) + + // Add next 101 metadata updates (will trigger capacity flush at 500) + for i := 499; i < numAgents; i++ { + require.NoError(t, b.Add(agents[i], []string{"key1"}, []string{"value1"}, []string{""}, []time.Time{t1})) + } + + // Wait for all channel messages to be processed. The 500th entry should have + // triggered an automatic capacity flush, leaving 100 entries in the batch. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + channelEmpty := len(b.updateCh) == 0 + nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore + batchHasExpected := int(b.currentBatchLen.Load()) == 100 + return channelEmpty && nothingDropped && batchHasExpected + }, testutil.IntervalFast) + + // Verify capacity flush metrics and total metadata count. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + capacity := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushCapacity)) + totalMeta := prom_testutil.ToFloat64(b.Metrics.MetadataTotal) + // Should have 1 capacity flush (500 entries) so far + return capacity == float64(1) && totalMeta == float64(500) + }, testutil.IntervalFast) + + // Flush remaining entries and verify all updates were processed + clock.Advance(defaultMetadataFlushInterval).MustWait(ctx) + + // Wait for pubsub capture to complete and verify all agent IDs were published. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return psCap.count() == numAgents + }, testutil.IntervalFast) + psCap.requireContainsAll(agents) + + // Verify that all metadata was flushed successfully. 
+ // We should have 1 capacity flush (500 entries) and 1 scheduled flush (100 entries). + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + capacity := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushCapacity)) + scheduled := prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker)) + totalMeta := prom_testutil.ToFloat64(b.Metrics.MetadataTotal) + // Check that we've had 1 capacity flush and 1 scheduled flush + return capacity == float64(1) && scheduled == float64(1) && totalMeta == float64(600) + }, testutil.IntervalFast) + require.Equal(t, float64(numAgents), prom_testutil.ToFloat64(b.Metrics.MetadataTotal)) +} + +func TestMetadataBatcher_ConcurrentAddsToSameAgent(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + ps := psmock.NewMockPubsub(ctrl) + clock := quartz.NewMock(t) + + reg := prometheus.NewRegistry() + b, err := NewBatcher(ctx, reg, store, ps, + WithLogger(log), + WithClock(clock), + ) + require.NoError(t, err) + t.Cleanup(b.Close) + + // Single agent, multiple goroutines updating same keys concurrently + agentID := uuid.New() + numGoroutines := 20 + timestamps := make([]time.Time, numGoroutines) + initialTS := clock.Now() + for i := 0; i < numGoroutines; i++ { + timestamps[i] = initialTS.Add(time.Duration(i) * time.Millisecond) + } + + // The latest timestamp will have the final values, since deduplication keeps the newest value for each key. + latestTimestamp := timestamps[numGoroutines-1] + latestValue := fmt.Sprintf("value_from_goroutine_%d", numGoroutines-1) + + // Set up pubsub capture for the flush. + psCap := newPubsubCapture(t) + + // Expect the store to be called with exactly 3 keys (after deduplication). + // The values should be from the goroutine with the latest timestamp. + store.EXPECT(). 
+ BatchUpdateWorkspaceAgentMetadata( + gomock.Any(), + matchMetadata( + []uuid.UUID{agentID, agentID, agentID}, + []string{"key1", "key2", "key3"}, + []string{latestValue, latestValue, latestValue}, + []string{"", "", ""}, + []time.Time{latestTimestamp, latestTimestamp, latestTimestamp}, + ), + ). + Return(nil). + Times(1) + + ps.EXPECT(). + Publish(gomock.Any(), gomock.Any()). + Do(psCap.capture). + Return(nil). + Times(1) + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Capture dropped count before any adds. + droppedBefore := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) + + // Each goroutine updates the same set of keys with different values + for i := 0; i < numGoroutines; i++ { + go func(routineNum int) { + defer wg.Done() + timestamp := timestamps[routineNum] + value := fmt.Sprintf("value_from_goroutine_%d", routineNum) + // Report Add failures rather than discarding them. t.Errorf may be called from a goroutine other than the test's own, unlike require/FailNow. + err := b.Add(agentID, []string{"key1", "key2", "key3"}, + []string{value, value, value}, + []string{"", "", ""}, + []time.Time{timestamp, timestamp, timestamp}) + if err != nil { + t.Errorf("add metadata from goroutine %d: %v", routineNum, err) + } + }(i) + } + + wg.Wait() + + // Wait for all channel messages to be processed by the run() goroutine into the batch. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + channelEmpty := len(b.updateCh) == 0 + nothingDropped := prom_testutil.ToFloat64(b.Metrics.DroppedKeysTotal) == droppedBefore + batchHasExpected := int(b.currentBatchLen.Load()) == 3 + return channelEmpty && nothingDropped && batchHasExpected + }, testutil.IntervalFast) + + // Flush and check that we have exactly 3 keys (deduplication worked). + // Advance the full flush interval from when the batcher was created. + clock.Advance(defaultMetadataFlushInterval).MustWait(ctx) + + // Wait for pubsub capture to complete and verify all agent IDs were published. 
+ testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return psCap.count() == 1 + }, testutil.IntervalFast) + psCap.requireContainsAll([]uuid.UUID{agentID}) + + // Verify exactly 3 unique keys were flushed + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return float64(1) == prom_testutil.ToFloat64(b.Metrics.BatchesTotal.WithLabelValues(flushTicker)) + }, testutil.IntervalFast) + require.Equal(t, float64(3), prom_testutil.ToFloat64(b.Metrics.MetadataTotal)) +} diff --git a/coderd/agentapi/metadatabatcher/metadata_batcher_metrics.go b/coderd/agentapi/metadatabatcher/metadata_batcher_metrics.go new file mode 100644 index 0000000000..b559069c75 --- /dev/null +++ b/coderd/agentapi/metadatabatcher/metadata_batcher_metrics.go @@ -0,0 +1,95 @@ +package metadatabatcher + +import ( + "github.com/prometheus/client_golang/prometheus" +) + +type Metrics struct { + BatchUtilization prometheus.Histogram + FlushDuration *prometheus.HistogramVec + BatchSize prometheus.Histogram + BatchesTotal *prometheus.CounterVec + DroppedKeysTotal prometheus.Counter + MetadataTotal prometheus.Counter + PublishErrors prometheus.Counter +} + +func NewMetrics() Metrics { + return Metrics{ + BatchUtilization: prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "agentapi", + Name: "metadata_batch_utilization", + Help: "Number of metadata keys per agent in each batch, updated before flushes.", + Buckets: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 40, 80, 160}, + }), + + BatchSize: prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "agentapi", + Name: "metadata_batch_size", + Help: "Total number of metadata entries in each batch, updated before flushes.", + Buckets: []float64{10, 25, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500}, + }), + + FlushDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "coderd", + Subsystem: "agentapi", + Name: "metadata_flush_duration_seconds", + 
Help: "Time taken to flush metadata batch to database and pubsub.", + Buckets: []float64{0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0}, + }, []string{"reason"}), + + BatchesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "agentapi", + Name: "metadata_batches_total", + Help: "Total number of metadata batches flushed.", + }, []string{"reason"}), + + DroppedKeysTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "agentapi", + Name: "metadata_dropped_keys_total", + Help: "Total number of metadata keys dropped due to capacity limits.", + }), + + MetadataTotal: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "agentapi", + Name: "metadata_flushed_total", + Help: "Total number of unique metadatas flushed.", + }), + + PublishErrors: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "agentapi", + Name: "metadata_publish_errors_total", + Help: "Total number of metadata batch pubsub publish calls that have resulted in an error.", + }), + } +} + +func (m Metrics) Collectors() []prometheus.Collector { + return []prometheus.Collector{ + m.BatchUtilization, + m.BatchSize, + m.FlushDuration, + m.BatchesTotal, + m.DroppedKeysTotal, + m.MetadataTotal, + m.PublishErrors, + } +} + +func (m Metrics) register(reg prometheus.Registerer) { + if reg != nil { + reg.MustRegister(m.BatchUtilization) + reg.MustRegister(m.BatchSize) + reg.MustRegister(m.FlushDuration) + reg.MustRegister(m.DroppedKeysTotal) + reg.MustRegister(m.BatchesTotal) + reg.MustRegister(m.MetadataTotal) + reg.MustRegister(m.PublishErrors) + } +} diff --git a/coderd/coderd.go b/coderd/coderd.go index 645830b7b5..b53f78e56b 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -44,6 +44,7 @@ import ( "cdr.dev/slog/v3" agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/buildinfo" + "github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" _ 
"github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs. "github.com/coder/coder/v2/coderd/appearance" "github.com/coder/coder/v2/coderd/audit" @@ -241,6 +242,8 @@ type Options struct { UpdateAgentMetrics func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) StatsBatcher workspacestats.Batcher + MetadataBatcherOptions []metadatabatcher.Option + ProvisionerdServerMetrics *provisionerdserver.Metrics // WorkspaceAppAuditSessionTimeout allows changing the timeout for audit @@ -786,6 +789,23 @@ func New(options *Options) *API { AppStatBatchSize: workspaceapps.DefaultStatsDBReporterBatchSize, DisableDatabaseInserts: !options.DeploymentValues.StatsCollection.UsageStats.Enable.Value(), }) + + // Initialize the metadata batcher for batching agent metadata updates. + batcherOpts := []metadatabatcher.Option{ + metadatabatcher.WithLogger(options.Logger.Named("metadata_batcher")), + } + batcherOpts = append(batcherOpts, options.MetadataBatcherOptions...) 
+ api.metadataBatcher, err = metadatabatcher.NewBatcher( + api.ctx, + options.PrometheusRegistry, + options.Database, + options.Pubsub, + batcherOpts..., + ) + if err != nil { + api.Logger.Fatal(context.Background(), "failed to initialize metadata batcher", slog.Error(err)) + } + workspaceAppsLogger := options.Logger.Named("workspaceapps") if options.WorkspaceAppsStatsCollectorOptions.Logger == nil { named := workspaceAppsLogger.Named("stats_collector") @@ -1865,7 +1885,8 @@ type API struct { healthCheckCache atomic.Pointer[healthsdk.HealthcheckReport] healthCheckProgress healthcheck.Progress - statsReporter *workspacestats.Reporter + statsReporter *workspacestats.Reporter + metadataBatcher *metadatabatcher.Batcher Acquirer *provisionerdserver.Acquirer // dbRolluper rolls up template usage stats from raw agent and app @@ -1917,6 +1938,9 @@ func (api *API) Close() error { _ = (*coordinator).Close() } _ = api.statsReporter.Close() + if api.metadataBatcher != nil { + api.metadataBatcher.Close() + } _ = api.NetworkTelemetryBatcher.Close() _ = api.OIDCConvertKeyCache.Close() _ = api.AppSigningKeyCache.Close() diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 5014d3c383..bb4d687db1 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -55,6 +55,7 @@ import ( "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/archive" "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/autobuild" "github.com/coder/coder/v2/coderd/awsidentity" @@ -171,8 +172,9 @@ type Options struct { SwaggerEndpoint bool // Logger should only be overridden if you expect errors // as part of your test. 
- Logger *slog.Logger - StatsBatcher workspacestats.Batcher + Logger *slog.Logger + StatsBatcher workspacestats.Batcher + MetadataBatcherOptions []metadatabatcher.Option WebpushDispatcher webpush.Dispatcher WorkspaceAppsStatsCollectorOptions workspaceapps.StatsCollectorOptions @@ -598,6 +600,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can HealthcheckTimeout: options.HealthcheckTimeout, HealthcheckRefresh: options.HealthcheckRefresh, StatsBatcher: options.StatsBatcher, + MetadataBatcherOptions: options.MetadataBatcherOptions, WorkspaceAppsStatsCollectorOptions: options.WorkspaceAppsStatsCollectorOptions, AllowWorkspaceRenames: options.AllowWorkspaceRenames, NewTicker: options.NewTicker, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 53cff0567b..2beda99c47 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1458,6 +1458,15 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas return q.db.ArchiveUnusedTemplateVersions(ctx, arg) } +func (q *querier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error { + // Could be any workspace agent and checking auth to each workspace agent is overkill for + // the purpose of this function. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceWorkspace.All()); err != nil { + return err + } + return q.db.BatchUpdateWorkspaceAgentMetadata(ctx, arg) +} + func (q *querier) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error { // Could be any workspace and checking auth to each workspace is overkill for // the purpose of this function. 
diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index c6ef2d2490..127aa63181 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1862,6 +1862,18 @@ func (s *MethodTestSuite) TestWorkspace() { dbm.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), arg).Return([]database.WorkspaceAgentMetadatum{dt}, nil).AnyTimes() check.Args(arg).Asserts(w, policy.ActionRead).Returns([]database.WorkspaceAgentMetadatum{dt}) })) + s.Run("BatchUpdateWorkspaceAgentMetadata", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) + arg := database.BatchUpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: []uuid.UUID{agt.ID}, + Key: []string{"key1"}, + Value: []string{"value1"}, + Error: []string{""}, + CollectedAt: []time.Time{dbtime.Now()}, + } + dbm.EXPECT().BatchUpdateWorkspaceAgentMetadata(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceWorkspace.All(), policy.ActionUpdate).Returns() + })) s.Run("GetWorkspaceAgentByInstanceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { w := testutil.Fake(s.T(), faker, database.Workspace{}) agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 38e0a78beb..1682f6f2a5 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -152,6 +152,13 @@ func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, ar return r0, r1 } +func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error { + start := time.Now() + r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg) + 
m.queryLatencies.WithLabelValues("BatchUpdateWorkspaceAgentMetadata").Observe(time.Since(start).Seconds()) + return r0 +} + func (m queryMetricsStore) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error { start := time.Now() r0 := m.s.BatchUpdateWorkspaceLastUsedAt(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index f57a6afb73..fe057ea74d 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -132,6 +132,20 @@ func (mr *MockStoreMockRecorder) ArchiveUnusedTemplateVersions(ctx, arg any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArchiveUnusedTemplateVersions", reflect.TypeOf((*MockStore)(nil).ArchiveUnusedTemplateVersions), ctx, arg) } +// BatchUpdateWorkspaceAgentMetadata mocks base method. +func (m *MockStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BatchUpdateWorkspaceAgentMetadata", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// BatchUpdateWorkspaceAgentMetadata indicates an expected call of BatchUpdateWorkspaceAgentMetadata. +func (mr *MockStoreMockRecorder) BatchUpdateWorkspaceAgentMetadata(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdateWorkspaceAgentMetadata", reflect.TypeOf((*MockStore)(nil).BatchUpdateWorkspaceAgentMetadata), ctx, arg) +} + // BatchUpdateWorkspaceLastUsedAt mocks base method. 
func (m *MockStore) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index a91410676e..de31cd410a 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -56,6 +56,7 @@ type sqlcQuerier interface { // Only unused template versions will be archived, which are any versions not // referenced by the latest build of a workspace. ArchiveUnusedTemplateVersions(ctx context.Context, arg ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error) + BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error BulkMarkNotificationMessagesFailed(ctx context.Context, arg BulkMarkNotificationMessagesFailedParams) (int64, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 680016f189..864cd971b4 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -17954,6 +17954,47 @@ func (q *sqlQuerier) UpdateVolumeResourceMonitor(ctx context.Context, arg Update return err } +const batchUpdateWorkspaceAgentMetadata = `-- name: BatchUpdateWorkspaceAgentMetadata :exec +WITH metadata AS ( + SELECT + unnest($1::uuid[]) AS workspace_agent_id, + unnest($2::text[]) AS key, + unnest($3::text[]) AS value, + unnest($4::text[]) AS error, + unnest($5::timestamptz[]) AS collected_at +) +UPDATE + workspace_agent_metadata wam +SET + value = m.value, + error = m.error, + collected_at = m.collected_at +FROM + metadata m +WHERE + wam.workspace_agent_id = m.workspace_agent_id + AND wam.key = m.key +` + +type BatchUpdateWorkspaceAgentMetadataParams struct { + WorkspaceAgentID []uuid.UUID `db:"workspace_agent_id" json:"workspace_agent_id"` + 
Key []string `db:"key" json:"key"` + Value []string `db:"value" json:"value"` + Error []string `db:"error" json:"error"` + CollectedAt []time.Time `db:"collected_at" json:"collected_at"` +} + +func (q *sqlQuerier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error { + _, err := q.db.ExecContext(ctx, batchUpdateWorkspaceAgentMetadata, + pq.Array(arg.WorkspaceAgentID), + pq.Array(arg.Key), + pq.Array(arg.Value), + pq.Array(arg.Error), + pq.Array(arg.CollectedAt), + ) + return err +} + const deleteOldWorkspaceAgentLogs = `-- name: DeleteOldWorkspaceAgentLogs :execrows WITH latest_builds AS ( diff --git a/coderd/database/queries/workspaceagents.sql b/coderd/database/queries/workspaceagents.sql index 3e52e6269f..d4dfa9a7a0 100644 --- a/coderd/database/queries/workspaceagents.sql +++ b/coderd/database/queries/workspaceagents.sql @@ -142,6 +142,27 @@ WHERE wam.workspace_agent_id = $1 AND wam.key = m.key; +-- name: BatchUpdateWorkspaceAgentMetadata :exec +WITH metadata AS ( + SELECT + unnest(sqlc.arg('workspace_agent_id')::uuid[]) AS workspace_agent_id, + unnest(sqlc.arg('key')::text[]) AS key, + unnest(sqlc.arg('value')::text[]) AS value, + unnest(sqlc.arg('error')::text[]) AS error, + unnest(sqlc.arg('collected_at')::timestamptz[]) AS collected_at +) +UPDATE + workspace_agent_metadata wam +SET + value = m.value, + error = m.error, + collected_at = m.collected_at +FROM + metadata m +WHERE + wam.workspace_agent_id = m.workspace_agent_id + AND wam.key = m.key; + -- name: GetWorkspaceAgentMetadata :many SELECT * diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 17a831188e..68835c19c5 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1,6 +1,7 @@ package coderd import ( + "bytes" "context" "database/sql" "encoding/json" @@ -24,7 +25,7 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog/v3" - "github.com/coder/coder/v2/coderd/agentapi" + 
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -1701,6 +1702,12 @@ func (api *API) watchWorkspaceAgentMetadata( r = r.WithContext(ctx) // Rewire context for SSE cancellation. waws := httpmw.WorkspaceAgentAndWorkspaceParam(r) + agentIDEncoded := make([]byte, metadatabatcher.UUIDBase64Size) + err := metadatabatcher.EncodeAgentID(waws.WorkspaceAgent.ID, agentIDEncoded) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } log := api.Logger.Named("workspace_metadata_watcher").With( slog.F("workspace_agent_id", waws.WorkspaceAgent.ID), slog.F("workspace_id", waws.WorkspaceTable.ID), @@ -1708,34 +1715,49 @@ func (api *API) watchWorkspaceAgentMetadata( // Send metadata on updates, we must ensure subscription before sending // initial metadata to guarantee that events in-between are not missed. - update := make(chan agentapi.WorkspaceAgentMetadataChannelPayload, 1) - cancelSub, err := api.Pubsub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(waws.WorkspaceAgent.ID), func(_ context.Context, byt []byte) { + // The channel carries no data - it's just a signal to fetch all metadata. + update := make(chan struct{}, 1) + + // Subscribe to the global batched metadata channel. + // The batcher publishes only to this channel to achieve O(1) NOTIFY scaling. 
+ cancelBatchSub, err := api.Pubsub.Subscribe(metadatabatcher.MetadataBatchPubsubChannel, func(_ context.Context, byt []byte) { if ctx.Err() != nil { return } - var payload agentapi.WorkspaceAgentMetadataChannelPayload - err := json.Unmarshal(byt, &payload) - if err != nil { - log.Error(ctx, "failed to unmarshal pubsub message", slog.Error(err)) + if len(byt)%metadatabatcher.UUIDBase64Size != 0 { + // No error value is produced by this check; log the offending payload length so the malformed message can be diagnosed. + log.Error(ctx, "invalid batched pubsub message, pubsub message length was not a multiple of encoded agent UUID length", slog.F("message_len", len(byt))) return } - log.Debug(ctx, "received metadata update", slog.F("payload", payload)) + // Compare each encoded agentID to our encoded agent ID. + for i := 0; i < len(byt); i += metadatabatcher.UUIDBase64Size { + if !bytes.Equal(byt[i:i+metadatabatcher.UUIDBase64Size], agentIDEncoded) { + continue + } - select { - case prev := <-update: - payload.Keys = appendUnique(prev.Keys, payload.Keys) - default: + log.Debug(ctx, "received metadata update from batch channel", + slog.F("agent_id", waws.WorkspaceAgent.ID), + slog.F("batch_size", len(byt)/metadatabatcher.UUIDBase64Size), + ) + + // Signal to re-fetch all metadata for this agent. + // Batch notifications don't include which keys changed, so we + // always fetch all keys for this agent. + // Attempt to read from the channel first so that we do not block on the write. + select { + case <-update: + default: + } + update <- struct{}{} + break } - // This can never block since we pop and merge beforehand. - update <- payload }) if err != nil { httpapi.InternalServerError(rw, err) return } - defer cancelSub() + defer cancelBatchSub() // We always use the original Request context because it contains // the RBAC actor. @@ -1819,10 +1841,11 @@ func (api *API) watchWorkspaceAgentMetadata( select { case <-ctx.Done(): return - case payload := <-update: + case <-update: + // Batch notification received - fetch all metadata for this agent. 
md, err := api.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{ WorkspaceAgentID: waws.WorkspaceAgent.ID, - Keys: payload.Keys, + Keys: nil, // nil means fetch all keys }) if err != nil { if !database.IsQueryCanceledError(err) { @@ -1843,9 +1866,7 @@ func (api *API) watchWorkspaceAgentMetadata( // We want to block here to avoid constantly pinging the // database when the metadata isn't being processed. case fetchedMetadata <- md: - log.Debug(ctx, "fetched metadata update for keys", - slog.F("keys", payload.Keys), - slog.F("num", len(md))) + log.Debug(ctx, "fetched all metadata after batch update", slog.F("num", len(md))) } } } @@ -1881,21 +1902,6 @@ func (api *API) watchWorkspaceAgentMetadata( } } -// appendUnique is like append and adds elements from src to dst, -// skipping any elements that already exist in dst. -func appendUnique[T comparable](dst, src []T) []T { - exists := make(map[T]struct{}, len(dst)) - for _, key := range dst { - exists[key] = struct{}{} - } - for _, key := range src { - if _, ok := exists[key]; !ok { - dst = append(dst, key) - } - } - return dst -} - func convertWorkspaceAgentMetadata(db []database.WorkspaceAgentMetadatum) []codersdk.WorkspaceAgentMetadata { // Sort the input database slice by DisplayOrder and then by Key before processing sort.Slice(db, func(i, j int) bool { diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index f055ff7c75..3373e2b32b 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -35,6 +35,7 @@ import ( "github.com/coder/coder/v2/agent/agentcontainers/watcher" "github.com/coder/coder/v2/agent/agenttest" agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" @@ -2105,7 +2106,11 @@ func TestWorkspaceAgent_LifecycleState(t 
*testing.T) { func TestWorkspaceAgent_Metadata(t *testing.T) { t.Parallel() - client, db := coderdtest.NewWithDatabase(t, nil) + client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ + MetadataBatcherOptions: []metadatabatcher.Option{ + metadatabatcher.WithInterval(100 * time.Millisecond), + }, + }) user := coderdtest.CreateFirstUser(t, client) r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ OrganizationID: user.OrganizationID, @@ -2232,7 +2237,7 @@ func TestWorkspaceAgent_Metadata(t *testing.T) { update = recvUpdate() require.Len(t, update, 3) - check(wantMetadata1, update[0], false) + check(wantMetadata1, update[0], true) // The second metadata result is not yet posted. require.Zero(t, update[1].Result.CollectedAt) diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index 79273ae35f..b4e9cc7650 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -144,6 +144,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { TailnetCoordinator: &api.TailnetCoordinator, AppearanceFetcher: &api.AppearanceFetcher, StatsReporter: api.statsReporter, + MetadataBatcher: api.metadataBatcher, PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate, PublishWorkspaceAgentLogsUpdateFn: api.publishWorkspaceAgentLogsUpdate, NetworkTelemetryHandler: api.NetworkTelemetryBatcher.Handler,