From b199eb1c389ba713d1358bb62d4ffb87698c2a70 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 9 Dec 2025 22:05:12 +1100 Subject: [PATCH] fix: allow stops and deletes after breaching AI limit (#21186) Fixes a bug a customer encountered once they breached their managed agent limit: stop and delete builds of workspaces whose template version has an AI task were rejected by the usage check. The AI usage check now only applies to start transitions. Adds a test. --- coderd/wsbuilder/wsbuilder.go | 6 +-- coderd/wsbuilder/wsbuilder_test.go | 10 ++-- enterprise/coderd/coderd.go | 42 ++++++++++------ enterprise/coderd/coderd_test.go | 78 ++++++++++++++++++++++++++++-- 4 files changed, 109 insertions(+), 27 deletions(-) diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index 6aef8c2c2a..7d388966c9 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -93,7 +93,7 @@ type Builder struct { } type UsageChecker interface { - CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (UsageCheckResponse, error) + CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (UsageCheckResponse, error) } type UsageCheckResponse struct { @@ -105,7 +105,7 @@ type NoopUsageChecker struct{} var _ UsageChecker = NoopUsageChecker{} -func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion) (UsageCheckResponse, error) { +func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion, _ database.WorkspaceTransition) (UsageCheckResponse, error) { return UsageCheckResponse{ Permitted: true, }, nil @@ -1307,7 +1307,7 @@ func (b *Builder) checkUsage() error { return BuildError{http.StatusInternalServerError, "Failed to fetch template version", err} } - resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion) + resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion, b.trans) if err != nil { return BuildError{http.StatusInternalServerError, "Failed to check build 
usage", err} } diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go index 3a8921dd6d..c3b4fe723c 100644 --- a/coderd/wsbuilder/wsbuilder_test.go +++ b/coderd/wsbuilder/wsbuilder_test.go @@ -1049,7 +1049,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { var calls int64 fakeUsageChecker := &fakeUsageChecker{ - checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { atomic.AddInt64(&calls, 1) return wsbuilder.UsageCheckResponse{Permitted: true}, nil }, @@ -1126,7 +1126,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { var calls int64 fakeUsageChecker := &fakeUsageChecker{ - checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { atomic.AddInt64(&calls, 1) return c.response, c.responseErr }, @@ -1577,11 +1577,11 @@ func expectFindMatchingPresetID(id uuid.UUID, err error) func(mTx *dbmock.MockSt } type fakeUsageChecker struct { - checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) + checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) } -func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { - return f.checkBuildUsageFunc(ctx, store, templateVersion) 
+func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { + return f.checkBuildUsageFunc(ctx, store, templateVersion, transition) } func withNoTask(mTx *dbmock.MockStore) { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 3875c83797..bf1a5acf53 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -971,7 +971,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { var _ wsbuilder.UsageChecker = &API{} -func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { +func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { // If the template version has an external agent, we need to check that the // license is entitled to this feature. if templateVersion.HasExternalAgent.Valid && templateVersion.HasExternalAgent.Bool { @@ -984,16 +984,31 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ } } - // If the template version doesn't have an AI task, we don't need to check - // usage. - if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool { - return wsbuilder.UsageCheckResponse{ - Permitted: true, - }, nil + resp, err := api.checkAIBuildUsage(ctx, store, templateVersion, transition) + if err != nil { + return wsbuilder.UsageCheckResponse{}, err + } + if !resp.Permitted { + return resp, nil } - // When unlicensed, we need to check that we haven't breached the managed agent - // limit. + return wsbuilder.UsageCheckResponse{Permitted: true}, nil +} + +// checkAIBuildUsage validates AI-related usage constraints. 
It is a no-op +// unless the transition is "start" and the template version has an AI task. +func (api *API) checkAIBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { + // Only check AI usage rules for start transitions. + if transition != database.WorkspaceTransitionStart { + return wsbuilder.UsageCheckResponse{Permitted: true}, nil + } + + // If the template version doesn't have an AI task, we don't need to check usage. + if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool { + return wsbuilder.UsageCheckResponse{Permitted: true}, nil + } + + // When licensed, ensure we haven't breached the managed agent limit. // Unlicensed deployments are allowed to use unlimited managed agents. if api.Entitlements.HasLicense() { managedAgentLimit, ok := api.Entitlements.Feature(codersdk.FeatureManagedAgentLimit) @@ -1004,8 +1019,9 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ }, nil } - // This check is intentionally not committed to the database. It's fine if - // it's not 100% accurate or allows for minor breaches due to build races. + // This check is intentionally not committed to the database. It's fine + // if it's not 100% accurate or allows for minor breaches due to build + // races. // nolint:gocritic // Requires permission to read all usage events. 
managedAgentCount, err := store.GetTotalUsageDCManagedAgentsV1(agpldbauthz.AsSystemRestricted(ctx), database.GetTotalUsageDCManagedAgentsV1Params{ StartDate: managedAgentLimit.UsagePeriod.Start, @@ -1023,9 +1039,7 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ } } - return wsbuilder.UsageCheckResponse{ - Permitted: true, - }, nil + return wsbuilder.UsageCheckResponse{Permitted: true}, nil } // getProxyDERPStartingRegionID returns the starting region ID that should be diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index c3e6e1579f..0e1078128a 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -3,6 +3,7 @@ package coderd_test import ( "bytes" "context" + "database/sql" "encoding/json" "fmt" "io" @@ -21,6 +22,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "go.uber.org/mock/gomock" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -39,13 +41,16 @@ import ( "github.com/coder/retry" "github.com/coder/serpent" + agplcoderd "github.com/coder/coder/v2/coderd" agplaudit "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -635,18 +640,18 @@ func TestManagedAgentLimit(t *testing.T) { }) // Get entitlements to check that the license is a-ok. 
- entitlements, err := cli.Entitlements(ctx) //nolint:gocritic // we're not testing authz on the entitlements endpoint, so using owner is fine + sdkEntitlements, err := cli.Entitlements(ctx) //nolint:gocritic // we're not testing authz on the entitlements endpoint, so using owner is fine require.NoError(t, err) - require.True(t, entitlements.HasLicense) - agentLimit := entitlements.Features[codersdk.FeatureManagedAgentLimit] + require.True(t, sdkEntitlements.HasLicense) + agentLimit := sdkEntitlements.Features[codersdk.FeatureManagedAgentLimit] require.True(t, agentLimit.Enabled) require.NotNil(t, agentLimit.Limit) require.EqualValues(t, 1, *agentLimit.Limit) require.NotNil(t, agentLimit.SoftLimit) require.EqualValues(t, 1, *agentLimit.SoftLimit) - require.Empty(t, entitlements.Errors) + require.Empty(t, sdkEntitlements.Errors) // There should be a warning since we're really close to our agent limit. - require.Equal(t, entitlements.Warnings[0], "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.") + require.Equal(t, sdkEntitlements.Warnings[0], "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.") // Create a fake provision response that claims there are agents in the // template and every built workspace. @@ -723,6 +728,69 @@ func TestManagedAgentLimit(t *testing.T) { coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID) } +func TestCheckBuildUsage_SkipsAIForNonStartTransitions(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Prepare entitlements with a managed agent limit to enforce. 
+ entSet := entitlements.New() + entSet.Modify(func(e *codersdk.Entitlements) { + e.HasLicense = true + limit := int64(1) + issuedAt := time.Now().Add(-2 * time.Hour) + start := time.Now().Add(-time.Hour) + end := time.Now().Add(time.Hour) + e.Features[codersdk.FeatureManagedAgentLimit] = codersdk.Feature{ + Enabled: true, + Limit: &limit, + UsagePeriod: &codersdk.UsagePeriod{IssuedAt: issuedAt, Start: start, End: end}, + } + }) + + // Enterprise API instance with entitlements injected. + agpl := &agplcoderd.API{ + Options: &agplcoderd.Options{ + Entitlements: entSet, + }, + } + eapi := &coderd.API{ + AGPL: agpl, + Options: &coderd.Options{Options: agpl.Options}, + } + + // Template version that has an AI task. + tv := &database.TemplateVersion{ + HasAITask: sql.NullBool{Valid: true, Bool: true}, + HasExternalAgent: sql.NullBool{Valid: true, Bool: false}, + } + + // Mock DB: expect exactly one count call for the "start" transition. + mDB := dbmock.NewMockStore(ctrl) + mDB.EXPECT(). + GetTotalUsageDCManagedAgentsV1(gomock.Any(), gomock.Any()). + Times(1). + Return(int64(1), nil) // equal to limit -> should breach + + ctx := context.Background() + + // Start transition: should be not permitted due to limit breach. + startResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, database.WorkspaceTransitionStart) + require.NoError(t, err) + require.False(t, startResp.Permitted) + require.Contains(t, startResp.Message, "breached the managed agent limit") + + // Stop transition: should be permitted and must not trigger additional DB calls. + stopResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, database.WorkspaceTransitionStop) + require.NoError(t, err) + require.True(t, stopResp.Permitted) + + // Delete transition: should be permitted and must not trigger additional DB calls. 
+ deleteResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, database.WorkspaceTransitionDelete) + require.NoError(t, err) + require.True(t, deleteResp.Permitted) +} + // testDBAuthzRole returns a context with a subject that has a role // with permissions required for test setup. func testDBAuthzRole(ctx context.Context) context.Context {