diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index bb4d687db1..ed56cbb39a 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -83,6 +83,7 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/telemetry" "github.com/coder/coder/v2/coderd/updatecheck" + "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/coderd/util/namesgenerator" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/coderd/webpush" @@ -190,6 +191,7 @@ type Options struct { TelemetryReporter telemetry.Reporter ProvisionerdServerMetrics *provisionerdserver.Metrics + UsageInserter usage.Inserter } // New constructs a codersdk client connected to an in-memory API instance. @@ -270,6 +272,11 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can } } + var usageInserter *atomic.Pointer[usage.Inserter] + if options.UsageInserter != nil { + usageInserter = &atomic.Pointer[usage.Inserter]{} + usageInserter.Store(&options.UsageInserter) + } if options.Database == nil { options.Database, options.Pubsub = dbtestutil.NewDB(t) } @@ -563,6 +570,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can Database: options.Database, Pubsub: options.Pubsub, ExternalAuthConfigs: options.ExternalAuthConfigs, + UsageInserter: usageInserter, Auditor: options.Auditor, ConnectionLogger: options.ConnectionLogger, diff --git a/coderd/coderdtest/usage.go b/coderd/coderdtest/usage.go new file mode 100644 index 0000000000..4da724b177 --- /dev/null +++ b/coderd/coderdtest/usage.go @@ -0,0 +1,44 @@ +package coderdtest + +import ( + "context" + "sync" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/usage" + "github.com/coder/coder/v2/coderd/usage/usagetypes" +) + +var _ usage.Inserter = (*UsageInserter)(nil) + +type UsageInserter struct { + sync.Mutex + events []usagetypes.DiscreteEvent +} + +func NewUsageInserter() *UsageInserter { + return &UsageInserter{ + events: []usagetypes.DiscreteEvent{}, + } +} + +func (u *UsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error { + u.Lock() + defer u.Unlock() + u.events = append(u.events, event) + return nil +} + +func (u *UsageInserter) GetEvents() []usagetypes.DiscreteEvent { + u.Lock() + defer u.Unlock() + eventsCopy := make([]usagetypes.DiscreteEvent, len(u.events)) + copy(eventsCopy, u.events) + return eventsCopy +} + +func (u *UsageInserter) Reset() { + u.Lock() + defer u.Unlock() + u.events = []usagetypes.DiscreteEvent{} +} diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index d7ace7a4e2..e870c89ee3 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -6,7 +6,6 @@ import ( "context" "database/sql" "encoding/json" - "errors" "fmt" "net/http" "time" @@ -85,13 +84,15 @@ type Builder struct { templateVersionPresetParameterValues *[]database.TemplateVersionPresetParameter parameterRender dynamicparameters.Renderer workspaceTags *map[string]string + task *database.Task + hasTask *bool // A workspace without a task will have a nil `task` and false `hasTask`. prebuiltWorkspaceBuildStage sdkproto.PrebuiltWorkspaceBuildStage verifyNoLegacyParametersOnce bool } type UsageChecker interface { - CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (UsageCheckResponse, error) + CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, task *database.Task, transition database.WorkspaceTransition) (UsageCheckResponse, error) } type UsageCheckResponse struct { @@ -103,7 +104,7 @@ type NoopUsageChecker struct{} var _ UsageChecker = NoopUsageChecker{} -func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion, _ database.WorkspaceTransition) (UsageCheckResponse, error) { +func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion, _ *database.Task, _ database.WorkspaceTransition) (UsageCheckResponse, error) { return UsageCheckResponse{ Permitted: true, }, nil @@ -487,8 +488,12 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object return BuildError{code, "insert workspace build", err} } + task, err := b.getWorkspaceTask() + if err != nil { + return BuildError{http.StatusInternalServerError, "get task by workspace id", err} + } // If this is a task workspace, link it to the latest workspace build. - if task, err := store.GetTaskByWorkspaceID(b.ctx, b.workspace.ID); err == nil { + if task != nil { _, err = store.UpsertTaskWorkspaceApp(b.ctx, database.UpsertTaskWorkspaceAppParams{ TaskID: task.ID, WorkspaceBuildNumber: buildNum, @@ -498,8 +503,6 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object if err != nil { return BuildError{http.StatusInternalServerError, "upsert task workspace app", err} } - } else if !errors.Is(err, sql.ErrNoRows) { - return BuildError{http.StatusInternalServerError, "get task by workspace id", err} } err = store.InsertWorkspaceBuildParameters(b.ctx, database.InsertWorkspaceBuildParametersParams{ @@ -632,6 +635,27 @@ func (b *Builder) getTemplateVersionID() (uuid.UUID, error) { return bld.TemplateVersionID, nil } +// getWorkspaceTask returns the task associated with the workspace, if any. +// If no task exists, it returns (nil, nil). +func (b *Builder) getWorkspaceTask() (*database.Task, error) { + if b.hasTask != nil { + return b.task, nil + } + t, err := b.store.GetTaskByWorkspaceID(b.ctx, b.workspace.ID) + if err != nil { + if xerrors.Is(err, sql.ErrNoRows) { + b.hasTask = ptr.Ref(false) + //nolint:nilnil // No task exists. + return nil, nil + } + return nil, xerrors.Errorf("get task: %w", err) + } + + b.task = &t + b.hasTask = ptr.Ref(true) + return b.task, nil +} + func (b *Builder) getTemplateTerraformValues() (*database.TemplateVersionTerraformValue, error) { if b.terraformValues != nil { return b.terraformValues, nil @@ -1313,7 +1337,12 @@ func (b *Builder) checkUsage() error { return BuildError{http.StatusInternalServerError, "Failed to fetch template version", err} } - resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion, b.trans) + task, err := b.getWorkspaceTask() + if err != nil { + return BuildError{http.StatusInternalServerError, "Failed to fetch workspace task", err} + } + + resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion, task, b.trans) if err != nil { return BuildError{http.StatusInternalServerError, "Failed to check build usage", err} } diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go index c3b4fe723c..38f88f7508 100644 --- a/coderd/wsbuilder/wsbuilder_test.go +++ b/coderd/wsbuilder/wsbuilder_test.go @@ -570,6 +570,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { mDB := expectDB(t, // Inputs withTemplate, + withNoTask, withInactiveVersionNoParams(), withLastBuildFound, withTemplateVersionVariables(inactiveVersionID, nil), @@ -605,6 +606,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { withTemplate, withInactiveVersion(richParameters), withLastBuildFound, + withNoTask, withTemplateVersionVariables(inactiveVersionID, nil), withRichParameters(initialBuildParameters), withParameterSchemas(inactiveJobID, nil), @@ -1049,7 +1051,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { var calls int64 fakeUsageChecker := &fakeUsageChecker{ - checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { + checkBuildUsageFunc: func(_ context.Context, _ database.Store, _ *database.TemplateVersion, _ *database.Task, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { atomic.AddInt64(&calls, 1) return wsbuilder.UsageCheckResponse{Permitted: true}, nil }, @@ -1126,7 +1128,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { var calls int64 fakeUsageChecker := &fakeUsageChecker{ - checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { + checkBuildUsageFunc: func(_ context.Context, _ database.Store, _ *database.TemplateVersion, _ *database.Task, _ database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { atomic.AddInt64(&calls, 1) return c.response, c.responseErr }, @@ -1134,6 +1136,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { mDB := expectDB(t, withTemplate, + withNoTask, withInactiveVersionNoParams(), ) fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) @@ -1577,11 +1580,11 @@ func expectFindMatchingPresetID(id uuid.UUID, err error) func(mTx *dbmock.MockSt } type fakeUsageChecker struct { - checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) + checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, task *database.Task, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) } -func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { - return f.checkBuildUsageFunc(ctx, store, templateVersion, transition) +func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, task *database.Task, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { + return f.checkBuildUsageFunc(ctx, store, templateVersion, task, transition) } func withNoTask(mTx *dbmock.MockStore) { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 75875b0766..9ad5369666 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -975,7 +975,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { var _ wsbuilder.UsageChecker = &API{} -func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { +func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, task *database.Task, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { // If the template version has an external agent, we need to check that the // license is entitled to this feature. if templateVersion.HasExternalAgent.Valid && templateVersion.HasExternalAgent.Bool { @@ -988,7 +988,7 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ } } - resp, err := api.checkAIBuildUsage(ctx, store, templateVersion, transition) + resp, err := api.checkAIBuildUsage(ctx, store, task, transition) if err != nil { return wsbuilder.UsageCheckResponse{}, err } @@ -1001,14 +1001,14 @@ func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templ // checkAIBuildUsage validates AI-related usage constraints. It is a no-op // unless the transition is "start" and the template version has an AI task. -func (api *API) checkAIBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { +func (api *API) checkAIBuildUsage(ctx context.Context, store database.Store, task *database.Task, transition database.WorkspaceTransition) (wsbuilder.UsageCheckResponse, error) { // Only check AI usage rules for start transitions. if transition != database.WorkspaceTransitionStart { return wsbuilder.UsageCheckResponse{Permitted: true}, nil } // If the template version doesn't have an AI task, we don't need to check usage. - if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool { + if task == nil { return wsbuilder.UsageCheckResponse{Permitted: true}, nil } diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 19f81bd51d..c4ab994dbf 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -772,6 +772,10 @@ func TestCheckBuildUsage_SkipsAIForNonStartTransitions(t *testing.T) { HasExternalAgent: sql.NullBool{Valid: true, Bool: false}, } + task := &database.Task{ + TemplateVersionID: tv.ID, + } + // Mock DB: expect exactly one count call for the "start" transition. mDB := dbmock.NewMockStore(ctrl) mDB.EXPECT(). @@ -782,18 +786,18 @@ func TestCheckBuildUsage_SkipsAIForNonStartTransitions(t *testing.T) { ctx := context.Background() // Start transition: should be not permitted due to limit breach. - startResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, database.WorkspaceTransitionStart) + startResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, task, database.WorkspaceTransitionStart) require.NoError(t, err) require.False(t, startResp.Permitted) require.Contains(t, startResp.Message, "breached the managed agent limit") // Stop transition: should be permitted and must not trigger additional DB calls. - stopResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, database.WorkspaceTransitionStop) + stopResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, task, database.WorkspaceTransitionStop) require.NoError(t, err) require.True(t, stopResp.Permitted) // Delete transition: should be permitted and must not trigger additional DB calls. - deleteResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, database.WorkspaceTransitionDelete) + deleteResp, err := eapi.CheckBuildUsage(ctx, mDB, tv, task, database.WorkspaceTransitionDelete) require.NoError(t, err) require.True(t, deleteResp.Permitted) } diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index fd4f1d3934..78ba5e4656 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -4705,3 +4705,121 @@ func TestWorkspacesSharedWith(t *testing.T) { assert.Equal(t, "/emojis/1f60d.png", groupActor.AvatarURL) }) } + +//nolint:tparallel,paralleltest // Sub tests need to run sequentially. +func TestWorkspaceAITask(t *testing.T) { + t.Parallel() + + usage := coderdtest.NewUsageInserter() + owner, _, first := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + UsageInserter: usage, + IncludeProvisionerDaemon: true, + }, + LicenseOptions: (&coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureTemplateRBAC: 1, + }, + }).ManagedAgentLimit(10, 20), + }) + + client, _ := coderdtest.CreateAnotherUser(t, owner, first.OrganizationID, + rbac.RoleTemplateAdmin(), rbac.RoleUserAdmin()) + + graphWithTask := []*proto.Response{{ + Type: &proto.Response_Graph{ + Graph: &proto.GraphComplete{ + Error: "", + Timings: nil, + Resources: nil, + Parameters: nil, + ExternalAuthProviders: nil, + Presets: nil, + HasAiTasks: true, + AiTasks: []*proto.AITask{ + { + Id: "test", + SidebarApp: nil, + AppId: "test", + }, + }, + HasExternalAgents: false, + }, + }, + }} + planWithTask := []*proto.Response{{ + Type: &proto.Response_Plan{ + Plan: &proto.PlanComplete{ + Plan: []byte("{}"), + AiTaskCount: 1, + }, + }, + }} + + t.Run("CreateWorkspaceWithTaskNormally", func(t *testing.T) { + // Creating a workspace that has agentic tasks, but is not launced via task + // should not count towards the usage. + t.Cleanup(usage.Reset) + version := coderdtest.CreateTemplateVersion(t, client, first.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionInit: echo.InitComplete, + ProvisionPlan: planWithTask, + ProvisionApply: echo.ApplyComplete, + ProvisionGraph: graphWithTask, + }) + _ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, first.OrganizationID, version.ID) + wrk := coderdtest.CreateWorkspace(t, client, template.ID) + build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID) + require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) + require.Len(t, usage.GetEvents(), 0) + }) + + t.Run("CreateTaskWorkspace", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitMedium) + t.Cleanup(usage.Reset) + version := coderdtest.CreateTemplateVersion(t, client, first.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionInit: echo.InitComplete, + ProvisionPlan: planWithTask, + ProvisionApply: echo.ApplyComplete, + ProvisionGraph: graphWithTask, + }) + _ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, first.OrganizationID, version.ID) + + task, err := client.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{ + TemplateVersionID: template.ActiveVersionID, + Name: "istask", + }) + require.NoError(t, err) + + wrk, err := client.Workspace(ctx, task.WorkspaceID.UUID) + require.NoError(t, err) + + build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID) + require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status) + require.Len(t, usage.GetEvents(), 1) + + usage.Reset() // Clean slate for easy checks + // Stopping the workspace should not create additional usage. + build, err = client.CreateWorkspaceBuild(ctx, wrk.ID, codersdk.CreateWorkspaceBuildRequest{ + TemplateVersionID: wrk.LatestBuild.TemplateVersionID, + Transition: codersdk.WorkspaceTransitionStop, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID) + require.Len(t, usage.GetEvents(), 0) + + usage.Reset() // Clean slate for easy checks + // Starting the workspace manually **WILL** create usage, as it's + // still a task workspace. + build, err = client.CreateWorkspaceBuild(ctx, wrk.ID, codersdk.CreateWorkspaceBuildRequest{ + TemplateVersionID: wrk.LatestBuild.TemplateVersionID, + Transition: codersdk.WorkspaceTransitionStart, + }) + require.NoError(t, err) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID) + require.Len(t, usage.GetEvents(), 1) + }) +}